pytorch conv3d모델이다
class BaseModel(nn.Module):
def __init__(self, num_classes=13):
super(BaseModel, self).__init__()
# super().__init__()->python 3 에서만 작동
# super(자기자신,self).__init__()->python 2,3 모두 작동
# 성능차이는 없음
self.feature_extract = nn.Sequential(
nn.Conv3d(3, 8, (1, 3, 3)),
nn.ReLU(),
nn.BatchNorm3d(8),
nn.MaxPool3d(2),
nn.Conv3d(8, 32, (1, 2, 2)),
nn.ReLU(),
nn.BatchNorm3d(32),
nn.MaxPool3d(2),
nn.Conv3d(32, 64, (1, 2, 2)),
nn.ReLU(),
nn.BatchNorm3d(64),
nn.MaxPool3d(2),
nn.Conv3d(64, 128, (1, 2, 2)),
nn.ReLU(),
nn.BatchNorm3d(128),
nn.MaxPool3d((3, 7, 7)),
)
self.classifier = nn.Linear(1024, num_classes)
def forward(self, x):
batch_size = x.size(0)
x = self.feature_extract(x)
x = x.view(batch_size, -1)
x = self.classifier(x)
return x
아래 코드를 입력하면 summary된 모델을 볼수가 있다
from torchsummary import summary
model = BaseModel()
summary(model, (3, 50, 128, 128))
혹시나 에러가 뜬다면 gpu를 끄면 해결될수도있다(나는 그랬다)
레이어마다 달라지는 데이터의 shape을 정확하게 파악할수있어 흐름을 이해하거나 미쳐몰랐었던 내용도 알수있다
'기타' 카테고리의 다른 글
facenet 임베딩 모델 구조의 구현 (0) | 2023.04.28 |
---|---|
FL STUDIO 첫번째 곡(기초조작법 익히기) (0) | 2023.04.16 |
넘파이에서 리스트처럼 곱하기 적용하고싶을때 (0) | 2023.02.25 |
re.search 위치 여러개 찾기 (0) | 2023.02.23 |
numpy where 기능소개 (0) | 2023.02.21 |
댓글