- 데이터 준비하기
- Generator 구성하기
- Generator 재구성하기
- Discriminator 구성하기
- 학습 및 테스트하기(현재 page)
Generator와 Discriminator 학습 및 테스트
이번 스탭에서는 구현된 Generator와 Discriminator를 학습시켜보고, 스케치를 입력으로 채색된 이미지를 생성해봅시다.
먼저 학습에 필요한 손실 함수부터 정의하도록 하자. 논문의 여러 실험 결과 중 손실 함수 선택에 따른 결과의 차이는 아래와 같습니다.
레이블 정보만 있는 입력에 대해 여러 손실 함수를 사용해 실제 이미지를 만들어 낸 결과는, 일반적인 GAN의 손실 함수에 L1을 추가로 이용했을 때 가장 실제에 가까운 이미지를 생성해 내었습니다. 이번 실험에서도 두 가지 손실 함수를 모두 사용해 보도록 합시다.
from tensorflow.keras import losses
bce = losses.BinaryCrossentropy(from_logits=False)
mae = losses.MeanAbsoluteError()
def get_gene_loss(fake_output, real_output, fake_disc):
l1_loss = mae(real_output, fake_output)
gene_loss = bce(tf.ones_like(fake_disc), fake_disc)
return gene_loss, l1_loss
def get_disc_loss(fake_disc, real_disc):
return bce(tf.zeros_like(fake_disc), fake_disc) + bce(tf.ones_like(real_disc), real_disc)
Generator 및 Discriminator의 손실 계산
- Generator의 손실 함수 (위 코드의
get_gene_loss
)는 총 3개의 입력이 있습니다. 이 중fake_disc는
Generator가 생성한 가짜 이미지를 Discriminator에 입력하여 얻어진 값이며, 실제 이미지를 뜻하는 "1"과 비교하기 위해tf.ones_like()
를 사용한다. 또한 L1 손실을 계산하기 위해 생성한 가짜 이미지(fake_output
)와 실제 이미지(real_output
) 사이의 MAE(Mean Absolute Error)를 계산한다. - Discriminator의 손실 함수 (위 코드의
get_disc_loss
)는 2개의 입력이 있으며, 이들은 가짜 및 진짜 이미지가 Discriminator에 각각 입력되어 얻어진 값이다. Discriminator는 실제 이미지를 잘 구분해 내야 하므로real_disc는
"1"로 채워진 벡터와 비교하고,fake_disc는
"0"으로 채워진 벡터와 비교한다.
사용할 optimizer
논문과 동일하게 설정하였다.
from tensorflow.keras import optimizers
gene_opt = optimizers.Adam(2e-4, beta_1=.5, beta_2=.999)
disc_opt = optimizers.Adam(2e-4, beta_1=.5, beta_2=.999)
하나의 배치 크기만큼 데이터를 입력했을 때 가중치를 1회 업데이트하는 과정은 아래와 같이 구현하였다.
@tf.function
def train_step(sketch, real_colored):
with tf.GradientTape() as gene_tape, tf.GradientTape() as disc_tape:
# Generator 예측
fake_colored = generator(sketch, training=True)
# Discriminator 예측
fake_disc = discriminator(sketch, fake_colored, training=True)
real_disc = discriminator(sketch, real_colored, training=True)
# Generator 손실 계산
gene_loss, l1_loss = get_gene_loss(fake_colored, real_colored, fake_disc)
gene_total_loss = gene_loss + (100 * l1_loss) ## <===== L1 손실 반영 λ=100
# Discrminator 손실 계산
disc_loss = get_disc_loss(fake_disc, real_disc)
gene_gradient = gene_tape.gradient(gene_total_loss, generator.trainable_variables)
disc_gradient = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
gene_opt.apply_gradients(zip(gene_gradient, generator.trainable_variables))
disc_opt.apply_gradients(zip(disc_gradient, discriminator.trainable_variables))
return gene_loss, l1_loss, disc_loss
전반적인 학습 과정은 앞서 진행했었던 cGAN 학습과 크게 다르지 않습니다.
다만 위 코드의 gene_total_loss
계산 라인에서 최종 Generator 손실을 계산할 때, L1 손실에 100을 곱을 해주자~
(## 이 표시가 있는 부분을 확인)
논문에서는 Generator의 손실을 아래와 같이 정의하였는데요.
위 식에서 λ는 학습 과정에서 L1 손실을 얼마나 반영할 것인지를 나타내며 논문에서는 λ=100을 사용하였다.
앞서 정의한 함수를 이용해서 학습을 진행합니다. 우선 10 epoch 학습을 진행해봅시다.
학습진행
EPOCHS = 10
generator = UNetGenerator()
discriminator = Discriminator()
for epoch in range(1, EPOCHS+1):
for i, (sketch, colored) in enumerate(train_images):
g_loss, l1_loss, d_loss = train_step(sketch, colored)
# 10회 반복마다 손실을 출력합니다.
if (i+1) % 10 == 0:
print(f"EPOCH[{epoch}] - STEP[{i+1}] \
\nGenerator_loss:{g_loss.numpy():.4f} \
\nL1_loss:{l1_loss.numpy():.4f} \
\nDiscriminator_loss:{d_loss.numpy():.4f}", end="\n\n")
테스트
epoch 10의 학습으로 짧게 진행하여 채색이 이쁘겐 나오지는 않았다. epoch 수를 늘리면 더 좋은 결과가 나오지 않을까 싶군요.
test_ind = 1
f = data_path + os.listdir(data_path)[test_ind]
sketch, colored = load_img(f)
pred = generator(tf.expand_dims(sketch, 0))
pred = denormalize(pred)
plt.figure(figsize=(20,10))
plt.subplot(1,3,1); plt.imshow(denormalize(sketch))
plt.subplot(1,3,2); plt.imshow(pred[0])
plt.subplot(1,3,3); plt.imshow(denormalize(colored))
데이터셋을 제공한 출처를 보면, Pix2Pix로 128 epoch 학습 후 테스트 결과가 아래와 같다고 합니다.
이전 10 epoch 학습의 결과보다는 훨씬 낫지만, 조금 오래 학습했어도 채색해야 할 전체적인 색감 정도만 학습되며 아직까지 세부적으로는 제대로 채색되지 않다는 것을 확인할 수 있습니다.
Encoder-Decoder Generator, U-Net Generator, Discriminator의 구현까지 많은 것을 배워보았습니다. 앞서 다뤘던 코드의 이해에 큰 어려움이 없었다면, 연습을 통해 이와 비슷한 구조의 모델들 또한 쉽게 구현할 수 있을 것입니다.
U-Net Generator를 구현했으니 원래 segmentation에 사용되었던 U-Net도 쉽게 구현할 수 있지 않을까 싶네요.
우리의 머리도 몇 시간의 학습을 통해 빨리빨리 기억하면 좋을 텐데 말이죠.. 또르르.... 😥
'인공지능' 카테고리의 다른 글
배치 정규화-속도 향상 미세조정|Neural Network (0) | 2022.03.29 |
---|---|
Segmentation map-도로 이미지 만들기|Pix2Pix (0) | 2022.03.24 |
[Part4]Sketch2Pokemon-Discriminator구성|Pix2Pix (0) | 2022.03.23 |
[Part3]Sketch2Pokemon-UNet Generator|Pix2Pix (0) | 2022.03.23 |
[Part2]Sketch2Pokemon-Generator 구성하기|Pix2Pix (0) | 2022.03.23 |