mgbam committed (verified)
Commit f3076fc · Parent(s): c608949

Update app.py

Files changed (1):
  1. app.py  +30 -11
app.py CHANGED
@@ -33,10 +33,10 @@ class QADataGenerator:
         self._setup_providers()
         self._setup_input_handlers()
         self._initialize_session_state()
-        # Updated prompt template with escaped curly braces for literal output
+        # Updated prompt template with dynamic {num_examples} parameter and escaped curly braces
         self.custom_prompt_template: str = (
             "You are an expert in extracting question and answer pairs from documents. "
-            "Generate 3 Q&A pairs from the following data, formatted as a JSON list of dictionaries. "
+            "Generate {num_examples} Q&A pairs from the following data, formatted as a JSON list of dictionaries. "
             "Each dictionary must have keys 'question' and 'answer'. "
             "The questions should be clear and concise, and the answers must be based solely on the provided data with no external information. "
             "Do not hallucinate. \n\n"
@@ -44,7 +44,7 @@ class QADataGenerator:
             "[{{'question': 'What is the capital of France?', 'answer': 'Paris'}}, "
             "{{'question': 'What is the highest mountain in the world?', 'answer': 'Mount Everest'}}, "
             "{{'question': 'What is the chemical symbol for gold?', 'answer': 'Au'}}]\n\n"
-            "Now, generate 3 Q&A pairs from this data:\n{data}"
+            "Now, generate {num_examples} Q&A pairs from this data:\n{data}"
         )
 
     def _setup_providers(self) -> None:
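The doubled braces in the few-shot block matter: str.format treats {{ and }} as literal braces, so only {num_examples} and {data} are substituted when build_prompt formats the template. A minimal, standalone illustration of that behaviour (the variable names below are illustrative, not taken from app.py):

    # Illustration only: brace escaping with str.format, as used by the prompt template.
    template = (
        "Example format:\n"
        "[{{'question': 'What is the capital of France?', 'answer': 'Paris'}}]\n\n"
        "Now, generate {num_examples} Q&A pairs from this data:\n{data}"
    )
    print(template.format(num_examples=3, data="Paris is the capital of France."))
    # Output: the doubled braces come out as single literal braces:
    # [{'question': 'What is the capital of France?', 'answer': 'Paris'}]
    # Now, generate 3 Q&A pairs from this data:
    # Paris is the capital of France.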
@@ -85,10 +85,11 @@ class QADataGenerator:
                 "provider": "OpenAI",
                 "model": "gpt-4-turbo",
                 "temperature": DEFAULT_TEMPERATURE,
+                "num_examples": 3,  # Default number of Q&A pairs
             },
             "api_key": "",
             "inputs": [],  # List to store input sources
-            "qa_pairs": "",  # Generated Q&A pairs output
+            "qa_pairs": None,  # Generated Q&A pairs output
             "error_logs": [],  # To store any error messages
         }
         for key, value in defaults.items():
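The body of the loop over defaults.items() is outside this hunk, so the following is only an assumption about the usual Streamlit pattern for seeding session state without clobbering values on rerun; it is a sketch, not the app's actual code:

    # Assumed continuation (loop body not shown in the diff): only set keys
    # that are not already present in st.session_state.
    import streamlit as st

    defaults = {
        "config": {"provider": "OpenAI", "model": "gpt-4-turbo", "num_examples": 3},
        "api_key": "",
        "inputs": [],
        "qa_pairs": None,
        "error_logs": [],
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value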
@@ -157,10 +158,12 @@ class QADataGenerator:
 
     def build_prompt(self) -> str:
         """
-        Build the complete prompt using the custom template and aggregated inputs.
+        Build the complete prompt using the custom template, aggregated inputs,
+        and the number of examples.
         """
         data = self.aggregate_inputs()
-        prompt = self.custom_prompt_template.format(data=data)
+        num_examples = st.session_state.config.get("num_examples", 3)
+        prompt = self.custom_prompt_template.format(data=data, num_examples=num_examples)
         st.write("### Built Prompt")
         st.write(prompt)
         return prompt
@@ -242,11 +245,10 @@ class QADataGenerator:
         """
         Parse the LLM response and return a list of Q&A pairs.
         Expects the response to be JSON formatted; if JSON decoding fails,
-        tries to use ast.literal_eval as a fallback.
+        uses ast.literal_eval as a fallback.
         """
         st.write("Parsing response for provider:", provider)
         try:
-            # For non-HuggingFace providers, extract the raw text from the response.
             if provider == "HuggingFace":
                 if isinstance(response, list) and response and "generated_text" in response[0]:
                     raw_text = response[0]["generated_text"]
@@ -260,7 +262,6 @@ class QADataGenerator:
                 self.log_error("Unexpected response format from provider.")
                 return []
 
-            # Attempt to parse using json.loads first.
             try:
                 qa_list = json.loads(raw_text)
             except json.JSONDecodeError as e:
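The docstring says ast.literal_eval is used when json.loads fails, but that fallback sits below this hunk and is not shown. A minimal sketch of the pattern the docstring describes, under the assumption that the model echoes the single-quoted, Python-literal style shown in the prompt's example output (which json.loads rejects):

    # Hypothetical sketch of the json.loads -> ast.literal_eval fallback;
    # the real implementation in app.py lives outside this hunk.
    import ast
    import json

    raw_text = "[{'question': 'What is the capital of France?', 'answer': 'Paris'}]"
    try:
        qa_list = json.loads(raw_text)        # fails: JSON requires double quotes
    except json.JSONDecodeError:
        qa_list = ast.literal_eval(raw_text)  # accepts Python literal syntax
    print(qa_list[0]["answer"])  # Paris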
@@ -297,6 +298,9 @@ def config_ui(generator: QADataGenerator) -> None:
     temperature = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
     st.session_state.config["temperature"] = temperature
 
+    num_examples = st.number_input("Number of Q&A Pairs", min_value=1, max_value=10, value=3, step=1)
+    st.session_state.config["num_examples"] = num_examples
+
     api_key = st.text_input(f"{provider} API Key", type="password")
     st.session_state.api_key = api_key
 
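Because value, min_value, max_value, and step are all integers here, st.number_input returns an int, so the substituted prompt reads "Generate 3 Q&A pairs" rather than "Generate 3.0 Q&A pairs". A quick check of that formatting detail in plain Python (no Streamlit needed):

    # Why the integer arguments matter for the rendered prompt text.
    print("Generate {num_examples} Q&A pairs".format(num_examples=3))    # Generate 3 Q&A pairs
    print("Generate {num_examples} Q&A pairs".format(num_examples=3.0))  # Generate 3.0 Q&A pairs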
@@ -347,17 +351,32 @@ def input_ui(generator: QADataGenerator) -> None:
         st.success("Database input added!")
 
 def output_ui(generator: QADataGenerator) -> None:
-    """Display the generated Q&A pairs and provide a download option."""
+    """Display the generated Q&A pairs and provide download options."""
     st.subheader("Q&A Pairs Output")
     if st.session_state.qa_pairs:
         st.write("### Generated Q&A Pairs")
         st.write(st.session_state.qa_pairs)
+
+        # Download as JSON
         st.download_button(
-            "Download Output",
+            "Download as JSON",
             json.dumps(st.session_state.qa_pairs, indent=2),
             file_name="qa_pairs.json",
             mime="application/json"
         )
+
+        # Download as CSV
+        try:
+            df = pd.DataFrame(st.session_state.qa_pairs)
+            csv_data = df.to_csv(index=False)
+            st.download_button(
+                "Download as CSV",
+                csv_data,
+                file_name="qa_pairs.csv",
+                mime="text/csv"
+            )
+        except Exception as e:
+            st.error(f"Error generating CSV: {e}")
     else:
         st.info("No Q&A pairs generated yet.")
 
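The new CSV branch assumes pandas is available as pd elsewhere in app.py (the import is not part of this diff) and that st.session_state.qa_pairs is the list of dicts with 'question' and 'answer' keys that the prompt requests. Under those assumptions, the conversion behaves like this small sketch:

    # Illustration only: what pd.DataFrame(...).to_csv(index=False) produces
    # for the list-of-dicts shape the prompt asks the model to return.
    import pandas as pd

    qa_pairs = [
        {"question": "What is the capital of France?", "answer": "Paris"},
        {"question": "What is the chemical symbol for gold?", "answer": "Au"},
    ]
    print(pd.DataFrame(qa_pairs).to_csv(index=False))
    # question,answer
    # What is the capital of France?,Paris
    # What is the chemical symbol for gold?,Au

Wrapping the conversion in try/except means a malformed qa_pairs value (for example, a plain string left over from an earlier run) surfaces as an st.error message instead of crashing the output view; the JSON download above it is unaffected.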