diff --git a/timm/models/mambaout.py b/timm/models/mambaout.py index 3cc6e08253..f53a9cdfc2 100644 --- a/timm/models/mambaout.py +++ b/timm/models/mambaout.py @@ -151,7 +151,7 @@ def __init__( self.num_features = in_features self.pre_logits = nn.Identity() - self.fc = nn.Linear(hidden_size, num_classes, bias=bias) if num_classes > 0 else nn.Identity() + self.fc = nn.Linear(self.num_features, num_classes, bias=bias) if num_classes > 0 else nn.Identity() self.head_dropout = nn.Dropout(drop_rate) def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):