diff --git a/search_space_compression.py b/search_space_compression.py index 92c23b1..8748efb 100644 --- a/search_space_compression.py +++ b/search_space_compression.py @@ -419,7 +419,12 @@ def forward(self, latent): else: name = self.shared_attr_vocab.index_to_name[name_index] - attr_dims = max(1, int(math.ceil(F.softplus(self.attr_dims_head(embedding)).item()))) + raw_dim = self.attr_dims_head(embedding) + raw_dim = torch.nan_to_num(raw_dim, nan=0.0, posinf=10.0, neginf=-10.0) + attr_dim_val = F.softplus(raw_dim).item() + if math.isnan(attr_dim_val) or math.isinf(attr_dim_val): + attr_dim_val = 1.0 + attr_dims = max(1, int(math.ceil(attr_dim_val))) values = [] value_hidden = embedding.unsqueeze(0).unsqueeze(1) value_input = torch.zeros(1, 1, 1, device=device)