ankur-bohra committed on
Commit
2fa693e
1 Parent(s): c0ab6fd

Fix flagging and possible race conditions

Browse files
Files changed (1) hide show
  1. app.py +90 -77
app.py CHANGED
@@ -1,10 +1,11 @@
1
  import base64
2
  import os
 
3
  from io import BytesIO
4
  from pathlib import Path
5
 
6
- from langchain.schema.output_parser import OutputParserException
7
  import gradio as gr
 
8
  from PIL import Image
9
 
10
  import categories
@@ -40,7 +41,6 @@ hf_writer_incorrect = gr.HuggingFaceDatasetSaver(
40
  # global example_paths, current_file_path
41
  # if current_file_path not in example_paths:
42
  # return function(*args, **kwargs)
43
-
44
 
45
 
46
  def display_file(input_file):
@@ -76,63 +76,72 @@ def clear_inputs():
76
  return gr.File.update(value=None)
77
 
78
 
79
- def submit(input_file, old_text):
 
 
 
 
 
 
80
  if not input_file:
81
  gr.Error("Please upload a file to continue!")
82
  return gr.Textbox.update()
83
- print("-"*5)
84
- print("New input")
85
  # Send change to preprocessed image or to extracted text
86
  if input_file.name.endswith(".pdf"):
87
  text = process_pdf(Path(input_file.name), extract_only=True)
88
  else:
89
  text = process_image(Path(input_file.name), extract_only=True)
90
- print("Extracted text")
91
  return text
92
 
93
 
94
- def categorize_extracted_text(extracted_text):
95
- category = categories.categorize_text(extracted_text)
96
- print("Recognized category:", category)
97
- # gr.Info(f"Recognized category: {category}")
98
  return category
99
 
100
 
101
- def parse_from_category(category, extracted_text):
102
- if not category:
103
- print("Updated with no category:", category)
104
- return (
105
- gr.Chatbot.update(None),
106
- gr.JSON.update(None),
107
- gr.Button.update(interactive=False),
108
- gr.Button.update(interactive=False),
109
- )
110
- else:
111
- print("Updated with actual category:", category)
112
  category = Category[category]
113
- print("Parsing text from", category)
114
  chain = categories.category_modules[category].chain
115
  formatted_prompt = chain.prompt.format_prompt(
116
- text=extracted_text,
117
  format_instructions=chain.output_parser.get_format_instructions(),
118
  )
 
 
 
 
 
 
119
  result = chain.generate(
120
  input_list=[
121
  {
122
- "text": extracted_text,
123
  "format_instructions": chain.output_parser.get_format_instructions(),
124
  }
125
  ]
126
  )
127
- question = f""
128
- if len(formatted_prompt.messages) > 1:
129
- question += f"**System:**\n{formatted_prompt.messages[0].content}"
130
- question += f"\n\n**Human:**\n{formatted_prompt.messages[1].content}"
131
- print("\tConstructed prompt")
132
  answer = result.generations[0][0].text
133
- print("\tProcessed text")
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  try:
135
- information = chain.output_parser.parse_with_prompt(answer, formatted_prompt)
136
  information = information.json() if information else {}
137
  except OutputParserException as e:
138
  information = {
@@ -140,25 +149,22 @@ def parse_from_category(category, extracted_text):
140
  "details": str(e),
141
  "output": e.llm_output,
142
  }
143
- return (
144
- gr.Chatbot.update([[question, answer]]),
145
- gr.JSON.update(information),
146
- gr.Button.update(interactive=True),
147
- gr.Button.update(interactive=True),
148
- )
149
 
150
 
151
- def dynamic_auto_flag(flag_method):
152
- def modified_flag_method(share_result, *args, **kwargs):
153
- if share_result:
154
- flag_method(*args, **kwargs)
155
 
156
- return modified_flag_method
157
 
 
 
158
 
159
- # def save_example_and_submit(input_file):
160
- # example_paths.append(input_file.name)
161
- # submit(input_file, "")
 
 
 
162
 
163
 
164
  with gr.Blocks(title="Automatic Reimbursement Tool Demo") as page:
@@ -261,35 +267,14 @@ with gr.Blocks(title="Automatic Reimbursement Tool Demo") as page:
261
  flag_irrelevant_button = gr.Button(
262
  "Flag as irrelevant", variant="stop", interactive=True
263
  )
264
-
265
  show_intermediate.change(
266
  show_intermediate_outputs, show_intermediate, [intermediate_outputs]
267
  )
268
 
269
- clear.click(clear_inputs, None, [input_file])
270
- submit_button.click(
271
- submit,
272
- [input_file, extracted_text],
273
- [extracted_text],
274
- )
275
- # submit_button.click(
276
- # lambda input_file, category, chatbot, information: (print("File supplied, resetting") or (
277
- # gr.Dropdown.update(Category.ACCOMODATION),
278
- # gr.Chatbot.update(None),
279
- # gr.Textbox.update(None),
280
- # )) if input_file else (print("File not supplied, keeping") or print(category, chatbot, information)),
281
- # [input_file, category, chatbot, information],
282
- # [category, chatbot, information],
283
- # )
284
- extracted_text.change(
285
- categorize_extracted_text,
286
- [extracted_text],
287
- [category],
288
- )
289
- category.change(
290
- parse_from_category,
291
- [category, extracted_text],
292
- [chatbot, information, flag_incorrect_button, flag_irrelevant_button],
293
  )
294
 
295
  hf_writer_normal.setup(
@@ -297,11 +282,37 @@ with gr.Blocks(title="Automatic Reimbursement Tool Demo") as page:
297
  flagging_dir="flagged",
298
  )
299
  flag_method = gr.flagging.FlagMethod(
300
- hf_writer_normal, "", "", visual_feedback=True
301
  )
302
- information.change(
303
- dynamic_auto_flag(flag_method),
304
- inputs=[
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
  share_result,
306
  input_file,
307
  extracted_text,
@@ -310,9 +321,8 @@ with gr.Blocks(title="Automatic Reimbursement Tool Demo") as page:
310
  information,
311
  contact,
312
  ],
313
- outputs=None,
314
  preprocess=False,
315
- queue=False,
316
  )
317
 
318
  hf_writer_incorrect.setup(
@@ -373,5 +383,8 @@ with gr.Blocks(title="Automatic Reimbursement Tool Demo") as page:
373
  queue=False,
374
  )
375
 
376
-
 
 
 
377
  page.launch(show_api=True, show_error=True, debug=True)
 
1
  import base64
2
  import os
3
+ import re
4
  from io import BytesIO
5
  from pathlib import Path
6
 
 
7
  import gradio as gr
8
+ from langchain.schema.output_parser import OutputParserException
9
  from PIL import Image
10
 
11
  import categories
 
41
  # global example_paths, current_file_path
42
  # if current_file_path not in example_paths:
43
  # return function(*args, **kwargs)
 
44
 
45
 
46
  def display_file(input_file):
 
76
  return gr.File.update(value=None)
77
 
78
 
79
+ def clear_outputs(input_file):
80
+ if input_file:
81
+ return None, None, None, None
82
+
83
+
84
+ def extract_text(input_file):
85
+ """Takes the input file and updates the extracted text"""
86
  if not input_file:
87
  gr.Error("Please upload a file to continue!")
88
  return gr.Textbox.update()
 
 
89
  # Send change to preprocessed image or to extracted text
90
  if input_file.name.endswith(".pdf"):
91
  text = process_pdf(Path(input_file.name), extract_only=True)
92
  else:
93
  text = process_image(Path(input_file.name), extract_only=True)
 
94
  return text
95
 
96
 
97
+ def categorize_text(text):
98
+ """Takes the extracted text and updates the category"""
99
+ category = categories.categorize_text(text)
 
100
  return category
101
 
102
 
103
+ def query(category, text):
104
+ """Takes the extracted text and category and updates the chatbot in two steps:
105
+ 1. Construct a prompt
106
+ 2. Generate a response
107
+ """
 
 
 
 
 
 
108
  category = Category[category]
 
109
  chain = categories.category_modules[category].chain
110
  formatted_prompt = chain.prompt.format_prompt(
111
+ text=text,
112
  format_instructions=chain.output_parser.get_format_instructions(),
113
  )
114
+ question = f""
115
+ if len(formatted_prompt.messages) > 1:
116
+ question += f"**System:**\n{formatted_prompt.messages[0].content}"
117
+ question += f"\n\n**Human:**\n{formatted_prompt.messages[1].content}"
118
+ yield gr.Chatbot.update([[question, "Generating..."]])
119
+
120
  result = chain.generate(
121
  input_list=[
122
  {
123
+ "text": text,
124
  "format_instructions": chain.output_parser.get_format_instructions(),
125
  }
126
  ]
127
  )
 
 
 
 
 
128
  answer = result.generations[0][0].text
129
+ yield gr.Chatbot.update([[question, answer]])
130
+
131
+
132
+ PARSING_REGEXP = r"\*\*System:\*\*\n([\s\S]+)\n\n\*\*Human:\*\*\n([\s\S]+)"
133
+
134
+
135
+ def parse(category, chatbot):
136
+ """Takes the chatbot prompt and response and updates the extracted information"""
137
+ global PARSING_REGEXP
138
+
139
+ answer = chatbot[0][1]
140
+ category = Category[category]
141
+ chain = categories.category_modules[category].chain
142
+ yield {"status": "Parsing response..."}
143
  try:
144
+ information = chain.output_parser.parse(answer)
145
  information = information.json() if information else {}
146
  except OutputParserException as e:
147
  information = {
 
149
  "details": str(e),
150
  "output": e.llm_output,
151
  }
152
+ yield information
 
 
 
 
 
153
 
154
 
155
+ def activate_flags():
156
+ return gr.Button.update(interactive=True), gr.Button.update(interactive=True)
 
 
157
 
 
158
 
159
+ def deactivate_flags():
160
+ return gr.Button.update(interactive=False), gr.Button.update(interactive=False)
161
 
162
+
163
+ def flag_if_shared(flag_method):
164
+ def proxy(share_result, request: gr.Request, *args, **kwargs):
165
+ if share_result:
166
+ return flag_method(request, *args, **kwargs)
167
+ return proxy
168
 
169
 
170
  with gr.Blocks(title="Automatic Reimbursement Tool Demo") as page:
 
267
  flag_irrelevant_button = gr.Button(
268
  "Flag as irrelevant", variant="stop", interactive=True
269
  )
 
270
  show_intermediate.change(
271
  show_intermediate_outputs, show_intermediate, [intermediate_outputs]
272
  )
273
 
274
+ clear.click(clear_inputs, None, [input_file]).then(
275
+ deactivate_flags,
276
+ None,
277
+ [flag_incorrect_button, flag_irrelevant_button],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  )
279
 
280
  hf_writer_normal.setup(
 
282
  flagging_dir="flagged",
283
  )
284
  flag_method = gr.flagging.FlagMethod(
285
+ hf_writer_normal, "", "", visual_feedback=False
286
  )
287
+
288
+ submit_button.click(
289
+ clear_outputs,
290
+ [input_file],
291
+ [extracted_text, category, chatbot, information],
292
+ ).then(
293
+ extract_text,
294
+ [input_file],
295
+ [extracted_text],
296
+ ).then(
297
+ categorize_text,
298
+ [extracted_text],
299
+ [category],
300
+ ).then(
301
+ query,
302
+ [category, extracted_text],
303
+ [chatbot],
304
+ queue=True,
305
+ ).then(
306
+ parse,
307
+ [category, chatbot],
308
+ [information],
309
+ ).then(
310
+ activate_flags,
311
+ None,
312
+ [flag_incorrect_button, flag_irrelevant_button],
313
+ ).then(
314
+ flag_if_shared(flag_method),
315
+ [
316
  share_result,
317
  input_file,
318
  extracted_text,
 
321
  information,
322
  contact,
323
  ],
324
+ None,
325
  preprocess=False,
 
326
  )
327
 
328
  hf_writer_incorrect.setup(
 
383
  queue=False,
384
  )
385
 
386
+ page.queue(
387
+ concurrency_count=1,
388
+ max_size=1,
389
+ )
390
  page.launch(show_api=True, show_error=True, debug=True)