liamcripwell committed on
Commit
13e3b2f
·
1 Parent(s): b17b8f7

allow non-json template

Browse files
Files changed (1) hide show
  1. app.py +24 -11
app.py CHANGED
@@ -92,11 +92,31 @@ def sliding_window_prediction(template, text, model, tokenizer, window_size=4000
92
  pred = handle_broken_output(pred, prev)
93
 
94
  # create highlighted text
95
- highlighted_pred = highlight_words(text, json.loads(pred))
96
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  # Sync empty fields
98
- synced_pred = sync_empty_fields(json.loads(pred), json.loads(template))
99
- synced_pred = json.dumps(synced_pred, indent=4, ensure_ascii=False)
 
 
 
 
 
100
 
101
  # Return progress, current prediction, and updated HTML
102
  yield f"Processed chunk {i+1}/{len(chunks)}", synced_pred, highlighted_pred
@@ -118,13 +138,6 @@ tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token)
118
  model.eval()
119
 
120
  def gradio_interface_function(template, text, is_example):
121
- # reject invalid JSON
122
- try:
123
- template_json = json.loads(template)
124
- except:
125
- yield "", "Invalid JSON template", ""
126
- return # End the function since there was an error
127
-
128
  if len(tokenizer.tokenize(text)) > MAX_INPUT_SIZE:
129
  yield "", "Input text too long for space. Download model to use unrestricted.", ""
130
  return # End the function since there was an error
 
92
  pred = handle_broken_output(pred, prev)
93
 
94
  # create highlighted text
95
+ try:
96
+ highlighted_pred = highlight_words(text, json.loads(pred))
97
+ except:
98
+ highlighted_pred = text
99
+
100
+ # attempt json parsing
101
+ template_dict = None
102
+ pred_dict = None
103
+ try:
104
+ template_dict = json.loads(template)
105
+ except:
106
+ pass
107
+ try:
108
+ pred_dict = json.loads(pred)
109
+ except:
110
+ pass
111
+
112
  # Sync empty fields
113
+ if template_dict and pred_dict:
114
+ synced_pred = sync_empty_fields(pred_dict, template_dict)
115
+ synced_pred = json.dumps(synced_pred, indent=4, ensure_ascii=False)
116
+ elif pred_dict:
117
+ synced_pred = json.dumps(pred_dict, indent=4, ensure_ascii=False)
118
+ else:
119
+ synced_pred = pred
120
 
121
  # Return progress, current prediction, and updated HTML
122
  yield f"Processed chunk {i+1}/{len(chunks)}", synced_pred, highlighted_pred
 
138
  model.eval()
139
 
140
  def gradio_interface_function(template, text, is_example):
 
 
 
 
 
 
 
141
  if len(tokenizer.tokenize(text)) > MAX_INPUT_SIZE:
142
  yield "", "Input text too long for space. Download model to use unrestricted.", ""
143
  return # End the function since there was an error