기타
pytorch summary() 기능
블로그별명
2023. 2. 21. 15:27
pytorch conv3d모델이다
class BaseModel(nn.Module):
def __init__(self, num_classes=13):
super(BaseModel, self).__init__()
# super().__init__()->python 3 에서만 작동
# super(자기자신,self).__init__()->python 2,3 모두 작동
# 성능차이는 없음
self.feature_extract = nn.Sequential(
nn.Conv3d(3, 8, (1, 3, 3)),
nn.ReLU(),
nn.BatchNorm3d(8),
nn.MaxPool3d(2),
nn.Conv3d(8, 32, (1, 2, 2)),
nn.ReLU(),
nn.BatchNorm3d(32),
nn.MaxPool3d(2),
nn.Conv3d(32, 64, (1, 2, 2)),
nn.ReLU(),
nn.BatchNorm3d(64),
nn.MaxPool3d(2),
nn.Conv3d(64, 128, (1, 2, 2)),
nn.ReLU(),
nn.BatchNorm3d(128),
nn.MaxPool3d((3, 7, 7)),
)
self.classifier = nn.Linear(1024, num_classes)
def forward(self, x):
batch_size = x.size(0)
x = self.feature_extract(x)
x = x.view(batch_size, -1)
x = self.classifier(x)
return x
아래 코드를 입력하면 summary된 모델을 볼수가 있다
from torchsummary import summary
model = BaseModel()
summary(model, (3, 50, 128, 128))
혹시나 에러가 뜬다면 gpu를 끄면 해결될수도있다(나는 그랬다)
레이어마다 달라지는 데이터의 shape을 정확하게 파악할수있어 흐름을 이해하거나 미쳐몰랐었던 내용도 알수있다