호이호우
Beatlefeed
호이호우
전체 방문자
오늘
어제
  • 분류 전체보기 (75)
    • 세상은 지금... (4)
    • 인공지능 (24)
    • 코딩배우기 (21)
      • HTML, CSS (7)
    • 심리학 (25)

블로그 메뉴

  • 홈
  • 태그
  • 미디어로그
  • 위치로그
  • 방명록

공지사항

인기 글

태그

  • Sketch2Pokemon
  • 권위주의적육아
  • Optimizer
  • Loss Function
  • ADHD
  • Momentum
  • keras
  • generator
  • CIFAR-10
  • mnist
  • cGAN
  • 심리치료
  • 파이썬
  • Encoder
  • LeakyReLU
  • discriminator
  • 행동심리
  • Python
  • Deep learning
  • BatchNormalization
  • U-Net
  • Diana Baumrind
  • U-Net Generator
  • 발달심리학
  • pix2pix
  • Decoder
  • tensorflow
  • 인공지능
  • Gan
  • DCGAN

최근 댓글

최근 글

티스토리

hELLO · Designed By 정상우.
호이호우

Beatlefeed

[Part3]Sketch2Pokemon-UNet Generator|Pix2Pix
인공지능

[Part3]Sketch2Pokemon-UNet Generator|Pix2Pix

2022. 3. 23. 20:49
반응형
  1. 데이터 준비하기
  2. Generator 구성하기
  3. Generator 재구성하기(현재 page)
  4. Discriminator 구성하기
  5. 학습 및 테스트하기

Generator 재구성하기

이전 스탭에서 Encoder와 Decoder를 연결시켜 Generator를 만들어 보았습니다. 하지만 앞서 설명드린 것처럼 Pix2Pix의 Generator 구조는 아래 그림처럼 두 가지를 제안하였는데, 아래 그림을 한번 살펴봅시다.

위 그림에서 각 구조 아래에 표시된 이미지는 해당 구조를 Generator로 사용했을 때의 결과입니다.

 

단순한 Encoder-Decoder 구조에 비해 Encoder와 Decoder 사이를 skip connection으로 연결한 U-Net 구조를 사용한 결과가 훨씬 더 실제 이미지에 가까운 품질을 보이는 것을 알 수 있군요.

 

이전 스탭에서 구현했던 Generator는 위 그림의 Encoder-decoder 구조로 Encoder에서 출력된 결과를 Decoder의 입력으로 연결했고, 이 외에 추가적으로 Encoder와 Decoder를 연결시키는 부분은 없었습니다. 더 좋은 결과를 기대하기 위해 이전에 구현했던 것들을 조금 수정하여 위 그림의 U-Net 구조를 만들어보는 시간을 가져봅시다.

 

(아래 단락부터는 두 가지 구조를 각각 "Encoder-Decoder Generator" 및 "U-Net Generator"라는 용어로 구분하여 사용)

 

먼저 Encoder 및 Decoder에 사용되는 기본적인 블록은 이전에 아래 코드와 같이 구현하였었다.

 

class EncodeBlock(layers.Layer):
    def __init__(self, n_filters, use_bn=True):
        super(EncodeBlock, self).__init__()
        self.use_bn = use_bn       
        self.conv = layers.Conv2D(n_filters, 4, 2, "same", use_bias=False)
        self.batchnorm = layers.BatchNormalization()
        self.lrelu = layers.LeakyReLU(0.2)

    def call(self, x):
        x = self.conv(x)
        if self.use_bn:
            x = self.batchnorm(x)
        return self.lrelu(x)


class DecodeBlock(layers.Layer):
    def __init__(self, f, dropout=True):
        super(DecodeBlock, self).__init__()
        self.dropout = dropout
        self.Transconv = layers.Conv2DTranspose(f, 4, 2, "same", use_bias=False)
        self.batchnorm = layers.BatchNormalization()
        self.relu = layers.ReLU()

    def call(self, x):
        x = self.Transconv(x)
        x = self.batchnorm(x)
        if self.dropout:
            x = layers.Dropout(.5)(x)
        return self.relu(x)

 

여기서 특별히 수정해야 할 부분은 없으니 그대로 가져다가 사용하도록 합시다.

U-Net Generator 정의

정의된 블록들을 이용해 한 번에 U-Net Generator를 정의해 보자. 아래 모델의 __init__() 메서드에서 Encoder 및 Decoder에서 사용할 모든 블록들을 정의해 놓고, call()에서 forward propagation을 진행하고, 이전 구현에는 없었던 skip connection이 call() 내부에서 어떻게 구현되었는지 잘 확인해 보시길 바랍니다.

 

