# Reproducing the WGAN-GP Model in PaddleScience

| Task name | Reproducing the WGAN-GP Model in PaddleScience |
| --- | --- |
| Authors | robinbg, XvLingWYY |
| Submission date | 2025-04-04 |
| Version | V1.0 |
| Required Paddle version | develop |
| File name | 20250404_add_wgan_gp_for_paddlescience.md |

# 1. Overview
The core of WGAN-GP lies in its loss function and the gradient-penalty term; the key pieces of the implementation are shown below.
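
For reference, the critic (discriminator) and generator objectives implemented below are the standard WGAN-GP losses (this notation is added here for clarity and does not appear in the example code):

$$
L_D = \mathbb{E}_{\tilde{x}\sim\mathbb{P}_g}\big[D(\tilde{x})\big] - \mathbb{E}_{x\sim\mathbb{P}_r}\big[D(x)\big] + \lambda\,\mathbb{E}_{\hat{x}}\big[(\lVert\nabla_{\hat{x}} D(\hat{x})\rVert_2 - 1)^2\big], \qquad L_G = -\mathbb{E}_{\tilde{x}\sim\mathbb{P}_g}\big[D(\tilde{x})\big]
$$

where $\hat{x}$ is sampled uniformly along straight lines between real and generated samples, and $\lambda = 10$ in the code below.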

### 1.1 Loss Functions
``` python
from typing import Dict

import paddle


# Generator loss for the CIFAR-10 experiment
class Cifar10GenFuncs:
    """
    Loss functions for the CIFAR-10 generator.

    Args:
        discriminator_model: discriminator model
        acgan_scale_g: scale of the ACGAN loss for the generator
    """

    def __init__(
        self,
        discriminator_model,
        acgan_scale_g=0.1,
    ):
        self.crossEntropyLoss = paddle.nn.CrossEntropyLoss()
        self.acgan_scale_g = acgan_scale_g
        self.discriminator_model = discriminator_model

    def loss(self, output_dict: Dict, *args):
        fake_image = output_dict["fake_data"]
        labels = output_dict["labels"]
        outputs = self.discriminator_model({"data": fake_image, "labels": labels})
        disc_fake, disc_fake_acgan = outputs["disc_fake"], outputs["disc_acgan"]
        gen_cost = -paddle.mean(disc_fake)
        if disc_fake_acgan is not None:
            gen_acgan_cost = self.crossEntropyLoss(disc_fake_acgan, labels)
            gen_cost += self.acgan_scale_g * gen_acgan_cost
        return {"loss_g": gen_cost}


# Discriminator loss (including the gradient penalty) for the CIFAR-10 experiment
class Cifar10DisFuncs:
    """
    Loss functions for the CIFAR-10 discriminator.

    Args:
        discriminator_model: discriminator model
        acgan_scale: scale of the ACGAN loss for the discriminator
    """

    def __init__(self, discriminator_model, acgan_scale):
        self.crossEntropyLoss = paddle.nn.CrossEntropyLoss()
        self.acgan_scale = acgan_scale
        self.discriminator_model = discriminator_model

    def loss(self, output_dict: Dict, label_dict: Dict, *args):
        fake_image = output_dict["fake_data"]
        real_image = label_dict["real_data"]
        labels = output_dict["labels"]
        disc_fake = self.discriminator_model({"data": fake_image, "labels": labels})[
            "disc_fake"
        ]
        out = self.discriminator_model({"data": real_image, "labels": labels})
        disc_real, disc_real_acgan = out["disc_fake"], out["disc_acgan"]
        gradient_penalty = self.compute_gradient_penalty(real_image, fake_image, labels)
        disc_cost = paddle.mean(disc_fake) - paddle.mean(disc_real)
        disc_wgan = disc_cost + gradient_penalty
        if disc_real_acgan is not None:
            disc_acgan_cost = self.crossEntropyLoss(disc_real_acgan, labels)
            disc_acgan = disc_acgan_cost.sum()
            disc_cost = disc_wgan + (self.acgan_scale * disc_acgan)
        else:
            disc_cost = disc_wgan
        return {"loss_d": disc_cost}

    def compute_gradient_penalty(self, real_data, fake_data, labels):
        differences = fake_data - real_data
        alpha = paddle.rand([fake_data.shape[0], 1])
        interpolates = real_data + (alpha * differences)
        gradients = paddle.grad(
            outputs=self.discriminator_model({"data": interpolates, "labels": labels})[
                "disc_fake"
            ],
            inputs=interpolates,
            create_graph=True,
            retain_graph=False,
        )[0]
        slopes = paddle.sqrt(paddle.sum(paddle.square(gradients), axis=1))
        gradient_penalty = 10 * paddle.mean((slopes - 1.0) ** 2)
        return gradient_penalty
```
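
When the auxiliary ACGAN classifier head (`disc_acgan`) is active, the two classes above effectively optimize the following (notation added for clarity; $C$ denotes the classifier head, $\mathrm{CE}$ the cross-entropy, and $\alpha_g$, $\alpha_d$ the `acgan_scale_g` / `acgan_scale` factors):

$$
L_G = -\mathbb{E}\big[D(\tilde{x})\big] + \alpha_g\,\mathrm{CE}\big(C(\tilde{x}), y\big), \qquad L_D = \mathbb{E}\big[D(\tilde{x})\big] - \mathbb{E}\big[D(x)\big] + \mathrm{GP} + \alpha_d\,\mathrm{CE}\big(C(x), y\big)
$$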

### 1.2 Gradient-Penalty Computation
``` python
# Gradient-penalty computation in the CIFAR-10 discriminator
def compute_gradient_penalty(self, real_data, fake_data, labels):
    differences = fake_data - real_data
    alpha = paddle.rand([fake_data.shape[0], 1])
    interpolates = real_data + (alpha * differences)
    gradients = paddle.grad(
        outputs=self.discriminator_model({"data": interpolates, "labels": labels})[
            "disc_fake"
        ],
        inputs=interpolates,
        create_graph=True,
        retain_graph=False,
    )[0]
    slopes = paddle.sqrt(paddle.sum(paddle.square(gradients), axis=1))
    gradient_penalty = 10 * paddle.mean((slopes - 1.0) ** 2)
    return gradient_penalty
```
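
A self-contained sketch of the same computation on random tensors can be used to sanity-check the penalty. The tiny critic, the flattened 32×32×3 data shape, and all values below are illustrative only and are not part of the example code:

``` python
import paddle


class TinyCritic(paddle.nn.Layer):
    """Stand-in critic that maps a flattened image to a scalar score."""

    def __init__(self, dim: int = 3 * 32 * 32):
        super().__init__()
        self.fc = paddle.nn.Linear(dim, 1)

    def forward(self, x):
        return self.fc(x)


critic = TinyCritic()
real = paddle.rand([8, 3 * 32 * 32])
fake = paddle.rand([8, 3 * 32 * 32])

alpha = paddle.rand([8, 1])
interpolates = real + alpha * (fake - real)
interpolates.stop_gradient = False  # required so paddle.grad can differentiate w.r.t. it

gradients = paddle.grad(
    outputs=critic(interpolates), inputs=interpolates, create_graph=True
)[0]
slopes = paddle.sqrt(paddle.sum(paddle.square(gradients), axis=1))
gradient_penalty = 10 * paddle.mean((slopes - 1.0) ** 2)
print(float(gradient_penalty))  # a finite, non-negative value
```

In the example itself `fake_data` comes from the generator, so `interpolates` already carries gradients; the explicit `stop_gradient = False` is only needed in this standalone check.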

## 2. Network Architecture

## 3. Training Procedure

The WGAN-GP training procedure differs from that of a standard GAN; in this example it is realized by alternating two PaddleScience solvers, one driving the discriminator updates and one driving the generator updates:

``` python
# Training loop example
for i in range(cfg.TRAIN.epochs):
    logger.message(f"\nEpoch: {i + 1}\n")
    optimizer_discriminator.clear_grad()
    solver_discriminator.train()
    optimizer_generator.clear_grad()
    solver_generator.train()
```
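
The number of discriminator updates per generator update is governed by the `epochs_*` and `iters_per_epoch_*` settings of the two solvers (see section 5.2). Conceptually, this schedule corresponds to the classical WGAN-GP loop sketched below in plain Paddle; the toy networks, random data, and hyper-parameters are purely illustrative and are not the PaddleScience implementation:

``` python
import paddle

# illustrative hyper-parameters
DIM, NOISE_DIM, BATCH, N_CRITIC, LAMBDA = 2, 8, 64, 5, 10.0

generator = paddle.nn.Sequential(
    paddle.nn.Linear(NOISE_DIM, 64), paddle.nn.ReLU(), paddle.nn.Linear(64, DIM)
)
critic = paddle.nn.Sequential(
    paddle.nn.Linear(DIM, 64), paddle.nn.ReLU(), paddle.nn.Linear(64, 1)
)
opt_g = paddle.optimizer.Adam(1e-4, beta1=0.5, beta2=0.9, parameters=generator.parameters())
opt_d = paddle.optimizer.Adam(1e-4, beta1=0.5, beta2=0.9, parameters=critic.parameters())


def gradient_penalty(real, fake):
    alpha = paddle.rand([real.shape[0], 1])
    inter = real + alpha * (fake - real)
    inter.stop_gradient = False
    grads = paddle.grad(critic(inter), inter, create_graph=True)[0]
    slopes = paddle.sqrt(paddle.sum(paddle.square(grads), axis=1))
    return LAMBDA * paddle.mean((slopes - 1.0) ** 2)


for step in range(100):
    # several critic updates per generator update
    for _ in range(N_CRITIC):
        real = paddle.randn([BATCH, DIM])  # stand-in for a real data batch
        fake = generator(paddle.randn([BATCH, NOISE_DIM])).detach()
        d_loss = paddle.mean(critic(fake)) - paddle.mean(critic(real)) + gradient_penalty(real, fake)
        opt_d.clear_grad()
        d_loss.backward()
        opt_d.step()

    # one generator update
    fake = generator(paddle.randn([BATCH, NOISE_DIM]))
    g_loss = -paddle.mean(critic(fake))
    opt_g.clear_grad()
    g_loss.backward()
    opt_g.step()
```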

## 4. Evaluation Metrics
### 4.1 Inception Score (IS)
Used to evaluate the quality and diversity of generated images.
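
The score is $\exp\big(\mathbb{E}_x[\mathrm{KL}(p(y\mid x)\,\|\,p(y))]\big)$, computed from the class probabilities an Inception-style classifier assigns to generated images. A minimal sketch (the classifier itself and the usual split-wise averaging are omitted; `probs` is an assumed input):

``` python
import numpy as np


def inception_score(probs: np.ndarray, eps: float = 1e-16) -> float:
    """probs: [N, num_classes] softmax outputs for N generated images."""
    p_y = probs.mean(axis=0, keepdims=True)  # marginal class distribution p(y)
    kl = (probs * (np.log(probs + eps) - np.log(p_y + eps))).sum(axis=1)
    return float(np.exp(kl.mean()))  # exp of the average KL divergence
```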
### 4.2 Visualization of Generated Samples
Generated samples are saved for visual inspection of model performance.
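
A minimal sketch for dumping a grid of samples to disk; the `[N, 3, 32, 32]` layout and the `[-1, 1]` value range are assumptions matching CIFAR-10-sized output and should be adapted to the actual generator:

``` python
import numpy as np
import matplotlib.pyplot as plt


def save_sample_grid(samples: np.ndarray, path: str, nrow: int = 8) -> None:
    """samples: [N, 3, 32, 32] array with values in [-1, 1]."""
    samples = (samples + 1.0) / 2.0  # map [-1, 1] -> [0, 1]
    fig, axes = plt.subplots(nrow, nrow, figsize=(nrow, nrow))
    for ax in axes.flat:
        ax.axis("off")
    for ax, img in zip(axes.flat, samples[: nrow * nrow]):
        ax.imshow(np.transpose(img, (1, 2, 0)))  # CHW -> HWC
    fig.savefig(path, bbox_inches="tight")
    plt.close(fig)
```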
## 5. Integration with PaddleScience
We design a modular implementation that integrates easily with PaddleScience:

### 5.1 Directory Structure
```
PaddleScience/
└── examples/
    └── wgangp/
        ├── conf/
        │   ├── wgangp_cifar10.yaml     # CIFAR-10 config
        │   ├── wgangp_mnist.yaml       # MNIST config
        │   └── wgangp_toy.yaml         # toy-dataset config
        ├── functions.py                # loss functions, evaluation metrics, visualization utilities
        ├── wgangp_cifar10.py           # CIFAR-10 example
        ├── wgangp_cifar10_model.py     # models for the CIFAR-10 experiment
        ├── wgangp_mnist.py             # MNIST example
        ├── wgangp_mnist_model.py       # models for the MNIST experiment
        ├── wgangp_toy.py               # toy-dataset example
        └── wgangp_toy_model.py         # models for the toy-dataset experiment
```

### 5.2 Interface Design
A concise, unified interface is provided for ease of use:

``` python
# Example usage
import os

import paddle
import ppsci
from functions import Cifar10DisFuncs
from functions import Cifar10GenFuncs
from functions import load_cifar10
from omegaconf import DictConfig
from ppsci.optimizer.lr_scheduler import Linear  # assumed location of the linear LR schedule
from ppsci.utils import logger
from wgangp_cifar10_model import WganGpCifar10Discriminator
from wgangp_cifar10_model import WganGpCifar10Generator


def train(cfg: DictConfig):
    # set models
    generator_model = WganGpCifar10Generator(**cfg["MODEL"]["gen_net"])
    discriminator_model = WganGpCifar10Discriminator(**cfg["MODEL"]["dis_net"])
    if cfg.TRAIN.pretrained_dis_model_path and os.path.exists(
        cfg.TRAIN.pretrained_dis_model_path
    ):
        discriminator_model.load_dict(paddle.load(cfg.TRAIN.pretrained_dis_model_path))

    # set losses
    generator_funcs = Cifar10GenFuncs(
        **cfg["LOSS"]["gen"], discriminator_model=discriminator_model
    )
    discriminator_funcs = Cifar10DisFuncs(
        **cfg["LOSS"]["dis"], discriminator_model=discriminator_model
    )

    # set dataloader
    inputs, labels = load_cifar10(**cfg["DATA"])
    dataloader_cfg = {
        "dataset": {
            "name": cfg["EVAL"]["dataset"]["name"],
            "input": inputs,
            "label": labels,
        },
        "sampler": {
            **cfg["TRAIN"]["sampler"],
        },
        "batch_size": cfg["TRAIN"]["batch_size"],
        "use_shared_memory": cfg["TRAIN"]["use_shared_memory"],
        "num_workers": cfg["TRAIN"]["num_workers"],
        "drop_last": cfg["TRAIN"]["drop_last"],
    }

    # set constraints
    constraint_generator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        output_expr={"labels": lambda out: out["labels"]},
        name="constraint_generator",
    )
    constraint_generator_dict = {constraint_generator.name: constraint_generator}

    constraint_discriminator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(discriminator_funcs.loss),
        output_expr={"labels": lambda out: out["labels"]},
        name="constraint_discriminator",
    )
    constraint_discriminator_dict = {
        constraint_discriminator.name: constraint_discriminator
    }

    # set optimizers
    lr_scheduler_generator = Linear(**cfg["TRAIN"]["lr_scheduler_gen"])()
    lr_scheduler_discriminator = Linear(**cfg["TRAIN"]["lr_scheduler_dis"])()

    optimizer_generator = ppsci.optimizer.Adam(
        learning_rate=lr_scheduler_generator,
        beta1=cfg["TRAIN"]["optimizer"]["beta1"],
        beta2=cfg["TRAIN"]["optimizer"]["beta2"],
    )
    optimizer_discriminator = ppsci.optimizer.Adam(
        learning_rate=lr_scheduler_discriminator,
        beta1=cfg["TRAIN"]["optimizer"]["beta1"],
        beta2=cfg["TRAIN"]["optimizer"]["beta2"],
    )
    optimizer_generator = optimizer_generator(generator_model)
    optimizer_discriminator = optimizer_discriminator(discriminator_model)

    # initialize solvers
    solver_generator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(cfg.output_dir, "generator"),
        constraint=constraint_generator_dict,
        optimizer=optimizer_generator,
        epochs=cfg.TRAIN.epochs_gen,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_gen,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )
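
    # Note (added for clarity): the discriminator solver below also wraps
    # `generator_model`, since during discriminator training the generator only
    # produces the fake samples consumed by `discriminator_funcs.loss`; the
    # parameters actually updated are those of `discriminator_model`, which is
    # captured inside the loss object and driven by `optimizer_discriminator`.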
    solver_discriminator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(cfg.output_dir, "discriminator"),
        constraint=constraint_discriminator_dict,
        optimizer=optimizer_discriminator,
        epochs=cfg.TRAIN.epochs_dis,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_dis,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )

    # train
    for i in range(cfg.TRAIN.epochs):
        logger.message(f"\nEpoch: {i + 1}\n")
        optimizer_discriminator.clear_grad()
        solver_discriminator.train()
        optimizer_generator.clear_grad()
        solver_generator.train()

    # save model weights
    paddle.save(
        generator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_generator.pdparams"),
    )
    paddle.save(
        discriminator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_discriminator.pdparams"),
    )
```
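
In line with other PaddleScience examples, the script is expected to be driven by Hydra. A minimal entry point might look like the following; the `mode` field follows the usual example convention and is an assumption here:

``` python
import hydra
from omegaconf import DictConfig


@hydra.main(version_base=None, config_path="./conf", config_name="wgangp_cifar10.yaml")
def main(cfg: DictConfig):
    if cfg.mode == "train":
        train(cfg)
    else:
        raise ValueError(f"cfg.mode should be 'train', but got '{cfg.mode}'")


if __name__ == "__main__":
    main()
```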

# 6. Considerations for Testing and Acceptance