Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
349 changes: 239 additions & 110 deletions rfcs/Science/20250404_add_wgan_gp_for_paddlescience.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# 在 PaddleScience 中复现 WGAN-GP 模型

| 任务名称 | 在 PaddleScience 中复现 WGAN-GP 模型 |
| --- | --- |
| 提交作者 | robinbg |
| 提交时间 | 2025-04-04 |
| 版本号 | V1.0 |
| 依赖飞桨版本 | develop |
| 任务名称 | 在 PaddleScience 中复现 WGAN-GP 模型 |
| --- |-------------------------------------------|
| 提交作者 | robinbg、XvLingWYY |
| 提交时间 | 2025-04-04 |
| 版本号 | V1.0 |
| 依赖飞桨版本 | develop |
| 文件名 | 20250404_add_wgan_gp_for_paddlescience.md |

# 一、概述
Expand Down Expand Up @@ -75,44 +75,106 @@ WGAN-GP 的核心在于其损失函数和梯度惩罚项的计算。以下是主

### 1.1 损失函数
```python
# 生成器损失
def generator_loss(fake_output):
    """Return the generator loss: the negated mean of ``fake_output``.

    Args:
        fake_output: Discriminator output for generated samples.

    Returns:
        ``-paddle.mean(fake_output)``.
    """
    mean_fake_score = paddle.mean(fake_output)
    return -mean_fake_score

# 判别器损失(包含梯度惩罚)
def discriminator_loss(real_output, fake_output, gradient_penalty):
    """Return the discriminator loss including the gradient-penalty term.

    Args:
        real_output: Discriminator output for real samples.
        fake_output: Discriminator output for generated samples.
        gradient_penalty: Gradient-penalty value to be weighted by ``LAMBDA``.

    Returns:
        ``mean(fake_output) - mean(real_output) + LAMBDA * gradient_penalty``.
        ``LAMBDA`` is a constant defined outside this snippet.
    """
    wasserstein_term = paddle.mean(fake_output) - paddle.mean(real_output)
    penalty_term = LAMBDA * gradient_penalty
    return wasserstein_term + penalty_term
# CIFAR10实验中生成器损失
class Cifar10GenFuncs:
    """Loss functions for the CIFAR-10 generator.

    Args:
        discriminator_model: Discriminator model. It is called with a dict
            holding ``"data"`` and ``"labels"`` and its result is indexed with
            ``"disc_fake"`` and ``"disc_acgan"``.
        acgan_scale_g: Weight of the ACGAN cross-entropy term that is added to
            the generator loss. Defaults to 0.1.
    """

    def __init__(
        self,
        discriminator_model,
        acgan_scale_g=0.1,
    ):
        self.discriminator_model = discriminator_model
        self.acgan_scale_g = acgan_scale_g
        self.crossEntropyLoss = paddle.nn.CrossEntropyLoss()

    def loss(self, output_dict: Dict, *args):
        """Compute the generator loss for one batch.

        Args:
            output_dict: Dict providing ``"fake_data"`` and ``"labels"``.
            *args: Ignored; accepted so extra positional arguments do not fail.

        Returns:
            ``{"loss_g": cost}`` where ``cost`` is ``-mean(disc_fake)`` plus,
            when the discriminator returns a non-None ``"disc_acgan"`` entry,
            ``acgan_scale_g`` times the cross-entropy of that entry against
            the labels.
        """
        fake_image = output_dict["fake_data"]
        labels = output_dict["labels"]
        disc_out = self.discriminator_model({"data": fake_image, "labels": labels})
        critic_score = disc_out["disc_fake"]
        class_logits = disc_out["disc_acgan"]

        gen_cost = -paddle.mean(critic_score)
        if class_logits is None:
            # No auxiliary classifier output: plain WGAN generator loss.
            return {"loss_g": gen_cost}

        gen_cost += self.acgan_scale_g * self.crossEntropyLoss(class_logits, labels)
        return {"loss_g": gen_cost}

# CIFAR10实验中判别器损失
class Cifar10DisFuncs:
    """Loss functions for the CIFAR-10 discriminator.

    Args:
        discriminator_model: Discriminator model. It is called with a dict
            holding ``"data"`` and ``"labels"`` and its result is indexed with
            ``"disc_fake"`` and ``"disc_acgan"``.
        acgan_scale: Weight of the ACGAN cross-entropy term that is added to
            the discriminator loss.
    """

    def __init__(self, discriminator_model, acgan_scale):
        self.discriminator_model = discriminator_model
        self.acgan_scale = acgan_scale
        self.crossEntropyLoss = paddle.nn.CrossEntropyLoss()

    def loss(self, output_dict: Dict, label_dict: Dict, *args):
        """Compute the discriminator loss for one batch.

        Args:
            output_dict: Dict providing ``"fake_data"`` and ``"labels"``.
            label_dict: Dict providing ``"real_data"``.
            *args: Ignored; accepted so extra positional arguments do not fail.

        Returns:
            ``{"loss_d": cost}`` where ``cost`` is
            ``mean(disc_fake) - mean(disc_real) + gradient_penalty`` plus, when
            the discriminator returns a non-None ``"disc_acgan"`` entry for the
            real batch, ``acgan_scale`` times the summed cross-entropy of that
            entry against the labels.
        """
        fake_image = output_dict["fake_data"]
        real_image = label_dict["real_data"]
        labels = output_dict["labels"]

        # Forward passes happen in the order: fake batch, real batch, then the
        # interpolated batch inside compute_gradient_penalty.
        fake_out = self.discriminator_model({"data": fake_image, "labels": labels})
        disc_fake = fake_out["disc_fake"]
        real_out = self.discriminator_model({"data": real_image, "labels": labels})
        # The score for the real batch is read from the same "disc_fake" key.
        disc_real = real_out["disc_fake"]
        disc_real_acgan = real_out["disc_acgan"]

        gradient_penalty = self.compute_gradient_penalty(real_image, fake_image, labels)
        disc_wgan = (paddle.mean(disc_fake) - paddle.mean(disc_real)) + gradient_penalty
        if disc_real_acgan is None:
            return {"loss_d": disc_wgan}

        disc_acgan = self.crossEntropyLoss(disc_real_acgan, labels).sum()
        return {"loss_d": disc_wgan + (self.acgan_scale * disc_acgan)}

    def compute_gradient_penalty(self, real_data, fake_data, labels):
        """Return the gradient penalty evaluated on random interpolates.

        The interpolates are ``real_data + alpha * (fake_data - real_data)``
        with ``alpha`` drawn by ``paddle.rand`` with shape ``[batch, 1]``. The
        penalty is ``10 * mean((slopes - 1) ** 2)``, where ``slopes`` is the
        square root of the squared gradients summed over axis 1.

        NOTE(review): the ``[batch, 1]`` alpha and the reduction over axis 1
        only presumably assume 2-D ``(batch, features)`` inputs -- confirm
        against the model's data layout.

        Args:
            real_data: Batch of real samples.
            fake_data: Batch of generated samples.
            labels: Labels forwarded to the discriminator.

        Returns:
            The gradient-penalty value, already multiplied by 10.
        """
        batch_size = fake_data.shape[0]
        differences = fake_data - real_data
        alpha = paddle.rand([batch_size, 1])
        interpolates = real_data + (alpha * differences)

        critic_on_interpolates = self.discriminator_model(
            {"data": interpolates, "labels": labels}
        )["disc_fake"]
        # create_graph=True keeps the gradient computation differentiable so
        # the penalty itself can be back-propagated.
        gradients = paddle.grad(
            outputs=critic_on_interpolates,
            inputs=interpolates,
            create_graph=True,
            retain_graph=False,
        )[0]

        squared_sum = paddle.sum(paddle.square(gradients), axis=1)
        slopes = paddle.sqrt(squared_sum)
        return 10 * paddle.mean((slopes - 1.0) ** 2)
```

