Je construis mon propre modèle et avant de le connecter à la couche entièrement connectée, je demande toujours: "Quelles sont les caractéristiques de l'entrée?" Vous pouvez comprendre la structure du modèle en tapant souvent print (model)
, mais vous ne pouvez pas vérifier la taille de la carte des caractéristiques. C'est là que le «résumé de la torche» est utile.
En termes simples, c'est un "vous pouvez voir la taille de la carte des caractéristiques".
Cette fois, j'ai fait le modèle simple suivant. Je n'ai pas écrit avant de le classer.
Pliage ➡︎BN➡︎ReLU➡︎pool Pliage ➡︎BN➡︎ReLU➡︎pool Pliage ➡︎ Global Average Pooling
import torch
import torch.nn as nn
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN,self).__init__()
self.conv1 = nn.Conv2d(3,16,kernel_size=3,stride=1)
self.bn1 = nn.BatchNorm2d(16)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d((2,2))
self.conv2 = nn.Conv2d(16,32,kernel_size=3,stride=1)
self.bn2 = nn.BatchNorm2d(32)
self.conv3 = nn.Conv2d(32,64,kernel_size=3,stride=1)
self.gap = nn.AdaptiveMaxPool2d(1)
def forward(self,x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.conv3(x)
x = self.gap(x)
return x
pip install torchsummary
from torchsummary import summary
model = SimpleCNN()
summary(model,(3,224,224)) # summary(model,(channels,H,W))
Cette fois, j'essaie de supposer une taille d'entrée d'image de 224x224. Si vous voulez essayer d'autres résolutions, modifiez les valeurs de «H» et «W».
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 16, 222, 222] 448
BatchNorm2d-2 [-1, 16, 222, 222] 32
ReLU-3 [-1, 16, 222, 222] 0
MaxPool2d-4 [-1, 16, 111, 111] 0
Conv2d-5 [-1, 32, 109, 109] 4,640
BatchNorm2d-6 [-1, 32, 109, 109] 64
ReLU-7 [-1, 32, 109, 109] 0
MaxPool2d-8 [-1, 32, 54, 54] 0
Conv2d-9 [-1, 64, 52, 52] 18,496
AdaptiveMaxPool2d-10 [-1, 64, 1, 1] 0
================================================================
Total params: 23,680
Trainable params: 23,680
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 30.29
Params size (MB): 0.09
Estimated Total Size (MB): 30.95
----------------------------------------------------------------
Il est assez pratique de pouvoir vérifier la forme de sortie. Je suis reconnaissant qu'il compte également le nombre de paramètres.
torchsummary est pratique, veuillez donc l'utiliser.