-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
190 lines (164 loc) · 8.1 KB
/
train.py
File metadata and controls
190 lines (164 loc) · 8.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import argparse
import os
from src import *
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
def main():
    """Parse command-line arguments and launch training (and optional generation).

    All hyper-parameters of the diffusion process and the data/checkpoint
    directories are exposed as CLI flags; the parsed namespace is handed
    straight to train_generate().
    """
    parser = argparse.ArgumentParser(description="Train the diffusion model and optionally run generation")
    parser.add_argument("--save_dir_weights", type=str, help="Directory to save the weights from training")
    parser.add_argument("--save_periods", type=int, help="Integer to specify how often the model saves weights. Ex: save every 5 epochs")
    parser.add_argument("--inference_outs_dir", type=str, help="Directory to save the outputs from the generation")
    parser.add_argument("--inference", type=str, help="Available options: ddpm, ddim")
    parser.add_argument("--batch_size", type=int, default=32, help="Size of batch to train and learn from")
    parser.add_argument("--epochs", type=int, default=2, help="Number of epochs to train on")
    # fixed typo in help text: "fro" -> "for"
    parser.add_argument("--lr", type=float, default=3e-3, help="Learning rate for training")
    parser.add_argument("--features", type=int, default=64, help="Size of the hidden layer from the U-net architecture")
    parser.add_argument("--T", type=int, default=500, help="The size of steps to reverse within the diffusion process")
    parser.add_argument("--beta_start", type=float, default=1e-3, help="Start of the beta scheduler")
    parser.add_argument("--beta_end", type=float, default=0.02, help="End of the beta scheduler")
    parser.add_argument("--dataset_name", type=str, default="sprites", help="Dataset name to train on")
    args = parser.parse_args()
    train_generate(args)
