diff --git a/README.md b/README.md index ffb552fe..3ba1f9dc 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,22 @@ For a demo of how to run the system in colab: [![Open In Colab](https://colab.re For chemical NER and the following for gene NER (`GENE-Y` are normalizable gene names and `GENE-N` are non-normalizable): ```gene_ner = {'acc': 0.9835720382010852, 'token_f1': [0.5503875968992249, 0.8223125230882896, 0.0, 0.0, 0.9949847518170479], 'f1': 0.7477207783371888, 'report': '\n precision recall f1-score support\n\n GENE-N 0.70 0.45 0.55 2355\n GENE-Y 0.76 0.88 0.82 5013\n\n micro avg 0.75 0.75 0.75 7368\n macro avg 0.73 0.67 0.68 7368\nweighted avg 0.74 0.75 0.73 7368\n'}``` +### Fine-tuning for concept normalization: End-to-end example +1. Prepare the concept normalization input data (train.tsv, dev.tsv, and test.tsv) using the following format (.tsv files). + +| text | conceptnorm | +| --- | --- | +| pleural effusion | C0032227 | +| pulmonary consolidation | C0521530 | +| aorta tortuous | CUI-less | + +2. Prepare the ontology CUI text file (`ontology_cui.txt`, placed in the data directory); each line of the file will be a CUI. + +3. Prepare the concept embeddings file `concept_embeddings.npy` and save it into a folder $concept_norm_path (a giant matrix with each row corresponding to the embedding of one CUI; the order of the CUI embeddings follows the order of the ontology CUI text file). The same folder must also contain `cuiless_threshold.txt`, which holds the CUI-less similarity threshold. + +4. Fine-tune with something like: +```python -m cnlpt.train_system # --do_train --do_eval --task_name conceptnorm --data_dir cnlp_concept_norm/ --encoder_name cambridgeltl/SapBERT-from-PubMedBERT-fulltext-mean-token --output_dir temp/ --concept_norm_path $concept_norm_path --overwrite_output_dir --cache cache --token true --num_train_epochs 5 --learning_rate 3e-5 --per_device_train_batch_size 64 --max_seq_length 16 --layer 12 --seed 24 --evals_per_epoch 1``` + ### Fine-tuning options Run ```python -m cnlpt.train_system -h``` to see all the available options. 
class CosineLayer(nn.Module):
    """Cosine-similarity classification head used for concept normalization.

    The layer holds one embedding row per concept. ``forward`` scores a batch
    of feature vectors against every row by cosine similarity and appends one
    extra column holding the "CUI-less" threshold, so the output has
    ``num_concepts + 1`` columns. All scores are divided by ``scaling``
    (a temperature) before being returned as logits.

    Args:
        config: object exposing ``hidden_size`` and ``concept_norm``. When
            ``config.concept_norm`` is not ``None`` it is a directory that
            contains ``concept_embeddings.npy`` (the concept matrix) and
            ``cuiless_threshold.txt`` (the CUI-less threshold).
        concept_dim: number of concept rows to allocate when no pre-computed
            embeddings are given. It is ignored when embeddings are loaded
            from disk; the loaded matrix then defines the number of concepts.
        scaling: positive temperature the similarity scores are divided by.
            Defaults to 0.03, the value that used to be hard-coded.

    Raises:
        ValueError: if ``scaling`` is not strictly positive.
    """

    def __init__(self, config, concept_dim=88150, scaling=0.03):
        super().__init__()

        if scaling <= 0:
            raise ValueError("scaling must be a positive number, got %r" % (scaling,))
        self.scaling = scaling

        # Neither of these modules is used by forward(); they are kept so the
        # layer exposes the same attributes as before.
        self.dropout = nn.Dropout(0.1)
        self.cos = nn.CosineSimilarity(dim=-1)

        if config.concept_norm is not None:
            # Pre-computed concept embeddings: fine-tuned together with the
            # encoder, while the CUI-less threshold read from disk stays fixed.
            weights_matrix = np.load(
                os.path.join(config.concept_norm, "concept_embeddings.npy")
            ).astype(np.float32)
            self.weight = Parameter(torch.from_numpy(weights_matrix),
                                    requires_grad=True)
            threshold_value = np.loadtxt(
                os.path.join(config.concept_norm, "cuiless_threshold.txt")
            ).astype(np.float32)
            self.threshold = Parameter(torch.tensor(threshold_value),
                                       requires_grad=False)
        else:
            # No pre-computed embeddings: random concept matrix and a
            # trainable threshold. torch.rand is kept (rather than
            # torch.empty) so seeded runs consume the RNG exactly as before.
            self.weight = Parameter(torch.rand(concept_dim, config.hidden_size),
                                    requires_grad=True)
            # xavier_uniform_ is the in-place, non-deprecated initializer.
            torch.nn.init.xavier_uniform_(self.weight)
            self.threshold = Parameter(torch.tensor(0.35), requires_grad=True)

    def forward(self, features):
        """Score ``features`` against every concept plus the CUI-less slot.

        Args:
            features: 2-D tensor with one row per instance (``torch.mm``
                requires two dimensions) whose second dimension matches the
                width of the concept matrix.

        Returns:
            Tensor with one row per instance and ``num_concepts + 1``
            columns: the scaled cosine similarities followed by the scaled
            CUI-less threshold in the last column.
        """
        # The dot product of L2-normalized rows is the cosine similarity.
        features_norm = normalize(features)
        weight_norm = normalize(self.weight)
        sim_mt = torch.mm(features_norm, weight_norm.transpose(0, 1))

        # Broadcast the threshold to one column per instance, matching the
        # dtype and device of the similarity matrix. expand() keeps the
        # autograd link to the threshold when it is trainable.
        cui_less_score = self.threshold.to(sim_mt).expand(sim_mt.shape[0], 1)

        similarity_score = torch.cat((sim_mt, cui_less_score), 1)
        return similarity_score / self.scaling
to [CLS]) x = features[self.layer_to_use][..., 0, :] - x = self.dropout(x) - x = self.dense(x) - x = torch.tanh(x) + # for normal classification we pass through a dense layer, for cosine layer + # classification we just grab the representation directly: + if not self.skip_projection: + x = self.dropout(x) + x = self.dense(x) + x = torch.tanh(x) + return x @@ -131,6 +179,7 @@ def __init__( tagger = [False], relations = [False], use_prior_tasks=False, + concept_norm=None, **kwargs ): super().__init__(**kwargs) @@ -146,6 +195,7 @@ def __init__( self.use_prior_tasks = use_prior_tasks self.encoder_name = encoder_name self.encoder_config = AutoConfig.from_pretrained(encoder_name).to_dict() + self.concept_norm = concept_norm if encoder_name.startswith('distilbert'): self.hidden_dropout_prob = self.encoder_config['dropout'] self.hidden_size = self.encoder_config['dim'] @@ -222,13 +272,17 @@ def __init__(self, self.classifiers = nn.ModuleList() total_prev_task_labels = 0 for task_ind,task_num_labels in enumerate(self.num_labels): - self.feature_extractors.append(RepresentationProjectionLayer(config, layer=config.layer, tokens=config.tokens, tagger=config.tagger[task_ind], relations=config.relations[task_ind], num_attention_heads=config.num_rel_attention_heads, head_size=config.rel_attention_head_dims)) + conceptnorm = config.finetuning_task[task_ind] == "conceptnorm" + self.feature_extractors.append(RepresentationProjectionLayer(config, layer=config.layer, tokens=config.tokens, tagger=config.tagger[task_ind], relations=config.relations[task_ind], skip_projection=conceptnorm, num_attention_heads=config.num_rel_attention_heads, head_size=config.rel_attention_head_dims)) if config.relations[task_ind]: hidden_size = config.num_rel_attention_heads if config.use_prior_tasks: hidden_size += total_prev_task_labels self.classifiers.append(ClassificationHead(config, task_num_labels, hidden_size=hidden_size)) + elif conceptnorm: + self.classifiers.append(CosineLayer(config, + 
concept_dim=task_num_labels -1)) else: self.classifiers.append(ClassificationHead(config, task_num_labels)) total_prev_task_labels += task_num_labels @@ -424,7 +478,7 @@ def forward( labels=None, output_attentions=None, output_hidden_states=None, - event_tokens=None, + event_mask=None, ): r""" Forward method. @@ -449,7 +503,7 @@ def forward( If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. output_hidden_states: not used. - event_tokens: a mask defining which tokens in the input are to be averaged for input to classifier head; only used when self.tokens==True. + event_mask: a mask defining which tokens in the input are to be averaged for input to classifier head; only used when self.tokens==True. Returns: (`transformers.SequenceClassifierOutput`) the output of the model """ @@ -480,7 +534,7 @@ def forward( ) for task_ind,task_num_labels in enumerate(self.num_labels): - features = self.feature_extractors[task_ind](outputs.hidden_states, event_tokens) + features = self.feature_extractors[task_ind](outputs.hidden_states, event_mask) if self.use_prior_tasks: # note: this specific way of incorporating previous logits doesn't help in my experiments with thyme/clinical tempeval if self.relations[task_ind]: diff --git a/src/cnlpt/api/cnlp_rest.py b/src/cnlpt/api/cnlp_rest.py index 3a9f00cb..a34e4d64 100644 --- a/src/cnlpt/api/cnlp_rest.py +++ b/src/cnlpt/api/cnlp_rest.py @@ -1,6 +1,7 @@ # Core python imports import os +import os.path # FastAPI imports from pydantic import BaseModel @@ -70,6 +71,10 @@ def initialize_cnlpt_model(app, model_name, cuda=True, batch_size=8): AutoModel.register(CnlpConfig, CnlpModelForClassification) config = AutoConfig.from_pretrained(model_name) + + if 'concept_norm' in config.__dict__: + config.concept_norm = os.path.join(model_name, '..') + app.state.tokenizer = 
AutoTokenizer.from_pretrained(model_name, config=config) model = CnlpModelForClassification.from_pretrained(model_name, cache_dir=os.getenv('HF_CACHE'), config=config) diff --git a/src/cnlpt/api/concept_norm_rest.py b/src/cnlpt/api/concept_norm_rest.py new file mode 100644 index 00000000..3ec8feea --- /dev/null +++ b/src/cnlpt/api/concept_norm_rest.py @@ -0,0 +1,120 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
from fastapi import FastAPI
from pydantic import BaseModel

