Spaces: Runtime error
Commit 9c98606
Parent(s): 981717f
Upload batch_main.py
batch_main.py ADDED
@@ -0,0 +1,63 @@
import time
import torch
import string
import argparse

from transformers import BertTokenizer, BertForMaskedLM

import BatchInference as bd
import batched_main_NER as ner
import aggregate_server_json as aggr
import json

DEFAULT_TOP_K = 20
SPECIFIC_TAG = ":__entity__"
DEFAULT_MODEL_PATH = "ajitrajasekharan/biomedical"
DEFAULT_RESULTS = "results.txt"


def perform_inference(text, bio_model, ner_bio, aggr_ner):
    # Run the biomedical MLM over the tagged sentence, turn its descriptor
    # predictions into NER tags, then aggregate the results.
    print("Getting predictions from BIO model...")
    bio_descs = bio_model.get_descriptors(text, None)
    print("Computing BIO results...")
    bio_ner = ner_bio.tag_sentence_service(text, bio_descs)
    obj = json.loads(bio_ner)
    # The aggregator expects output from two models; with only the biomedical
    # model loaded, its result is passed twice.
    combined_arr = [obj, obj]
    aggregate_results = aggr_ner.fetch_all(text, combined_arr)
    return aggregate_results


def process_input(args):
    try:
        input_file = args.input
        output_file = args.output
        print("Initializing BIO module...")
        bio_model = bd.BatchInference("bio/desc_a100_config.json", 'ajitrajasekharan/biomedical',
                                      False, False, DEFAULT_TOP_K, True, True,
                                      "bio/", "bio/a100_labels.txt", False)
        ner_bio = ner.UnsupNER("bio/ner_a100_config.json")
        print("Initializing Aggregation module...")
        aggr_ner = aggr.AggregateNER("./ensemble_config.json")
        with open(output_file, "w") as wfp, open(input_file) as fp:
            for line in fp:
                # Tag every token in the sentence as a candidate entity.
                text_input = line.strip().split()
                print(text_input)
                text_input = ' '.join(t + SPECIFIC_TAG for t in text_input)
                start = time.time()
                results = perform_inference(text_input, bio_model, ner_bio, aggr_ner)
                print(f"prediction took {time.time() - start:.2f}s")
                wfp.write(json.dumps(results))
                wfp.write("\n\n")
    except Exception as e:
        print("Some error occurred in batch processing:", e)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Batch handling of NER',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-model', action="store", dest="model", default=DEFAULT_MODEL_PATH,
                        help='BERT pretrained models, or custom model path')
    parser.add_argument('-input', action="store", dest="input", required=True,
                        help='Input file with sentences')
    parser.add_argument('-output', action="store", dest="output", default=DEFAULT_RESULTS,
                        help='Output file with sentences')
    args = parser.parse_args()
    process_input(args)
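For reference, a minimal usage sketch follows. It assumes batch_main.py is importable from the working directory along with the bio/ config files and ensemble_config.json it expects; the sentences.txt file name and the example sentence are illustrative assumptions, not part of the commit.

# Hypothetical driver for the script above; file names are assumptions.
import argparse
import batch_main

# The input file is expected to hold one plain-text sentence per line, e.g.
#   "parkinson's disease is treated with levodopa"
# process_input tags every token with :__entity__ before inference and writes
# one JSON result per sentence to the output file.
args = argparse.Namespace(model="ajitrajasekharan/biomedical",
                          input="sentences.txt",   # assumed input file
                          output="results.txt")    # matches DEFAULT_RESULTS
batch_main.process_input(args)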