Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ you need download pretrained bert model and xlnet model.
8. Modify configuration information in `pybert/configs/basic_config.py`(the path of data,...).
9. Run `python run_bert.py --do_data` to preprocess data.
10. Run `python run_bert.py --do_train --save_best --do_lower_case` to fine tuning bert model.
11. Run `run_bert.py --do_test --do_lower_case` to predict new data.
11. Run `python run_bert.py --do_test --do_lower_case` to predict new data.

### training

Expand Down
21 changes: 12 additions & 9 deletions pybert/callback/earlystopping.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np
from ..common.tools import logger


class EarlyStopping(object):
'''
"""Stop training when a monitored quantity has stopped improving.
Expand Down Expand Up @@ -35,21 +37,22 @@ class EarlyStopping(object):
monitor: 计算指标
baseline: 基线
'''

def __init__(self,
min_delta = 0,
patience = 10,
verbose = 1,
mode = 'min',
monitor = 'loss',
baseline = None):
min_delta=0,
patience=10,
verbose=1,
mode='min',
monitor='loss',
baseline=None):

self.baseline = baseline
self.patience = patience
self.verbose = verbose
self.min_delta = min_delta
self.monitor = monitor

assert mode in ['min','max']
assert mode in ['min', 'max']

if mode == 'min':
self.monitor_op = np.less
Expand All @@ -70,13 +73,13 @@ def reset(self):
else:
self.best = np.Inf if self.monitor_op == np.less else -np.Inf

def epoch_step(self,current):
def epoch_step(self, current):
if self.monitor_op(current - self.min_delta, self.best):
self.best = current
self.wait = 0
else:
self.wait += 1
if self.wait >= self.patience:
if self.verbose >0:
if self.verbose > 0:
logger.info(f"{self.patience} epochs with no improvement after which training will be stopped")
self.stop_training = True
2 changes: 1 addition & 1 deletion pybert/callback/trainingmonitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def epoch_step(self, logs={}):
for (k, v) in logs.items():
l = self.H.get(k, [])
# np.float32会报错
if not isinstance(v, np.float):
if not isinstance(v, float):
v = round(float(v), 4)
l.append(v)
self.H[k] = l
Expand Down
4 changes: 2 additions & 2 deletions pybert/io/albert_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ def create_examples(self,lines,example_type,cached_examples_file):
text_a = line[0]
label = line[1]
if isinstance(label,str):
label = [np.float(x) for x in label.split(",")]
label = [float(x) for x in label.split(",")]
else:
label = [np.float(x) for x in list(label)]
label = [float(x) for x in list(label)]
text_b = None
example = InputExample(guid = guid,text_a = text_a,text_b=text_b,label= label)
examples.append(example)
Expand Down
75 changes: 39 additions & 36 deletions pybert/io/bert_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch.utils.data import TensorDataset
from transformers import BertTokenizer


