-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
979 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from tqdm import tqdm | ||
import dazero | ||
import dazero.functions as F | ||
from PIL import Image | ||
from modules.unet import UNet | ||
from modules.sampler import DDPM | ||
import numpy as xp | ||
import matplotlib.pyplot as plt | ||
|
||
class Diffusion: | ||
''' | ||
ノイズ予測・サンプラーを受け取って画像生成・学習ステップを定義する。 | ||
''' | ||
|
||
def __init__(self, unet, sampler): | ||
self.unet = unet | ||
self.unet.to_gpu() | ||
self.sampler = sampler | ||
|
||
def generate(self, context, channels, height, width, cfg_scale = 1.0): | ||
''' | ||
画像生成を行うメソッド。 | ||
context: ラベルのxp配列 | ||
''' | ||
batch_size = context.shape[0] | ||
with dazero.test_mode(): | ||
with dazero.no_grad(): | ||
x = xp.random.randn(batch_size, channels, height, width) # 初期ノイズ x_0 | ||
for t in tqdm(reversed(range(1000)), total=1000): # t = 999, ..., 0 | ||
noise_pred = self.unet(x, xp.array([[t]]*batch_size).astype(xp.int32), context) # ノイズ予測 | ||
if cfg_scale != 1.0: | ||
noise_pred_uncond = self.unet(x, xp.array([[t]]*batch_size).astype(xp.int32), context * 0) # ノイズ予測 | ||
noise_pred = noise_pred * cfg_scale + noise_pred_uncond * (1 - cfg_scale) | ||
x = self.sampler.step(x, noise_pred, t) # x_{t+1} -> x_{t} | ||
|
||
images = [] | ||
for image in x: | ||
image = (xp.clip(image.data*127.5 + 127.5, 0, 255)).astype(xp.uint8) # 0~255に変換 | ||
image = xp.asnumpy(image) | ||
image = image.transpose(1, 2, 0).squeeze() | ||
image = Image.fromarray(image) | ||
images.append(image) | ||
return images | ||
|
||
|
||
def generate_grid(self, num_images, channels, height, width, image_path, id2label, cfg_scale = 1.0): | ||
''' | ||
生成画像をラベルごとにグリッド状に並べて保存するメソッド。 | ||
''' | ||
num_labels = self.unet.context_dim | ||
fig, axes = plt.subplots(num_labels, num_images, figsize=(7, 14)) | ||
images = self.generate(xp.eye(num_labels).repeat(num_images, axis=0),channels, height, width, cfg_scale) | ||
for i in range(num_labels): | ||
for j in range(num_images): | ||
axes[i, j].imshow(images[i*num_images+j], cmap='gray') | ||
axes[i, j].axis('off') | ||
|
||
axes[i, 0].text(-0.3, 0.5, f'{id2label[i]}', fontsize=12, verticalalignment='center', horizontalalignment='right', transform=axes[i, 0].transAxes) | ||
fig.savefig(image_path) | ||
|
||
def train_step(self, image, context): | ||
''' | ||
学習1ステップ分を実装、lossを返す。 | ||
''' | ||
|
||
# 加えるノイズ | ||
noise = xp.random.randn(*image.shape) | ||
|
||
# ランダムな時刻を選択 | ||
t = xp.random.randint(0, 1000, size=(image.shape[0], 1)).astype(xp.int32) | ||
|
||
# ノイズを加える | ||
noisy_image = self.sampler.add_noise(image, noise, t) | ||
|
||
# ノイズ予測 | ||
noise_pred = self.unet(noisy_image, t, context) | ||
|
||
# ノイズ予測と実際のノイズのMSEを計算 | ||
loss = F.mean_squared_error(noise, noise_pred) / (image.shape[1]*image.shape[2]*image.shape[3]) | ||
return loss | ||
|
||
|
||
if __name__ == "__main__": | ||
unet = UNet() | ||
ddpm = DDPM() | ||
diffusion = Diffusion(unet, ddpm) | ||
image = xp.random.randn(3, 1, 28, 28) | ||
loss = diffusion.train_step(image, xp.array([0, 1, 2])) | ||
# images = diffusion.generate(xp.array([0,1,2]),1,28,28) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import numpy as xp | ||
from modules.utils import expand_2d | ||
|
||
|
||
class DDPM: | ||
def __init__(self, beta_start=1e-4, beta_end=0.02, T=1000): | ||
''' | ||
Denoise Diffusion Probabilistic Modelの実装 | ||
引数のデフォルトは論文通りの値 | ||
''' | ||
self.beta_start = beta_start # beta_0 | ||
self.beta_end = beta_end # beta_T | ||
self.T = T | ||
self.beta = xp.linspace(beta_start, beta_end, T) # beta_0, ..., beta_T | ||
self.sqrt_beta = xp.sqrt(self.beta) | ||
self.alpha = 1 - self.beta # alpha_0, ..., alpha_T | ||
self.alpha_bar = xp.cumprod(self.alpha) # Π_{i=0}^t alpha_i | ||
self.sqrt_alpha_bar = xp.sqrt(self.alpha_bar) | ||
self.beta_bar = 1 - self.alpha_bar | ||
self.sqrt_beta_bar = xp.sqrt(self.beta_bar) | ||
self.one_over_sqrt_alpha = 1 / xp.sqrt(self.alpha) # ddpm.stepで使う | ||
self.beta_over_sqrt_beta_bar = self.beta / self.sqrt_beta_bar # ddpm.stepで使う | ||
|
||
def add_noise(self, x, noise, t): | ||
''' | ||
時刻tに応じたノイズを加える | ||
x_t = sqrt_alpha_bar_t * x_0 + sqrt_beta_bar_t * noise | ||
''' | ||
return expand_2d(self.sqrt_alpha_bar[t]) * x + expand_2d(self.sqrt_beta_bar[t]) * noise | ||
|
||
def step(self, x, noise_pred, t): | ||
''' | ||
x_t -> x_{t-1}のサンプリング | ||
x_{t-1} = 1/sqrt_alpha_t * (x_t - beta_t/sqrt_beta_bar_t * noise_pred) + sqrt_beta_t * noise | ||
''' | ||
noise = xp.random.randn(*x.shape) | ||
prev_x = self.one_over_sqrt_alpha[t] * (x - self.beta_over_sqrt_beta_bar[t] * noise_pred) + self.sqrt_beta[t] * noise | ||
return prev_x | ||
|
||
|
||
if __name__ == "__main__": | ||
ddpm = DDPM() | ||
x = xp.random.randn(2, 3, 28, 28) | ||
noise_pred = xp.random.randn(2, 3, 28, 28) | ||
t = 999 | ||
ddpm.step(x, noise_pred, t) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import dazero | ||
from dazero import DataLoader | ||
from tqdm import tqdm | ||
from dazero.transforms import Compose, ToFloat, Normalize | ||
import numpy as np | ||
import numpy as xp | ||
import os | ||
|
||
|
||
class Trainer: | ||
def __init__(self, diffusion, batch_size, lr, ucg=0.1, output_dir="outputs", dataset="mnist"): | ||
self.batch_size = batch_size | ||
self.diffusion = diffusion | ||
self.ucg = ucg | ||
if dataset == "mnist": | ||
self.train_set = dazero.datasets.MNIST(train=True, transform=Compose([ToFloat(), Normalize(127.5, 127.5)]),) | ||
elif dataset == "cifar10": | ||
self.train_set = dazero.datasets.CIFAR10(train=True, transform=Compose([ToFloat(), Normalize(127.5, 127.5)]),) | ||
else: | ||
raise ValueError(f"{dataset} is not supported.") | ||
|
||
self.train_loader = DataLoader(self.train_set, batch_size) | ||
self.train_loader.to_gpu() | ||
|
||
self.optimizer = dazero.optimizers.Adam().setup(self.diffusion.unet) | ||
self.optimizer.add_hook(dazero.optimizers.WeightDecay(lr)) | ||
|
||
self.output_dir = os.path.join(output_dir, "models") | ||
self.image_dir = os.path.join(output_dir, "images") | ||
self.log_dir = os.path.join(output_dir, "logs") | ||
os.makedirs(self.output_dir, exist_ok=True) | ||
os.makedirs(self.image_dir, exist_ok=True) | ||
os.makedirs(self.log_dir, exist_ok=True) | ||
|
||
|
||
def train(self, epochs, save_n_epochs=5, sample_cfg_scale=3.0, limited_steps=10000000): | ||
progress_bar = tqdm(range(epochs*len(self.train_set)//self.batch_size), desc="Total Steps", leave=False) | ||
loss_ema = None | ||
loss_emas = [] | ||
for epoch in range(epochs): | ||
steps = 0 | ||
for x, c in self.train_loader: | ||
|
||
ucg_random = xp.random.uniform(0, 1, size=(x.shape[0], 1)).astype(xp.float32) > self.ucg | ||
context = xp.eye(self.diffusion.unet.context_dim)[c] # one hot vector | ||
context *= ucg_random | ||
|
||
loss = self.diffusion.train_step(x, context) | ||
self.diffusion.unet.cleargrads() | ||
loss.backward() | ||
self.optimizer.update() | ||
|
||
if loss_ema is not None: | ||
loss_ema = 0.9 * loss_ema + 0.1 * float(loss.data) | ||
else: | ||
loss_ema = float(loss.data) | ||
loss_emas.append(loss_ema) | ||
|
||
progress_bar.update(1) | ||
progress_bar.set_postfix({"loss": loss_ema}) | ||
steps += 1 | ||
if steps > limited_steps: # test用 | ||
break | ||
|
||
if ((epoch+1) % save_n_epochs) == 0: | ||
self.diffusion.unet.save_weights(os.path.join(self.output_dir, f"model_{epoch:02}.npz")) | ||
self.diffusion.unet.to_gpu() # セーブ時にcpuに移動してしまう仕様 | ||
np.save(os.path.join(self.log_dir, f"log_{epoch:02}.npy"), np.array(loss_emas)) | ||
self.diffusion.generate_grid(4, x.shape[1], x.shape[2], x.shape[3], os.path.join(self.image_dir, f"image_{epoch:02}.png"), self.train_set.labels(), cfg_scale=sample_cfg_scale) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
from dazero.models import Model, Sequential | ||
import dazero.functions as F | ||
import dazero.layers as L | ||
from dazero.core import Function | ||
|
||
from utils import expand_2d | ||
import numpy as xp | ||
|
||
|
||
class Cat(Function): | ||
''' | ||
dazeroにはcatが定義されていないので、chatgptに作ってもらった。 | ||
''' | ||
def __init__(self, axis=0): | ||
self.axis = axis | ||
|
||
def forward(self, *inputs): | ||
z = xp.concatenate(inputs, axis=self.axis) | ||
return z | ||
|
||
def backward(self, gz): | ||
inputs = self.inputs | ||
gradients = [] | ||
start_idx = 0 | ||
|
||
for x in inputs: | ||
end_idx = start_idx + x.shape[self.axis] | ||
|
||
indices = [slice(None)] * gz.ndim | ||
indices[self.axis] = slice(start_idx, end_idx) | ||
|
||
gradients.append(gz[tuple(indices)]) | ||
|
||
start_idx = end_idx | ||
|
||
return tuple(gradients) | ||
|
||
|
||
def cat(inputs, axis=0): | ||
return Cat(axis=axis)(*inputs) | ||
|
||
|
||
class ConvBlock(Model): | ||
''' | ||
複数の畳み込み層+ばっちのーむ+ReLUによるブロック。 | ||
最後にアップサンプリングかダウンサンプリングを行うこともある(lastで指定)。 | ||
''' | ||
def __init__(self, channels, num_layers, last=None): | ||
''' | ||
channels: 畳み込み層の出力チャンネル数 | ||
num_layers: 畳み込み層の数 | ||
last: None or "up" or "down" | ||
''' | ||
super().__init__() | ||
convs = [] | ||
norms = [] | ||
for _ in range(num_layers): | ||
convs.append(L.Conv2d(channels, kernel_size=3, pad=1, nobias=True)) | ||
norms.append(L.BatchNorm()) | ||
|
||
self.convs = Sequential(*convs) | ||
self.norms = Sequential(*norms) | ||
|
||
if last == "up": | ||
self.last = L.Deconv2d(channels, kernel_size=4, stride=2, pad=1) | ||
elif last == "down": | ||
self.last = L.Conv2d(channels, kernel_size=3, stride=2, pad=1) | ||
else: | ||
self.last = None | ||
|
||
def forward(self, x): | ||
for conv, norm in zip(self.convs.layers, self.norms.layers): | ||
x = F.relu(norm(conv(x))) | ||
|
||
if self.last is not None: | ||
x = self.last(x) | ||
return x | ||
|
||
|
||
class UNet(Model): | ||
def __init__(self, out_channels=1, context_dim=10, hidden_channels=16, num_blocks=2, num_layers=3): | ||
''' | ||
out_channels: 出力画像のチャンネル数 | ||
context_dim: ラベルの数 | ||
hidden_channels: 中間のチャンネル数、ダウンサンプルごとに2倍になる。 | ||
num_blocks: ブロックの数。 | ||
num_layers: ブロックごとの畳み込み層の数。 | ||
''' | ||
super().__init__() | ||
self.context_dim = 10 | ||
self.conv_in = L.Conv2d(1, hidden_channels, kernel_size=3, pad=1) | ||
|
||
# 時刻[0,1000]を全結合層に入力する。本当はsinとか使うやつにしたい。 | ||
time_embs = [] | ||
for i in range(num_blocks): | ||
if i == 0: | ||
time_embs.append(L.Linear(hidden_channels, hidden_channels)) | ||
else: | ||
time_embs.append(L.Linear(hidden_channels, hidden_channels*(2**(i-1)))) | ||
self.time_embs = Sequential(*time_embs) | ||
|
||
# one hot vectorのラベルを全結合層に入力する。 | ||
context_embs = [] | ||
for i in range(num_blocks): | ||
if i == 0: | ||
context_embs.append(L.Linear(hidden_channels, hidden_channels)) | ||
else: | ||
context_embs.append(L.Linear(hidden_channels, hidden_channels*(2**(i-1)))) | ||
self.context_embs = Sequential(*context_embs) | ||
|
||
self.down_blocks = Sequential( | ||
*[ConvBlock(hidden_channels*(2**i), num_layers, "down") for i in range(num_blocks)] | ||
) | ||
|
||
self.mid_blocks = ConvBlock(hidden_channels*2**num_blocks, num_layers) | ||
|
||
self.up_blocks = Sequential( | ||
*[ConvBlock(hidden_channels*(2**(num_blocks-i)), num_layers, "up") for i in range(num_blocks)] | ||
) | ||
|
||
self.conv_out = L.Conv2d(out_channels, kernel_size=3, pad=1) | ||
|
||
def forward(self, x, t, context): | ||
t = t.astype(xp.float32) / 1000 # [0,1000] -> [0,1] | ||
h = self.conv_in(x) | ||
hs = [h] # skip connection | ||
for down_block, time_emb, context_emb in zip(self.down_blocks.layers, self.time_embs.layers, self.context_embs.layers): | ||
emb = time_emb(t) + context_emb(context) # 時刻埋め込み、ラベル埋め込み | ||
emb = expand_2d(emb) | ||
h = down_block(h + emb) | ||
hs.append(h) # skip connection | ||
|
||
h = self.mid_blocks(h) | ||
|
||
for up_block in self.up_blocks.layers: | ||
res = hs.pop() | ||
h = up_block(cat((h, res), axis=1)) # skip connectionを結合 | ||
|
||
h = self.conv_out(h) | ||
return h | ||
|
||
|
||
if __name__ == "__main__": | ||
x = xp.random.randn(1, 1, 28, 28).astype(xp.float32) | ||
t = xp.random.randint(0, 1000, size=(1, 1)).astype(xp.int32) | ||
c = xp.array([1]) | ||
model = UNet(1, 4, 2, 2) | ||
model.to_gpu() | ||
y = model(x, t, c) | ||
print(y.shape) |
Oops, something went wrong.