diff --git a/rfcs/Science/20250404_add_wgan_gp_for_paddlescience.md b/rfcs/Science/20250404_add_wgan_gp_for_paddlescience.md index e81972af3..c26998eb0 100644 --- a/rfcs/Science/20250404_add_wgan_gp_for_paddlescience.md +++ b/rfcs/Science/20250404_add_wgan_gp_for_paddlescience.md @@ -1,11 +1,11 @@ # 在 PaddleScience 中复现 WGAN-GP 模型 -| 任务名称 | 在 PaddleScience 中复现 WGAN-GP 模型 | -| --- | --- | -| 提交作者 | robinbg | -| 提交时间 | 2025-04-04 | -| 版本号 | V1.0 | -| 依赖飞桨版本 | develop | +| 任务名称 | 在 PaddleScience 中复现 WGAN-GP 模型 | +| --- |-------------------------------------------| +| 提交作者 | robinbg、XvLingWYY | +| 提交时间 | 2025-04-04 | +| 版本号 | V1.0 | +| 依赖飞桨版本 | develop | | 文件名 | 20250404_add_wgan_gp_for_paddlescience.md | # 一、概述 @@ -75,44 +75,106 @@ WGAN-GP 的核心在于其损失函数和梯度惩罚项的计算。以下是主 ### 1.1 损失函数 ```python -# 生成器损失 -def generator_loss(fake_output): - return -paddle.mean(fake_output) - -# 判别器损失(包含梯度惩罚) -def discriminator_loss(real_output, fake_output, gradient_penalty): - return paddle.mean(fake_output) - paddle.mean(real_output) + LAMBDA * gradient_penalty +# CIFAR10实验中生成器损失 +class Cifar10GenFuncs: + """ + Loss function for cifar10 generator + Args + discriminator_model: discriminator model + acgan_scale_g: scale of acgan loss for generator + + """ + + def __init__( + self, + discriminator_model, + acgan_scale_g=0.1, + ): + self.crossEntropyLoss = paddle.nn.CrossEntropyLoss() + self.acgan_scale_g = acgan_scale_g + self.discriminator_model = discriminator_model + + def loss(self, output_dict: Dict, *args): + fake_image = output_dict["fake_data"] + labels = output_dict["labels"] + outputs = self.discriminator_model({"data": fake_image, "labels": labels}) + disc_fake, disc_fake_acgan = outputs["disc_fake"], outputs["disc_acgan"] + gen_cost = -paddle.mean(disc_fake) + if disc_fake_acgan is not None: + gen_acgan_cost = self.crossEntropyLoss(disc_fake_acgan, labels) + gen_cost += self.acgan_scale_g * gen_acgan_cost + return {"loss_g": gen_cost} + +# CIFAR10实验中判别器损失 +class Cifar10DisFuncs: + """ + Loss function for cifar10 discriminator + Args + discriminator_model: discriminator model + acgan_scale: scale of acgan loss for discriminator + + """ + + def __init__(self, discriminator_model, acgan_scale): + self.crossEntropyLoss = paddle.nn.CrossEntropyLoss() + self.acgan_scale = acgan_scale + self.discriminator_model = discriminator_model + + def loss(self, output_dict: Dict, label_dict: Dict, *args): + fake_image = output_dict["fake_data"] + real_image = label_dict["real_data"] + labels = output_dict["labels"] + disc_fake = self.discriminator_model({"data": fake_image, "labels": labels})[ + "disc_fake" + ] + out = self.discriminator_model({"data": real_image, "labels": labels}) + disc_real, disc_real_acgan = out["disc_fake"], out["disc_acgan"] + gradient_penalty = self.compute_gradient_penalty(real_image, fake_image, labels) + disc_cost = paddle.mean(disc_fake) - paddle.mean(disc_real) + disc_wgan = disc_cost + gradient_penalty + if disc_real_acgan is not None: + disc_acgan_cost = self.crossEntropyLoss(disc_real_acgan, labels) + disc_acgan = disc_acgan_cost.sum() + disc_cost = disc_wgan + (self.acgan_scale * disc_acgan) + else: + disc_cost = disc_wgan + return {"loss_d": disc_cost} + + def compute_gradient_penalty(self, real_data, fake_data, labels): + differences = fake_data - real_data + alpha = paddle.rand([fake_data.shape[0], 1]) + interpolates = real_data + (alpha * differences) + gradients = paddle.grad( + outputs=self.discriminator_model({"data": interpolates, "labels": labels})[ + "disc_fake" + ], + inputs=interpolates, + create_graph=True, + retain_graph=False, + )[0] + slopes = paddle.sqrt(paddle.sum(paddle.square(gradients), axis=1)) + gradient_penalty = 10 * paddle.mean((slopes - 1.0) ** 2) + return gradient_penalty ``` ### 1.2 梯度惩罚计算 ```python -def gradient_penalty(discriminator, real_samples, fake_samples): - # 生成随机插值系数 - alpha = paddle.rand(shape=[real_samples.shape[0], 1, 1, 1]) - - # 创建真实样本和生成样本之间的插值 - interpolates = real_samples + alpha * (fake_samples - real_samples) - interpolates.stop_gradient = False - - # 计算判别器对插值样本的输出 - disc_interpolates = discriminator(interpolates) - - # 计算梯度 - gradients = paddle.grad( - outputs=disc_interpolates, - inputs=interpolates, - grad_outputs=paddle.ones_like(disc_interpolates), - create_graph=True, - retain_graph=True - )[0] - - # 计算梯度范数 - gradients_norm = paddle.sqrt(paddle.sum(paddle.square(gradients), axis=[1, 2, 3])) - - # 计算梯度惩罚 - gradient_penalty = paddle.mean(paddle.square(gradients_norm - 1.0)) - - return gradient_penalty +# CIFAR-10 判别器中的梯度惩罚计算 +def compute_gradient_penalty(self, real_data, fake_data, labels): + differences = fake_data - real_data + alpha = paddle.rand([fake_data.shape[0], 1]) + interpolates = real_data + (alpha * differences) + gradients = paddle.grad( + outputs=self.discriminator_model({"data": interpolates, "labels": labels})[ + "disc_fake" + ], + inputs=interpolates, + create_graph=True, + retain_graph=False, + )[0] + slopes = paddle.sqrt(paddle.sum(paddle.square(gradients), axis=1)) + gradient_penalty = 10 * paddle.mean((slopes - 1.0) ** 2) + return gradient_penalty ``` ## 2. 网络架构 @@ -135,34 +197,12 @@ WGAN-GP 的训练流程与标准 GAN 有所不同,主要区别在于: ```python # 训练循环示例 -for iteration in range(ITERATIONS): - # 训练判别器 - for _ in range(CRITIC_ITERS): - real_data = next(data_iterator) - noise = paddle.randn([BATCH_SIZE, NOISE_DIM]) - - # 计算判别器损失 - fake_data = generator(noise) - real_output = discriminator(real_data) - fake_output = discriminator(fake_data) - gp = gradient_penalty(discriminator, real_data, fake_data) - d_loss = discriminator_loss(real_output, fake_output, gp) - - # 更新判别器参数 - d_optimizer.clear_grad() - d_loss.backward() - d_optimizer.step() - - # 训练生成器 - noise = paddle.randn([BATCH_SIZE, NOISE_DIM]) - fake_data = generator(noise) - fake_output = discriminator(fake_data) - g_loss = generator_loss(fake_output) - - # 更新生成器参数 - g_optimizer.clear_grad() - g_loss.backward() - g_optimizer.step() + for i in range(cfg.TRAIN.epochs): + logger.message(f"\nEpoch: {i + 1}\n") + optimizer_discriminator.clear_grad() + solver_discriminator.train() + optimizer_generator.clear_grad() + solver_generator.train() ``` ## 4. 评估指标 @@ -171,11 +211,8 @@ for iteration in range(ITERATIONS): ### 4.1 Inception Score (IS) 用于评估生成图像的质量和多样性。 -### 4.2 Fréchet Inception Distance (FID) -测量生成图像分布与真实图像分布之间的距离。 - -### 4.3 生成样本可视化 -定期保存生成的样本,用于直观评估模型性能。 +### 4.2 生成样本可视化 +保存生成的样本,用于直观评估模型性能。 ## 5. 与 PaddleScience 集成 我们将设计一个模块化的实现,便于与 PaddleScience 集成: @@ -184,22 +221,18 @@ for iteration in range(ITERATIONS): ``` PaddleScience/ └── examples/ - └── wgan_gp/ - ├── __init__.py - ├── utils/ - │ ├── __init__.py - │ ├── losses.py # 损失函数 - │ ├── metrics.py # 评估指标 - │ └── visualization.py # 可视化工具 - ├── models/ - │ ├── __init__.py - │ ├── base_gan.py # GAN 基类 - │ ├── wgan.py # WGAN 实现 - │ └── wgan_gp.py # WGAN-GP 实现 - └── cases/ - ├── wgan_gp_toy.py # 玩具数据集示例 - ├── wgan_gp_mnist.py # MNIST 示例 - └── wgan_gp_cifar.py # CIFAR-10 示例 + └── wgangp/ + ├── conf + │ ├── wgangp_cifar10.yaml # CIFAR-10 配置文件 + │ ├── wgangp_mnist.yaml # MNIST 配置文件 + │ └── wgangp_toy.yaml # 玩具数据集配置文件 + ├── function.py # 损失函数、评估指标、可视化工具 + ├── wgangp_cifr10.py # CIFAR-10 示例 + ├── wgangp_cifar10_model.py # CIFAR-10实验模型 + ├── wgangp_mnist.py # MNIST 示例 + ├── wgangp_mnist_model.py # MNIST实验模型 + └── wgangp_toy.py # 玩具数据集示例 + └── wgangp_toy_model.py # 玩具数据集实验模型 ``` ### 5.2 接口设计 @@ -207,27 +240,123 @@ PaddleScience/ ```python # 示例用法 -from models.wgan_gp import WGAN_GP - -# 创建模型 -model = WGAN_GP( - generator=generator_network, - discriminator=discriminator_network, - lambda_gp=10.0, - critic_iters=5 -) - -# 训练模型 -model.train( - train_data=dataset, - batch_size=64, - iterations=100000, - g_learning_rate=1e-4, - d_learning_rate=1e-4 -) - -# 生成样本 -samples = model.generate(num_samples=100) +import os +import paddle +from functions import Cifar10DisFuncs +from functions import Cifar10GenFuncs +from functions import load_cifar10 +from omegaconf import DictConfig +from wgangp_cifar10_model import WganGpCifar10Discriminator +from wgangp_cifar10_model import WganGpCifar10Generator + +def train(cfg: DictConfig): + # set model + generator_model = WganGpCifar10Generator(**cfg["MODEL"]["gen_net"]) + discriminator_model = WganGpCifar10Discriminator(**cfg["MODEL"]["dis_net"]) + if cfg.TRAIN.pretrained_dis_model_path and os.path.exists( + cfg.TRAIN.pretrained_dis_model_path + ): + discriminator_model.load_dict(paddle.load(cfg.TRAIN.pretrained_dis_model_path)) + + # set Loss + generator_funcs = Cifar10GenFuncs( + **cfg["LOSS"]["gen"], discriminator_model=discriminator_model + ) + discriminator_funcs = Cifar10DisFuncs( + **cfg["LOSS"]["dis"], discriminator_model=discriminator_model + ) + + # set dataloader + inputs, labels = load_cifar10(**cfg["DATA"]) + dataloader_cfg = { + "dataset": { + "name": cfg["EVAL"]["dataset"]["name"], + "input": inputs, + "label": labels, + }, + "sampler": { + **cfg["TRAIN"]["sampler"], + }, + "batch_size": cfg["TRAIN"]["batch_size"], + "use_shared_memory": cfg["TRAIN"]["use_shared_memory"], + "num_workers": cfg["TRAIN"]["num_workers"], + "drop_last": cfg["TRAIN"]["drop_last"], + } + + # set constraint + constraint_generator = ppsci.constraint.SupervisedConstraint( + dataloader_cfg=dataloader_cfg, + loss=ppsci.loss.FunctionalLoss(generator_funcs.loss), + output_expr={"labels": lambda out: out["labels"]}, + name="constraint_generator", + ) + constraint_generator_dict = {constraint_generator.name: constraint_generator} + + constraint_discriminator = ppsci.constraint.SupervisedConstraint( + dataloader_cfg=dataloader_cfg, + loss=ppsci.loss.FunctionalLoss(discriminator_funcs.loss), + output_expr={"labels": lambda out: out["labels"]}, + name="constraint_discriminator", + ) + constraint_discriminator_dict = { + constraint_discriminator.name: constraint_discriminator + } + + # set optimizer + lr_scheduler_generator = Linear(**cfg["TRAIN"]["lr_scheduler_gen"])() + lr_scheduler_discriminator = Linear(**cfg["TRAIN"]["lr_scheduler_dis"])() + + optimizer_generator = ppsci.optimizer.Adam( + learning_rate=lr_scheduler_generator, + beta1=cfg["TRAIN"]["optimizer"]["beta1"], + beta2=cfg["TRAIN"]["optimizer"]["beta2"], + ) + optimizer_discriminator = ppsci.optimizer.Adam( + learning_rate=lr_scheduler_discriminator, + beta1=cfg["TRAIN"]["optimizer"]["beta1"], + beta2=cfg["TRAIN"]["optimizer"]["beta2"], + ) + optimizer_generator = optimizer_generator(generator_model) + optimizer_discriminator = optimizer_discriminator(discriminator_model) + + # initialize solver + solver_generator = ppsci.solver.Solver( + model=generator_model, + output_dir=os.path.join(cfg.output_dir, "generator"), + constraint=constraint_generator_dict, + optimizer=optimizer_generator, + epochs=cfg.TRAIN.epochs_gen, + iters_per_epoch=cfg.TRAIN.iters_per_epoch_gen, + pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path, + ) + solver_discriminator = ppsci.solver.Solver( + model=generator_model, + output_dir=os.path.join(cfg.output_dir, "discriminator"), + constraint=constraint_discriminator_dict, + optimizer=optimizer_discriminator, + epochs=cfg.TRAIN.epochs_dis, + iters_per_epoch=cfg.TRAIN.iters_per_epoch_dis, + pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path, + ) + + # train + for i in range(cfg.TRAIN.epochs): + logger.message(f"\nEpoch: {i + 1}\n") + optimizer_discriminator.clear_grad() + solver_discriminator.train() + optimizer_generator.clear_grad() + solver_generator.train() + + # save model weight + paddle.save( + generator_model.state_dict(), + os.path.join(cfg.output_dir, "model_generator.pdparams"), + ) + paddle.save( + discriminator_model.state_dict(), + os.path.join(cfg.output_dir, "model_discriminator.pdparams"), + ) + ``` # 六、测试验收的考量