Update app.py
app.py CHANGED
@@ -33,10 +33,10 @@ class QADataGenerator:
         self._setup_providers()
         self._setup_input_handlers()
         self._initialize_session_state()
-        # Updated prompt template with escaped curly braces
+        # Updated prompt template with dynamic {num_examples} parameter and escaped curly braces
         self.custom_prompt_template: str = (
             "You are an expert in extracting question and answer pairs from documents. "
-            "Generate
+            "Generate {num_examples} Q&A pairs from the following data, formatted as a JSON list of dictionaries. "
             "Each dictionary must have keys 'question' and 'answer'. "
             "The questions should be clear and concise, and the answers must be based solely on the provided data with no external information. "
             "Do not hallucinate. \n\n"
@@ -44,7 +44,7 @@ class QADataGenerator:
             "[{{'question': 'What is the capital of France?', 'answer': 'Paris'}}, "
             "{{'question': 'What is the highest mountain in the world?', 'answer': 'Mount Everest'}}, "
             "{{'question': 'What is the chemical symbol for gold?', 'answer': 'Au'}}]\n\n"
-            "Now, generate
+            "Now, generate {num_examples} Q&A pairs from this data:\n{data}"
         )

     def _setup_providers(self) -> None:
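Note on the template: the doubled braces in the example block are what str.format needs to emit literal braces, while the single-braced {num_examples} and {data} stay live placeholders. A minimal standalone illustration (the strings here are only for the example):

template = "Generate {num_examples} pairs like [{{'question': 'q', 'answer': 'a'}}] from:\n{data}"
print(template.format(num_examples=2, data="some source text"))
# Generate 2 pairs like [{'question': 'q', 'answer': 'a'}] from:
# some source text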
@@ -85,10 +85,11 @@ class QADataGenerator:
                 "provider": "OpenAI",
                 "model": "gpt-4-turbo",
                 "temperature": DEFAULT_TEMPERATURE,
+                "num_examples": 3,  # Default number of Q&A pairs
             },
             "api_key": "",
             "inputs": [],  # List to store input sources
-            "qa_pairs":
+            "qa_pairs": None,  # Generated Q&A pairs output
             "error_logs": [],  # To store any error messages
         }
         for key, value in defaults.items():
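The body of the loop that consumes defaults sits outside this hunk; a common Streamlit idiom for seeding st.session_state from such a dict, shown here only as a sketch of what that loop typically does (not necessarily this app's exact code):

import streamlit as st

defaults = {
    "config": {"provider": "OpenAI", "model": "gpt-4-turbo", "num_examples": 3},
    "api_key": "",
    "inputs": [],
    "qa_pairs": None,
    "error_logs": [],
}

for key, value in defaults.items():
    # Seed each key only once so user edits survive Streamlit reruns.
    if key not in st.session_state:
        st.session_state[key] = value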
@@ -157,10 +158,12 @@ class QADataGenerator:

     def build_prompt(self) -> str:
         """
-        Build the complete prompt using the custom template
+        Build the complete prompt using the custom template, aggregated inputs,
+        and the number of examples.
         """
         data = self.aggregate_inputs()
-
+        num_examples = st.session_state.config.get("num_examples", 3)
+        prompt = self.custom_prompt_template.format(data=data, num_examples=num_examples)
         st.write("### Built Prompt")
         st.write(prompt)
         return prompt
@@ -242,11 +245,10 @@ class QADataGenerator:
         """
         Parse the LLM response and return a list of Q&A pairs.
         Expects the response to be JSON formatted; if JSON decoding fails,
-
+        uses ast.literal_eval as a fallback.
         """
         st.write("Parsing response for provider:", provider)
         try:
-            # For non-HuggingFace providers, extract the raw text from the response.
             if provider == "HuggingFace":
                 if isinstance(response, list) and response and "generated_text" in response[0]:
                     raw_text = response[0]["generated_text"]
@@ -260,7 +262,6 @@ class QADataGenerator:
                 self.log_error("Unexpected response format from provider.")
                 return []

-            # Attempt to parse using json.loads first.
             try:
                 qa_list = json.loads(raw_text)
             except json.JSONDecodeError as e:
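The ast.literal_eval fallback mentioned in the updated docstring lives below this hunk and is not part of the diff; a minimal sketch of that parsing strategy, assuming the model may return either strict JSON or the single-quoted Python-literal style shown in the prompt's example (the helper name is illustrative):

import ast
import json

def parse_qa_text(raw_text: str) -> list:
    try:
        # Strict JSON first, since the prompt asks for a JSON list of dicts.
        return json.loads(raw_text)
    except json.JSONDecodeError:
        try:
            # Fallback: accept Python-literal syntax such as single-quoted keys.
            return ast.literal_eval(raw_text)
        except (ValueError, SyntaxError):
            return []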
@@ -297,6 +298,9 @@ def config_ui(generator: QADataGenerator) -> None:
     temperature = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
     st.session_state.config["temperature"] = temperature

+    num_examples = st.number_input("Number of Q&A Pairs", min_value=1, max_value=10, value=3, step=1)
+    st.session_state.config["num_examples"] = num_examples
+
     api_key = st.text_input(f"{provider} API Key", type="password")
     st.session_state.api_key = api_key

@@ -347,17 +351,32 @@ def input_ui(generator: QADataGenerator) -> None:
         st.success("Database input added!")

 def output_ui(generator: QADataGenerator) -> None:
-    """Display the generated Q&A pairs and provide
+    """Display the generated Q&A pairs and provide download options."""
     st.subheader("Q&A Pairs Output")
     if st.session_state.qa_pairs:
         st.write("### Generated Q&A Pairs")
         st.write(st.session_state.qa_pairs)
+
+        # Download as JSON
         st.download_button(
-            "Download
+            "Download as JSON",
             json.dumps(st.session_state.qa_pairs, indent=2),
             file_name="qa_pairs.json",
             mime="application/json"
         )
+
+        # Download as CSV
+        try:
+            df = pd.DataFrame(st.session_state.qa_pairs)
+            csv_data = df.to_csv(index=False)
+            st.download_button(
+                "Download as CSV",
+                csv_data,
+                file_name="qa_pairs.csv",
+                mime="text/csv"
+            )
+        except Exception as e:
+            st.error(f"Error generating CSV: {e}")
     else:
         st.info("No Q&A pairs generated yet.")

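Usage note on the CSV export: pd.DataFrame over a list of {'question': ..., 'answer': ...} dicts produces one column per key, so the CSV mirrors the JSON structure, and the surrounding try/except catches pairs that cannot be tabulated. Using the example pairs from the prompt:

import pandas as pd

qa_pairs = [
    {"question": "What is the capital of France?", "answer": "Paris"},
    {"question": "What is the chemical symbol for gold?", "answer": "Au"},
]

print(pd.DataFrame(qa_pairs).to_csv(index=False))
# question,answer
# What is the capital of France?,Paris
# What is the chemical symbol for gold?,Au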