Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit f2efb4d

Browse files
markurtzdbogunowiczBenjamin
committed
[Fix] label_list not being set for NLP token classification training if distillation teacher and student labels do not match (#1414)
* [Fix] Fix label_list not being set for NLP token classification training if distillation teacher and student labels do not match * Added two fixes: omitting the labels/indices matching for student/teacher if teacher is a string; prioritizing teacher labels to student labels if teacher labels are string and student's int * revert previous int label patch - allow int labels to let given dataset be the source of truth * only override label_list when teacher and student labels sets are equal --------- Co-authored-by: Damian <[email protected]> Co-authored-by: Benjamin <[email protected]>
1 parent 2d0a6a9 commit f2efb4d

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

src/sparseml/transformers/token_classification.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -382,9 +382,11 @@ def main(**kwargs):
382382
},
383383
)
384384

385-
if teacher:
385+
if teacher and not isinstance(teacher, str):
386386
# check whether teacher and student have the corresponding outputs
387-
label_to_id, label_list = _check_teacher_student_outputs(teacher, label_to_id)
387+
label_to_id, label_list = _check_teacher_student_outputs(
388+
teacher, label_to_id, label_list
389+
)
388390

389391
tokenizer_src = (
390392
model_args.tokenizer_name
@@ -580,7 +582,7 @@ def compute_metrics(p):
580582

581583

582584
def _check_teacher_student_outputs(
583-
teacher: Module, label_to_id: Dict[str, int]
585+
teacher: Module, label_to_id: Dict[str, int], label_list: List[str]
584586
) -> Tuple[Dict[str, int], List[str]]:
585587
# Check that the teacher and student have the same labels and if they do,
586588
# check that the mapping between labels and ids is the same.
@@ -765,7 +767,9 @@ def _get_tokenized_dataset(
765767
# Map that sends B-Xxx label to its I-Xxx counterpart
766768
b_to_i_label = []
767769
for idx, label in enumerate(label_list):
768-
if label.startswith("B-") and label.replace("B-", "I-") in label_list:
770+
if isinstance(label, str) and (
771+
label.startswith("B-") and label.replace("B-", "I-") in label_list
772+
):
769773
b_to_i_label.append(label_list.index(label.replace("B-", "I-")))
770774
else:
771775
b_to_i_label.append(idx)

0 commit comments

Comments
 (0)