Home Generative Models - 2.1 Autoregressive Models FVSBN
Post
Cancel

Generative Models - 2.1 Autoregressive Models FVSBN

FVSBN Implementation

간단한 Autoregressive Model을 Pytorch Module을 사용하여 구현해 보기로 했다.

Mnist 데이터는 아래와 같이 labelimage 데이터 두개로 이루어져 있다.

sample_mnist

FVSBN의 이름에서 알 수 있듯이, 각 입력에 대해서 단일 parameter 만을 사용하는 구조이기 때문에 성능이 상당히 떨어질 것이라고 예상했다. 생성형 모델을 MNIST 데이터에 어떻게 활용할지 고민해보았을 때 아래 두 가지를 시도해 보면 좋을 것 같아 진행해보았다.

    1. Total Data (0 ~ 9) -> Random Number generating Model Training
    1. Individiual Label Data -> Individual Model Training

1. Random Number Generating Model Training

먼저 모든 데이터를 바탕으로 학습을 수행한 이후 랜덤으로 이미지를 생성하는 방식으로 학습을 진행해 보았다. 전체 파라미터의 개수는 $O(n^2)$에 Bound된다는 것을 확인했기 때문에 연산 효율성을 위해서 Parameter 집단 Wb를 각각 $784 \times 784$ 과 $1 \times 784$ 크기의 matrix로 설정했다.

FVSBN의 수식을 보면 i번재 Prediction을 진행하기 위해 필요한 argument의 개수는 i개가 되기 때문에 이에 맞추어 대각 행렬의 Upper paramenter는 전부 masking헤 아래의 값만 사용했다. (추가적으로 대각행렬의 값 또한 마스킹해주었다. 이는 Bias성분을 사용하기 때문이다.)

\[W = \begin{bmatrix} 0 & & & & 0 \\ \ell_{2,1} & 0 & & & \\ \ell_{3,1} & \ell_{3,2} & 0 & & \\ \vdots & \vdots & \ddots & \ddots & \\ \ell_{784,1} & \ell_{784,2} & \ldots & \ell_{784,783} & 0 \end{bmatrix}\]

꽤나 간단한 설정이기 때문에 multi-gpu를 사용한 gradient update는 따로 추가하지 않았다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import torch, torch.nn as nn, torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from utils.data_preprocessing import *
import wandb
import os
import time

D = 28 * 28
EPOCHS = 10000
LR = 1e-3


def wandb_init():
    wandb.init(
        project="Generative Model",
        id="FVSBN_random",
        config={
            "learning_rate": LR,
            "epochs": EPOCHS,
        },
    )


class FVSBN(nn.Module):
    def __init__(self, D):
        super().__init__()
        self.W = nn.Parameter(torch.zeros(D, D))
        self.b = nn.Parameter(torch.zeros(1, D))
        tril = torch.tril(torch.ones(D, D), diagonal=-1)
        self.register_buffer("mask", tril)

    def forward(self, x):  # Batch Size x (B, 784)
        Wm = self.W * self.mask  # Weight Matrix (784, 784)
        logits = x @ Wm.t() + self.b  # (B, 784) * (784, 784) + (1, 784)
        loss = F.binary_cross_entropy_with_logits(
            logits, x, reduction="none"
        ).sum(dim=1).mean()
        return loss, logits

    @torch.no_grad()
    def sample(self, B=16, device=None):
        if device is None:
            device = self.b.device
        x = torch.zeros(B, D, device=device)
        Wm = self.W * self.mask
        for i in range(D):  # Autoregressive Decoding
            logits_i = (x @ Wm[:, i]) + self.b[0, i]  # (B,)
            prob_i = torch.sigmoid(logits_i)
            x[:, i] = torch.bernoulli(prob_i)
        return x


def train_one_epoch(model, loader, opt, device):
    model.train()
    running = 0.0
    # Training with whole data (Not using Label Data)
    for x, _ in loader:
        x = x.to(device, non_blocking=True)  # (B,784)
        opt.zero_grad(set_to_none=True)
        # Forward
        loss, _ = model(x)
        # Backward
        loss.backward()
        # Gradient Update
        opt.step()
        with torch.no_grad():
            model.W.mul_(model.mask)
        iteration_loss = loss.item() * x.size(0)
        wandb.log({"iter_loss": iteration_loss})
        running += iteration_loss
    epoch_loss = running / len(loader.dataset)
    wandb.log({"epoch_loss": epoch_loss})
    return epoch_loss


@torch.no_grad()
def evaluate_nll(model, loader, device):
    model.eval()
    total = 0.0
    for x, _ in loader:
        x = x.to(device, non_blocking=True)
        loss, _ = model(x)
        total += loss.item() * x.size(0)
    average_nll_loss = total / len(loader.dataset)
    wandb.log({"test_loss": average_nll_loss})
    return average_nll_loss


