Commit

update to_gpu

konas122 committed May 4, 2024
1 parent 929674b commit a8489bb
Showing 11 changed files with 979 additions and 1 deletion.
3 changes: 3 additions & 0 deletions dazero/core.py
@@ -3,6 +3,7 @@
import numpy as np

import dazero
import dazero.cuda


class Config:
@@ -173,6 +174,8 @@ def to_cpu(self):
        self.data = dazero.cuda.as_numpy(self.data)

    def to_gpu(self):
        if not dazero.cuda.gpu_enable:
            return
        if self.data is not None:
            self.data = dazero.cuda.as_cupy(self.data)
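
With this guard, moving data to the GPU degrades gracefully on CPU-only machines. A minimal usage sketch (not part of the commit; it assumes these methods live on dazero's Variable class, as in DeZero):

import numpy as np
import dazero

x = dazero.Variable(np.ones((2, 3)))
x.to_gpu()           # silently a no-op when dazero.cuda.gpu_enable is False
print(type(x.data))  # numpy.ndarray on machines without CUDA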

3 changes: 3 additions & 0 deletions dazero/dataloaders.py
@@ -60,6 +60,9 @@ def to_cpu(self):
        self.gpu = False

    def to_gpu(self):
        if not cuda.gpu_enable:
            self.gpu = False
            return
        self.gpu = True


2 changes: 2 additions & 0 deletions dazero/layers.py
@@ -45,6 +45,8 @@ def to_cpu(self):
            param.to_cpu()

    def to_gpu(self):
        if not cuda.gpu_enable:
            return
        for param in self.params():
            param.to_gpu()
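
DataLoader and Layer get the same fallback, which is what lets trainer.py below call to_gpu() unconditionally. A CPU-only sketch (not part of the commit; dataset constructor defaults are assumed):

import dazero
from dazero import DataLoader

train_set = dazero.datasets.MNIST(train=True)
loader = DataLoader(train_set, 32)
loader.to_gpu()
print(loader.gpu)  # False when no GPU is available, True otherwise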

592 changes: 592 additions & 0 deletions examples/Diffusion/diffusion.ipynb

Large diffs are not rendered by default.

90 changes: 90 additions & 0 deletions examples/Diffusion/modules/diffusion.py
@@ -0,0 +1,90 @@
from tqdm import tqdm
import dazero
import dazero.functions as F
from PIL import Image
from modules.unet import UNet
from modules.sampler import DDPM
import numpy as xp
import matplotlib.pyplot as plt

class Diffusion:
    '''
    Takes a noise predictor and a sampler, and defines the
    image-generation and training steps.
    '''

    def __init__(self, unet, sampler):
        self.unet = unet
        self.unet.to_gpu()
        self.sampler = sampler

    def generate(self, context, channels, height, width, cfg_scale=1.0):
        '''
        Generate images.
        context: xp array of labels
        '''
        batch_size = context.shape[0]
        with dazero.test_mode():
            with dazero.no_grad():
                x = xp.random.randn(batch_size, channels, height, width)  # initial noise x_T
                for t in tqdm(reversed(range(1000)), total=1000):  # t = 999, ..., 0
                    noise_pred = self.unet(x, xp.array([[t]]*batch_size).astype(xp.int32), context)  # conditional noise prediction
                    if cfg_scale != 1.0:
                        noise_pred_uncond = self.unet(x, xp.array([[t]]*batch_size).astype(xp.int32), context * 0)  # unconditional noise prediction
                        noise_pred = noise_pred * cfg_scale + noise_pred_uncond * (1 - cfg_scale)
                    x = self.sampler.step(x, noise_pred, t)  # x_t -> x_{t-1}

        images = []
        for image in x:
            image = (xp.clip(image.data*127.5 + 127.5, 0, 255)).astype(xp.uint8)  # map to [0, 255]
            image = dazero.cuda.as_numpy(image)  # xp.asnumpy exists only for cupy; as_numpy handles both backends
            image = image.transpose(1, 2, 0).squeeze()
            image = Image.fromarray(image)
            images.append(image)
        return images
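
    # Note (not part of the commit): the guidance line above is the standard
    # classifier-free guidance combination; it is algebraically equal to
    #   noise_pred_uncond + cfg_scale * (noise_pred - noise_pred_uncond)
    # i.e. an extrapolation away from the unconditional prediction.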


    def generate_grid(self, num_images, channels, height, width, image_path, id2label, cfg_scale=1.0):
        '''
        Arrange generated images in a grid, one row per label, and save them.
        '''
        num_labels = self.unet.context_dim
        fig, axes = plt.subplots(num_labels, num_images, figsize=(7, 14))
        images = self.generate(xp.eye(num_labels).repeat(num_images, axis=0), channels, height, width, cfg_scale)
        for i in range(num_labels):
            for j in range(num_images):
                axes[i, j].imshow(images[i*num_images+j], cmap='gray')
                axes[i, j].axis('off')

            axes[i, 0].text(-0.3, 0.5, f'{id2label[i]}', fontsize=12, verticalalignment='center', horizontalalignment='right', transform=axes[i, 0].transAxes)
        fig.savefig(image_path)

    def train_step(self, image, context):
        '''
        One training step; returns the loss.
        '''

        # noise to add
        noise = xp.random.randn(*image.shape)

        # pick random timesteps
        t = xp.random.randint(0, 1000, size=(image.shape[0], 1)).astype(xp.int32)

        # add the noise
        noisy_image = self.sampler.add_noise(image, noise, t)

        # predict the noise
        noise_pred = self.unet(noisy_image, t, context)

        # MSE between the predicted and the actual noise
        loss = F.mean_squared_error(noise, noise_pred) / (image.shape[1]*image.shape[2]*image.shape[3])
        return loss
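
    # Note (not part of the commit): assuming F.mean_squared_error follows
    # DeZero (sum over features, mean over the batch), dividing by C*H*W
    # above makes the loss a per-pixel MSE, equivalent to
    #   loss = F.sum((noise - noise_pred) ** 2) / noise.size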


if __name__ == "__main__":
    unet = UNet()
    ddpm = DDPM()
    diffusion = Diffusion(unet, ddpm)
    image = xp.random.randn(3, 1, 28, 28)
    loss = diffusion.train_step(image, xp.array([0, 1, 2]))
    # images = diffusion.generate(xp.array([0, 1, 2]), 1, 28, 28)
46 changes: 46 additions & 0 deletions examples/Diffusion/modules/sampler.py
@@ -0,0 +1,46 @@
import numpy as xp
from modules.utils import expand_2d


class DDPM:
    def __init__(self, beta_start=1e-4, beta_end=0.02, T=1000):
        '''
        Implementation of the Denoising Diffusion Probabilistic Model.
        The argument defaults follow the paper.
        '''
        self.beta_start = beta_start  # beta_0
        self.beta_end = beta_end  # beta_T
        self.T = T
        self.beta = xp.linspace(beta_start, beta_end, T)  # beta_0, ..., beta_T
        self.sqrt_beta = xp.sqrt(self.beta)
        self.alpha = 1 - self.beta  # alpha_0, ..., alpha_T
        self.alpha_bar = xp.cumprod(self.alpha)  # Π_{i=0}^t alpha_i
        self.sqrt_alpha_bar = xp.sqrt(self.alpha_bar)
        self.beta_bar = 1 - self.alpha_bar
        self.sqrt_beta_bar = xp.sqrt(self.beta_bar)
        self.one_over_sqrt_alpha = 1 / xp.sqrt(self.alpha)  # used in DDPM.step
        self.beta_over_sqrt_beta_bar = self.beta / self.sqrt_beta_bar  # used in DDPM.step

    def add_noise(self, x, noise, t):
        '''
        Add noise scaled for timestep t:
        x_t = sqrt_alpha_bar_t * x_0 + sqrt_beta_bar_t * noise
        '''
        return expand_2d(self.sqrt_alpha_bar[t]) * x + expand_2d(self.sqrt_beta_bar[t]) * noise

    def step(self, x, noise_pred, t):
        '''
        Sample x_{t-1} from x_t:
        x_{t-1} = 1/sqrt_alpha_t * (x_t - beta_t/sqrt_beta_bar_t * noise_pred) + sqrt_beta_t * noise
        '''
        noise = xp.random.randn(*x.shape)
        prev_x = self.one_over_sqrt_alpha[t] * (x - self.beta_over_sqrt_beta_bar[t] * noise_pred) + self.sqrt_beta[t] * noise
        return prev_x


if __name__ == "__main__":
    ddpm = DDPM()
    x = xp.random.randn(2, 3, 28, 28)
    noise_pred = xp.random.randn(2, 3, 28, 28)
    t = 999
    ddpm.step(x, noise_pred, t)
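
A complementary check of the forward process (a sketch, not part of the commit; it assumes this commit's module layout, with expand_2d coming from modules/utils.py, which is not rendered on this page):

import numpy as xp
from modules.sampler import DDPM

ddpm = DDPM()
x0 = xp.random.randn(2, 1, 28, 28)           # clean images
noise = xp.random.randn(*x0.shape)           # noise to mix in
t = xp.random.randint(0, 1000, size=(2, 1))  # one timestep per sample, as in train_step
xt = ddpm.add_noise(x0, noise, t)            # forward process: x_0 -> x_t
assert xt.shape == x0.shape
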
69 changes: 69 additions & 0 deletions examples/Diffusion/modules/trainer.py
@@ -0,0 +1,69 @@
import dazero
from dazero import DataLoader
from tqdm import tqdm
from dazero.transforms import Compose, ToFloat, Normalize
import numpy as np
import numpy as xp
import os


class Trainer:
    def __init__(self, diffusion, batch_size, lr, ucg=0.1, output_dir="outputs", dataset="mnist"):
        self.batch_size = batch_size
        self.diffusion = diffusion
        self.ucg = ucg
        if dataset == "mnist":
            self.train_set = dazero.datasets.MNIST(train=True, transform=Compose([ToFloat(), Normalize(127.5, 127.5)]))
        elif dataset == "cifar10":
            self.train_set = dazero.datasets.CIFAR10(train=True, transform=Compose([ToFloat(), Normalize(127.5, 127.5)]))
        else:
            raise ValueError(f"{dataset} is not supported.")

        self.train_loader = DataLoader(self.train_set, batch_size)
        self.train_loader.to_gpu()

        self.optimizer = dazero.optimizers.Adam().setup(self.diffusion.unet)
        self.optimizer.add_hook(dazero.optimizers.WeightDecay(lr))

        self.output_dir = os.path.join(output_dir, "models")
        self.image_dir = os.path.join(output_dir, "images")
        self.log_dir = os.path.join(output_dir, "logs")
        os.makedirs(self.output_dir, exist_ok=True)
        os.makedirs(self.image_dir, exist_ok=True)
        os.makedirs(self.log_dir, exist_ok=True)


    def train(self, epochs, save_n_epochs=5, sample_cfg_scale=3.0, limited_steps=10000000):
        progress_bar = tqdm(range(epochs*len(self.train_set)//self.batch_size), desc="Total Steps", leave=False)
        loss_ema = None
        loss_emas = []
        for epoch in range(epochs):
            steps = 0
            for x, c in self.train_loader:

                # drop the label with probability ucg so the model also learns
                # an unconditional prediction (used for classifier-free guidance)
                ucg_random = xp.random.uniform(0, 1, size=(x.shape[0], 1)).astype(xp.float32) > self.ucg
                context = xp.eye(self.diffusion.unet.context_dim)[c]  # one-hot vectors
                context *= ucg_random

                loss = self.diffusion.train_step(x, context)
                self.diffusion.unet.cleargrads()
                loss.backward()
                self.optimizer.update()

                if loss_ema is not None:
                    loss_ema = 0.9 * loss_ema + 0.1 * float(loss.data)
                else:
                    loss_ema = float(loss.data)
                loss_emas.append(loss_ema)

                progress_bar.update(1)
                progress_bar.set_postfix({"loss": loss_ema})
                steps += 1
                if steps > limited_steps:  # for testing
                    break

            if ((epoch+1) % save_n_epochs) == 0:
                self.diffusion.unet.save_weights(os.path.join(self.output_dir, f"model_{epoch:02}.npz"))
                self.diffusion.unet.to_gpu()  # save_weights moves the weights to the CPU, so move them back
                np.save(os.path.join(self.log_dir, f"log_{epoch:02}.npy"), np.array(loss_emas))
                self.diffusion.generate_grid(4, x.shape[1], x.shape[2], x.shape[3], os.path.join(self.image_dir, f"image_{epoch:02}.png"), self.train_set.labels(), cfg_scale=sample_cfg_scale)
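
Putting the new modules together, a hypothetical driver script (not part of the commit; the batch size, learning rate, and epoch counts are placeholders):

from modules.unet import UNet
from modules.sampler import DDPM
from modules.diffusion import Diffusion
from modules.trainer import Trainer

diffusion = Diffusion(UNet(), DDPM())  # Diffusion moves the UNet to the GPU if one is available
trainer = Trainer(diffusion, batch_size=128, lr=1e-4, dataset="mnist")
trainer.train(epochs=20, save_n_epochs=5)
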
150 changes: 150 additions & 0 deletions examples/Diffusion/modules/unet.py
@@ -0,0 +1,150 @@
from dazero.models import Model, Sequential
import dazero.functions as F
import dazero.layers as L
from dazero.core import Function

from modules.utils import expand_2d
import numpy as xp


class Cat(Function):
    '''
    dazero does not define cat, so this one was written by ChatGPT.
    '''
    def __init__(self, axis=0):
        self.axis = axis

    def forward(self, *inputs):
        z = xp.concatenate(inputs, axis=self.axis)
        return z

    def backward(self, gz):
        inputs = self.inputs
        gradients = []
        start_idx = 0

        # split the upstream gradient back into the original input shapes
        for x in inputs:
            end_idx = start_idx + x.shape[self.axis]

            indices = [slice(None)] * gz.ndim
            indices[self.axis] = slice(start_idx, end_idx)

            gradients.append(gz[tuple(indices)])

            start_idx = end_idx

        return tuple(gradients)


def cat(inputs, axis=0):
    return Cat(axis=axis)(*inputs)
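
# Quick shape check (a sketch, not part of the commit), with shapes matching
# the skip connections in UNet.forward below:
#   cat((a, b), axis=1) on two (1, 16, 14, 14) inputs -> (1, 32, 14, 14)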


class ConvBlock(Model):
    '''
    A block of several convolution layers, each followed by batch norm and ReLU.
    It may end with an upsampling or downsampling layer (chosen via `last`).
    '''
    def __init__(self, channels, num_layers, last=None):
        '''
        channels: number of output channels of the convolution layers
        num_layers: number of convolution layers
        last: None or "up" or "down"
        '''
        super().__init__()
        convs = []
        norms = []
        for _ in range(num_layers):
            convs.append(L.Conv2d(channels, kernel_size=3, pad=1, nobias=True))
            norms.append(L.BatchNorm())

        self.convs = Sequential(*convs)
        self.norms = Sequential(*norms)

        if last == "up":
            self.last = L.Deconv2d(channels, kernel_size=4, stride=2, pad=1)
        elif last == "down":
            self.last = L.Conv2d(channels, kernel_size=3, stride=2, pad=1)
        else:
            self.last = None

    def forward(self, x):
        for conv, norm in zip(self.convs.layers, self.norms.layers):
            x = F.relu(norm(conv(x)))

        if self.last is not None:
            x = self.last(x)
        return x


class UNet(Model):
    def __init__(self, out_channels=1, context_dim=10, hidden_channels=16, num_blocks=2, num_layers=3):
        '''
        out_channels: number of channels of the output image
        context_dim: number of labels
        hidden_channels: base channel count; doubles with each downsampling
        num_blocks: number of down/up blocks
        num_layers: number of convolution layers per block
        '''
        super().__init__()
        self.context_dim = context_dim
        self.conv_in = L.Conv2d(1, hidden_channels, kernel_size=3, pad=1)

        # The timestep in [0, 1000] is fed through fully connected layers.
        # Sinusoidal embeddings would really be preferable.
        time_embs = []
        for i in range(num_blocks):
            if i == 0:
                time_embs.append(L.Linear(hidden_channels, hidden_channels))
            else:
                time_embs.append(L.Linear(hidden_channels, hidden_channels*(2**(i-1))))
        self.time_embs = Sequential(*time_embs)

        # The one-hot label vector is fed through fully connected layers.
        context_embs = []
        for i in range(num_blocks):
            if i == 0:
                context_embs.append(L.Linear(hidden_channels, hidden_channels))
            else:
                context_embs.append(L.Linear(hidden_channels, hidden_channels*(2**(i-1))))
        self.context_embs = Sequential(*context_embs)

        self.down_blocks = Sequential(
            *[ConvBlock(hidden_channels*(2**i), num_layers, "down") for i in range(num_blocks)]
        )

        self.mid_blocks = ConvBlock(hidden_channels*2**num_blocks, num_layers)

        self.up_blocks = Sequential(
            *[ConvBlock(hidden_channels*(2**(num_blocks-i)), num_layers, "up") for i in range(num_blocks)]
        )

        self.conv_out = L.Conv2d(out_channels, kernel_size=3, pad=1)

    def forward(self, x, t, context):
        t = t.astype(xp.float32) / 1000  # [0, 1000] -> [0, 1]
        h = self.conv_in(x)
        hs = [h]  # skip connections
        for down_block, time_emb, context_emb in zip(self.down_blocks.layers, self.time_embs.layers, self.context_embs.layers):
            emb = time_emb(t) + context_emb(context)  # timestep embedding + label embedding
            emb = expand_2d(emb)
            h = down_block(h + emb)
            hs.append(h)  # skip connection

        h = self.mid_blocks(h)

        for up_block in self.up_blocks.layers:
            res = hs.pop()
            h = up_block(cat((h, res), axis=1))  # concatenate the skip connection

        h = self.conv_out(h)
        return h


if __name__ == "__main__":
    x = xp.random.randn(1, 1, 28, 28).astype(xp.float32)
    t = xp.random.randint(0, 1000, size=(1, 1)).astype(xp.int32)
    c = xp.array([1])
    model = UNet(1, 4, 2, 2)
    model.to_gpu()
    y = model(x, t, c)
    print(y.shape)
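
# Shape trace (my reading of the code, not part of the commit) for the defaults
# (out_channels=1, hidden_channels=16, num_blocks=2) on a (B, 1, 28, 28) batch;
# note the conv_in output pushed onto hs is never popped:
#   x:        (B,  1, 28, 28)
#   conv_in:  (B, 16, 28, 28)
#   down 0:   (B, 16, 14, 14)
#   down 1:   (B, 32,  7,  7)
#   mid:      (B, 64,  7,  7)
#   up 0:     cat(64+32) -> (B, 64, 14, 14)
#   up 1:     cat(64+16) -> (B, 32, 28, 28)
#   conv_out: (B,  1, 28, 28)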
The remaining changed files are not rendered.