Kevin Wu commited on
Commit
0e0266f
·
1 Parent(s): 6adea60
Files changed (1) hide show
  1. app.py +52 -33
app.py CHANGED
@@ -9,6 +9,7 @@ import re
9
  import pandas as pd
10
  import prompts
11
  import traceback
 
12
 
13
  client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
14
 
@@ -31,15 +32,20 @@ def parse_xml_response(xml_string: str) -> pd.DataFrame:
31
  then convert it to a pandas DataFrame with a nested index.
32
  """
33
  try:
34
- # Extract only the XML content between the first and last tags
35
- xml_content = re.search(r'<.*?>.*</.*?>', xml_string, re.DOTALL)
36
  if xml_content:
37
  xml_string = xml_content.group(0)
38
  else:
39
  print("No valid XML content found.")
40
  return pd.DataFrame()
41
 
42
- root = ET.fromstring(xml_string)
 
 
 
 
 
43
 
44
  result = {}
45
 
@@ -113,35 +119,40 @@ def parse_xml_response(xml_string: str) -> pd.DataFrame:
113
  print(f"Traceback: {traceback.format_exc()}")
114
  return pd.DataFrame()
115
 
116
- def get_response(file_id, assistant_id):
117
- try:
118
- thread = client.beta.threads.create(
119
- messages=[
120
- {
121
- "role": "user",
122
- "content": prompts.info_prompt,
123
- "attachments": [
124
- {"file_id": file_id, "tools": [{"type": "file_search"}]}
125
- ],
126
- }
127
- ]
128
- )
129
- run = client.beta.threads.runs.create_and_poll(
130
- thread_id=thread.id, assistant_id=assistant_id
131
- )
132
- messages = list(
133
- client.beta.threads.messages.list(thread_id=thread.id, run_id=run.id)
134
- )
135
- assert len(messages) == 1, f"Expected 1 message, got {len(messages)}"
136
- message_content = messages[0].content[0].text
137
- annotations = message_content.annotations
138
- for index, annotation in enumerate(annotations):
139
- message_content.value = message_content.value.replace(annotation.text, f"")
140
- return message_content.value
141
- except Exception as e:
142
- print(f"Error in get_response: {str(e)}")
143
- print(f"Traceback: {traceback.format_exc()}")
144
- raise
 
 
 
 
 
145
 
146
  def process(file_content):
147
  try:
@@ -153,9 +164,17 @@ def process(file_content):
153
 
154
  message_file = client.files.create(file=open(file_name, "rb"), purpose="assistants")
155
 
156
- response = get_response(message_file.id, demo.id)
157
  df = parse_xml_response(response)
158
 
 
 
 
 
 
 
 
 
159
  if df.empty:
160
  return "<p>No valid information could be extracted from the provided file.</p>"
161
 
 
9
  import pandas as pd
10
  import prompts
11
  import traceback
12
+ from io import StringIO
13
 
14
  client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
15
 
 
32
  then convert it to a pandas DataFrame with a nested index.
33
  """
34
  try:
35
+ # Extract only the XML content between the outermost tags
36
+ xml_content = re.search(r'<[^>]+>.*</[^>]+>', xml_string, re.DOTALL)
37
  if xml_content:
38
  xml_string = xml_content.group(0)
39
  else:
40
  print("No valid XML content found.")
41
  return pd.DataFrame()
42
 
43
+ # Wrap the content in a root element to ensure there's only one root
44
+ xml_string = f"<root>{xml_string}</root>"
45
+
46
+ # Parse the XML
47
+ parser = ET.XMLParser(recover=True) # This allows for more lenient parsing
48
+ root = ET.fromstring(xml_string, parser=parser)
49
 
50
  result = {}
51
 
 
119
  print(f"Traceback: {traceback.format_exc()}")
120
  return pd.DataFrame()
121
 
122
+ def get_response(file_id, assistant_id, max_retries=3):
123
+ for attempt in range(max_retries):
124
+ try:
125
+ thread = client.beta.threads.create(
126
+ messages=[
127
+ {
128
+ "role": "user",
129
+ "content": prompts.info_prompt,
130
+ "attachments": [
131
+ {"file_id": file_id, "tools": [{"type": "file_search"}]}
132
+ ],
133
+ }
134
+ ]
135
+ )
136
+ run = client.beta.threads.runs.create_and_poll(
137
+ thread_id=thread.id, assistant_id=assistant_id
138
+ )
139
+ messages = list(
140
+ client.beta.threads.messages.list(thread_id=thread.id, run_id=run.id)
141
+ )
142
+ assert len(messages) == 1, f"Expected 1 message, got {len(messages)}"
143
+ message_content = messages[0].content[0].text
144
+ annotations = message_content.annotations
145
+ for index, annotation in enumerate(annotations):
146
+ message_content.value = message_content.value.replace(annotation.text, f"")
147
+ return message_content.value
148
+ except Exception as e:
149
+ print(f"Error in get_response (attempt {attempt + 1}): {str(e)}")
150
+ print(f"Traceback: {traceback.format_exc()}")
151
+ if attempt < max_retries - 1:
152
+ print(f"Retrying in 5 seconds...")
153
+ time.sleep(5)
154
+ else:
155
+ raise Exception("Max retries reached. Unable to get response from the model.")
156
 
157
  def process(file_content):
158
  try:
 
164
 
165
  message_file = client.files.create(file=open(file_name, "rb"), purpose="assistants")
166
 
167
+ response = get_response(message_file.id, demo.id) # This now includes retry logic
168
  df = parse_xml_response(response)
169
 
170
+ # ... (rest of the function remains the same)
171
+
172
+ except Exception as e:
173
+ error_message = f"An error occurred while processing the file: {str(e)}"
174
+ print(error_message)
175
+ print(f"Traceback: {traceback.format_exc()}")
176
+ return f"<p>{error_message}</p>"
177
+
178
  if df.empty:
179
  return "<p>No valid information could be extracted from the provided file.</p>"
180