@@ -100,6 +100,8 @@ def __init__(self,
100100 self .posterior_rho_init = posterior_rho_init
101101 self .bias = bias
102102
103+ self .kl = 0
104+
103105 self .mu_kernel = nn .Parameter (
104106 torch .Tensor (out_channels , in_channels // groups , kernel_size ))
105107 self .rho_kernel = nn .Parameter (
@@ -150,7 +152,7 @@ def init_parameters(self):
150152 self .prior_bias_mu .data .fill_ (self .prior_mean )
151153 self .prior_bias_sigma .data .fill_ (self .prior_variance )
152154
153- def forward (self , x ):
155+ def forward (self , x , return_kl = True ):
154156
155157 # linear outputs
156158 outputs = F .conv1d (x ,
@@ -191,8 +193,11 @@ def forward(self, x):
191193 dilation = self .dilation ,
192194 groups = self .groups ) * sign_output
193195
196+ self .kl = kl
194197 # returning outputs + perturbations
195- return outputs + perturbed_outputs , kl
198+ if return_kl :
199+ return outputs + perturbed_outputs , kl
200+ return outputs + perturbed_outputs
196201
197202
198203class Conv2dFlipout (BaseVariationalLayer_ ):
@@ -244,6 +249,8 @@ def __init__(self,
244249 self .posterior_rho_init = posterior_rho_init
245250 self .bias = bias
246251
252+ self .kl = 0
253+
247254 self .mu_kernel = nn .Parameter (
248255 torch .Tensor (out_channels , in_channels // groups , kernel_size ,
249256 kernel_size ))
@@ -299,7 +306,7 @@ def init_parameters(self):
299306 self .prior_bias_mu .data .fill_ (self .prior_mean )
300307 self .prior_bias_sigma .data .fill_ (self .prior_variance )
301308
302- def forward (self , x ):
309+ def forward (self , x , return_kl = True ):
303310
304311 # linear outputs
305312 outputs = F .conv2d (x ,
@@ -340,8 +347,11 @@ def forward(self, x):
340347 dilation = self .dilation ,
341348 groups = self .groups ) * sign_output
342349
350+ self .kl = kl
343351 # returning outputs + perturbations
344- return outputs + perturbed_outputs , kl
352+ if return_kl :
353+ return outputs + perturbed_outputs , kl
354+ return outputs + perturbed_outputs
345355
346356
347357class Conv3dFlipout (BaseVariationalLayer_ ):
@@ -388,6 +398,8 @@ def __init__(self,
388398 self .groups = groups
389399 self .bias = bias
390400
401+ self .kl = 0
402+
391403 self .prior_mean = prior_mean
392404 self .prior_variance = prior_variance
393405 self .posterior_mu_init = posterior_mu_init
@@ -448,7 +460,7 @@ def init_parameters(self):
448460 self .prior_bias_mu .data .fill_ (self .prior_mean )
449461 self .prior_bias_sigma .data .fill_ (self .prior_variance )
450462
451- def forward (self , x ):
463+ def forward (self , x , return_kl = True ):
452464
453465 # linear outputs
454466 outputs = F .conv3d (x ,
@@ -489,8 +501,11 @@ def forward(self, x):
489501 dilation = self .dilation ,
490502 groups = self .groups ) * sign_output
491503
504+ self .kl = kl
492505 # returning outputs + perturbations
493- return outputs + perturbed_outputs , kl
506+ if return_kl :
507+ return outputs + perturbed_outputs , kl
508+ return outputs + perturbed_outputs
494509
495510
496511class ConvTranspose1dFlipout (BaseVariationalLayer_ ):
@@ -537,6 +552,8 @@ def __init__(self,
537552 self .groups = groups
538553 self .bias = bias
539554
555+ self .kl = 0
556+
540557 self .prior_mean = prior_mean
541558 self .prior_variance = prior_variance
542559 self .posterior_mu_init = posterior_mu_init
@@ -593,7 +610,7 @@ def init_parameters(self):
593610 self .prior_bias_mu .data .fill_ (self .prior_mean )
594611 self .prior_bias_sigma .data .fill_ (self .prior_variance )
595612
596- def forward (self , x ):
613+ def forward (self , x , return_kl = True ):
597614
598615 # linear outputs
599616 outputs = F .conv_transpose1d (x ,
@@ -635,8 +652,11 @@ def forward(self, x):
635652 dilation = self .dilation ,
636653 groups = self .groups ) * sign_output
637654
655+ self .kl = kl
638656 # returning outputs + perturbations
639- return outputs + perturbed_outputs , kl
657+ if return_kl :
658+ return outputs + perturbed_outputs , kl
659+ return outputs + perturbed_outputs
640660
641661
642662class ConvTranspose2dFlipout (BaseVariationalLayer_ ):
@@ -683,6 +703,8 @@ def __init__(self,
683703 self .groups = groups
684704 self .bias = bias
685705
706+ self .kl = 0
707+
686708 self .prior_mean = prior_mean
687709 self .prior_variance = prior_variance
688710 self .posterior_mu_init = posterior_mu_init
@@ -743,7 +765,7 @@ def init_parameters(self):
743765 self .prior_bias_mu .data .fill_ (self .prior_mean )
744766 self .prior_bias_sigma .data .fill_ (self .prior_variance )
745767
746- def forward (self , x ):
768+ def forward (self , x , return_kl = True ):
747769
748770 # linear outputs
749771 outputs = F .conv_transpose2d (x ,
@@ -785,8 +807,11 @@ def forward(self, x):
785807 dilation = self .dilation ,
786808 groups = self .groups ) * sign_output
787809
810+ self .kl = kl
788811 # returning outputs + perturbations
789- return outputs + perturbed_outputs , kl
812+ if return_kl :
813+ return outputs + perturbed_outputs , kl
814+ return outputs + perturbed_outputs
790815
791816
792817class ConvTranspose3dFlipout (BaseVariationalLayer_ ):
@@ -838,6 +863,8 @@ def __init__(self,
838863 self .posterior_rho_init = posterior_rho_init
839864 self .bias = bias
840865
866+ self .kl = 0
867+
841868 self .mu_kernel = nn .Parameter (
842869 torch .Tensor (in_channels , out_channels // groups , kernel_size ,
843870 kernel_size , kernel_size ))
@@ -893,7 +920,7 @@ def init_parameters(self):
893920 self .prior_bias_mu .data .fill_ (self .prior_mean )
894921 self .prior_bias_sigma .data .fill_ (self .prior_variance )
895922
896- def forward (self , x ):
923+ def forward (self , x , return_kl = True ):
897924
898925 # linear outputs
899926 outputs = F .conv_transpose3d (x ,
@@ -935,5 +962,8 @@ def forward(self, x):
935962 dilation = self .dilation ,
936963 groups = self .groups ) * sign_output
937964
965+ self .kl = kl
938966 # returning outputs + perturbations
939- return outputs + perturbed_outputs , kl
967+ if return_kl :
968+ return outputs + perturbed_outputs , kl
969+ return outputs + perturbed_outputs
0 commit comments