3 mins
TP : Implémenter un GAN simple avec MNIST (PyTorch)
Objectifs du TP #
Comprendre l’architecture d’un GAN (Generative Adversarial Network).
Implémenter :
- un générateur capable de produire des images 28×28 ressemblant à MNIST,
- un discriminateur qui distingue les vraies images MNIST des fausses images générées.
- Entraîner le GAN et visualiser les images produites.
Un GAN se compose de deux réseaux :
- Générateur (G) : prend en entrée un vecteur de bruit et retourne une image générée.
- Discriminateur (D) : prend une image et retourne une probabilité “vrai/faux” qu’elle soit réelle ou fausse.
G essaie de tromper D, et D essaie de détecter les fausses images.
Voici un code de départ travaillant avec le dataset MNIST.
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# ---------------------------------------------------------
# Hyperparamètres
# ---------------------------------------------------------
latent_dim = 100
batch_size = 128
lr = 0.0002
epochs = 30
img_size = 28*28
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ---------------------------------------------------------
# MNIST
# ---------------------------------------------------------
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# ---------------------------------------------------------
# Modèle: Générateur
# ---------------------------------------------------------
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(
# todo
)
def forward(self, z):
img = self.net(z)
return img
# ---------------------------------------------------------
# Modèle: Discriminateur
# ---------------------------------------------------------
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(
# todo
)
def forward(self, img):
return self.net(img)
# ---------------------------------------------------------
# Initialisation
# ---------------------------------------------------------
G = Generator().to(device)
D = Discriminator().to(device)
criterion = nn.BCELoss()
opt_G = optim.Adam(G.parameters(), lr=lr)
opt_D = optim.Adam(D.parameters(), lr=lr)
# ---------------------------------------------------------
# Entraînement
# ---------------------------------------------------------
for epoch in range(epochs):
for imgs, _ in loader:
# Vraies images
real_imgs = imgs.view(imgs.size(0), -1).to(device)
real_labels = torch.ones(imgs.size(0), 1).to(device) # label à 1=Vrai
# Images générées
z = torch.randn(imgs.size(0), latent_dim).to(device)
fake_imgs = G(z)
fake_labels = torch.zeros(imgs.size(0), 1).to(device) # label à 0=Faux
# -----------------------------
# 1) Entraîner le discriminateur
# -----------------------------
D_real = D(real_imgs)
D_fake = D(fake_imgs.detach())
# TODO : loss de D
opt_D.zero_grad()
loss_D.backward()
opt_D.step()
# -----------------------------
# Entraîner le générateur
# -----------------------------
D_fake = D(fake_imgs)
# TODO : loss de G
opt_G.zero_grad()
loss_G.backward()
opt_G.step()
print(f"Epoch {epoch+1}/{epochs} | Loss_D = {loss_D.item():.4f} | Loss_G = {loss_G.item():.4f}")
# ---------------------------------------------------------
# Génération
# ---------------------------------------------------------
z = torch.randn(16, latent_dim).to(device)
samples = G(z).view(-1, 28, 28).cpu().detach()
fig, axes = plt.subplots(4, 4, figsize=(5, 5))
for i, ax in enumerate(axes.flatten()):
ax.imshow(samples[i], cmap="gray")
ax.axis("off")
plt.show()
A faire #
Complétez les TODO dans le code.
Des pistes pour aller plus loin
- Augmenter la taille du réseau (nombre de convolution),
- ajouter BatchNorm,
- tester DCGAN voir le tutrorial de pytorch
- entraîner sur FashionMNIST.