본문 바로가기
AI Developer/Paper Review

[AI Paper] GAN(Generative Adversarial Nets) 리뷰

by 성 언 2023. 4. 7.

GAN(Generative Adversarial Nets) 논문 학습 후 정리한 포스팅 입니다.

이번 포스팅에서는 GAN(Generative Adversarial Nets)에 대해 리뷰합니다.

Generative Adversarial Nets

0. Abstract

적대적 프로세스를 통해 생성 모델을 예측하는 새로운 프레임워크를 제안합니다.

데이터 분포를 포착하는 생성 모델 G와 훈련 데이터에서 샘플이 나올 확률을 추정하는 판별 모델 D의 두 모델을 동시에 훈련합니다.

G의 훈련 절차는 D가 실수할 확률을 최대화하는 것입니다. 이 프레임워크는 미니 맥스 2인 게임에 해당합니다.

임의의 함수 G와 D의 공간에서 G는 훈련 데이터 분포를 복구하고 D는 모든 곳에서 0.5와 같은 고유한 솔루션이 존재합니다.

G와 D가 다음과 같이 정의된 경우 다층 퍼셉트론을 통해 전체 시스템을 역전파 방식으로 학습시킬 수 있습니다. 훈련이나 샘플 생성 과정에서 마르코프 체인이나 풀린 근사 추론 네트워크가 필요하지 않습니다.

실험을 통해 생성된 샘플의 정성적, 정량적 평가를 통해 프레임 워크의 잠재력을 입증합니다.

1. Introduction

지금까지 딥러닝에서 가장 눈에 띄는 성공을 거둔 모델은 일반적으로 고차원의 풍부한 감각 입력(rich sensory input)을 클래스 레이블에 맵핑하는 판별 모델이었습니다.

이러한 놀라운 성공은 주로 특히 잘 작동하는 그라디언트(그래디언트 vanshing x)를 가진 조각별 선형 단위(piecewise linear units)를 사용하는 역전파 및 드롭아웃 알고리즘에 기반했습니다.

  • 기존 연구 한계
    • 심층 생성 모델은 최대 가능성 추정(MLE) 및 관련 전략에서 발생하는 많은 다루기 어려운 확률 계산을 근사화하기 어렵고 생성 컨텐스트에서 조각 선형 단위(piecewise linear units)의 이점을 활용하기 어렵기 때문에 그 영향력이 적었습니다. 이러한 어려움을 피할 수 있는 새로운 생성 모델 추정 절차를 제안합니다.

제안된 적대적 네트워크(adversarial nets) 프레임워크에서 생성 모델은 샘플이 모델 분포에서 나온 것인지 데이터 분포에서 나온 것인지 판단하는 방법을 학습하는 판별 모델(discriminative model)이라는 적과 맞붙게 됩니다.

생성 모델(generative model)은 위조 화폐를 만들어 탐지되지 않고 사용하려는 위조범 팀에 비유할 수 있으며, 판별 모델은 위조 화폐를 탐지하려는 경찰과 유사합니다.

이 게임에서 경쟁을 통해 두 팀은 위조 화폐가 진품과 구별할 수 없을 때까지 개선하게 됩니다.

Generator의 목적: Fake data를 Real data와 최대한 유사하게 만들어 Discriminator가 구분할 수 없게 하는 것

Discriminator의 목적: Fake data와 Real data를 구별 할 수 있게 학습하는 것

 

이 프레임워크는 다양한 종류의 모델과 최적화 알고리즘을 위한 특정 훈련 알고리즘을 생성할 수 있습니다.

생성 모델은 다충 퍼셉트론을 통해 무작위 노이즈를 통과시켜 샘플을 생성하고 판별 모델도 다층 퍼셉트론입니다.

이러한 특수한 경우를 적대적 그물망(adversarial nets)이라고 합니다.

이 경우, 매우 성공적인 역전파 및 드롭아웃 알고리즘만 사용하여 두 모델을 모두 훈련하고 순방향 전파만을 사용하여 생성 모델에서 샘픔을 추출할 수 있습니다. 근사 추론이나 마르코프 체인이 필요하지 않습니다.

2. Related Work

  • RBMs: restricted Boltzmann machines, 잠재 변수를 가진 유향 그래프 모델에 대한 대안으로, 무향 그래프 모델
  • DBMs: deep Boltzmann machines, RBMs와 비슷함. 다양한 변형이 존재
  • MCMC: Markov chain Monte Carlo methods, 위 모델의 측정 방법
  • DBNs: Deep belief networks, 하나의 무향 레이어와 여러 유향 레이어의 hybrid 모델. 계삭적 문제가 있음
  • NCE: noise-contrasive estimation, log-likelihood를 근사하거나 경계값을 구하지 않는 방법
  • GSN: generative stochastic network, 확률분포를 명시적으로 정의하지 않고 분포 샘플을 생성하도록 학습시키는 방법을 사용
  • adversarial nets: 적대적 망은 생성 중 feedback loop를 필요로 하지 않아 sampling에서 Markov chain이 필요가 없다. 이는 backpropagation 성능 향상으로 이어진다.
  • auto-encoding varitional Bayes와 stochastic backpropagation은 생성 머신을 학습시키는 방법들 중 하나이다.