from typing import List, Tuple, Dict

# Modeling imports
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
)

from datasets import Dataset

from ..CnlpModelForClassification import CnlpModelForClassification, CnlpConfig
from .cnlp_rest import get_dataset, initialize_cnlpt_model
import numpy as np
import torch

import json
import logging
import os
from time import time

app = FastAPI()
model_name = "/lab-share/CHIP-Savova-e2/Public/resources/cnlpt/concept_norm/share/checkpoint-57456/"
logger = logging.getLogger('Concept_Normalization_REST_Processor')
logger.setLevel(logging.DEBUG)

task = 'conceptnorm'

# The ontology file lives next to the checkpoint directory. Its position in
# the list is the class index, and "CUI-less" is appended as the last class.
# NOTE(review): the file is parsed as JSON even though it is a .txt file --
# confirm it holds a JSON list of CUIs rather than one CUI per line.
with open(os.path.join(model_name, "../ontology_cui.txt"), 'r') as cui_file:
    labels = json.load(cui_file)
labels = labels + ["CUI-less"]

max_length = 32


class Entity(BaseModel):
    ''' entity_text: The raw text of the entity mention to normalize'''
    entity_text: str


class ConceptNormResults(BaseModel):
    ''' statuses: the predicted label (a CUI, or "CUI-less") for each processed entity, in input order'''
    statuses: List[str]


