mgbam commited on
Commit
945d7f4
·
verified ·
1 Parent(s): 1148bbc

Update app.py: make the number of generated Q&A pairs configurable (new `num_pairs` setting, UI number input, and `{num_pairs}` placeholder in the prompt template) and add an `ast.literal_eval` fallback in `_parse_response` for LLM outputs that are Python-dict-style rather than strict JSON (e.g. single-quoted keys).

Browse files
Files changed (1) hide show
  1. app.py +34 -18
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import json
2
  import requests
3
  import streamlit as st
@@ -32,19 +33,20 @@ class QADataGenerator:
32
  self._setup_providers()
33
  self._setup_input_handlers()
34
  self._initialize_session_state()
35
- # This prompt instructs the LLM to generate three Q&A pairs.
36
- # Note: Literal curly braces in the example are escaped with double braces.
37
  self.custom_prompt_template = (
38
  "You are an expert in extracting question and answer pairs from documents. "
39
- "Generate 3 Q&A pairs from the following data, formatted as a JSON list of dictionaries. "
40
  "Each dictionary must have keys 'question' and 'answer'. "
41
  "The questions should be clear and concise, and the answers must be based solely on the provided data with no external information. "
42
  "Do not hallucinate.\n\n"
43
- "Example JSON Output:\n"
44
- "[{{'question': 'What is the capital of France?', 'answer': 'Paris'}}, "
45
- "{{'question': 'What is the highest mountain in the world?', 'answer': 'Mount Everest'}}, "
46
- "{{'question': 'What is the chemical symbol for gold?', 'answer': 'Au'}}]\n\n"
47
- "Now, generate 3 Q&A pairs from this data:\n{data}"
 
48
  )
49
 
50
  def _setup_providers(self) -> None:
@@ -85,6 +87,7 @@ class QADataGenerator:
85
  "provider": "OpenAI",
86
  "model": "gpt-4-turbo",
87
  "temperature": DEFAULT_TEMPERATURE,
 
88
  },
89
  "api_key": "",
90
  "inputs": [], # List to store input sources
@@ -156,9 +159,11 @@ class QADataGenerator:
156
  def build_prompt(self) -> str:
157
  """
158
  Build the complete prompt using the custom template and aggregated inputs.
 
159
  """
160
  data = self.aggregate_inputs()
161
- prompt = self.custom_prompt_template.format(data=data)
 
162
  st.write("### Built Prompt")
163
  st.write(prompt)
164
  return prompt
@@ -239,27 +244,25 @@ class QADataGenerator:
239
  def _parse_response(self, response: Any, provider: str) -> List[Dict[str, str]]:
240
  """
241
  Parse the LLM response and return a list of Q&A pairs.
242
- Expects the response to be JSON formatted.
 
243
  """
244
  st.write("Parsing response for provider:", provider)
245
  try:
246
  if provider == "HuggingFace":
247
- # For HuggingFace, assume the generated text is under "generated_text"
248
  if isinstance(response, list) and response and "generated_text" in response[0]:
249
  raw_text = response[0]["generated_text"]
250
  else:
251
  self.log_error("Unexpected HuggingFace response format.")
252
  return []
253
  else:
254
- # For OpenAI (and similar providers), assume the response is similar to:
255
- # response.choices[0].message.content
256
  if response and hasattr(response, "choices") and response.choices:
257
  raw_text = response.choices[0].message.content
258
  else:
259
  self.log_error("Unexpected response format from provider.")
260
  return []
261
 
262
- # Try parsing the raw text as JSON
263
  try:
264
  qa_list = json.loads(raw_text)
265
  if isinstance(qa_list, list):
@@ -267,9 +270,18 @@ class QADataGenerator:
267
  else:
268
  self.log_error("Parsed output is not a list.")
269
  return []
270
- except json.JSONDecodeError as e:
271
- self.log_error(f"JSON Parsing Error: {e}. Raw output: {raw_text}")
272
- return []
 
 
 
 
 
 
 
 
 
273
  except Exception as e:
274
  self.log_error(f"Response Parsing Error: {e}")
275
  return []
@@ -291,6 +303,10 @@ def config_ui(generator: QADataGenerator):
291
  temperature = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
292
  st.session_state.config["temperature"] = temperature
293
 
 
 
 
 
294
  api_key = st.text_input(f"{provider} API Key", type="password")
295
  st.session_state.api_key = api_key
296
 
@@ -401,4 +417,4 @@ def main():
401
 
402
 
403
  if __name__ == "__main__":
