Skip to content

Commit 2b162e8

Browse files

Commit message: remove _eval_batch from jax workload
1 parent: f531b35 · commit: 2b162e8

File tree

2 files changed: +17 −36 lines changed

algoperf/workloads/lm/lm_jax/workload.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -139,20 +139,3 @@ def _normalize_eval_metrics(
139139
del num_examples
140140
eval_denominator = total_metrics.pop('denominator')
141141
return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics)
142-
143-
144-
def _eval_batch(self,
145-
params: spec.ParameterContainer,
146-
batch: Dict[str, spec.Tensor],
147-
model_state: spec.ModelAuxiliaryState,
148-
rng: spec.RandomState) -> spec.Tensor:
149-
"""Evaluate the model on a single batch."""
150-
logits, _ = self.model_fn(
151-
params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False)
152-
# Calculate cross-entropy loss
153-
# TODO(kasimbeg): add weights?
154-
metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights'])
155-
return {
156-
'loss': metrics['summed'],
157-
'denominator': metrics['n_valid_examples'],
158-
}

algoperf/workloads/lm/workload.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -124,25 +124,6 @@ def _build_input_queue(
124124
):
125125
"""Build an input queue for the given split."""
126126

127-
def _eval_batch(
128-
self,
129-
params: spec.ParameterContainer,
130-
batch: Dict[str, spec.Tensor],
131-
model_state: spec.ModelAuxiliaryState,
132-
rng: spec.RandomState,
133-
) -> spec.Tensor:
134-
"""Evaluate the model on a single batch."""
135-
logits, _ = self.model_fn(
136-
params,
137-
batch,
138-
model_state,
139-
spec.ForwardPassMode.EVAL,
140-
rng,
141-
update_batch_norm=False,
142-
)
143-
144-
loss_dict = self.loss_fn(batch['targets'], logits)
145-
return loss_dict
146127

147128
def _eval_model_on_split(
148129
self,
@@ -181,6 +162,23 @@ def _eval_model_on_split(
181162
eval_results['ppl'] = np.exp(eval_results['loss'])
182163
return eval_results
183164

165+
166+
def _eval_batch(self,
167+
params: spec.ParameterContainer,
168+
batch: Dict[str, spec.Tensor],
169+
model_state: spec.ModelAuxiliaryState,
170+
rng: spec.RandomState) -> spec.Tensor:
171+
"""Evaluate the model on a single batch."""
172+
logits, _ = self.model_fn(
173+
params, batch, model_state, spec.ForwardPassMode.EVAL, rng, False)
174+
# Calculate cross-entropy loss
175+
metrics = self.compute_weighted_cross_entropy(logits, batch['targets'], batch['weights'])
176+
return {
177+
'loss': metrics['summed'],
178+
'denominator': metrics['n_valid_examples'],
179+
}
180+
181+
184182
@abc.abstractmethod
185183
def _normalize_eval_metrics(
186184
self, num_examples: int, total_metrics: Dict[str, Any]

0 commit comments

Comments (0)