File tree Expand file tree Collapse file tree 1 file changed +9
-5
lines changed
chapter15_Differential_Privacy Expand file tree Collapse file tree 1 file changed +9
-5
lines changed Original file line number Diff line number Diff line change @@ -26,9 +26,13 @@ def get_model(name="vgg16", pretrained=True):
26
26
else :
27
27
return model
28
28
29
+
29
30
def model_norm (model_1 , model_2 ):
30
- squared_sum = 0
31
- for name , layer in model_1 .named_parameters ():
32
- # print(torch.mean(layer.data), torch.mean(model_2.state_dict()[name].data))
33
- squared_sum += torch .sum (torch .pow (layer .data - model_2 .state_dict ()[name ].data , 2 ))
34
- return math .sqrt (squared_sum )
31
+ params_1 = torch .cat ([param .view (- 1 ) for param in model_1 .parameters ()])
32
+ params_2 = torch .cat ([param .view (- 1 ) for param in model_2 .parameters ()])
33
+
34
+ return torch .norm (params_1 - params_2 , p = 2 )
35
+
36
+ def quick_model_norm (model_1 , model_2 ):
37
+ diffs = [(p1 - p2 ).view (- 1 ) for p1 , p2 in zip (model_1 .parameters (), model_2 .parameters ())]
38
+ return torch .norm (torch .cat (diffs ), p = 2 )
You can’t perform that action at this time.
0 commit comments