# Create a W&B run to track this training job.
run = wandb.init(project="dlai_sprite_diffusion",
                 job_type="train",
                 config=config)
# Read the config back from W&B so server-side/sweep overrides take effect.
config = wandb.config

for ep in tqdm(range(config.n_epoch), leave=True, total=config.n_epoch):
    # Put the network into training mode each epoch (it is switched to
    # eval mode below when checkpointing/sampling).
    nn_model.train()

    # Linear learning-rate decay: from config.lrate at epoch 0 down toward 0.
    optim.param_groups[0]['lr'] = config.lrate * (1 - ep / config.n_epoch)

    pbar = tqdm(dataloader, leave=False)
    for x, c in pbar:  # x: images, c: context
        optim.zero_grad()
        x = x.to(DEVICE)
        c = c.to(DEVICE)

        # Keep each sample's context with probability 0.8, zero it otherwise,
        # so the model also learns the unconditional distribution
        # (classifier-free-guidance style training).
        context_mask = torch.bernoulli(torch.zeros(c.shape[0]) + 0.8).to(DEVICE)
        c = c * context_mask.unsqueeze(-1)

        # DDPM objective: perturb the image to a uniformly sampled timestep
        # t in [1, timesteps] and train the network to predict the noise.
        noise = torch.randn_like(x)
        t = torch.randint(1, config.timesteps + 1, (x.shape[0],)).to(DEVICE)
        x_pert = perturb_input(x, t, noise)

        # Timestep is fed to the model normalized to (0, 1].
        pred_noise = nn_model(x_pert, t / config.timesteps, c=c)
        loss = F.mse_loss(pred_noise, noise)
        loss.backward()
        optim.step()

        wandb.log({"loss": loss.item(),
                   "lr": optim.param_groups[0]['lr'],
                   "epoch": ep})

    # Save and sample every 4 epochs and on the final epoch.
    if ep % 4 == 0 or ep == config.n_epoch - 1:
        nn_model.eval()
        # NOTE: the same file is overwritten each time; per-epoch history is
        # preserved via the W&B artifact aliases below.
        ckpt_file = SAVE_DIR / "context_model.pth"
        torch.save(nn_model.state_dict(), ckpt_file)

        artifact_name = f"{wandb.run.id}_context_model"
        at = wandb.Artifact(artifact_name, type="model")
        at.add_file(ckpt_file)
        wandb.log_artifact(at, aliases=[f"epoch_{ep}"])

        # Sample under no_grad: gradients are never needed for generation,
        # and this prevents autograd from recording the full reverse-
        # diffusion graph (saves memory; numerically identical output).
        with torch.no_grad():
            samples, _ = sample_ddpm_context(nn_model,
                                             noises,
                                             ctx_vector[:config.num_samples])
        wandb.log({
            "train_samples": [
                wandb.Image(img) for img in samples.split(1)
            ]})

# Finish the W&B run so all pending logs and artifacts are flushed.
wandb.finish()