3. 제안 방법론

Main Idea: Adversarial nets

적대적 모델링 프레임워크는 모델이 모두 다층 퍼셉트론일 때 적용하기가 가장 간단합니다.

→ 어떻게 Adversarial 관계로 학습 할까? → 목적(loss) 함수 확인

G는 V를 최소화하고 D는 V를 최대화한다. (한방향으로 최소화하는 보통의 loss function과 다르다)

 

 

[G의 관점에서의 Loss function]

z는 Noise, pz(z)는 사전확률 (가우시안을 사용), G(z)는 z를 통해 만들어낸 Fake data

→ D(G(z))는 Fake data를 D가 어떻게 판별할까?

→ G의 관점에서 D(G(z))가 1로 가야한다.

→ Loss function의 관점에서도(최소화 해야함) log(1-D(G(z)))가 최소화되려면 D(G(z))가 1로 가야한다.

 

 

[D의 관점에서의 Loss function]

 

x는 실제 데이터 분포,

→ D의 관점에서 Real data를 1로 판별해야한다. (D(x)을 1로 판별)

→ D의 관점에서 Fake data는 0으로 판별해야한다. (D(G(z))를 0으로 판별)

—> 생성자와 판별자가 서로 다른 목적을 갖고 학습을 진행한다.

 

 

[학습]

초록색 선: Generator가 맵핑하는 이미지 분포 (생성한 데이터의 분포)

검정색 선: 실제 이미지의 분포 (실제 데이터의 분포)

파랑색 선: Discriminator의 분포 (판별자가 판단하는 값)

z: Noise (Uniform distribution), generator을 통과에 x와 유사한 데이터 생성

  • 처음 학습을 진행하는 (b)

Generator 고정하고 Discriminaotr을 학습

(Real data에 대해 안정적으로 1에 가까운 확률, Fake data에 대해 0에 가까운 확률을 반환)

  • Discriminator을 고정한 후 Generator 학습을 진행하는 (c)

z에서 x로의 맵핑이 실제 Real 이미지 분포에 가까운 분포를 형성

  • 계속해서 학습을 진행하는 (d)

z에서 x로의 맵핑이 Real 이미지 분포와 거의 동일한 분포를 형성

Discriminator은 0.5를 판별 확률로 반환

Generator, Discriminator 둘 다 학습이 잘 된 상황

 

 

  • Global Optimality (최적해를 갖는 과정)

D*g(x) : Optimal 한 최적의 Dg(x) → 어떠한 G가 들어와도 최적의 D의 값

기존의 목적함수를 변형

x를 샘플링해서 g(z)를 만드는 것과 Pg를 따르는 x를 만드는 것은 동일하다. (x에 관한 intergral로 변경)

D가 optimal하려면 V(G, D)가 최대가 되어야 한다.

P data (x) = a, P g (x) = b, D(x) = y 로 치환하면

a log y + b log (1-y)를 y로 미분한 값이 0이 되어야 한다.

즉, y = a / a+b 일때, 최대값을 갖는다.

  • C(G) = -log4

global minimum을 어떻게 찾을까?

D는 maximize된 값이므로 이전의 optimal 값을 대입할 수 있다.

log4 를 더하고 빼준다 (수학적 스킬)

log 4 = log 2 + log 2 이므로 V(G, D)에 대입가능

cf) KL(A||B): A와 B가 분포의 차이를 설명, KL divergence,

cf) JSD(p,q): KL divergence로 표현할 수 있다.

JSD의 값은 0이상이고, 분포가 동일할 때 0을 가진다. 즉, C(G)는 generator입장에서 minimize해야하는 값이므로 Pdata와 Pg과 동일할 때, 최솟값(-log4)을 갖는다.

 

 

[모델 관점에서의 GAN]

G는 Noise의 한 점을 샘플링하여 Muliti-Layer Perceptron에 넣고 Mapping을 하면 Fake 데이터를 생성한다.

D는 Fake = 0, Real = 1 을 출력하도록 학습한다.

즉 G는 들킬 가능성을 최소화, D는 속을 가능성을 최소화 한다. → Adversarial

 

 

4. 실험 및 결과

Dataset

a : MNIST, b : TFT(Toronto Face Database), c, d : CIFAR10

 

5. 코드

[import library]

import torch
import torch.nn as nn   # 생성자 판별자 아키텍처 정의

from torchvision import datasets  # Mnist 데이터 셋 활용의 경우
import torchvision.transforms as transforms # 변형(전처리)
from torchvision.utils import save_image # 이미지 출력

[Generator Code]

latent_dim = 100 # noise dim

# 생성자 클래스 정의
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
	
		# 하나의 블록 정의
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
				# batch normalization, 차원 동일
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
				
		# 여러개의 블록을 가짐 -> 1 * 28 * 28 
        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape) # 배치 사이즈를 이미지 형태를 갖게끔
        return img