class InputExample(object):
def __init__(self, guid, text_a, text_b=None, label=None):
"""Constructs a InputExample.
Expand All @@ -19,27 +20,30 @@ def __init__(self, guid, text_a, text_b=None, label=None):
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
self.guid = guid
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.label = label
self.label = label


class InputFeature(object):
'''
A single set of features of data.
'''
def __init__(self,input_ids,input_mask,segment_ids,label_id,input_len):
self.input_ids = input_ids
self.input_mask = input_mask

def __init__(self, input_ids, input_mask, segment_ids, label_id, input_len):
self.input_ids = input_ids
self.input_mask = input_mask
self.segment_ids = segment_ids
self.label_id = label_id
self.label_id = label_id
self.input_len = input_len


class BertProcessor(object):
"""Base class for data converters for sequence classification data sets."""

def __init__(self,vocab_path,do_lower_case):
self.tokenizer = BertTokenizer(vocab_path,do_lower_case)
def __init__(self, vocab_path, do_lower_case):
self.tokenizer = BertTokenizer(vocab_path, do_lower_case)

def get_train(self, data_file):
"""Gets a collection of `InputExample`s for the train set."""
Expand All @@ -49,23 +53,23 @@ def get_dev(self, data_file):
"""Gets a collection of `InputExample`s for the dev set."""
return self.read_data(data_file)

def get_test(self,lines):
def get_test(self, lines):
return lines

def get_labels(self):
"""Gets the list of labels for this data set."""
return ["toxic","severe_toxic","obscene","threat","insult","identity_hate"]
return ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

@classmethod
def read_data(cls, input_file,quotechar = None):
def read_data(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
if 'pkl' in str(input_file):
lines = load_pickle(input_file)
else:
lines = input_file
return lines

def truncate_seq_pair(self,tokens_a,tokens_b,max_length):
def truncate_seq_pair(self, tokens_a, tokens_b, max_length):
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
Expand All @@ -79,33 +83,33 @@ def truncate_seq_pair(self,tokens_a,tokens_b,max_length):
else:
tokens_b.pop()

def create_examples(self,lines,example_type,cached_examples_file):
def create_examples(self, lines, example_type, cached_examples_file):
'''
Creates examples for data
'''
pbar = ProgressBar(n_total = len(lines),desc='create examples')
pbar = ProgressBar(n_total=len(lines), desc='create examples')
if cached_examples_file.exists():
logger.info("Loading examples from cached file %s", cached_examples_file)
examples = torch.load(cached_examples_file)
else:
examples = []
for i,line in enumerate(lines):
guid = '%s-%d'%(example_type,i)
for i, line in enumerate(lines):
guid = '%s-%d' % (example_type, i)
text_a = line[0]
label = line[1]
if isinstance(label,str):
label = [np.float(x) for x in label.split(",")]
if isinstance(label, str):
label = [float(x) for x in label.split(",")]
else:
label = [np.float(x) for x in list(label)]
label = [float(x) for x in list(label)]
text_b = None
example = InputExample(guid = guid,text_a = text_a,text_b=text_b,label= label)
example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
examples.append(example)
pbar(step=i)
logger.info("Saving examples into cached file %s", cached_examples_file)
torch.save(examples, cached_examples_file)
return examples

def create_features(self,examples,max_seq_len,cached_features_file):
def create_features(self, examples, max_seq_len, cached_features_file):
'''
# The convention in BERT is:
# (a) For sequence pairs:
Expand All @@ -115,13 +119,13 @@ def create_features(self,examples,max_seq_len,cached_features_file):
# tokens: [CLS] the dog is hairy . [SEP]
# type_ids: 0 0 0 0 0 0 0
'''
pbar = ProgressBar(n_total=len(examples),desc='create features')
pbar = ProgressBar(n_total=len(examples), desc='create features')
if cached_features_file.exists():
logger.info("Loading features from cached file %s", cached_features_file)
features = torch.load(cached_features_file)
else:
features = []
for ex_id,example in enumerate(examples):
for ex_id, example in enumerate(examples):
tokens_a = self.tokenizer.tokenize(example.text_a)
tokens_b = None
label_id = example.label
Expand All @@ -131,7 +135,7 @@ def create_features(self,examples,max_seq_len,cached_features_file):
# Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length.
# Account for [CLS], [SEP], [SEP] with "- 3"
self.truncate_seq_pair(tokens_a,tokens_b,max_length = max_seq_len - 3)
self.truncate_seq_pair(tokens_a, tokens_b, max_length=max_seq_len - 3)
else:
# Account for [CLS] and [SEP] with '-2'
if len(tokens_a) > max_seq_len - 2:
Expand All @@ -147,8 +151,8 @@ def create_features(self,examples,max_seq_len,cached_features_file):
padding = [0] * (max_seq_len - len(input_ids))
input_len = len(input_ids)

input_ids += padding
input_mask += padding
input_ids += padding
input_mask += padding
segment_ids += padding

assert len(input_ids) == max_seq_len
Expand All @@ -163,27 +167,26 @@ def create_features(self,examples,max_seq_len,cached_features_file):
logger.info(f"input_mask: {' '.join([str(x) for x in input_mask])}")
logger.info(f"segment_ids: {' '.join([str(x) for x in segment_ids])}")

feature = InputFeature(input_ids = input_ids,
input_mask = input_mask,
segment_ids = segment_ids,
label_id = label_id,
input_len = input_len)
feature = InputFeature(input_ids=input_ids,
input_mask=input_mask,
segment_ids=segment_ids,
label_id=label_id,
input_len=input_len)
features.append(feature)
pbar(step=ex_id)
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)
return features

def create_dataset(self,features,is_sorted = False):
def create_dataset(self, features, is_sorted=False):
# Convert to Tensors and build dataset
if is_sorted:
logger.info("sorted data by th length of input")
features = sorted(features,key=lambda x:x.input_len,reverse=True)
features = sorted(features, key=lambda x: x.input_len, reverse=True)
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in features],dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
all_input_lens = torch.tensor([f.input_len for f in features], dtype=torch.long)
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids,all_input_lens)
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_input_lens)
return dataset

4 changes: 2 additions & 2 deletions pybert/io/xlnet_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ def create_examples(self,lines,example_type,cached_examples_file):
text_a = line[0]
label = line[1]
if isinstance(label,str):
label = [np.float(x) for x in label.split(",")]
label = [float(x) for x in label.split(",")]
else:
label = [np.float(x) for x in list(label)]
label = [float(x) for x in list(label)]
text_b = None
example = InputExample(guid = guid,text_a = text_a,text_b=text_b,label= label)
examples.append(example)
Expand Down
2 changes: 1 addition & 1 deletion pybert/model/bert_for_multi_label.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch.nn as nn
from transformers.modeling_bert import BertPreTrainedModel, BertModel
from transformers import BertPreTrainedModel, BertModel

class BertForMultiLable(BertPreTrainedModel):
def __init__(self, config):
Expand Down
Loading