update app.py and evaluate_data.py
- app.py +19 -2
- evaluate_data.py +3 -8
app.py CHANGED
@@ -45,7 +45,6 @@ def get_completion(prompt, model=llm_model):
     return response.choices[0].message.content
 
 
-
 def find_orgs_gpt(sentence):
     prompt = f"""
     In context of named entity recognition (NER), find all organizations in the text delimited by triple backticks.
@@ -106,7 +105,25 @@ def find_orgs(uploaded_file):
     print(uploaded_file.decode())
     uploaded_data = json.loads(uploaded_file)
 
-
+
+    all_metrics = {}
+    all_metrics['trf'] = get_metrics_trf(uploaded_data)
+
+    store_sample_data(uploaded_data)
+    with open('./data/sample_data.json', 'r') as f:
+        sample_data = json.load(f)
+
+    gpt_orgs, true_orgs = [], []
+
+    for sent in sample_data:
+        gpt_orgs.append(find_orgs_gpt(sent['text']))
+        true_orgs.append(sent['orgs'])
+
+
+    # sim_model = SimCSE('sentence-transformers/all-MiniLM-L6-v2')
+    # all_metrics['gpt'] = calc_metrics(true_orgs, gpt_orgs, sim_model)
+
+    return
     # radio_btn = gr.Radio(choices=['GPT', 'iSemantics'], value='iSemantics', label='Available models', show_label=True)
     # textbox = gr.Textbox(label="Enter your text", placeholder=str(all_metrics), lines=8)
     upload_btn = gr.UploadButton(label='Upload a json file.', type='binary')
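Note on the app.py change: the reworked find_orgs handler now parses the uploaded JSON, computes transformer metrics with get_metrics_trf, writes a sample to ./data/sample_data.json, and then collects a GPT prediction (find_orgs_gpt) plus the gold organization list for each sample sentence. The SimCSE-based scoring of the GPT output is still commented out, so the handler returns nothing yet. As a rough illustration of how the collected gpt_orgs and true_orgs lists could be scored, here is a hypothetical exact-match helper; the name score_orgs and the assumption that each prediction is already parsed into a list of organization strings are not from the repo, and the commented-out calc_metrics + SimCSE path presumably does similarity-based matching instead.

# Hypothetical sketch: exact-match precision/recall/F1 over per-sentence organization lists.
# Assumes each entry in pred_orgs is already a list of strings; the repo's calc_metrics
# with a SimCSE model likely uses similarity-based matching rather than exact matching.
def score_orgs(true_orgs, pred_orgs):
    tp = fp = fn = 0
    for gold, pred in zip(true_orgs, pred_orgs):
        gold, pred = set(gold), set(pred)
        tp += len(gold & pred)
        fp += len(pred - gold)
        fn += len(gold - pred)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {'precision': precision, 'recall': recall, 'f1': f1}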
evaluate_data.py CHANGED
@@ -34,11 +34,8 @@ with open(feature_path, 'rb') as f:
 
 ner_model = AutoModelForTokenClassification.from_pretrained(checkpoint)
 
-
-
 # tokenized_dataset.set_format('torch')
 
-
 def collate_fn(data):
     input_ids = [(element['input_ids']) for element in data]
     attention_mask = [element['attention_mask'] for element in data]
@@ -56,7 +53,7 @@ def get_metrics_trf(data):
     print(device)
 
     data = Dataset.from_dict(data)
-
+
     tokenized_data = data.map(
         tokenize_and_align_labels,
         batched=True,
@@ -90,9 +87,7 @@ def get_metrics_trf(data):
     # json.dump(all_metrics, f)
 
 
-
-
-def find_orgs(tokens, labels):
+def find_orgs_in_data(tokens, labels):
     orgs = []
     prev_tok_id = 0
     for i, (token, label) in enumerate(zip(tokens, labels)):
@@ -129,5 +124,5 @@ def store_sample_data(data):
             'orgs': sent_orgs
         })
 
-    with open('data/sample_data.json', 'w') as f:
+    with open('./data/sample_data.json', 'w') as f:
         json.dump(test_data, f)
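Note on the evaluate_data.py change: besides the whitespace cleanup, the substantive edits are the rename of find_orgs to find_orgs_in_data (presumably so it no longer collides with the find_orgs upload handler defined in app.py) and the output path now matching the ./data/sample_data.json that app.py reads back. For orientation, a helper with this signature typically walks the token/label pairs and groups consecutive organization tokens into spans; the sketch below assumes BIO-style 'B-ORG'/'I-ORG' labels and whitespace-joinable tokens, which may not match the repository's actual logic (the real function also tracks token indices via prev_tok_id).

# Minimal sketch of grouping BIO-tagged tokens into organization spans.
# The label scheme ('B-ORG'/'I-ORG'/'O') and whitespace joining are assumptions
# for illustration; find_orgs_in_data in the repo may differ in detail.
def group_org_spans(tokens, labels):
    orgs, current = [], []
    for token, label in zip(tokens, labels):
        if label == 'B-ORG':
            if current:
                orgs.append(' '.join(current))
            current = [token]
        elif label == 'I-ORG' and current:
            current.append(token)
        else:
            if current:
                orgs.append(' '.join(current))
            current = []
    if current:
        orgs.append(' '.join(current))
    return orgs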