논문에선 100을 차원으로 맵핑 (latent_dim = 100)

이미지 크기와 같은 1024가 출력 (2^10, image를 Flatten 했을 때 pixel)

Q. tanh 쓰는 이유 (-1~1)

A. 보통 픽셀값을 -1~1까지 normalize 하는데 그것을 맞춰주기 위해 사용한다.

 

[Discriminator Code]

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

		# 이미지에 대한 판별 결과를 반환
    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity

img_shape인 1024를 입력으로 하여 마지막엔 0~1 사이의 값을 출력하는 Sigmoid (Real 이미지일 확률)

 

[Loss function, Optimizers]

# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator (학습을 위해 초기화)
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

loss를 학습하기 위해 Binary Cross entropy loss를 사용한다. (수식 참고)

optimizer은 Generator과 Discriminator에 각각에 맞게 설정해준다. (model이 2개있는 것과 마찬가지이므로)

 

[real, fake data label 생성]

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

1024의 이미지 사이즈 만큼을 valid에선 1로, fake는 0으로 채운다 → 정답지

real_imgs: batch 에서 들어오는 실제 이미지 데이터

 

[Generator 학습]

optimizer_G.zero_grad()

# Sample noise as generator input, 랜덤 노이즈를 뽑아 샘플링한다.
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

# Generate a batch of images, 이미지 생성
gen_imgs = generator(z)

# Loss measures generator's ability to fool the discriminator
g_loss = adversarial_loss(discriminator(gen_imgs), valid)

# 학습 진행
g_loss.backward()
optimizer_G.step()

Discriminator 고정 시키고, Generator만 학습을 진행

z: 가우시안 분포에서 한 값을 뽑는다 (사전확률 분포를 가우시안 분포로 둔다.)

generator에 z를 대입 → 1024의 이미지 나옴

valid와 get_imgs를 비교한다. → 1에 가깝게 학습을 해야하는 Generator

 

 

[Discriminator 학습]

optimizer_D.zero_grad()

# Measure discriminator's ability to classify real from generated samples
real_loss = adversarial_loss(discriminator(real_imgs), valid)
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2

d_loss.backward()
optimizer_D.step()

Generator는 고정 real_imgs는 1에 가깝게, gen_imgs는 0과 가깝게 학습을 진행

 

 

6. 한계점

  • model collapse

분포의 형태를 전반적으로 맵핑하기보다, 단순히 오류를 최소화하기 위해서 최빈값(model)에만 집중하여 학습함에 따라 실제 값(이미지 등) 중 특정한 형태만 생성한다.

→ Mini-batch discrimination, feature matching으로 해결 가능.

  • 학습시키는 것이 어렵다
    • 생성자와 판별자의 실육이 비슷한 경우에 두 모델이 균형있게 학습이 진행되는데, 그렇지 않은 경우 학습이 잘 진행되지 않음
    • 서로 이기려고 하는 minimax 게임을 통해 학습하므로, G와 D 간의 힘의 균형이 깨지기 쉬움
    • → DCGAN등 학습을 안정적으로 바꾸고자 구조를 새로 제안하는 모델 등장
  • 사용된 생성자의 결과물 형태가 어떠한 과정을 통해 나왔는지 알 수 없다
  • 새롭게 만들어진 데이터가 얼마나 정확한지 객관적으로 판단하기 어렵다
    • 주관적인 판단 필요
    • 생성된 결과가 잘 생성된 값인지 판단할 수 있는 지표가 없고, 이에 따라 학습을 얼마나 진행해야 하는지 명확한 기준이 부족
    → Inception Score (생성된 이미지의 다양성을 측정하는 지표) 사용

7. 추가 가이드라인

Before reading

  1. 논문의 main figure를 보고 전체 흐름을 유추해봅시다
    1. 이해되지 않는 파트가 있나요? 있다면 미리 표시해두고 집중적으로 읽어봅시다
    2. Global Optimality 부분 (수식이 많아서 겁먹었다)
  2. 해당 모델을 구현한 코드가 있는지 체크해봅시다
  3. PyTorch-GAN/gan.py at master · eriklindernoren/PyTorch-GAN

After reading

  1. 논문에서 풀고자 했던 문제는 무엇인가요 ? (task 정의해보기)
  2. 기존 논문들은 이 TAsk를 어떻게 풀어왔나요?
  3. 기존 논문 대비 이 논문의 강점은 무엇인가요?
  4. 내가 가설을 세운 모델의 특징과 실제 모델의 주장이 어느정도 일치하나요?
  5. 논문의 내용과 내 생각을 짧게 정리해봅시다

 

<Summary>

GAN(Generative Adversarial Nets) 리뷰

 

 

*유의사항

- AI 논문 공부 중인 인공지능공학과 학부생이 공부하여 남긴 정리입니다.

- 정확하지 않거나, 틀린 점이 있다면 댓글로 알려주시면 감사하겠습니다.

 

 

 

 

댓글