A015-基于DCGAN模型实现彩色图像生成
导出时间:2025/12/16 16:50:31
1、题目:彩色图像生成
数据文件
CIFAR-10 是一个用于图像分类任务的广泛使用的数据集,包含 10 个不同类别的彩色图片。每个类别包含 6000 张图片,总共 60000 张图片,大小为 32x32 像素,分为 50000 张训练图片和 10000 张测试图片。这个数据集在机器学习和计算机视觉任务中非常常见。
CIFAR-10 的类别包括:飞机(airplane)、汽车(automobile)、鸟类(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)。分析结果时,既可以给出总体性能,又可以按类型进行分析。
性能指标
主观指标:生成图像质量主观评价,对比数据集中的真实图像。报告中可以给出随着训练迭代轮数增加,生成图像结果的变化情况。
客观指标:使用Inception Score(IS)和Frechet Inception Distance(FID)等评价指标,分析生成图像的质量。
# 首先安装torch-fidelity库
pip install torch-fidelity
import torch_fidelity
def fidelity_metric(genereated_images_path, real_images_path):
"""
使用fidelity package计算所有的生成相关的指标,输入生成图像路径和真实图像路径
isc: inception score
kid: kernel inception distance
fid: frechet inception distance
"""
metrics_dict = torch_fidelity.calculate_metrics(
input1=genereated_images_path,
input2=real_images_path,
cuda=True,
isc=True,
fid=True,
kid=True,
verbose=False
)
return metrics_dict
第四个题目的数据集可以通过以下代码获取:
from torchvision.datasets import CIFAR10
dataset = CIFAR10(root='./CIFARdata', download=True, transform=transforms.ToTensor())
2、项目说明
一、项目概述
本项目旨在利用深度学习中的生成对抗网络(Generative Adversarial Network, GAN)方法,针对 CIFAR-10 彩色图像数据集进行模型训练,实现对 10 类常见物体(如飞机、汽车、猫、狗等)的彩色图像生成。通过设计并实现 深度卷积生成对抗网络(DCGAN),使生成模型能够在低分辨率图像(32×32)上生成视觉上逼真的样本。
项目采用 PyTorch 框架 实现,训练完成后使用主观视觉评估和客观评价指标(Inception Score, FID, KID)对生成图像质量进行综合分析。
二、数据集说明
1. 数据集简介
- 名称:CIFAR-10
- 来源:Canadian Institute for Advanced Research (CIFAR)
- 数据量:共 60,000 张彩色图片
- 训练集:50,000 张
- 测试集:10,000 张
- 图片尺寸:32×32×3
- 类别:10 类(每类 6,000 张)
- airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck
2. 数据预处理
- 图片统一缩放至 32×32;
- 像素值归一化至 [-1, 1];
- 使用
torchvision.transforms.Normalize([0.5]*3, [0.5]*3); - 训练时打乱顺序(shuffle=True)以增强随机性。
三、模型设计
1. 模型框架
采用 DCGAN(Deep Convolutional GAN) 架构,由两个主要模块组成:
- 生成器(Generator, G) 输入为 128 维随机噪声向量(Z ~ N(0,1)),通过多层反卷积网络(ConvTranspose2d)逐步上采样生成 32×32 的彩色图像。 激活函数为 ReLU,输出层使用 Tanh,将像素值约束在 [-1, 1]。
- 判别器(Discriminator, D) 输入为图像(真实或生成),通过多层卷积提取特征并输出单个标量,用于判定图片的“真实性”。 激活函数为 LeakyReLU(负斜率 0.2)。
2. 关键网络结构
生成器结构:
层级
| 类型
| 输出尺寸
| 激活函数
|
输入
| z ∈ R^(128×1×1)
| -
| -
|
1
| ConvTranspose2d(128, 256, 4, 1, 0)
| 4×4
| ReLU
|
2
| ConvTranspose2d(256, 128, 4, 2, 1)
| 8×8
| ReLU
|
3
| ConvTranspose2d(128, 64, 4, 2, 1)
| 16×16
| ReLU
|
4
| ConvTranspose2d(64, 3, 4, 2, 1)
| 32×32
| Tanh
|
判别器结构:
层级
| 类型
| 输出尺寸
| 激活函数
|
输入
| 图像 (3×32×32)
| -
| -
|
1
| Conv2d(3, 64, 4, 2, 1)
| 16×16
| LeakyReLU(0.2)
|
2
| Conv2d(64, 128, 4, 2, 1)
| 8×8
| LeakyReLU(0.2)
|
3
| Conv2d(128, 256, 4, 2, 1)
| 4×4
| LeakyReLU(0.2)
|
4
| Conv2d(256, 1, 4, 1, 0)
| 1×1
| Sigmoid(隐式)
|
四、训练设置
参数
| 值
| 说明
|
优化器
| Adam
| β1=0.5, β2=0.999
|
学习率
| 2.00E-04
| G 与 D 共享
|
批量大小
| 128
| |
训练轮数
| 50
| |
噪声维度
| 128
| |
损失函数
| BCEWithLogitsLoss
| 适合二分类输出
|
训练时,每轮都会保存生成样例图像,以便观察生成效果随迭代次数的提升情况。
五、实验结果与评估
1. 主观评估
通过对比生成图像随训练轮数的变化(保存于
runs_dcgan_cifar10/samples/ 目录),可以直观观察:
- 早期生成图像模糊且无明显结构;
- 中期开始出现类属形状(如动物、交通工具等);
- 后期颜色趋于自然,轮廓清晰,模式多样。
2. 客观评估
使用
torch-fidelity 计算指标:
指标
| 含义
| 越高/越低
| 理想情况
|
Inception Score (IS)
| 多样性与清晰度
| 越高越好
| ≈ CIFAR-10 baseline
|
Frechet Inception Distance (FID)
| 与真实分布距离
| 越低越好
| < 50
|
Kernel Inception Distance (KID)
| 稳健小样本指标
| 越低越好
| 趋近 0
|
六、结果分析与改进方向
模型表现
DCGAN 在 CIFAR-10 上能够生成具备一定可辨识度的彩色图片,尤其在动物类样本上效果较好。然而,由于图像分辨率较低(32×32),生成细节有限。
潜在改进
- 使用条件 GAN (cGAN):引入类别条件,使生成器针对特定类别生成图像;
- 采用 WGAN-GP:通过梯度惩罚改善训练稳定性;
- 引入 EMA(Exponential Moving Average):平滑生成器参数,提升视觉质量;
- 提升分辨率:通过多尺度架构(如 Progressive GAN 或 StyleGAN)生成更清晰图像。
3、项目演示
安装依赖
pip install torch torchvision
pip install torch-fidelity
训练并保存每轮样例图(主观指标可视化)
python train_dcgan_cifar10.py --epochs 50 --batch_size 128 --n_vis 8
可在 ./runs_dcgan_cifar10/samples/ 下看到 samples_epoch_XXX.png,直观看不同轮次生成质量的变化
训练后直接客观评估(IS / FID / KID)
python train_dcgan_cifar10.py --epochs 50 --eval_after_train --eval_gen_num 10000 --eval_batch 256 --real_split test
模型训练过程的变化,如上图,