Skip to content

Commit 37d0634

Browse files
implement weight_classes for tagging tasks
1 parent 60a1a72 commit 37d0634

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

src/cnlpt/train_system.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,10 @@ def main(
309309
dataset.tasks_to_labels[task] = dataset.tasks_to_labels[task][1:] + [
310310
dataset.tasks_to_labels[task][0]
311311
]
312-
labels = dataset.processed_dataset["train"][task]
312+
if tagger[task]:
313+
labels = [token_label for sent in dataset.processed_dataset["train"][task] for token_label in sent.split()]
314+
else:
315+
labels = dataset.processed_dataset["train"][task]
313316
weights = []
314317
label_counts = Counter(labels)
315318
for label in dataset.tasks_to_labels[task]:

0 commit comments

Comments
 (0)