A015-基于DCGAN模型实现彩色图像生成

A015-基于DCGAN模型实现彩色图像生成

导出时间:2025/12/16 16:50:31

1、题目:彩色图像生成

数据文件

CIFAR-10 是一个用于图像分类任务的广泛使用的数据集,包含 10 个不同类别的彩色图片。每个类别包含 6000 张图片,总共 60000 张图片,大小为 32x32 像素,分为 50000 张训练图片和 10000 张测试图片。这个数据集在机器学习和计算机视觉任务中非常常见。
CIFAR-10 的类别包括:飞机(airplane)、汽车(automobile)、鸟类(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)。分析结果时,既可以给出总体性能,又可以按类型进行分析。
AJClbt2Nbojoa3xbyPGcwFR3nrf.png

性能指标

主观指标:生成图像质量主观评价,对比数据集中的真实图像。报告中可以给出随着训练迭代轮数增加,生成图像结果的变化情况。
客观指标:使用Inception Score(IS)和Frechet Inception Distance(FID)等评价指标,分析生成图像的质量。
# 首先安装torch-fidelity库
pip install torch-fidelity

import torch_fidelity
def fidelity_metric(genereated_images_path, real_images_path):
"""
使用fidelity package计算所有的生成相关的指标,输入生成图像路径和真实图像路径
isc: inception score
kid: kernel inception distance
fid: frechet inception distance
"""
  metrics_dict = torch_fidelity.calculate_metrics(
    input1=genereated_images_path,
    input2=real_images_path,
    cuda=True,
    isc=True,
    fid=True,
    kid=True,
    verbose=False
  )
  return metrics_dict
第四个题目的数据集可以通过以下代码获取:
from torchvision.datasets import CIFAR10
dataset = CIFAR10(root='./CIFARdata', download=True, transform=transforms.ToTensor())

2、项目说明

一、项目概述

本项目旨在利用深度学习中的生成对抗网络(Generative Adversarial Network, GAN)方法,针对 CIFAR-10 彩色图像数据集进行模型训练,实现对 10 类常见物体(如飞机、汽车、猫、狗等)的彩色图像生成。通过设计并实现 深度卷积生成对抗网络(DCGAN),使生成模型能够在低分辨率图像(32×32)上生成视觉上逼真的样本。
项目采用 PyTorch 框架 实现,训练完成后使用主观视觉评估和客观评价指标(Inception Score, FID, KID)对生成图像质量进行综合分析。

二、数据集说明

1. 数据集简介

  • 名称:CIFAR-10
  • 来源:Canadian Institute for Advanced Research (CIFAR)
  • 数据量:共 60,000 张彩色图片
    • 训练集:50,000 张
    • 测试集:10,000 张
  • 图片尺寸:32×32×3
  • 类别:10 类(每类 6,000 张)
    • airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck

2. 数据预处理

  • 图片统一缩放至 32×32;
  • 像素值归一化至 [-1, 1];
  • 使用 torchvision.transforms.Normalize([0.5]*3, [0.5]*3)
  • 训练时打乱顺序(shuffle=True)以增强随机性。

三、模型设计

1. 模型框架

采用 DCGAN(Deep Convolutional GAN) 架构,由两个主要模块组成:
  • 生成器(Generator, G) 输入为 128 维随机噪声向量(Z ~ N(0,1)),通过多层反卷积网络(ConvTranspose2d)逐步上采样生成 32×32 的彩色图像。 激活函数为 ReLU,输出层使用 Tanh,将像素值约束在 [-1, 1]。
  • 判别器(Discriminator, D) 输入为图像(真实或生成),通过多层卷积提取特征并输出单个标量,用于判定图片的“真实性”。 激活函数为 LeakyReLU(负斜率 0.2)。

2. 关键网络结构

生成器结构:
层级
类型
输出尺寸
激活函数
输入
z ∈ R^(128×1×1)
-
-
1
ConvTranspose2d(128, 256, 4, 1, 0)
4×4
ReLU
2
ConvTranspose2d(256, 128, 4, 2, 1)
8×8
ReLU
3
ConvTranspose2d(128, 64, 4, 2, 1)
16×16
ReLU
4
ConvTranspose2d(64, 3, 4, 2, 1)
32×32
Tanh
判别器结构:
层级
类型
输出尺寸
激活函数
输入
图像 (3×32×32)
-
-
1
Conv2d(3, 64, 4, 2, 1)
16×16
LeakyReLU(0.2)
2
Conv2d(64, 128, 4, 2, 1)
8×8
LeakyReLU(0.2)
3
Conv2d(128, 256, 4, 2, 1)
4×4
LeakyReLU(0.2)
4
Conv2d(256, 1, 4, 1, 0)
1×1
Sigmoid(隐式)

四、训练设置

参数
说明
优化器
Adam
β1=0.5, β2=0.999
学习率
2.00E-04
G 与 D 共享
批量大小
128

训练轮数
50

噪声维度
128

损失函数
BCEWithLogitsLoss
适合二分类输出
训练时,每轮都会保存生成样例图像,以便观察生成效果随迭代次数的提升情况。

五、实验结果与评估

1. 主观评估

通过对比生成图像随训练轮数的变化(保存于 runs_dcgan_cifar10/samples/ 目录),可以直观观察:
  • 早期生成图像模糊且无明显结构;
  • 中期开始出现类属形状(如动物、交通工具等);
  • 后期颜色趋于自然,轮廓清晰,模式多样。

2. 客观评估

使用 torch-fidelity 计算指标:
指标
含义
越高/越低
理想情况
Inception Score (IS)
多样性与清晰度
越高越好
≈ CIFAR-10 baseline
Frechet Inception Distance (FID)
与真实分布距离
越低越好
< 50
Kernel Inception Distance (KID)
稳健小样本指标
越低越好
趋近 0

六、结果分析与改进方向

模型表现

DCGAN 在 CIFAR-10 上能够生成具备一定可辨识度的彩色图片,尤其在动物类样本上效果较好。然而,由于图像分辨率较低(32×32),生成细节有限。

潜在改进

  • 使用条件 GAN (cGAN):引入类别条件,使生成器针对特定类别生成图像;
  • 采用 WGAN-GP:通过梯度惩罚改善训练稳定性;
  • 引入 EMA(Exponential Moving Average):平滑生成器参数,提升视觉质量;
  • 提升分辨率:通过多尺度架构(如 Progressive GAN 或 StyleGAN)生成更清晰图像。

3、项目演示


安装依赖
pip install torch torchvision
pip install torch-fidelity

训练并保存每轮样例图(主观指标可视化)
python train_dcgan_cifar10.py --epochs 50 --batch_size 128 --n_vis 8
可在 ./runs_dcgan_cifar10/samples/ 下看到 samples_epoch_XXX.png,直观看不同轮次生成质量的变化
训练后直接客观评估(IS / FID / KID)
python train_dcgan_cifar10.py --epochs 50 --eval_after_train --eval_gen_num 10000 --eval_batch 256 --real_split test
image.png
image.pngimage.png
模型训练过程的变化,如上图,
image.png