def log_time(func):
    def wrapper_func(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        return end_time - start_time, result

    return wrapper_func


@torch.no_grad()
def sample_and_save(
    model, n=16, save_dir="FVSBN_Random", grid_name="FVSBN_grid.png", save_it=0
):
    save_dir = os.path.join(save_dir, str(save_it))
    os.makedirs(save_dir, exist_ok=True)
    # Save Trained Model
    torch.save(
        model.state_dict(), os.path.join(save_dir, f"fvsbn_epoch_{save_it}.pth")
    )
    model.eval()
    device = next(model.parameters()).device

    start_time = time.time()
    x_samp = model.sample(B=n, device=device)  # (n,784)
    end_time = time.time()
    wandb.log({"Sample Gen Time(s)": end_time - start_time})
    imgs = x_samp.view(n, 28, 28).cpu().numpy()

    cols = int(n**0.5)
    rows = (n + cols - 1) // cols
    plt.figure(figsize=(cols * 2, rows * 2))
    for i in range(n):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(imgs[i], cmap="gray", vmin=0, vmax=1)
        plt.axis("off")
    plt.tight_layout()
    grid_path = os.path.join(save_dir, grid_name)
    plt.savefig(grid_path, dpi=150)
    plt.close()
    print(f"[Grid] Saved to {grid_path}")

    for i in range(n):
        img_path = os.path.join(save_dir, f"sample_{i:03d}.png")
        plt.imsave(img_path, imgs[i], cmap="gray", vmin=0, vmax=1)

    print(f"[Individual] Saved {n} samples to {save_dir}/")


def main():
    wandb_init()
    # Import Dataset
    train_dataset, test_dataset = load_data()
    test_loader = make_loader(test_dataset, shuffle=True)
    train_loader = make_loader(train_dataset, shuffle=True)
    print("Train/Test Data Loaded")
    random_model = FVSBN(D)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    random_model = random_model.to(device)

    # Gradient Hook for Masking Robustness
    random_model.W.register_hook(lambda g: g * random_model.mask)
    # Optimizer Setting
    opt = torch.optim.Adam(random_model.parameters(), lr=LR)

    for ep in range(1, EPOCHS + 1):
        train_loss = train_one_epoch(random_model, train_loader, opt, device)
        test_loss = evaluate_nll(random_model, test_loader, device)
        print(f"[{ep:5d}] train NLL: {train_loss:.4f} | test NLL: {test_loss:.4f}")
        if ep % 100 == 0:
            sample_and_save(
                random_model, save_dir="FVSBN_Random_E2", n=16, save_it=ep
            )


if __name__ == "__main__":
    main()

위와 같은 설정으로 실험을 진행했을 때, 100 epoch마다 데이터를 출력하여 결과를 확인해보았다.

100 Epoch Sample Image

init_FVSBN

800 Epoch Sample Image

init_FVSBN 여러 실험 설정으로 학습과 Sampling을 진행해보았지만, 일정 수준 이상으로 더 좋은 학습은 불가능했고, 위와 같이 중심부에 점이 일부 보이긴 하나 어떤 문자인지 정확하게 확인할 수는 없는 수준이기 때문에 해당 실험은 800 epoch에서 중단했다.

2. Individual Number Generating Model Training

1번 주제의 실험에서 모든 Label에 대한 학습을 진행하고 Random Label에 대한 생성을 진행했을 때 성능이 굉장히 좋지 않은 점을 확인했다. 이에 문제의 난이도를 낮추어 개별 Label에 대한 이미지를 생성하는 것을 목적으로 하였다.

코드는 아래와 같다. 1번 실험에서 사용한 코드를 거의 동일하게 사용했으며, 아래의 data load 부분만 수정을 진행하면 된다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def main():
    wandb_init()
    # Import Dataset
    train_dataset, test_dataset = load_data()
    train_targets = train_dataset.targets
    test_targets = test_dataset.targets

    train_idx = (train_targets == 5).nonzero(as_tuple=True)[0]
    train_subset = Subset(train_dataset, train_idx)
    test_idx = (test_targets == 5).nonzero(as_tuple=True)[0]
    test_subset = Subset(test_dataset, test_idx)

    train_loader = make_loader(train_subset, shuffle=True)
    test_loader = make_loader(test_subset, shuffle=True)

간단하게 5 Label에 대해서만 학습과 생성을 진행해보았다.

100 Epoch Sample Image

init_FVSBN

1000 Epoch Sample Image

init_FVSBN

2000 Epoch Sample Image

init_FVSBN

3000 Epoch Sample Image

init_FVSBN 데이터셋을 3000회 반복하여 학습을 진행해 보았지만, 위와 같이 5의 숫자 형상은 전혀 알아볼 수 없었다.

CNN 모듈과 달리 어떤 이미지에 대한 학습을 진행한 것이 아니라 이미지 구조에 대한 이해가 전혀 없을 것이라고 예상한다. Vision Transformer에서도 이미지를 픽셀 별로 잘라서 하나의 토큰처럼 사용했던 것으로 기억하는데, 위와 같이 한 픽셀 별로 진행하게 되면 공간적인 부분에 대한 이해가 전혀 없을 것으로 예상한다.

추가적으로 널리 사용되는 Autoregressive 모델 중 성능이 굉장히 좋다고 알려진 Transformer의 경우에는 이전 token들의 Key와 Value값을 큰 벡터로 표현하지만, 이 간단한 모델에서는 그저 스칼라 값 하나만 사용하기 때문에 제대로 된 생성을 진행할 수 없을 것이다.

This post is licensed under CC BY 4.0 by the author.