0% found this document useful (0 votes)
10 views

Aai 3

The document details the training of a generative adversarial network (GAN) to generate MNIST handwritten digits. It defines discriminator and generator models, loads previously saved weights if they exist or otherwise trains the models, and logs the losses once per epoch. It then generates samples from the trained generator and plots the results.

Uploaded by

ahmed.412052.cs
Copyright
© All Rights Reserved
We take content rights seriously. If you suspect this is your content, claim it here.
Available Formats
Download as PDF, TXT or read online on Scribd
0% found this document useful (0 votes)
10 views

Aai 3

The document details the training of a generative adversarial network (GAN) to generate MNIST handwritten digits. It defines discriminator and generator models, loads previously saved weights if they exist or otherwise trains the models, and logs the losses once per epoch. It then generates samples from the trained generator and plots the results.

Uploaded by

ahmed.412052.cs
Copyright
© All Rights Reserved
We take content rights seriously. If you suspect this is your content, claim it here.
Available Formats
Download as PDF, TXT or read online on Scribd
You are on page 1/4

import torch

from torch import nn, optim

import math
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
import os

# Seed PyTorch's random number generator so weight initialisation, data
# shuffling and latent noise are reproducible from run to run.
_SEED = 111
torch.manual_seed(_SEED)

<torch._C.Generator at 0x7cb5041e1590>

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ToTensor maps pixels to [0, 1]; Normalize(0.5, 0.5) then shifts/scales
# them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# MNIST training split, downloaded into the current directory if missing.
train_set = torchvision.datasets.MNIST(
    root='.',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 32 images.
batch_size = 32
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
)

# Preview 16 real training digits in a 4x4 grid. The samples are normalised
# to [-1, 1]; imshow scales the colormap to the data range, so no rescaling
# is needed. 'gray_r' draws the digit strokes dark on a light background.
plt.figure(dpi=150)
real_samples, mnist_labels = next(iter(train_loader))
for i in range(16):
    plt.subplot(4, 4, i + 1)
    plt.imshow(real_samples[i].reshape(28, 28), cmap='gray_r')
    plt.xticks([])
    plt.yticks([])
# Lay out the grid once, after every subplot has been created.
plt.tight_layout()
class Discriminator(nn.Module):
    """MLP that maps a 28x28 image to a single score in (0, 1).

    Architecture: 784 -> 1024 -> 512 -> 256, each hidden layer followed by
    ReLU and Dropout(0.3), then a 1-unit linear layer with a sigmoid.
    """

    def __init__(self):
        super().__init__()
        widths = [784, 1024, 512, 256]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.3)]
        # The layer order (and thus the Sequential indices / state_dict keys)
        # matches a hand-written list, so existing checkpoints still load.
        layers += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten each sample to 784 features and return scores of shape (N, 1)."""
        flat = x.view(x.size(0), 784)
        return self.model(flat)

# Build the discriminator and move its parameters to the selected device.
discriminator = Discriminator().to(device)

class Generator(nn.Module):
    """MLP that maps a 100-d latent vector to a 1x28x28 image in [-1, 1].

    Architecture: 100 -> 256 -> 512 -> 1024, each followed by ReLU, then a
    784-unit linear layer with Tanh, reshaped to (N, 1, 28, 28).
    """

    def __init__(self):
        super().__init__()
        widths = [100, 256, 512, 1024]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        # Output layer; Tanh bounds pixels to [-1, 1]. Layer order matches a
        # hand-written list, so state_dict keys are unchanged.
        layers.extend((nn.Linear(widths[-1], 784), nn.Tanh()))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return images of shape (N, 1, 28, 28) for latent input of shape (N, 100)."""
        return self.model(x).view(x.size(0), 1, 28, 28)

# Build the generator on the selected device (before drawing any noise, so
# the seeded RNG is consumed in the same order).
generator = Generator().to(device)

# Training hyperparameters shared by both networks.
lr = 0.0001
num_epochs = 20
# Binary cross-entropy on the discriminator's sigmoid output.
loss_function = nn.BCELoss()

# One Adam optimizer per network, each updating only its own parameters.
optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=lr)
optimizer_generator = optim.Adam(generator.parameters(), lr=lr)

# Fixed batch of 16 latent vectors (100-d each), drawn once; the name
# suggests it is meant for plotting generator samples.
latent_space_samples_plot = torch.randn((16, 100)).to(device=device)

# Load trained networks when saved weights exist; otherwise train new ones
# and save them so the next run can skip training.
DISCRIMINATOR_PATH = './discriminator.pt'
GENERATOR_PATH = './generator.pt'

if os.path.isfile(DISCRIMINATOR_PATH) and os.path.isfile(GENERATOR_PATH):
    # map_location lets weights saved on a GPU load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file -- only load checkpoints you
    # trust (newer PyTorch versions support weights_only=True for this).
    discriminator.load_state_dict(torch.load(DISCRIMINATOR_PATH, map_location=device))
    generator.load_state_dict(torch.load(GENERATOR_PATH, map_location=device))