class UNetGenerator(Model):
    def __init__(self):
        super(UNetGenerator, self).__init__()
        encode_filters = [64,128,256,512,512,512,512,512]
        decode_filters = [512,512,512,512,256,128,64]

        self.encode_blocks = []
        for i, f in enumerate(encode_filters):
            if i == 0:
                self.encode_blocks.append(EncodeBlock(f, use_bn=False))
            else:
                self.encode_blocks.append(EncodeBlock(f))

        self.decode_blocks = []
        for i, f in enumerate(decode_filters):
            if i < 3:
                self.decode_blocks.append(DecodeBlock(f))
            else:
                self.decode_blocks.append(DecodeBlock(f, dropout=False))

        self.last_conv = layers.Conv2DTranspose(3, 4, 2, "same", use_bias=False)

    def call(self, x):
        features = []
        for block in self.encode_blocks:
            x = block(x)
            features.append(x)

        features = features[:-1]

        for block, feat in zip(self.decode_blocks, features[::-1]):
            x = block(x)
            x = layers.Concatenate()([x, feat])

        x = self.last_conv(x)
        return x

    def get_summary(self, input_shape=(256,256,3)):
        inputs = Input(input_shape)
        return Model(inputs, self.call(inputs)).summary()
  1. __init__() 에서 정의된 encode_blocks 및 decode_blocks 가 call() 내부에서 차례대로 사용되어 Encoder 및 Decoder 내부 연산을 수행한다.
  2. Encoder와 Decoder 사이의 skip connection을 위해 features라는 리스트를 만들고 Encoder 내에서 사용된 각 블록들의 출력을 차례대로 담는다.
  3. Encoder의 최종 출력이 Decoder의 입력으로 들어가면서 다시 한번 각각의 Decoder 블록들을 통과하는데, features 리스트에 있는 각각의 출력들이 Decoder 블록 연산 후 함께 연결되어 다음 블록의 입력으로 사용된다.
Question
위 코드의 call() 내에서 features = features [:-1]는 왜 필요할까?

Skip connection을 위해 만들어진 features 리스트에는 Encoder 내 각 블록의 출력이 들어있는데, Encoder의 마지막 출력(feature 리스트의 마지막 항목)은 Decoder로 직접 입력되므로 skip connection의 대상이 아니다.
Question
위 코드의 call() 내의 Decoder 연산 부분에서 features [::-1]는 왜 필요할까?

Skip connection은 Encoder 내 첫 번째 블록의 출력이 Decoder의 마지막 블록에 연결되고, Encoder 내 두 번째 블록의 출력이 Decoder의 뒤에서 2번째 블록에 연결되는.. 등 대칭을 이룬다.(맨 위 U-Net 구조 사진 참고). features에는 Encoder 블록들의 출력들이 순서대로 쌓여있고, 이를 Decoder에서 차례대로 사용하기 위해서 features의 역순으로 연결한다.
Question
아래와 같은 데이터 A, B가 있을 때,
- 데이터 A 크기 : (32, 128, 128, 200) #(batch, width, height, channel)
- 데이터 B 크기 : (32, 128, 128, 400) #(batch, width, height, channel)
여기서 사용되는 skip connection은 layers.Concatenate() 결과의 크기는 무엇일까?

(128, 128, 600) layers.Concatenate() 내에 별다른 설정이 없다면 가장 마지막 축(채널 축)을 기준으로 서로 연결된다.

U-Net Generator 내부 구조 확인

마지막으로 완성된 U-Net 구조 Generator 내부의 각 출력이 적절한지 아래 코드로 확인해봅시다.

 

UNetGenerator().get_summary()

 

Question
U-Net Generator의 파라미터가 늘어난 곳은 Encoder와 Decoder 중 어디일까? 또한, 두 종류의 Generator 구조에서 동일한 수의 convolution 레이어를 사용했는데, 구체적으로 어느 부분에서 파라미터가 늘어났나?

U-Net Generator의 Decoder 구조 내 파라미터가 많아졌습니다. 이 부분의 각 convolution 레이어에서 사용된 필터의 수는 두 종류의 Decoder에서 동일하지만, 그 크기가 다르다.

예를 들어, 이전 Decoder 블록의 출력의 크기가 (16, 16, 512)라면, Encoder-decoder Generator의 경우, Decoder의 다음 블록에서 계산할 convolution의 필터 크기는 4 * 4 * 512이다.

U-Net Generator의 경우, Encoder 내 블록 출력이 함께 연결되어 Decoder의 다음 블록에서 계산할 convolution의 필터 크기는 4 * 4 * (512 + 512)이다.

 

정리하면, U-Net Generator에서 사용한 skip-connection으로 인해 Decoder의 각 블록에서 입력받는 채널 수가 늘어났고, 이에 따라 블록 내 convolution 레이어에서 사용하는 필터 크기가 커지면서 학습해야 할 파라미터가 늘어난 것이다.

 

다음 투고에서는 Discriminator 구성하는 방법에 대해서 알아보는 시간을 가져봅시다. 😀

반응형
저작자표시 (새창열림)

'인공지능' 카테고리의 다른 글

[Part5]Sketch2Pokemon-학습 및 테스트하기|Pix2Pix  (2) 2022.03.23
[Part4]Sketch2Pokemon-Discriminator구성|Pix2Pix  (0) 2022.03.23
[Part2]Sketch2Pokemon-Generator 구성하기|Pix2Pix  (0) 2022.03.23
[Part1]Sketch2Pokemon-데이터 준비하기|Pix2Pix  (0) 2022.03.23
배치 정규화-신경망 훈련 속도 향상|Neural Network  (0) 2022.03.23
    '인공지능' 카테고리의 다른 글
    • [Part5]Sketch2Pokemon-학습 및 테스트하기|Pix2Pix
    • [Part4]Sketch2Pokemon-Discriminator구성|Pix2Pix
    • [Part2]Sketch2Pokemon-Generator 구성하기|Pix2Pix
    • [Part1]Sketch2Pokemon-데이터 준비하기|Pix2Pix
    호이호우
    호이호우
    나의 관심 콘텐츠를 즐겁게 볼 수 있는 Beatlefeed!

    티스토리툴바