def train_generate(args):
    """Train the context-conditioned U-Net diffusion model, checkpoint it
    periodically, and optionally sample from it with DDPM or DDIM.

    Args:
        args: argparse.Namespace produced by main(). Relevant fields:
            dataset_name ('sprites' | 'cifar10'), batch_size, epochs, lr,
            features, T, beta_start, beta_end, save_dir_weights,
            save_periods, inference_outs_dir, inference ('ddpm' | 'ddim').

    Side effects:
        Writes checkpoints under args.save_dir_weights (if set) and sampled
        tensors / loss curves under args.inference_outs_dir (if set).
    """
    torch.manual_seed(42)  # fixed seed for reproducible runs
    device = "cuda" if torch.cuda.is_available() else 'cpu'

    # Linear beta schedule over T+1 entries so index t in [0, T] is valid.
    betas = (args.beta_end - args.beta_start) * torch.linspace(0, 1, args.T+1, device=device) + args.beta_start
    alphas = 1 - betas
    # Cumulative product computed in log-space: numerical stability trick.
    alphas_hat = torch.cumsum(alphas.log(), dim=0).exp()
    alphas_hat[0] = 1  # alpha_bar at t=0 is exactly 1 (no noise added)

    # DATASET
    if args.dataset_name == 'sprites':
        context_features = 5  # hero, non-hero, food, spell, side-facing
        image_size = (16, 16)
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])
        dataset = SpritesDataset("data/sprites/sprites_1788_16x16.npy",
                                 "data/sprites/sprite_labels_nc_1788_16x16.npy",
                                 transform,
                                 null_context=False)
    elif args.dataset_name == 'cifar10':
        context_features = 10  # one-hot over the 10 CIFAR classes
        image_size = (32, 32)
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
        ])
        dataset = datasets.CIFAR10(
            root="data/cifar10/",
            train=True,
            download=True,
            transform=transform
        )
        dataset = CIFAR10OneHot(dataset, num_classes=10)

    model = ContextUnet(in_channels=3, features=args.features, context_features=context_features, image_size=image_size).to(device)
    # Workers only help when the GPU is the bottleneck; keep the simple loader on CPU.
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=1) if device=='cuda' else DataLoader(dataset, args.batch_size, shuffle=True)
    optim = torch.optim.Adam(model.parameters(), lr=args.lr)
    model.train()

    # TRAIN
    def perturb_input(x, t, noise):
        """Forward-diffuse x to timestep t: sqrt(a_bar)*x + sqrt(1-a_bar)*noise."""
        return alphas_hat.sqrt()[t, None, None, None] * x + (1 - alphas_hat[t, None, None, None]).sqrt() * noise

    losses_save = []
    save_path = None  # most recent checkpoint path, if any was written
    # BUGFIX: --save_periods has no default; fall back to saving every epoch
    # instead of raising TypeError on `epoch % None`.
    save_periods = args.save_periods if args.save_periods else 1
    for epoch in range(args.epochs):
        # Linearly decay the learning rate over the course of training.
        optim.param_groups[0]['lr'] = args.lr*(1-epoch/args.epochs)
        batch_losses = []
        for x, c in dataloader:
            optim.zero_grad()
            x = x.to(device)
            c = c.float().to(device)
            # Classifier-free guidance training: drop context ~10% of the time.
            context_mask = torch.bernoulli(torch.zeros(c.shape[0]) + 0.9).to(device)
            c = c * context_mask.unsqueeze(-1)
            noise = torch.randn_like(x)
            t = torch.randint(1, args.T+1, (x.shape[0],)).to(device)
            x_pert = perturb_input(x, t, noise)
            # Model predicts the noise; timestep is normalized to (0, 1].
            pred_noise = model(x_pert, t/args.T, c)
            loss = F.mse_loss(pred_noise, noise)
            batch_losses.append(loss.item())
            loss.backward()
            optim.step()
        epoch_loss = float(np.mean(batch_losses))
        print(f"epoch: {epoch} - loss: {epoch_loss:.4f}")
        losses_save.append(epoch_loss)
        # Save the model periodically; the final epoch is always saved.
        if args.save_dir_weights:
            if epoch % save_periods == 0 or epoch == args.epochs - 1:
                os.makedirs(args.save_dir_weights, exist_ok=True)
                save_path = os.path.join(args.save_dir_weights, f"model_{epoch}.pth")
                torch.save(model.state_dict(), save_path)
                print(f"saved model at {save_path}")
    # BUGFIX: removed an unconditional duplicate torch.save/print here that
    # raised NameError on `save_path` whenever --save_dir_weights was unset;
    # the last epoch is already checkpointed by the branch above.

    def denoise_add_noise(x, t, pred_noise, z=None):
        """One reverse DDPM step: posterior mean from the predicted noise,
        plus sigma_t * z (z defaults to fresh standard normal noise)."""
        if z is None:
            z = torch.randn_like(x)
        noise = betas.sqrt()[t] * z
        mean = (x - pred_noise * ((1 - alphas[t]) / (1 - alphas_hat[t]).sqrt())) / alphas[t].sqrt()
        return mean + noise

    @torch.no_grad()
    def sample_ddpm(n_sample, context, save_rate=20):
        """Ancestral DDPM sampling over all T steps.

        Returns (final samples, stacked intermediate snapshots every
        `save_rate` steps plus the first step and the last few)."""
        samples = torch.randn(n_sample, 3, image_size[0], image_size[1]).to(device)
        intermediate = []
        for i in range(args.T, 0, -1):
            print(f'sampling timestep {i:3d}', end='\r')
            t = torch.tensor([i / args.T])[:, None, None, None].to(device)
            z = torch.randn_like(samples) if i > 1 else 0  # no noise on the last step
            eps = model(samples, t, c=context)
            samples = denoise_add_noise(samples, i, eps, z)
            if i % save_rate == 0 or i == args.T or i < 8:
                intermediate.append(samples.detach().cpu().numpy())
        return samples, np.stack(intermediate)

    def denoise_ddim(x, t, t_prev, pred_noise):
        """One deterministic DDIM step from timestep t to t_prev (eta = 0)."""
        ab = alphas_hat[t]
        ab_prev = alphas_hat[t_prev]
        x0_pred = ab_prev.sqrt() / ab.sqrt() * (x - (1 - ab).sqrt() * pred_noise)
        dir_xt = (1 - ab_prev).sqrt() * pred_noise
        return x0_pred + dir_xt

    @torch.no_grad()
    def sample_ddim(n_sample, context, n=20):
        """Fast DDIM sampling in roughly n evenly spaced steps.

        Returns (final samples, stacked intermediate snapshots, one per step)."""
        samples = torch.randn(n_sample, 3, image_size[0], image_size[1]).to(device)
        intermediate = []
        step_size = args.T // n
        for i in range(args.T, 0, -step_size):
            print(f'sampling timestep {i:3d}', end='\r')
            t = torch.tensor([i / args.T])[:, None, None, None].to(device)
            eps = model(samples, t, c=context)
            samples = denoise_ddim(samples, i, i-step_size, eps)
            intermediate.append(samples.detach().cpu().numpy())
        intermediate = np.stack(intermediate)
        return samples, intermediate

    # INFERENCE (optional): reload the last checkpoint and sample.
    # NOTE: assumes --save_dir_weights was set so a checkpoint exists;
    # `epoch` is the final value from the training loop above.
    if args.inference_outs_dir:
        model.load_state_dict(torch.load(f"{args.save_dir_weights}/model_{str(epoch)}.pth", map_location=device))
        model.eval()
        print("Loaded in Model")
        # Fixed context batch to visualize: two samples per class.
        ctx = torch.tensor([
            # hero, non-hero, food, spell, side-facing
            [1,0,0,0,0],
            [1,0,0,0,0],
            [0,0,0,0,1],
            [0,0,0,0,1],
            [0,1,0,0,0],
            [0,1,0,0,0],
            [0,0,1,0,0],
            [0,0,1,0,0],
        ]).float().to(device)
        import time
        # BUGFIX: create the output dir and join paths properly instead of raw
        # f-string concatenation (which silently wrote to a sibling file unless
        # the directory argument ended with a slash).
        os.makedirs(args.inference_outs_dir, exist_ok=True)
        if args.inference=="ddpm":
            start = time.time()
            samples, _ = sample_ddpm(ctx.shape[0], ctx)
            speed = time.time()-start
            print(f'DDPM: generated outputs. taken {speed} time')
            torch.save(samples, os.path.join(args.inference_outs_dir, 'ddpm_file.pt'))
            torch.save(torch.tensor(losses_save), os.path.join(args.inference_outs_dir, 'ddpm_loss.pt'))
        if args.inference=="ddim":
            start = time.time()
            samples, _ = sample_ddim(ctx.shape[0], context=ctx, n=25)
            speed = time.time()-start
            print(f'DDIM: generated outputs. taken {speed} time')
            torch.save(samples, os.path.join(args.inference_outs_dir, 'ddim_file.pt'))
            torch.save(torch.tensor(losses_save), os.path.join(args.inference_outs_dir, 'ddim_loss.pt'))
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()