### 1.2 梯度惩罚计算
```python
def gradient_penalty(discriminator, real_samples, fake_samples):
    """Return the gradient penalty on random real/fake interpolates.

    Args:
        discriminator: Callable applied to the interpolated samples.
        real_samples: Batch of real samples.
        fake_samples: Batch of generated samples.

    Returns:
        ``mean((||grad|| - 1) ** 2)``, where the norm is taken over axes
        1, 2 and 3 of the gradient of the discriminator output with respect
        to the interpolates.
    """
    batch_size = real_samples.shape[0]
    # One random interpolation coefficient per sample.
    alpha = paddle.rand(shape=[batch_size, 1, 1, 1])

    # Points between the real and the generated samples.
    interpolates = real_samples + alpha * (fake_samples - real_samples)
    interpolates.stop_gradient = False

    # Discriminator output on the interpolated samples.
    disc_interpolates = discriminator(interpolates)

    # Gradient of that output with respect to the interpolates.
    gradients = paddle.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=paddle.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Per-sample gradient norm.
    squared_sum = paddle.sum(paddle.square(gradients), axis=[1, 2, 3])
    gradients_norm = paddle.sqrt(squared_sum)

    # Penalise the deviation of the norm from 1.
    return paddle.mean(paddle.square(gradients_norm - 1.0))
# CIFAR-10 判别器中的梯度惩罚计算
def compute_gradient_penalty(self, real_data, fake_data, labels):
    """Return the gradient penalty evaluated on random interpolates.

    The interpolates are ``real_data + alpha * (fake_data - real_data)`` with
    ``alpha`` drawn by ``paddle.rand`` with shape ``[batch, 1]``. The penalty
    is ``10 * mean((slopes - 1) ** 2)``, where ``slopes`` is the square root
    of the squared gradients summed over axis 1.

    NOTE(review): the ``[batch, 1]`` alpha and the reduction over axis 1 only
    presumably assume 2-D ``(batch, features)`` inputs -- confirm against the
    model's data layout.

    Args:
        real_data: Batch of real samples.
        fake_data: Batch of generated samples.
        labels: Labels forwarded to the discriminator.

    Returns:
        The gradient-penalty value, already multiplied by 10.
    """
    batch_size = fake_data.shape[0]
    differences = fake_data - real_data
    alpha = paddle.rand([batch_size, 1])
    interpolates = real_data + (alpha * differences)

    critic_on_interpolates = self.discriminator_model(
        {"data": interpolates, "labels": labels}
    )["disc_fake"]
    # create_graph=True keeps the gradient computation differentiable so the
    # penalty itself can be back-propagated.
    gradients = paddle.grad(
        outputs=critic_on_interpolates,
        inputs=interpolates,
        create_graph=True,
        retain_graph=False,
    )[0]

    squared_sum = paddle.sum(paddle.square(gradients), axis=1)
    slopes = paddle.sqrt(squared_sum)
    return 10 * paddle.mean((slopes - 1.0) ** 2)
```

