@@ -124,25 +124,6 @@ def _build_input_queue(
   ):
     """Build an input queue for the given split."""
 
-  def _eval_batch(
-      self,
-      params: spec.ParameterContainer,
-      batch: Dict[str, spec.Tensor],
-      model_state: spec.ModelAuxiliaryState,
-      rng: spec.RandomState,
-  ) -> spec.Tensor:
-    """Evaluate the model on a single batch."""
-    logits, _ = self.model_fn(
-        params,
-        batch,
-        model_state,
-        spec.ForwardPassMode.EVAL,
-        rng,
-        update_batch_norm=False,
-    )
-
-    loss_dict = self.loss_fn(batch['targets'], logits)
-    return loss_dict
 
   def _eval_model_on_split(
       self,
@@ -181,6 +162,23 @@ def _eval_model_on_split(
     eval_results['ppl'] = np.exp(eval_results['loss'])
     return eval_results
 
+
+  def _eval_batch(self,
+                  params: spec.ParameterContainer,
+                  batch: Dict[str, spec.Tensor],
+                  model_state: spec.ModelAuxiliaryState,
+                  rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+    """Evaluate the model on a single batch."""
+    logits, _ = self.model_fn(
+        params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False)
+    # Calculate the weighted cross-entropy loss over the batch.
+    metrics = self.compute_weighted_cross_entropy(
+        logits, batch['targets'], batch['weights'])
+    return {
+        'loss': metrics['summed'],
+        'denominator': metrics['n_valid_examples'],
+    }
+
 
   @abc.abstractmethod
   def _normalize_eval_metrics(
       self, num_examples: int, total_metrics: Dict[str, Any]
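
For context, the dict returned by the new `_eval_batch` is meant to be summed across evaluation batches and then reduced by `_normalize_eval_metrics`, which the diff leaves abstract. Below is a minimal sketch of that accumulate-then-normalize flow, assuming plain Python floats and a hypothetical `run_eval` driver; the names `normalize_eval_metrics` and `run_eval` are illustrative, and a concrete workload may instead normalize by `num_examples` or sync metrics across devices.

from typing import Any, Callable, Dict, Iterable


def normalize_eval_metrics(num_examples: int,
                           total_metrics: Dict[str, Any]) -> Dict[str, float]:
  """Divide each summed metric by the accumulated 'denominator'."""
  del num_examples  # Unused when a per-batch denominator is accumulated.
  denominator = total_metrics.pop('denominator')
  return {k: float(v) / float(denominator) for k, v in total_metrics.items()}


def run_eval(eval_batch: Callable[[Dict[str, Any]], Dict[str, Any]],
             batches: Iterable[Dict[str, Any]]) -> Dict[str, float]:
  """Accumulate the per-batch dicts from `_eval_batch`, then normalize."""
  totals: Dict[str, float] = {}
  num_examples = 0
  for batch in batches:
    metrics = eval_batch(batch)  # {'loss': summed_ce, 'denominator': n_valid}
    for k, v in metrics.items():
      totals[k] = totals.get(k, 0.0) + float(v)
    num_examples += int(metrics['denominator'])
  return normalize_eval_metrics(num_examples, totals)

Under these assumptions, the 'loss' that `_eval_model_on_split` exponentiates into 'ppl' is the mean cross-entropy over all valid entries, since both the summed loss and the valid-entry count come from `compute_weighted_cross_entropy`.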