@@ -84,21 +84,19 @@ def model_fn(
 
 
   def compute_weighted_cross_entropy(
-      self,
-      logits: spec.Tensor,
-      targets: spec.Tensor,
-      weights: Optional[spec.Tensor] = None,
-      label_smoothing: float = 0.1,
-  ) -> Dict[str, spec.Tensor]:  # differentiable
+    self,
+    logits: spec.Tensor,
+    targets: spec.Tensor,
+    weights: Optional[spec.Tensor] = None,
+    label_smoothing: float = 0.1,
+  ) -> Dict[str, spec.Tensor]:  # differentiable
     """Compute weighted cross entropy and entropy for log probs and targets.
-
     Args:
       logits: [batch, length, num_classes] float array.
       targets: categorical targets [batch, length] int array.
       weights: array of shape [batch, length].
       label_smoothing: label smoothing constant, used to determine the on and off
         values.
-
     Returns:
       {'summed': scalar summed loss, 'n_valid_examples': scalar number of
       valid examples in batch, 'per_example': 1-d array of per-example losses}
@@ -108,18 +106,26 @@ def compute_weighted_cross_entropy(
         f'Incorrect shapes. Got shape {logits.shape} logits and '
         f'{targets.shape} targets.'
       )
-    smoothed_targets = optax.smooth_labels(
-      common_utils.onehot(targets, self._vocab_size), label_smoothing
-    )
-
-    per_example_losses = -jnp.sum(
-      smoothed_targets * jax.nn.log_softmax(logits), axis=-1
-    )
-    if weights is None:
-      weights = jnp.ones_like(targets)
-    per_example_losses = jnp.where(weights, per_example_losses, 0.0)
+    # Compute log probabilities.
+    log_probs = jax.nn.log_softmax(logits, axis=-1)
+    # Extract the log probability of the target class.
+    # Shape: [batch, length].
+    target_log_probs = jnp.take_along_axis(
+      log_probs,
+      targets[..., None],
+      axis=-1
+    ).squeeze(-1)
+    # Cross-entropy with smoothing: -(1 - α) * log_p[target] - (α / V) * sum(log_p)
+    # This follows from substituting the smoothed targets (1 - α) * onehot + α / V into the cross-entropy definition.
+    confidence = 1.0 - label_smoothing
+    smoothing_term = label_smoothing / self._vocab_size
+    per_example_losses = -1.0 * (confidence * target_log_probs + smoothing_term * log_probs.sum(axis=-1))
+    if weights is not None:
+      per_example_losses = jnp.where(weights, per_example_losses, 0.0)
+      n_valid_examples = weights.sum()
+    else:
+      n_valid_examples = targets.shape[0] * targets.shape[1]
     summed_loss = per_example_losses.sum()
-    n_valid_examples = weights.sum()
     return {
       'summed': summed_loss,
       'n_valid_examples': n_valid_examples,
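
Not part of the diff: a minimal standalone sketch that checks the new direct label-smoothing formula against the previous `optax.smooth_labels` one-hot path. It assumes a toy vocabulary size and uses `jax.nn.one_hot` in place of flax's `common_utils.onehot`; the function names `smoothed_xent_direct` and `smoothed_xent_onehot` are illustrative only.

```python
import jax
import jax.numpy as jnp
import optax


def smoothed_xent_direct(logits, targets, label_smoothing, vocab_size):
  # New path: -(1 - α) * log_p[target] - (α / V) * sum(log_p).
  log_probs = jax.nn.log_softmax(logits, axis=-1)
  target_log_probs = jnp.take_along_axis(
    log_probs, targets[..., None], axis=-1
  ).squeeze(-1)
  confidence = 1.0 - label_smoothing
  smoothing_term = label_smoothing / vocab_size
  return -(confidence * target_log_probs + smoothing_term * log_probs.sum(axis=-1))


def smoothed_xent_onehot(logits, targets, label_smoothing, vocab_size):
  # Old path: cross-entropy against a smoothed one-hot target distribution.
  smoothed = optax.smooth_labels(jax.nn.one_hot(targets, vocab_size), label_smoothing)
  return -jnp.sum(smoothed * jax.nn.log_softmax(logits), axis=-1)


vocab_size = 7
logits = jax.random.normal(jax.random.PRNGKey(0), (2, 5, vocab_size))
targets = jax.random.randint(jax.random.PRNGKey(1), (2, 5), 0, vocab_size)
direct = smoothed_xent_direct(logits, targets, 0.1, vocab_size)
onehot = smoothed_xent_onehot(logits, targets, 0.1, vocab_size)
print(jnp.max(jnp.abs(direct - onehot)))  # expected to be ~1e-7 (float32 rounding)
```

A side effect of the direct form is that it never materializes a [batch, length, vocab_size] smoothed one-hot array, which matters for large vocabularies.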