diff --git a/lab1/text_recognizer/lit_models/base.py b/lab1/text_recognizer/lit_models/base.py index 9c2e8a9..7cc717b 100644 --- a/lab1/text_recognizer/lit_models/base.py +++ b/lab1/text_recognizer/lit_models/base.py @@ -21,7 +21,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: Normalized preds are not necessary for accuracy computation as we just care about argmax(). """ if preds.min() < 0 or preds.max() > 1: - preds = torch.nn.functional.softmax(preds, dim=-1) + preds = torch.nn.functional.softmax(preds, dim=1) super().update(preds=preds, target=target) diff --git a/lab2/text_recognizer/lit_models/base.py b/lab2/text_recognizer/lit_models/base.py index 9c2e8a9..7cc717b 100644 --- a/lab2/text_recognizer/lit_models/base.py +++ b/lab2/text_recognizer/lit_models/base.py @@ -21,7 +21,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: Normalized preds are not necessary for accuracy computation as we just care about argmax(). """ if preds.min() < 0 or preds.max() > 1: - preds = torch.nn.functional.softmax(preds, dim=-1) + preds = torch.nn.functional.softmax(preds, dim=1) super().update(preds=preds, target=target) diff --git a/lab3/text_recognizer/lit_models/base.py b/lab3/text_recognizer/lit_models/base.py index 9c2e8a9..7cc717b 100644 --- a/lab3/text_recognizer/lit_models/base.py +++ b/lab3/text_recognizer/lit_models/base.py @@ -21,7 +21,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: Normalized preds are not necessary for accuracy computation as we just care about argmax(). """ if preds.min() < 0 or preds.max() > 1: - preds = torch.nn.functional.softmax(preds, dim=-1) + preds = torch.nn.functional.softmax(preds, dim=1) super().update(preds=preds, target=target) diff --git a/lab4/text_recognizer/lit_models/base.py b/lab4/text_recognizer/lit_models/base.py index 9c2e8a9..7cc717b 100644 --- a/lab4/text_recognizer/lit_models/base.py +++ b/lab4/text_recognizer/lit_models/base.py @@ -21,7 +21,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: Normalized preds are not necessary for accuracy computation as we just care about argmax(). """ if preds.min() < 0 or preds.max() > 1: - preds = torch.nn.functional.softmax(preds, dim=-1) + preds = torch.nn.functional.softmax(preds, dim=1) super().update(preds=preds, target=target) diff --git a/lab5/text_recognizer/lit_models/base.py b/lab5/text_recognizer/lit_models/base.py index 9c2e8a9..7cc717b 100644 --- a/lab5/text_recognizer/lit_models/base.py +++ b/lab5/text_recognizer/lit_models/base.py @@ -21,7 +21,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: Normalized preds are not necessary for accuracy computation as we just care about argmax(). """ if preds.min() < 0 or preds.max() > 1: - preds = torch.nn.functional.softmax(preds, dim=-1) + preds = torch.nn.functional.softmax(preds, dim=1) super().update(preds=preds, target=target) diff --git a/lab7/text_recognizer/lit_models/base.py b/lab7/text_recognizer/lit_models/base.py index 9c2e8a9..7cc717b 100644 --- a/lab7/text_recognizer/lit_models/base.py +++ b/lab7/text_recognizer/lit_models/base.py @@ -21,7 +21,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: Normalized preds are not necessary for accuracy computation as we just care about argmax(). """ if preds.min() < 0 or preds.max() > 1: - preds = torch.nn.functional.softmax(preds, dim=-1) + preds = torch.nn.functional.softmax(preds, dim=1) super().update(preds=preds, target=target) diff --git a/lab8/text_recognizer/lit_models/base.py b/lab8/text_recognizer/lit_models/base.py index 9c2e8a9..7cc717b 100644 --- a/lab8/text_recognizer/lit_models/base.py +++ b/lab8/text_recognizer/lit_models/base.py @@ -21,7 +21,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: Normalized preds are not necessary for accuracy computation as we just care about argmax(). """ if preds.min() < 0 or preds.max() > 1: - preds = torch.nn.functional.softmax(preds, dim=-1) + preds = torch.nn.functional.softmax(preds, dim=1) super().update(preds=preds, target=target) diff --git a/lab9/text_recognizer/lit_models/base.py b/lab9/text_recognizer/lit_models/base.py index 9c2e8a9..7cc717b 100644 --- a/lab9/text_recognizer/lit_models/base.py +++ b/lab9/text_recognizer/lit_models/base.py @@ -21,7 +21,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: Normalized preds are not necessary for accuracy computation as we just care about argmax(). """ if preds.min() < 0 or preds.max() > 1: - preds = torch.nn.functional.softmax(preds, dim=-1) + preds = torch.nn.functional.softmax(preds, dim=1) super().update(preds=preds, target=target)