## 2. 网络架构
Expand All @@ -135,34 +197,12 @@ WGAN-GP 的训练流程与标准 GAN 有所不同,主要区别在于:

```python
# 训练循环示例
for iteration in range(ITERATIONS):
    # Discriminator phase: CRITIC_ITERS updates for every generator update.
    for _ in range(CRITIC_ITERS):
        real_data = next(data_iterator)
        fake_data = generator(paddle.randn([BATCH_SIZE, NOISE_DIM]))

        # Discriminator loss: scores on real and generated data plus the
        # gradient penalty.
        d_loss = discriminator_loss(
            discriminator(real_data),
            discriminator(fake_data),
            gradient_penalty(discriminator, real_data, fake_data),
        )

        # Update the discriminator parameters.
        d_optimizer.clear_grad()
        d_loss.backward()
        d_optimizer.step()

    # Generator phase: a fresh noise batch, scored by the discriminator.
    fake_data = generator(paddle.randn([BATCH_SIZE, NOISE_DIM]))
    g_loss = generator_loss(discriminator(fake_data))

    # Update the generator parameters.
    g_optimizer.clear_grad()
    g_loss.backward()
    g_optimizer.step()
# Each outer epoch runs the discriminator solver first, then the generator
# solver; the matching optimizer's gradients are cleared before each run.
training_phases = (
    (optimizer_discriminator, solver_discriminator),
    (optimizer_generator, solver_generator),
)
for i in range(cfg.TRAIN.epochs):
    logger.message(f"\nEpoch: {i + 1}\n")
    for phase_optimizer, phase_solver in training_phases:
        phase_optimizer.clear_grad()
        phase_solver.train()
```