@app.on_event("startup")
async def startup_event():
    '''Initialize the concept normalization model once, when the server starts.'''
    initialize_cnlpt_model(app, model_name)


@app.post("/concept_norm/process")
async def process(entity: Entity):
    '''Normalize one entity mention to a label and log per-stage timings.

    Returns a ConceptNormResults whose statuses list holds the predicted
    label for the submitted entity text.
    '''
    text = entity.entity_text
    logger.warning('Received entities of len %d to process', len(text))
    instances = [text]
    start_time = time()

    dataset = get_dataset(instances, app.state.tokenizer, [labels,], [task,], max_length)
    preproc_end = time()

    # presumably app.state.trainer is set up by initialize_cnlpt_model at
    # startup -- verify against cnlp_rest.
    prediction_output = app.state.trainer.predict(test_dataset=dataset)
    # First (and only) task's scores; the arg-max column is the label index.
    predictions = np.argmax(prediction_output.predictions[0], axis=1)

    pred_end = time()

    results = [labels[label_ind] for label_ind in predictions]
    response = ConceptNormResults(statuses=results)

    postproc_end = time()

    preproc_time = preproc_end - start_time
    pred_time = pred_end - preproc_end
    postproc_time = postproc_end - pred_end

    # Use the module logger (not the root logger) so the timing line carries
    # this service's logger name and level configuration.
    logger.warning("Pre-processing time: %f, processing time: %f, post-processing time %f",
                   preproc_time, pred_time, postproc_time)

    return response


@app.get("/conceptnorm/{test_str}")
async def test(test_str: str):
    '''Echo endpoint: returns the path argument it was called with.'''
    return {'argument': test_str}


def rest():
    '''Parse the command line and run the concept normalization server with uvicorn.'''
    import argparse

    parser = argparse.ArgumentParser(description='Run the http server for concept normalization')
    parser.add_argument('-p', '--port', type=int, help='The port number to run the server on', default=8000)

    args = parser.parse_args()

    import uvicorn
    uvicorn.run("cnlpt.api.concept_norm_rest:app", host='0.0.0.0', port=args.port, reload=True)


if __name__ == '__main__':
    rest()
