16
16
from typing import Callable , List , Optional
17
17
18
18
import torch
19
+ from torch .nn import Module
19
20
20
21
from sparseml .core import State
21
22
from sparseml .core .model .pytorch import ModifiableModelPyTorch
@@ -190,19 +191,7 @@ def _apply_smoothing(self):
190
191
smooth_layer = mapping .smooth_layer
191
192
balance_layers = mapping .balance_layers
192
193
193
- # get the channel-wise dynamic range for each layer to be balanced
194
- weight_scales = []
195
- for layer in balance_layers :
196
- scale = layer .weight .abs ().max (dim = 0 , keepdim = True )[0 ]
197
- weight_scales .append (scale )
198
- weight_scales = 2.0 * torch .cat (weight_scales , dim = 0 ).max (dim = 0 )[0 ]
199
-
200
- # calculate the amount of smoothing to apply
201
- # s_j = max(|X_j|)^alpha / max(|W_j|)^(1-alpha)
202
- # where j is the input channel, alpha is smoothing strength
203
- scales = activation_scales .pow (self .smoothing_strength ) / weight_scales .pow (
204
- 1 - self .smoothing_strength
205
- )
194
+ scales = self ._calculate_smoothing_scales (balance_layers , activation_scales )
206
195
207
196
# invert the smoothing in the following layers
208
197
for layer in balance_layers :
@@ -215,3 +204,29 @@ def _apply_smoothing(self):
215
204
smooth_layer .weight .div_ (scales .view (- 1 , 1 ))
216
205
if hasattr (smooth_layer , "bias" ):
217
206
smooth_layer .bias .div_ (scales )
207
+
208
def _calculate_smoothing_scales(
    self, balance_layers: List[Module], activation_scales: torch.Tensor
) -> torch.Tensor:
    """
    Calculate how much smoothing to apply to each channel based on the dynamic
    range of the activation and the following weights

    :param balance_layers: layers to offset activation smoothing to, each must
        expose a 2D ``weight`` whose dim 1 is the input channel
    :param activation_scales: channel-wise dynamic range of activation to smooth
    :return: channel-wise scales to use for smoothing activation, as a 1D
        tensor with one entry per input channel
    """
    # get the channel-wise dynamic range for each layer to be balanced:
    # max(|W|) over dim 0 of each weight, then max across all layers
    per_layer_ranges = [
        layer.weight.abs().max(dim=0, keepdim=True)[0] for layer in balance_layers
    ]
    weight_scales = 2.0 * torch.cat(per_layer_ranges, dim=0).max(dim=0)[0]

    # calculate the amount of smoothing to apply
    # s_j = max(|X_j|)^alpha / max(|W_j|)^(1-alpha)
    # where j is the input channel, alpha is smoothing strength
    scales = activation_scales.pow(self.smoothing_strength) / weight_scales.pow(
        1 - self.smoothing_strength
    )

    # a channel whose weights are all zero has weight_scales == 0, which makes
    # the division above produce inf/nan; fall back to the activation scale
    # for those channels so the returned scales stay finite
    scales = torch.where(weight_scales > 0.0, scales, activation_scales)
    return scales
0 commit comments