## 4. 评估指标
Expand All @@ -171,11 +211,8 @@ for iteration in range(ITERATIONS):
### 4.1 Inception Score (IS)
用于评估生成图像的质量和多样性。

### 4.2 Fréchet Inception Distance (FID)
测量生成图像分布与真实图像分布之间的距离。

### 4.3 生成样本可视化
定期保存生成的样本,用于直观评估模型性能。
### 4.2 生成样本可视化
保存生成的样本,用于直观评估模型性能。

## 5. 与 PaddleScience 集成
我们将设计一个模块化的实现,便于与 PaddleScience 集成:
Expand All @@ -184,50 +221,142 @@ for iteration in range(ITERATIONS):
```
PaddleScience/
└── examples/
└── wgan_gp/
├── __init__.py
├── utils/
│ ├── __init__.py
│ ├── losses.py # 损失函数
│ ├── metrics.py # 评估指标
│ └── visualization.py # 可视化工具
├── models/
│ ├── __init__.py
│ ├── base_gan.py # GAN 基类
│ ├── wgan.py # WGAN 实现
│ └── wgan_gp.py # WGAN-GP 实现
└── cases/
├── wgan_gp_toy.py # 玩具数据集示例
├── wgan_gp_mnist.py # MNIST 示例
└── wgan_gp_cifar.py # CIFAR-10 示例
└── wgangp/
├── conf
│ ├── wgangp_cifar10.yaml # CIFAR-10 配置文件
│ ├── wgangp_mnist.yaml # MNIST 配置文件
│ └── wgangp_toy.yaml # 玩具数据集配置文件
├── functions.py # 损失函数、评估指标、可视化工具
├── wgangp_cifar10.py # CIFAR-10 示例
├── wgangp_cifar10_model.py # CIFAR-10实验模型
├── wgangp_mnist.py # MNIST 示例
├── wgangp_mnist_model.py # MNIST实验模型
├── wgangp_toy.py # 玩具数据集示例
└── wgangp_toy_model.py # 玩具数据集实验模型
```

### 5.2 接口设计
提供简洁统一的接口,方便用户使用:

```python
# 示例用法
from models.wgan_gp import WGAN_GP

# Create the model
model = WGAN_GP(
    generator=generator_network,
    discriminator=discriminator_network,
    lambda_gp=10.0,
    critic_iters=5
)

# Train the model
model.train(
    train_data=dataset,
    batch_size=64,
    iterations=100000,
    g_learning_rate=1e-4,
    d_learning_rate=1e-4
)

# Generate samples
samples = model.generate(num_samples=100)
import os
import paddle
from functions import Cifar10DisFuncs
from functions import Cifar10GenFuncs
from functions import load_cifar10
from omegaconf import DictConfig
from wgangp_cifar10_model import WganGpCifar10Discriminator
from wgangp_cifar10_model import WganGpCifar10Generator

def train(cfg: DictConfig):
    """Build the CIFAR-10 WGAN-GP models, constraints and solvers, then train.

    The function creates the generator and discriminator, wraps the two loss
    objects in ``ppsci`` supervised constraints that share one dataloader
    config, builds one Adam optimizer and one solver per network, alternates
    discriminator and generator training for ``cfg.TRAIN.epochs`` rounds, and
    finally saves both state dicts under ``cfg.output_dir``.

    NOTE(review): ``ppsci``, ``logger`` and ``Linear`` are used here but are
    not among the imports shown with this snippet -- confirm they are imported.
    NOTE(review): the discriminator solver is given ``generator_model`` and the
    generator checkpoint path. This is presumably intentional because the
    discriminator is invoked inside the loss functions and its weights are
    loaded manually below -- confirm.

    Args:
        cfg: Configuration providing the ``MODEL``, ``LOSS``, ``DATA``,
            ``EVAL``, ``TRAIN`` sections and ``output_dir`` read below.
    """
    # set model
    model_cfg = cfg["MODEL"]
    generator_model = WganGpCifar10Generator(**model_cfg["gen_net"])
    discriminator_model = WganGpCifar10Discriminator(**model_cfg["dis_net"])

    train_cfg = cfg.TRAIN
    dis_weights_path = train_cfg.pretrained_dis_model_path
    if dis_weights_path and os.path.exists(dis_weights_path):
        discriminator_model.load_dict(paddle.load(dis_weights_path))

    # set Loss
    loss_cfg = cfg["LOSS"]
    generator_funcs = Cifar10GenFuncs(
        **loss_cfg["gen"], discriminator_model=discriminator_model
    )
    discriminator_funcs = Cifar10DisFuncs(
        **loss_cfg["dis"], discriminator_model=discriminator_model
    )

    # set dataloader
    inputs, labels = load_cifar10(**cfg["DATA"])
    dataset_cfg = {
        # NOTE(review): the dataset name is read from the EVAL section even
        # though this dataloader is used for training -- confirm.
        "name": cfg["EVAL"]["dataset"]["name"],
        "input": inputs,
        "label": labels,
    }
    dataloader_cfg = {
        "dataset": dataset_cfg,
        "sampler": {**train_cfg["sampler"]},
        "batch_size": train_cfg["batch_size"],
        "use_shared_memory": train_cfg["use_shared_memory"],
        "num_workers": train_cfg["num_workers"],
        "drop_last": train_cfg["drop_last"],
    }

    # set constraint
    constraint_generator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        output_expr={"labels": lambda out: out["labels"]},
        name="constraint_generator",
    )
    constraint_discriminator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(discriminator_funcs.loss),
        output_expr={"labels": lambda out: out["labels"]},
        name="constraint_discriminator",
    )
    constraint_generator_dict = {constraint_generator.name: constraint_generator}
    constraint_discriminator_dict = {
        constraint_discriminator.name: constraint_discriminator
    }

    # set optimizer
    lr_scheduler_generator = Linear(**train_cfg["lr_scheduler_gen"])()
    lr_scheduler_discriminator = Linear(**train_cfg["lr_scheduler_dis"])()

    optimizer_cfg = train_cfg["optimizer"]
    adam_builder_generator = ppsci.optimizer.Adam(
        learning_rate=lr_scheduler_generator,
        beta1=optimizer_cfg["beta1"],
        beta2=optimizer_cfg["beta2"],
    )
    adam_builder_discriminator = ppsci.optimizer.Adam(
        learning_rate=lr_scheduler_discriminator,
        beta1=optimizer_cfg["beta1"],
        beta2=optimizer_cfg["beta2"],
    )
    # Bind each optimizer builder to the network whose parameters it updates.
    optimizer_generator = adam_builder_generator(generator_model)
    optimizer_discriminator = adam_builder_discriminator(discriminator_model)

    # initialize solver
    output_root = cfg.output_dir
    solver_generator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(output_root, "generator"),
        constraint=constraint_generator_dict,
        optimizer=optimizer_generator,
        epochs=train_cfg.epochs_gen,
        iters_per_epoch=train_cfg.iters_per_epoch_gen,
        pretrained_model_path=train_cfg.pretrained_gen_model_path,
    )
    solver_discriminator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(output_root, "discriminator"),
        constraint=constraint_discriminator_dict,
        optimizer=optimizer_discriminator,
        epochs=train_cfg.epochs_dis,
        iters_per_epoch=train_cfg.iters_per_epoch_dis,
        pretrained_model_path=train_cfg.pretrained_gen_model_path,
    )

    # train: discriminator solver first, then generator solver, every round
    training_phases = (
        (optimizer_discriminator, solver_discriminator),
        (optimizer_generator, solver_generator),
    )
    for i in range(train_cfg.epochs):
        logger.message(f"\nEpoch: {i + 1}\n")
        for phase_optimizer, phase_solver in training_phases:
            phase_optimizer.clear_grad()
            phase_solver.train()

    # save model weight
    weight_files = (
        ("model_generator.pdparams", generator_model),
        ("model_discriminator.pdparams", discriminator_model),
    )
    for file_name, network in weight_files:
        paddle.save(
            network.state_dict(),
            os.path.join(cfg.output_dir, file_name),
        )


```

# 六、测试验收的考量
Expand Down