diff --git a/README.md b/README.md
index ffb552fe..3ba1f9dc 100644
--- a/README.md
+++ b/README.md
@@ -137,6 +137,22 @@ For a demo of how to run the system in colab: [:
```gene_ner = {'acc': 0.9835720382010852, 'token_f1': [0.5503875968992249, 0.8223125230882896, 0.0, 0.0, 0.9949847518170479], 'f1': 0.7477207783371888, 'report': '\n precision recall f1-score support\n\n GENE-N 0.70 0.45 0.55 2355\n GENE-Y 0.76 0.88 0.82 5013\n\n micro avg 0.75 0.75 0.75 7368\n macro avg 0.73 0.67 0.68 7368\nweighted avg 0.74 0.75 0.73 7368\n'}```
+### Fine-tuning for concept normalization: End-to-end example
+1. Prepare the concept normalization input data (train.tsv, dev.tsv, and test.tsv) as tab-separated (.tsv) files in the following format:
+
+| text | conceptnorm |
+| --- | --- |
+| pleural effusion | C0032227 |
+| pulmonary consolidation | C0521530 |
+| aorta tortuous | CUI-less |
+
+2. Prepare an ontology CUI file (ontology_cui.txt) in the data directory listing every CUI. Note that the processors read this file with json.load, so it should contain a JSON list of CUIs.
+
+3. Prepare a concept embeddings file (concept_embeddings.npy) and save it into a folder $concept_norm_path. It holds a single matrix with one row per CUI, containing that CUI's embedding; the row order must follow the order of the ontology CUI file. The same folder must also contain a cuiless_threshold.txt file holding a single float threshold for the CUI-less class. See the sketch after this list for one way to build these files.
+
+4. Fine-tune with something like:
+```python -m cnlpt.train_system --do_train --do_eval --task_name conceptnorm --data_dir cnlp_concept_norm/ --encoder_name cambridgeltl/SapBERT-from-PubMedBERT-fulltext-mean-token --output_dir temp/ --concept_norm_path $concept_norm_path --overwrite_output_dir --cache cache --token true --num_train_epochs 5 --learning_rate 3e-5 --per_device_train_batch_size 64 --max_seq_length 16 --layer 12 --seed 24 --evals_per_epoch 1```
+
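+The repository does not prescribe how to build the files in step 3; below is a minimal sketch under two assumptions: `ontology_cui.txt` is a JSON list of CUIs (that is how the processors read it), and you have some mapping from each CUI to a preferred term (the `cui_to_term.json` file here is hypothetical and must be built from your ontology). It mean-pools SapBERT token embeddings, one concept at a time for clarity:
+
+```python
+import json
+import numpy as np
+import torch
+from transformers import AutoModel, AutoTokenizer
+
+encoder = "cambridgeltl/SapBERT-from-PubMedBERT-fulltext-mean-token"
+tokenizer = AutoTokenizer.from_pretrained(encoder)
+model = AutoModel.from_pretrained(encoder).eval()
+
+with open("cnlp_concept_norm/ontology_cui.txt") as f:
+    cuis = json.load(f)
+with open("cui_to_term.json") as f:  # hypothetical CUI -> preferred-term map
+    cui_to_term = json.load(f)
+
+embeddings = []
+with torch.no_grad():
+    for cui in cuis:
+        enc = tokenizer(cui_to_term[cui], return_tensors="pt",
+                        truncation=True, max_length=16)
+        # mean-pool the final-layer token embeddings into one vector per concept
+        vec = model(**enc).last_hidden_state.mean(dim=1).squeeze(0)
+        embeddings.append(vec.numpy())
+
+# row order must follow ontology_cui.txt so label indices line up with rows
+np.save("concept_norm_path/concept_embeddings.npy",
+        np.stack(embeddings).astype(np.float32))
+# the same folder must hold a CUI-less threshold; 0.35 mirrors the default
+# the model uses when no precomputed threshold is available
+np.savetxt("concept_norm_path/cuiless_threshold.txt", [0.35])
+```
+
+At train time, `--concept_norm_path` makes the model initialize its cosine-similarity classifier from `concept_embeddings.npy` (the matrix remains trainable) and fix the CUI-less threshold to the value in `cuiless_threshold.txt`.
+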
### Fine-tuning options
Run ```python -m cnlpt.train_system -h``` to see all the available options. In addition to inherited Huggingface Transformers options, there are options to do the following:
* Run simple baselines (use ``--model cnn --tokenizer_name roberta-base`` -- since there is no HF model then you must specify the tokenizer explicitly)
diff --git a/setup.cfg b/setup.cfg
index d3cfddb4..f70cf950 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -43,6 +43,7 @@ install_requires =
[options.entry_points]
console_scripts =
+ cnlpt_conceptnorm_rest = cnlpt.api.concept_norm_rest:rest
cnlpt_dtr_rest = cnlpt.api.dtr_rest:rest
cnlpt_event_rest = cnlpt.api.event_rest:rest
cnlpt_negation_rest = cnlpt.api.negation_rest:rest
diff --git a/src/cnlpt/CnlpModelForClassification.py b/src/cnlpt/CnlpModelForClassification.py
index cf893bd9..ddff7933 100644
--- a/src/cnlpt/CnlpModelForClassification.py
+++ b/src/cnlpt/CnlpModelForClassification.py
@@ -13,11 +13,13 @@
import torch
from torch import nn
import logging
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import CrossEntropyLoss, MSELoss, Parameter
from transformers.modeling_outputs import SequenceClassifierOutput
-from torch.nn.functional import softmax, relu
+from torch.nn.functional import softmax, relu, normalize
import math
import random
+import numpy as np
+import os
logger = logging.getLogger(__name__)
@@ -33,9 +35,50 @@ def forward(self, features, *kwargs):
x = self.out_proj(x)
return x
+class CosineLayer(nn.Module):
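+    """Cosine-similarity classification head for concept normalization.
+
+    Scores each input representation against a matrix of concept embeddings
+    (one row per CUI, optionally precomputed) by cosine similarity, and
+    appends a threshold score for the synthetic "CUI-less" class.
+    """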
+ def __init__(
+ self,
+ config,
+ concept_dim=88150
+ ):
+ super(CosineLayer, self).__init__()
+
+ self.dropout = nn.Dropout(0.1)
+ self.cos = nn.CosineSimilarity(dim=-1)
+        concept_dims = (concept_dim, config.hidden_size)
+
+ if config.concept_norm is not None:
+ weights_matrix = np.load(os.path.join(config.concept_norm,
+ "concept_embeddings.npy")).astype(np.float32)
+ self.weight = Parameter(torch.from_numpy(weights_matrix),
+ requires_grad=True)
+ threshold_value = np.loadtxt(
+ os.path.join(config.concept_norm, "cuiless_threshold.txt")).astype(np.float32)
+ self.threshold = Parameter(torch.tensor(threshold_value),
+ requires_grad=False)
+ else:
+ self.weight = Parameter(torch.rand(concept_dims),
+ requires_grad=True)
+            torch.nn.init.xavier_uniform_(self.weight)
+ self.threshold = Parameter(torch.tensor(0.35), requires_grad=True)
+
+ def forward(self, features):
+ batch_size, fea_size = features.shape
+ features_norm = normalize(features)
+ weight_norm = normalize(self.weight)
+ sim_mt = torch.mm(features_norm, weight_norm.transpose(0, 1))
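+        # append the CUI-less threshold as one more column of scores, so that
+        # argmax can pick "no matching concept" when all similarities fall below it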
+        cui_less_score = self.threshold.to(features.device) * torch.ones(
+            (batch_size, 1), device=features.device)
+ similarity_score = torch.cat((sim_mt, cui_less_score), 1)
+        # temperature-scale the cosine similarities before the downstream
+        # cross-entropy; TODO: make this a hyper-parameter
+        scaling = 0.03
+        similarity_score = similarity_score / scaling
+ return similarity_score
+
class RepresentationProjectionLayer(nn.Module):
- def __init__(self, config, layer=10, tokens=False, tagger=False, relations=False, num_attention_heads=-1, head_size=64):
+    def __init__(self, config, layer=10, tokens=False, tagger=False, relations=False, skip_projection=False, num_attention_heads=-1, head_size=64):
super().__init__()
self.dropout = nn.Dropout(config.hidden_dropout_prob)
if relations:
@@ -47,6 +90,7 @@ def __init__(self, config, layer=10, tokens=False, tagger=False, relations=False
self.tokens = tokens
self.tagger = tagger
self.relations = relations
+ self.skip_projection = skip_projection
self.hidden_size = config.hidden_size
if num_attention_heads <= 0 and relations:
@@ -95,9 +139,13 @@ def forward(self, features, event_tokens, **kwargs):
# take token (equiv. to [CLS])
x = features[self.layer_to_use][..., 0, :]
- x = self.dropout(x)
- x = self.dense(x)
- x = torch.tanh(x)
+ # for normal classification we pass through a dense layer, for cosine layer
+ # classification we just grab the representation directly:
+ if not self.skip_projection:
+ x = self.dropout(x)
+ x = self.dense(x)
+ x = torch.tanh(x)
+
return x
@@ -131,6 +179,7 @@ def __init__(
tagger = [False],
relations = [False],
use_prior_tasks=False,
+ concept_norm=None,
**kwargs
):
super().__init__(**kwargs)
@@ -146,6 +195,7 @@ def __init__(
self.use_prior_tasks = use_prior_tasks
self.encoder_name = encoder_name
self.encoder_config = AutoConfig.from_pretrained(encoder_name).to_dict()
+ self.concept_norm = concept_norm
if encoder_name.startswith('distilbert'):
self.hidden_dropout_prob = self.encoder_config['dropout']
self.hidden_size = self.encoder_config['dim']
@@ -222,13 +272,17 @@ def __init__(self,
self.classifiers = nn.ModuleList()
total_prev_task_labels = 0
for task_ind,task_num_labels in enumerate(self.num_labels):
- self.feature_extractors.append(RepresentationProjectionLayer(config, layer=config.layer, tokens=config.tokens, tagger=config.tagger[task_ind], relations=config.relations[task_ind], num_attention_heads=config.num_rel_attention_heads, head_size=config.rel_attention_head_dims))
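+            # conceptnorm tasks skip the projection layer and use a cosine-similarity
+            # classifier (CosineLayer) over concept embeddings instead of a dense head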
+ conceptnorm = config.finetuning_task[task_ind] == "conceptnorm"
+ self.feature_extractors.append(RepresentationProjectionLayer(config, layer=config.layer, tokens=config.tokens, tagger=config.tagger[task_ind], relations=config.relations[task_ind], skip_projection=conceptnorm, num_attention_heads=config.num_rel_attention_heads, head_size=config.rel_attention_head_dims))
if config.relations[task_ind]:
hidden_size = config.num_rel_attention_heads
if config.use_prior_tasks:
hidden_size += total_prev_task_labels
self.classifiers.append(ClassificationHead(config, task_num_labels, hidden_size=hidden_size))
+ elif conceptnorm:
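+                # task_num_labels includes the appended "CUI-less" label, which the
+                # CosineLayer scores via its threshold rather than an embedding row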
+ self.classifiers.append(CosineLayer(config,
+                                                    concept_dim=task_num_labels - 1))
else:
self.classifiers.append(ClassificationHead(config, task_num_labels))
total_prev_task_labels += task_num_labels
@@ -424,7 +478,7 @@ def forward(
labels=None,
output_attentions=None,
output_hidden_states=None,
- event_tokens=None,
+ event_mask=None,
):
r"""
Forward method.
@@ -449,7 +503,7 @@ def forward(
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers.
output_hidden_states: not used.
- event_tokens: a mask defining which tokens in the input are to be averaged for input to classifier head; only used when self.tokens==True.
+ event_mask: a mask defining which tokens in the input are to be averaged for input to classifier head; only used when self.tokens==True.
Returns: (`transformers.SequenceClassifierOutput`) the output of the model
"""
@@ -480,7 +534,7 @@ def forward(
)
for task_ind,task_num_labels in enumerate(self.num_labels):
- features = self.feature_extractors[task_ind](outputs.hidden_states, event_tokens)
+ features = self.feature_extractors[task_ind](outputs.hidden_states, event_mask)
if self.use_prior_tasks:
# note: this specific way of incorporating previous logits doesn't help in my experiments with thyme/clinical tempeval
if self.relations[task_ind]:
diff --git a/src/cnlpt/api/cnlp_rest.py b/src/cnlpt/api/cnlp_rest.py
index 3a9f00cb..a34e4d64 100644
--- a/src/cnlpt/api/cnlp_rest.py
+++ b/src/cnlpt/api/cnlp_rest.py
@@ -1,6 +1,7 @@
# Core python imports
import os
+import os.path
# FastAPI imports
from pydantic import BaseModel
@@ -70,6 +71,10 @@ def initialize_cnlpt_model(app, model_name, cuda=True, batch_size=8):
AutoModel.register(CnlpConfig, CnlpModelForClassification)
config = AutoConfig.from_pretrained(model_name)
+
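+    # concept-norm checkpoints load concept_embeddings.npy and cuiless_threshold.txt
+    # from config.concept_norm; point it at the checkpoint's parent directory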
+ if 'concept_norm' in config.__dict__:
+ config.concept_norm = os.path.join(model_name, '..')
+
app.state.tokenizer = AutoTokenizer.from_pretrained(model_name,
config=config)
model = CnlpModelForClassification.from_pretrained(model_name, cache_dir=os.getenv('HF_CACHE'), config=config)
diff --git a/src/cnlpt/api/concept_norm_rest.py b/src/cnlpt/api/concept_norm_rest.py
new file mode 100644
index 00000000..3ec8feea
--- /dev/null
+++ b/src/cnlpt/api/concept_norm_rest.py
@@ -0,0 +1,120 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from fastapi import FastAPI
+from pydantic import BaseModel
+
+from typing import List
+
+from .cnlp_rest import get_dataset, initialize_cnlpt_model
+import numpy as np
+
+import logging, os, json
+from time import time
+
+app = FastAPI()
+model_name = "/lab-share/CHIP-Savova-e2/Public/resources/cnlpt/concept_norm/share/checkpoint-57456/"
+logger = logging.getLogger('Concept_Normalization_REST_Processor')
+logger.setLevel(logging.DEBUG)
+
+task = 'conceptnorm'
+with open(os.path.join(model_name, "../ontology_cui.txt"), 'r') as cui_file:
+    labels = json.load(cui_file)
+
+labels = labels + ["CUI-less"]
+
+max_length = 32
+
+class Entity(BaseModel):
+    ''' entity_text: the raw text of a single entity mention to be normalized '''
+ entity_text: str
+
+
+class ConceptNormResults(BaseModel):
+    ''' statuses: a list of predicted concept IDs (CUIs), one per input entity; "CUI-less" when no concept scores above the threshold'''
+ statuses: List[str]
+
+@app.on_event("startup")
+async def startup_event():
+ initialize_cnlpt_model(app, model_name)
+
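+# Example request, assuming the server is running on localhost:8000:
+#   curl -X POST localhost:8000/concept_norm/process \
+#     -H 'Content-Type: application/json' -d '{"entity_text": "pleural effusion"}'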
+@app.post("/concept_norm/process")
+async def process(entity: Entity):
+ text = entity.entity_text
+    logger.warning('Received entity of length %d to process' % len(text))
+ instances = [text]
+ start_time = time()
+
+ dataset = get_dataset(instances, app.state.tokenizer, [labels,], [task,], max_length)
+ preproc_end = time()
+
+ output = app.state.trainer.predict(test_dataset=dataset)
+ predictions = output.predictions[0]
+ predictions = np.argmax(predictions, axis=1)
+
+ pred_end = time()
+
+ results = []
+ for ent_ind in range(len(dataset)):
+ results.append(labels[predictions[ent_ind]])
+
+ output = ConceptNormResults(statuses=results)
+
+ postproc_end = time()
+
+ preproc_time = preproc_end - start_time
+ pred_time = pred_end - preproc_end
+ postproc_time = postproc_end - pred_end
+
+    logger.warning("Pre-processing time: %f, processing time: %f, post-processing time: %f" % (preproc_time, pred_time, postproc_time))
+
+ return output
+
+@app.get("/concept_norm/{test_str}")
+async def test(test_str: str):
+ return {'argument': test_str}
+
+
+def rest():
+ import argparse
+
+    parser = argparse.ArgumentParser(description='Run the http server for concept normalization')
+ parser.add_argument('-p', '--port', type=int, help='The port number to run the server on', default=8000)
+
+ args = parser.parse_args()
+
+ import uvicorn
+ uvicorn.run("cnlpt.api.concept_norm_rest:app", host='0.0.0.0', port=args.port, reload=True)
+
+if __name__ == '__main__':
+ rest()
\ No newline at end of file
diff --git a/src/cnlpt/cnlp_processors.py b/src/cnlpt/cnlp_processors.py
index 83dcfdf7..de88ddf0 100644
--- a/src/cnlpt/cnlp_processors.py
+++ b/src/cnlpt/cnlp_processors.py
@@ -26,18 +26,25 @@
classification = 'classification'
tagging = 'tagging'
relex = 'relations'
+conceptnorm = 'conceptnorm'
-def get_unique_labels(dataset, tasks, task_output_modes):
+def get_unique_labels(dataset, tasks, task_output_modes, data_dir):
dataset_unique_labels = dict()
for task_ind,task_name in enumerate(tasks):
- unique_labels = set()
- # check all splits for labels just in case they do not fully overlap
- for split in dataset:
- # Add labels from this split to the overall label set and give a warning if they are not the same
- split_labels = set( dataset[split][task_name])
- unique_labels |= split_labels
-
- unique_labels = list(unique_labels)
+ if task_name != conceptnorm:
+ unique_labels = set()
+ # check all splits for labels just in case they do not fully overlap
+ for split in dataset:
+ # Add labels from this split to the overall label set and give a warning if they are not the same
+ split_labels = set( dataset[split][task_name])
+ unique_labels |= split_labels
+
+ unique_labels = list(unique_labels)
+ else:
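+            # the conceptnorm label space is the whole ontology, not just the
+            # labels observed in the data splits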
+            with open(os.path.join(data_dir, 'ontology_cui.txt'), 'r') as cui_file:
+                labels = json.load(cui_file)
+            unique_labels = labels + ["CUI-less"]
output_mode = task_output_modes[task_name]
@@ -58,8 +65,8 @@ def get_unique_labels(dataset, tasks, task_output_modes):
rel_cat = rel_cat[:-1]
unique_relations.add(rel_cat)
unique_labels = list(unique_relations)
-
- unique_labels.sort()
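+        # conceptnorm labels must keep the ontology file's order so label indices
+        # line up with rows of the precomputed concept embedding matrix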
+ if task_name != conceptnorm:
+ unique_labels.sort()
dataset_unique_labels[task_name] = unique_labels
@@ -69,24 +76,26 @@ def infer_output_modes(dataset):
task_output_modes = {}
for task_ind, task_name in enumerate(dataset.tasks):
output_mode = classification
- unique_labels = set()
- # check all splits for labels just in case they do not fully overlap
- for split in dataset:
- # Add labels from this split to the overall label set and give a warning if they are not the same
- split_labels = set( dataset[split][task_name])
- unique_labels |= split_labels
+ if task_name != conceptnorm:
+ unique_labels = set()
+ # check all splits for labels just in case they do not fully overlap
+ for split in dataset:
+ # Add labels from this split to the overall label set and give a warning if they are not the same
+ split_labels = set( dataset[split][task_name])
+ unique_labels |= split_labels
- unique_labels = list(unique_labels)
+ unique_labels = list(unique_labels)
## Check if any unique label has a space in it, then we know we are actually
## dealing with a tagging dataset, or if it ends in ), in which case it is a relation task.
- for label in unique_labels:
- if str(label)[-1] == ')':
- output_mode = relex
- break
- elif ' ' in str(label):
- output_mode = tagging
- break
+ if task_name != conceptnorm:
+ for label in unique_labels:
+ if str(label)[-1] == ')':
+ output_mode = relex
+ break
+ elif ' ' in str(label):
+ output_mode = tagging
+ break
task_output_modes[task_name] = output_mode
@@ -198,7 +207,7 @@ def __init__(self, data_dir:str, tasks:Set[str]=None, max_train_items=-1):
self.dataset.task_output_modes = infer_output_modes(self.dataset)
# get any split of the data and ask for the set of unique labels for each task in the dataset from that split
- self.labels = get_unique_labels(self.dataset, self.dataset.tasks, self.dataset.task_output_modes)
+ self.labels = get_unique_labels(self.dataset, self.dataset.tasks, self.dataset.task_output_modes, data_dir)
self.dataset = get_task_pruned_dataset(self.dataset, self.dataset.tasks, self.labels)
diff --git a/src/cnlpt/train_system.py b/src/cnlpt/train_system.py
index 173e20fd..177a1aea 100644
--- a/src/cnlpt/train_system.py
+++ b/src/cnlpt/train_system.py
@@ -120,6 +120,10 @@ class ModelArguments:
default=False, metadata={"help": "Classify over an actual token rather than the [CLS] ('') token -- requires that the tokens to be classified are surrounded by / tokens"}
)
+    concept_norm_path: Optional[str] = field(
+        default=None, metadata={"help": "Path to a folder containing pre-computed concept embeddings (concept_embeddings.npy) and the CUI-less threshold (cuiless_threshold.txt)"}
+    )
+
# NxN relation classifier-specific arguments
num_rel_feats: Optional[int] = field(
default=12, metadata={"help": "Number of features/attention heads to use in the NxN relation classifier"}
@@ -304,6 +308,8 @@ def main(json_file=None, json_obj=None):
model_name = model_args.model
hierarchical = model_name == 'hier'
+ conceptnorm_path = model_args.concept_norm_path
+
# Get datasets
dataset = (
@@ -481,6 +487,7 @@ def main(json_file=None, json_obj=None):
model_name = tempmodel.name
else:
# setting 2) evaluate or make predictions
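+            # the learned concept embeddings are restored from the checkpoint
+            # weights, so there is no need to re-load them from disk here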
+ config.concept_norm = None
model = CnlpModelForClassification.from_pretrained(
model_args.encoder_name,
config=config,
@@ -503,7 +510,8 @@ def main(json_file=None, json_obj=None):
num_rel_attention_heads=model_args.num_rel_feats,
rel_attention_head_dims=model_args.head_features,
tagger=tagger,
- relations=relations,)
+ relations=relations,
+ concept_norm=conceptnorm_path,)
#num_tokens=len(tokenizer))
config.vocab_size = len(tokenizer)
model = CnlpModelForClassification(
@@ -691,6 +699,7 @@ def compute_metrics_fn(p: EvalPrediction):
if training_args.do_predict:
logging.info("*** Test ***")
test_dataset=dataset.datasets[0]['test']
+ dataset_labels = dataset.get_labels()[0]
# FIXME: this part hasn't been updated for the MTL setup so it doesn't work anymore since
# predictions is generalized to be a list of predictions and the output needs to be different for each kin.
# maybe it's ok to only handle classification since it has a very straightforward output format and evaluation,
@@ -709,7 +718,8 @@ def compute_metrics_fn(p: EvalPrediction):
with open(output_test_file, "w") as writer:
logger.info("***** Test results *****")
for index, item in enumerate(task_predictions):
- item = test_dataset.get_labels()[task_name][item]
+ item = dataset_labels[task_name][item]
writer.write("%s\n" % (item))
return eval_results