else:
    last_batch_index = len(train_loader) - 1
    for epoch in range(num_epochs):
        for n, (real_samples, mnist_labels) in enumerate(train_loader):
            # --- Data for training the discriminator ---
            real_samples = real_samples.to(device=device)
            # Size labels/noise from the actual batch: the final batch of an
            # epoch can be smaller than batch_size.
            current_batch_size = real_samples.size(0)
            real_samples_labels = torch.ones((current_batch_size, 1), device=device)
            latent_space_samples = torch.randn((current_batch_size, 100)).to(device=device)
            # Detach: the discriminator update must not backpropagate into
            # the generator.
            generated_samples = generator(latent_space_samples).detach()
            generated_samples_labels = torch.zeros((current_batch_size, 1), device=device)
            all_samples = torch.cat((real_samples, generated_samples))
            all_samples_labels = torch.cat((real_samples_labels, generated_samples_labels))

            # --- Training the discriminator: real -> 1, generated -> 0 ---
            discriminator.zero_grad()
            output_discriminator = discriminator(all_samples)
            loss_discriminator = loss_function(output_discriminator, all_samples_labels)
            loss_discriminator.backward()
            optimizer_discriminator.step()

            # --- Training the generator ---
            # Fresh noise; generated samples are scored against "real" labels
            # so the generator is pushed to fool the discriminator.
            latent_space_samples = torch.randn((current_batch_size, 100)).to(device=device)
            generator.zero_grad()
            generated_samples = generator(latent_space_samples)
            output_discriminator_generated = discriminator(generated_samples)
            loss_generator = loss_function(output_discriminator_generated, real_samples_labels)
            loss_generator.backward()
            optimizer_generator.step()

            # Show the losses from the last batch of each epoch.
            if n == last_batch_index:
                print(f"Epoch: {epoch} Loss D.: {loss_discriminator.item()}")
                print(f"Epoch: {epoch} Loss G.: {loss_generator.item()}")

    # Persist the trained weights so the load branch above is taken next run.
    torch.save(discriminator.state_dict(), DISCRIMINATOR_PATH)
    torch.save(generator.state_dict(), GENERATOR_PATH)

Epoch: 0 Loss D.: 0.5897145867347717


Epoch: 0 Loss G.: 0.45927101373672485
Epoch: 1 Loss D.: 0.02344338968396187
Epoch: 1 Loss G.: 6.751859664916992
Epoch: 2 Loss D.: 0.016641004011034966
Epoch: 2 Loss G.: 5.987716197967529
Epoch: 3 Loss D.: 0.2856064736843109
Epoch: 3 Loss G.: 5.64042329788208
Epoch: 4 Loss D.: 0.031034916639328003
Epoch: 4 Loss G.: 4.782405376434326
Epoch: 5 Loss D.: 0.0478241890668869
Epoch: 5 Loss G.: 4.23112678527832
Epoch: 6 Loss D.: 0.23478065431118011
Epoch: 6 Loss G.: 3.270514488220215
Epoch: 7 Loss D.: 0.4134957194328308
Epoch: 7 Loss G.: 2.6753756999969482
Epoch: 8 Loss D.: 0.2761218547821045
Epoch: 8 Loss G.: 2.560955762863159
Epoch: 9 Loss D.: 0.3327544927597046
Epoch: 9 Loss G.: 2.162978172302246
Epoch: 10 Loss D.: 0.36913758516311646
Epoch: 10 Loss G.: 2.086801052093506
Epoch: 11 Loss D.: 0.40733101963996887
Epoch: 11 Loss G.: 1.3113739490509033
Epoch: 12 Loss D.: 0.420818954706192
Epoch: 12 Loss G.: 1.4144489765167236
Epoch: 13 Loss D.: 0.4003124237060547
Epoch: 13 Loss G.: 1.4187294244766235
Epoch: 14 Loss D.: 0.3528069257736206
Epoch: 14 Loss G.: 1.2073755264282227
Epoch: 15 Loss D.: 0.45277005434036255
Epoch: 15 Loss G.: 1.5455167293548584
Epoch: 16 Loss D.: 0.5453391075134277
Epoch: 16 Loss G.: 1.3073005676269531
Epoch: 17 Loss D.: 0.49497324228286743
Epoch: 17 Loss G.: 1.3250724077224731
Epoch: 18 Loss D.: 0.5684107542037964
Epoch: 18 Loss G.: 1.2762668132781982
Epoch: 19 Loss D.: 0.5791763067245483
Epoch: 19 Loss G.: 1.1964823007583618

# Generate one image per vector in the fixed 16-vector latent batch and show
# them in a 4x4 grid. no_grad skips building an autograd graph for this
# inference-only pass, so no .detach() is needed.
with torch.no_grad():
    generated_samples = generator(latent_space_samples_plot).cpu()

plt.figure(dpi=150)
for i in range(16):
    plt.subplot(4, 4, i + 1)
    plt.imshow(generated_samples[i].reshape(28, 28), cmap='gray_r')
    plt.xticks([])
    plt.yticks([])
# Lay out the grid once, after every subplot has been created.
plt.tight_layout()

You might also like