본문 바로가기
기타

pytorch summary() 기능

by 블로그별명 2023. 2. 21.

pytorch conv3d모델이다

class BaseModel(nn.Module):
    def __init__(self, num_classes=13):
        super(BaseModel, self).__init__() 
        # super().__init__()->python 3 에서만 작동
        # super(자기자신,self).__init__()->python 2,3 모두 작동
        # 성능차이는 없음
        self.feature_extract = nn.Sequential(
            nn.Conv3d(3, 8, (1, 3, 3)),
            nn.ReLU(),
            nn.BatchNorm3d(8),
            nn.MaxPool3d(2),
            nn.Conv3d(8, 32, (1, 2, 2)),
            nn.ReLU(),
            nn.BatchNorm3d(32),
            nn.MaxPool3d(2),
            nn.Conv3d(32, 64, (1, 2, 2)),
            nn.ReLU(),
            nn.BatchNorm3d(64),
            nn.MaxPool3d(2),
            nn.Conv3d(64, 128, (1, 2, 2)),
            nn.ReLU(),
            nn.BatchNorm3d(128),
            nn.MaxPool3d((3, 7, 7)),
        )
        self.classifier = nn.Linear(1024, num_classes)
        
    def forward(self, x):
        batch_size = x.size(0)
        x = self.feature_extract(x)
        x = x.view(batch_size, -1)
        x = self.classifier(x)
        return x

아래 코드를 입력하면 summary된 모델을 볼수가 있다

from torchsummary import summary
model = BaseModel()
summary(model, (3, 50, 128, 128))

 

혹시나 에러가 뜬다면  gpu를 끄면 해결될수도있다(나는 그랬다)

레이어마다 달라지는 데이터의 shape을 정확하게 파악할수있어 흐름을 이해하거나 미쳐몰랐었던 내용도 알수있다

 

댓글