@@ -78,9 +78,7 @@ def __init__(self):
78
78
super ().__init__ ()
79
79
80
80
def forward (self , x , augment = False , profile = False , visualize = False ):
81
- y = []
82
- for module in self :
83
- y .append (module (x , augment , profile , visualize )[0 ])
81
+ y = [module (x , augment , profile , visualize )[0 ] for module in self ]
84
82
# y = torch.stack(y).max(0)[0] # max ensemble
85
83
# y = torch.stack(y).mean(0) # mean ensemble
86
84
y = torch .cat (y , 1 ) # nms ensemble
@@ -102,21 +100,19 @@ def attempt_load(weights, map_location=None, inplace=True, fuse=True):
102
100
t = type (m )
103
101
if t in (nn .Hardswish , nn .LeakyReLU , nn .ReLU , nn .ReLU6 , nn .SiLU , Detect , Model ):
104
102
m .inplace = inplace # torch 1.7.0 compatibility
105
- if t is Detect :
106
- if not isinstance (m .anchor_grid , list ): # new Detect Layer compatibility
107
- delattr (m , 'anchor_grid' )
108
- setattr (m , 'anchor_grid' , [torch .zeros (1 )] * m .nl )
103
+ if t is Detect and not isinstance (m .anchor_grid , list ):
104
+ delattr (m , 'anchor_grid' )
105
+ setattr (m , 'anchor_grid' , [torch .zeros (1 )] * m .nl )
109
106
elif t is Conv :
110
107
m ._non_persistent_buffers_set = set () # torch 1.6.0 compatibility
111
108
elif t is nn .Upsample and not hasattr (m , 'recompute_scale_factor' ):
112
109
m .recompute_scale_factor = None # torch 1.11.0 compatibility
113
110
114
111
if len (model ) == 1 :
115
112
return model [- 1 ] # return model
116
- else :
117
- print (f'Ensemble created with { weights } \n ' )
118
- for k in 'names' , 'nc' , 'yaml' :
119
- setattr (model , k , getattr (model [0 ], k ))
120
- model .stride = model [torch .argmax (torch .tensor ([m .stride .max () for m in model ])).int ()].stride # max stride
121
- assert all (model [0 ].nc == m .nc for m in model ), f'Models have different class counts: { [m .nc for m in model ]} '
122
- return model # return ensemble
113
+ print (f'Ensemble created with { weights } \n ' )
114
+ for k in 'names' , 'nc' , 'yaml' :
115
+ setattr (model , k , getattr (model [0 ], k ))
116
+ model .stride = model [torch .argmax (torch .tensor ([m .stride .max () for m in model ])).int ()].stride # max stride
117
+ assert all (model [0 ].nc == m .nc for m in model ), f'Models have different class counts: { [m .nc for m in model ]} '
118
+ return model # return ensemble
0 commit comments