404
- main()
 
1
+ import ast
2
  import json
3
  import requests
4
  import streamlit as st
 
33
  self._setup_providers()
34
  self._setup_input_handlers()
35
  self._initialize_session_state()
36
+ # This prompt instructs the LLM to generate a configurable number of Q&A pairs.
37
+ # Note: Literal curly braces for the example are escaped with double braces.
38
  self.custom_prompt_template = (
39
  "You are an expert in extracting question and answer pairs from documents. "
40
+ "Generate {num_pairs} Q&A pairs from the following data, formatted as a JSON list of dictionaries. "
41
  "Each dictionary must have keys 'question' and 'answer'. "
42
  "The questions should be clear and concise, and the answers must be based solely on the provided data with no external information. "
43
  "Do not hallucinate.\n\n"
44
+ "Example JSON Output for {num_pairs} pairs:\n"
45
+ "[{{'question': 'Example question 1', 'answer': 'Example answer 1'}}, "
46
+ "{{'question': 'Example question 2', 'answer': 'Example answer 2'}}, "
47
+ "..., "
48
+ "{{'question': 'Example question {num_pairs}', 'answer': 'Example answer {num_pairs}'}}]\n\n"
49
+ "Now, generate {num_pairs} Q&A pairs from this data:\n{data}"
50
  )
51
 
52
  def _setup_providers(self) -> None:
 
87
  "provider": "OpenAI",
88
  "model": "gpt-4-turbo",
89
  "temperature": DEFAULT_TEMPERATURE,
90
+ "num_pairs": 3, # Default to 3 Q&A pairs
91
  },
92
  "api_key": "",
93
  "inputs": [], # List to store input sources
 
159
  def build_prompt(self) -> str:
160
  """
161
  Build the complete prompt using the custom template and aggregated inputs.
162
+ The number of Q&A pairs is inserted via the {num_pairs} placeholder.
163
  """
164
  data = self.aggregate_inputs()
165
+ num_pairs = st.session_state.config.get("num_pairs", 3)
166
+ prompt = self.custom_prompt_template.format(data=data, num_pairs=num_pairs)
167
  st.write("### Built Prompt")
168
  st.write(prompt)
169
  return prompt
 
244
  def _parse_response(self, response: Any, provider: str) -> List[Dict[str, str]]:
245
  """
246
  Parse the LLM response and return a list of Q&A pairs.
247
+ Expects the response to be in a JSON-like format.
248
+ If JSON parsing fails (e.g. due to single quotes), fall back to ast.literal_eval.
249
  """
250
  st.write("Parsing response for provider:", provider)
251
  try:
252
  if provider == "HuggingFace":
 
253
  if isinstance(response, list) and response and "generated_text" in response[0]:
254
  raw_text = response[0]["generated_text"]
255
  else:
256
  self.log_error("Unexpected HuggingFace response format.")
257
  return []
258
  else:
 
 
259
  if response and hasattr(response, "choices") and response.choices:
260
  raw_text = response.choices[0].message.content
261
  else:
262
  self.log_error("Unexpected response format from provider.")
263
  return []
264
 
265
+ # Try parsing as JSON first
266
  try:
267
  qa_list = json.loads(raw_text)
268
  if isinstance(qa_list, list):
 
270
  else:
271
  self.log_error("Parsed output is not a list.")
272
  return []
273
+ except json.JSONDecodeError:
274
+ st.write("Standard JSON parsing failed. Falling back to ast.literal_eval...")
275
+ try:
276
+ qa_list = ast.literal_eval(raw_text)
277
+ if isinstance(qa_list, list):
278
+ return qa_list
279
+ else:
280
+ self.log_error("Parsed output using ast.literal_eval is not a list.")
281
+ return []
282
+ except Exception as e:
283
+ self.log_error(f"ast.literal_eval parsing error: {e}. Raw output: {raw_text}")
284
+ return []
285
  except Exception as e:
286
  self.log_error(f"Response Parsing Error: {e}")
287
  return []
 
303
  temperature = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
304
  st.session_state.config["temperature"] = temperature
305
 
306
+ # New: Number of Q&A pairs
307
+ num_pairs = st.number_input("Number of Q&A Pairs", min_value=1, max_value=20, value=3, step=1)
308
+ st.session_state.config["num_pairs"] = num_pairs
309
+
310
  api_key = st.text_input(f"{provider} API Key", type="password")
311
  st.session_state.api_key = api_key
312
 
 
417
 
418
 
419
  if __name__ == "__main__":
420
+ main()