elshehawy committed
Commit d4df546 · 1 Parent(s): bd639d6

update app.py and evaluate_data.py to work with uploaded files

Files changed (2)
  1. app.py +9 -7
  2. evaluate_data.py +17 -13
app.py CHANGED
@@ -10,7 +10,7 @@ import json
 # from simcse import SimCSE # use for gpt
 from evaluate_data import store_sample_data, get_metrics_trf
 
-store_sample_data()
+# store_sample_data()
 
 
 
@@ -92,10 +92,12 @@ example = """
 My latest exclusive for The Hill : Conservative frustration over Republican efforts to force a House vote on reauthorizing the Export - Import Bank boiled over Wednesday during a contentious GOP meeting.
 
 """
-def find_orgs(sentence, choice):
-    return all_metrics
-radio_btn = gr.Radio(choices=['GPT', 'iSemantics'], value='iSemantics', label='Available models', show_label=True)
-textbox = gr.Textbox(label="Enter your text", placeholder=str(all_metrics), lines=8)
-
-iface = gr.Interface(fn=find_orgs, inputs=[textbox, radio_btn], outputs="text", examples=[[example]])
+def find_orgs(uploaded_file):
+    uploaded_data = json.loads(uploaded_file)
+    return get_metrics_trf(uploaded_file)
+# radio_btn = gr.Radio(choices=['GPT', 'iSemantics'], value='iSemantics', label='Available models', show_label=True)
+# textbox = gr.Textbox(label="Enter your text", placeholder=str(all_metrics), lines=8)
+upload_btn = gr.UploadButton(label='Upload a json file.')
+
+iface = gr.Interface(fn=find_orgs, inputs=[upload_btn], outputs="text", examples=[[example]])
 iface.launch(share=True)
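Note: in the committed version, find_orgs calls json.loads on the raw gr.UploadButton value and then forwards the original argument (not the parsed dict) to get_metrics_trf. Below is a minimal illustrative sketch, not the committed code, of one way to read the upload, assuming Gradio hands the callback either a file path string (newer versions) or a temp-file object whose .name attribute is the path (older versions):

import json

def find_orgs(uploaded_file):
    # Hypothetical handling: resolve the upload to a path, then parse the JSON from disk.
    path = uploaded_file if isinstance(uploaded_file, str) else uploaded_file.name
    with open(path, encoding='utf-8') as f:
        uploaded_data = json.load(f)
    # Pass the parsed dict on to the metrics routine imported from evaluate_data.py.
    return get_metrics_trf(uploaded_data)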
evaluate_data.py CHANGED
@@ -35,13 +35,6 @@ with open(feature_path, 'rb') as f:
 ner_model = AutoModelForTokenClassification.from_pretrained(checkpoint)
 
 
-tokenized_test = test.map(
-    tokenize_and_align_labels,
-    batched=True,
-    batch_size=None,
-    remove_columns=test.column_names[2:],
-    fn_kwargs={'tokenizer': tokenizer}
-)
 
 # tokenized_dataset.set_format('torch')
 
@@ -54,18 +47,28 @@ def collate_fn(data):
 
     return input_ids, token_type_ids, attention_mask, labels
 
-loader = torch.utils.data.DataLoader(tokenized_test, batch_size=16, collate_fn=collate_fn)
-device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
 
 ner_model = ner_model.eval()
 
 
 
-def get_metrics_trf():
+def get_metrics_trf(data):
     print(device)
 
+    data = Dataset.from_dict(data)
+
+    tokenized_data = data.map(
+        tokenize_and_align_labels,
+        batched=True,
+        batch_size=None,
+        remove_columns=data.column_names[2:],
+        fn_kwargs={'tokenizer': tokenizer}
+    )
+
+    loader = torch.utils.data.DataLoader(tokenized_data, batch_size=16, collate_fn=collate_fn)
+    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+
     y_true, logits = [], []
-
     for input_ids, token_type_ids, attention_mask, labels in tqdm(loader):
         ner_model.to(device)
         with torch.no_grad():
@@ -110,10 +113,11 @@ def find_orgs(tokens, labels):
 
 
 
-def store_sample_data():
+def store_sample_data(data):
+    data = Dataset.from_dict(data)
     test_data = []
 
-    for sent in test:
+    for sent in data:
         labels = [ner_feature.feature.int2str(l) for l in sent['ner_tags']]
         # print(labels)
         sent_orgs = find_orgs(sent['tokens'], labels)
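Both get_metrics_trf(data) and store_sample_data(data) now build a datasets.Dataset via Dataset.from_dict(data), and downstream code reads sent['tokens'] and sent['ner_tags'], so the uploaded JSON is expected to be a column-oriented dict of parallel per-sentence lists. A minimal sketch of a payload with that shape (the token and tag values below are made up for illustration, not taken from the repo):

from datasets import Dataset

# Column-oriented payload: one list per column, one inner list per sentence.
sample = {
    'tokens':   [['My', 'latest', 'exclusive', 'for', 'The', 'Hill']],
    'ner_tags': [[0, 0, 0, 0, 3, 4]],  # integer class ids; evaluate_data.py decodes them with ner_feature.feature.int2str
}

data = Dataset.from_dict(sample)
for sent in data:
    print(sent['tokens'], sent['ner_tags'])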