import torch
from torch import nn
import torchvision
from torch.nn import Conv2d, MaxPool2d, ReLU
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# Sample 2x2 values with mixed signs, reshaped into a 4-D tensor; the -1
# lets torch infer the leading dimension (1 here, since there are 4 elements).
# The module-level name `input` is kept because later statements read it.
input = torch.tensor([[1, -0.5], [-1, 3]]).reshape(-1, 1, 2, 2)
class MyModule(nn.Module):
    """Minimal module that applies a single ReLU activation to its input."""

    def __init__(self):
        """Register the ReLU layer as the submodule ``relu1``."""
        super().__init__()
        self.relu1 = ReLU()

    def forward(self, input):
        """Return ``input`` passed through ``self.relu1``.

        The parameter keeps the name ``input`` (which shadows the builtin
        inside this method) so that existing keyword callers keep working.
        """
        return self.relu1(input)
# Demo driver: build the module, apply it to the sample tensor defined above,
# and print the resulting tensor.
module = MyModule()
output = module(input)
print(output)
# Reference: https://www.bilibili.com/video/BV1hE411t7RN?p=20