label set and give a warning if they are not the same + split_labels = set( dataset[split][task_name]) + unique_labels |= split_labels + + unique_labels = list(unique_labels) + else: + with open(os.path.join(data_dir,'ontology_cui.txt'), 'r') as outfile: + labels = json.load(outfile) + outfile.close() + unique_labels = labels + ["CUI-less"] output_mode = task_output_modes[task_name] @@ -58,8 +65,8 @@ def get_unique_labels(dataset, tasks, task_output_modes): rel_cat = rel_cat[:-1] unique_relations.add(rel_cat) unique_labels = list(unique_relations) - - unique_labels.sort() + if task_name != conceptnorm: + unique_labels.sort() dataset_unique_labels[task_name] = unique_labels @@ -69,24 +76,26 @@ def infer_output_modes(dataset): task_output_modes = {} for task_ind, task_name in enumerate(dataset.tasks): output_mode = classification - unique_labels = set() - # check all splits for labels just in case they do not fully overlap - for split in dataset: - # Add labels from this split to the overall label set and give a warning if they are not the same - split_labels = set( dataset[split][task_name]) - unique_labels |= split_labels + if task_name != conceptnorm: + unique_labels = set() + # check all splits for labels just in case they do not fully overlap + for split in dataset: + # Add labels from this split to the overall label set and give a warning if they are not the same + split_labels = set( dataset[split][task_name]) + unique_labels |= split_labels - unique_labels = list(unique_labels) + unique_labels = list(unique_labels) ## Check if any unique label has a space in it, then we know we are actually ## dealing with a tagging dataset, or if it ends in ), in which case it is a relation task. 
- for label in unique_labels: - if str(label)[-1] == ')': - output_mode = relex - break - elif ' ' in str(label): - output_mode = tagging - break + if task_name != conceptnorm: + for label in unique_labels: + if str(label)[-1] == ')': + output_mode = relex + break + elif ' ' in str(label): + output_mode = tagging + break task_output_modes[task_name] = output_mode @@ -198,7 +207,7 @@ def __init__(self, data_dir:str, tasks:Set[str]=None, max_train_items=-1): self.dataset.task_output_modes = infer_output_modes(self.dataset) # get any split of the data and ask for the set of unique labels for each task in the dataset from that split - self.labels = get_unique_labels(self.dataset, self.dataset.tasks, self.dataset.task_output_modes) + self.labels = get_unique_labels(self.dataset, self.dataset.tasks, self.dataset.task_output_modes, data_dir) self.dataset = get_task_pruned_dataset(self.dataset, self.dataset.tasks, self.labels) diff --git a/src/cnlpt/train_system.py b/src/cnlpt/train_system.py index 173e20fd..177a1aea 100644 --- a/src/cnlpt/train_system.py +++ b/src/cnlpt/train_system.py @@ -120,6 +120,10 @@ class ModelArguments: default=False, metadata={"help": "Classify over an actual token rather than the [CLS] ('') token -- requires that the tokens to be classified are surrounded by / tokens"} ) + concept_norm_path: Optional[str] = field( + default=None, metadata={"help": "the path of pre-computed concept embeddings and cui-less threshold"} + ) + # NxN relation classifier-specific arguments num_rel_feats: Optional[int] = field( default=12, metadata={"help": "Number of features/attention heads to use in the NxN relation classifier"} @@ -304,6 +308,8 @@ def main(json_file=None, json_obj=None): model_name = model_args.model hierarchical = model_name == 'hier' + conceptnorm_path = model_args.concept_norm_path + # Get datasets dataset = ( @@ -481,6 +487,7 @@ def main(json_file=None, json_obj=None): model_name = tempmodel.name else: # setting 2) evaluate or make predictions + 
config.concept_norm = None model = CnlpModelForClassification.from_pretrained( model_args.encoder_name, config=config, @@ -503,7 +510,8 @@ def main(json_file=None, json_obj=None): num_rel_attention_heads=model_args.num_rel_feats, rel_attention_head_dims=model_args.head_features, tagger=tagger, - relations=relations,) + relations=relations, + concept_norm=conceptnorm_path,) #num_tokens=len(tokenizer)) config.vocab_size = len(tokenizer) model = CnlpModelForClassification( @@ -691,6 +699,7 @@ def compute_metrics_fn(p: EvalPrediction): if training_args.do_predict: logging.info("*** Test ***") test_dataset=dataset.datasets[0]['test'] + dataset_labels = dataset.get_labels()[0] # FIXME: this part hasn't been updated for the MTL setup so it doesn't work anymore since # predictions is generalized to be a list of predictions and the output needs to be different for each kin. # maybe it's ok to only handle classification since it has a very straightforward output format and evaluation, @@ -709,7 +718,8 @@ def compute_metrics_fn(p: EvalPrediction): with open(output_test_file, "w") as writer: logger.info("***** Test results *****") for index, item in enumerate(task_predictions): - item = test_dataset.get_labels()[task_name][item] + item = dataset_labels[task_name][item] + # item = test_dataset.get_labels()[task_name][item] writer.write("%s\n" % (item)) return eval_results