mgbam committed on
Commit
c608949
·
verified ·
1 Parent(s): accacff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -39
app.py CHANGED
@@ -1,10 +1,11 @@
1
  import json
 
2
  import requests
3
  import streamlit as st
4
  import pdfplumber
5
  import pandas as pd
6
  import sqlalchemy
7
- from typing import Any, Dict, List
8
 
9
  # Provider clients – ensure these libraries are installed
10
  try:
@@ -17,10 +18,10 @@ try:
17
  except ImportError:
18
  groq = None
19
 
20
- # Hugging Face inference endpoint
21
- HF_API_URL = "https://api-inference.huggingface.co/models/"
22
- DEFAULT_TEMPERATURE = 0.1
23
- GROQ_MODEL = "mixtral-8x7b-32768"
24
 
25
 
26
  class QADataGenerator:
@@ -32,8 +33,8 @@ class QADataGenerator:
32
  self._setup_providers()
33
  self._setup_input_handlers()
34
  self._initialize_session_state()
35
- # Updated prompt template with escaped curly braces
36
- self.custom_prompt_template = (
37
  "You are an expert in extracting question and answer pairs from documents. "
38
  "Generate 3 Q&A pairs from the following data, formatted as a JSON list of dictionaries. "
39
  "Each dictionary must have keys 'question' and 'answer'. "
@@ -69,7 +70,7 @@ class QADataGenerator:
69
 
70
  def _setup_input_handlers(self) -> None:
71
  """Register handlers for different input data types."""
72
- self.input_handlers: Dict[str, Any] = {
73
  "text": self.handle_text,
74
  "pdf": self.handle_pdf,
75
  "csv": self.handle_csv,
@@ -79,7 +80,7 @@ class QADataGenerator:
79
 
80
  def _initialize_session_state(self) -> None:
81
  """Initialize Streamlit session state with default configuration."""
82
- defaults = {
83
  "config": {
84
  "provider": "OpenAI",
85
  "model": "gpt-4-turbo",
@@ -101,30 +102,31 @@ class QADataGenerator:
101
 
102
  # ----- Input Handlers -----
103
  def handle_text(self, text: str) -> Dict[str, Any]:
 
104
  return {"data": text, "source": "text"}
105
 
106
  def handle_pdf(self, file) -> Dict[str, Any]:
 
107
  try:
108
  with pdfplumber.open(file) as pdf:
109
- full_text = ""
110
- for page in pdf.pages:
111
- page_text = page.extract_text() or ""
112
- full_text += page_text + "\n"
113
- return {"data": full_text, "source": "pdf"}
114
  except Exception as e:
115
  self.log_error(f"PDF Processing Error: {e}")
116
  return {"data": "", "source": "pdf"}
117
 
118
  def handle_csv(self, file) -> Dict[str, Any]:
 
119
  try:
120
  df = pd.read_csv(file)
121
- # Convert the DataFrame to a JSON string
122
- return {"data": df.to_json(orient="records"), "source": "csv"}
123
  except Exception as e:
124
  self.log_error(f"CSV Processing Error: {e}")
125
  return {"data": "", "source": "csv"}
126
 
127
  def handle_api(self, config: Dict[str, str]) -> Dict[str, Any]:
 
128
  try:
129
  response = requests.get(config["url"], headers=config.get("headers", {}), timeout=10)
130
  response.raise_for_status()
@@ -134,6 +136,7 @@ class QADataGenerator:
134
  return {"data": "", "source": "api"}
135
 
136
  def handle_db(self, config: Dict[str, str]) -> Dict[str, Any]:
 
137
  try:
138
  engine = sqlalchemy.create_engine(config["connection"])
139
  with engine.connect() as conn:
@@ -166,22 +169,22 @@ class QADataGenerator:
166
  """
167
  Generate Q&A pairs by sending the built prompt to the selected LLM provider.
168
  """
169
- api_key = st.session_state.api_key
170
  if not api_key:
171
  self.log_error("API key is missing!")
172
  return False
173
 
174
- provider_name = st.session_state.config["provider"]
175
- provider_cfg = self.providers.get(provider_name)
176
  if not provider_cfg:
177
  self.log_error(f"Provider {provider_name} is not configured.")
178
  return False
179
 
180
- client_initializer = provider_cfg["client"]
181
  client = client_initializer(api_key)
182
- model = st.session_state.config["model"]
183
- temperature = st.session_state.config["temperature"]
184
- prompt = self.build_prompt()
185
 
186
  st.info(f"Using **{provider_name}** with model **{model}** at temperature **{temperature:.2f}**")
187
  try:
@@ -238,36 +241,40 @@ class QADataGenerator:
238
  def _parse_response(self, response: Any, provider: str) -> List[Dict[str, str]]:
239
  """
240
  Parse the LLM response and return a list of Q&A pairs.
241
- Expects the response to be JSON formatted.
 
242
  """
243
  st.write("Parsing response for provider:", provider)
244
  try:
 
245
  if provider == "HuggingFace":
246
- # For HuggingFace, assume the generated text is under "generated_text"
247
  if isinstance(response, list) and response and "generated_text" in response[0]:
248
  raw_text = response[0]["generated_text"]
249
  else:
250
  self.log_error("Unexpected HuggingFace response format.")
251
  return []
252
  else:
253
- # For OpenAI (and similar providers) assume the response is similar to:
254
- # response.choices[0].message.content
255
  if response and hasattr(response, "choices") and response.choices:
256
  raw_text = response.choices[0].message.content
257
  else:
258
  self.log_error("Unexpected response format from provider.")
259
  return []
260
 
261
- # Try parsing the raw text as JSON
262
  try:
263
  qa_list = json.loads(raw_text)
264
- if isinstance(qa_list, list):
265
- return qa_list
266
- else:
267
- self.log_error("Parsed output is not a list.")
268
- return []
269
  except json.JSONDecodeError as e:
270
- self.log_error(f"JSON Parsing Error: {e}. Raw output: {raw_text}")
 
 
 
 
 
 
 
 
 
 
271
  return []
272
  except Exception as e:
273
  self.log_error(f"Response Parsing Error: {e}")
@@ -276,7 +283,7 @@ class QADataGenerator:
276
 
277
  # ============ UI Components ============
278
 
279
- def config_ui(generator: QADataGenerator):
280
  """Display configuration options in the sidebar."""
281
  with st.sidebar:
282
  st.header("Configuration")
@@ -293,7 +300,7 @@ def config_ui(generator: QADataGenerator):
293
  api_key = st.text_input(f"{provider} API Key", type="password")
294
  st.session_state.api_key = api_key
295
 
296
- def input_ui(generator: QADataGenerator):
297
  """Display input data source options using tabs."""
298
  st.subheader("Input Data Sources")
299
  tabs = st.tabs(["Text", "PDF", "CSV", "API", "Database"])
@@ -339,7 +346,7 @@ def input_ui(generator: QADataGenerator):
339
  st.session_state.inputs.append(generator.handle_db({"connection": db_conn, "query": db_query}))
340
  st.success("Database input added!")
341
 
342
- def output_ui(generator: QADataGenerator):
343
  """Display the generated Q&A pairs and provide a download option."""
344
  st.subheader("Q&A Pairs Output")
345
  if st.session_state.qa_pairs:
@@ -354,7 +361,7 @@ def output_ui(generator: QADataGenerator):
354
  else:
355
  st.info("No Q&A pairs generated yet.")
356
 
357
- def logs_ui():
358
  """Display error logs and debugging information in an expandable section."""
359
  with st.expander("Error Logs & Debug Info", expanded=False):
360
  if st.session_state.error_logs:
@@ -363,7 +370,8 @@ def logs_ui():
363
  else:
364
  st.write("No logs yet.")
365
 
366
- def main():
 
367
  st.set_page_config(page_title="Advanced Q&A Synthetic Generator", layout="wide")
368
  st.title("Advanced Q&A Synthetic Generator")
369
  st.markdown(
 
1
  import json
2
+ import ast
3
  import requests
4
  import streamlit as st
5
  import pdfplumber
6
  import pandas as pd
7
  import sqlalchemy
8
+ from typing import Any, Dict, List, Callable
9
 
10
  # Provider clients – ensure these libraries are installed
11
  try:
 
18
  except ImportError:
19
  groq = None
20
 
21
+ # Hugging Face inference endpoint and defaults
22
+ HF_API_URL: str = "https://api-inference.huggingface.co/models/"
23
+ DEFAULT_TEMPERATURE: float = 0.1
24
+ GROQ_MODEL: str = "mixtral-8x7b-32768"
25
 
26
 
27
  class QADataGenerator:
 
33
  self._setup_providers()
34
  self._setup_input_handlers()
35
  self._initialize_session_state()
36
+ # Updated prompt template with escaped curly braces for literal output
37
+ self.custom_prompt_template: str = (
38
  "You are an expert in extracting question and answer pairs from documents. "
39
  "Generate 3 Q&A pairs from the following data, formatted as a JSON list of dictionaries. "
40
  "Each dictionary must have keys 'question' and 'answer'. "
 
70
 
71
  def _setup_input_handlers(self) -> None:
72
  """Register handlers for different input data types."""
73
+ self.input_handlers: Dict[str, Callable[[Any], Dict[str, Any]]] = {
74
  "text": self.handle_text,
75
  "pdf": self.handle_pdf,
76
  "csv": self.handle_csv,
 
80
 
81
  def _initialize_session_state(self) -> None:
82
  """Initialize Streamlit session state with default configuration."""
83
+ defaults: Dict[str, Any] = {
84
  "config": {
85
  "provider": "OpenAI",
86
  "model": "gpt-4-turbo",
 
102
 
103
  # ----- Input Handlers -----
104
  def handle_text(self, text: str) -> Dict[str, Any]:
105
+ """Process plain text input."""
106
  return {"data": text, "source": "text"}
107
 
108
  def handle_pdf(self, file) -> Dict[str, Any]:
109
+ """Extract text from a PDF file."""
110
  try:
111
  with pdfplumber.open(file) as pdf:
112
+ full_text = "\n".join(page.extract_text() or "" for page in pdf.pages)
113
+ return {"data": full_text, "source": "pdf"}
 
 
 
114
  except Exception as e:
115
  self.log_error(f"PDF Processing Error: {e}")
116
  return {"data": "", "source": "pdf"}
117
 
118
  def handle_csv(self, file) -> Dict[str, Any]:
119
+ """Process a CSV file by converting it to JSON."""
120
  try:
121
  df = pd.read_csv(file)
122
+ json_data = df.to_json(orient="records")
123
+ return {"data": json_data, "source": "csv"}
124
  except Exception as e:
125
  self.log_error(f"CSV Processing Error: {e}")
126
  return {"data": "", "source": "csv"}
127
 
128
  def handle_api(self, config: Dict[str, str]) -> Dict[str, Any]:
129
+ """Fetch data from an API endpoint."""
130
  try:
131
  response = requests.get(config["url"], headers=config.get("headers", {}), timeout=10)
132
  response.raise_for_status()
 
136
  return {"data": "", "source": "api"}
137
 
138
  def handle_db(self, config: Dict[str, str]) -> Dict[str, Any]:
139
+ """Query a database using the provided connection string and SQL query."""
140
  try:
141
  engine = sqlalchemy.create_engine(config["connection"])
142
  with engine.connect() as conn:
 
169
  """
170
  Generate Q&A pairs by sending the built prompt to the selected LLM provider.
171
  """
172
+ api_key: str = st.session_state.api_key
173
  if not api_key:
174
  self.log_error("API key is missing!")
175
  return False
176
 
177
+ provider_name: str = st.session_state.config["provider"]
178
+ provider_cfg: Dict[str, Any] = self.providers.get(provider_name, {})
179
  if not provider_cfg:
180
  self.log_error(f"Provider {provider_name} is not configured.")
181
  return False
182
 
183
+ client_initializer: Callable[[str], Any] = provider_cfg["client"]
184
  client = client_initializer(api_key)
185
+ model: str = st.session_state.config["model"]
186
+ temperature: float = st.session_state.config["temperature"]
187
+ prompt: str = self.build_prompt()
188
 
189
  st.info(f"Using **{provider_name}** with model **{model}** at temperature **{temperature:.2f}**")
190
  try:
 
241
  def _parse_response(self, response: Any, provider: str) -> List[Dict[str, str]]:
242
  """
243
  Parse the LLM response and return a list of Q&A pairs.
244
+ Expects the response to be JSON formatted; if JSON decoding fails,
245
+ tries to use ast.literal_eval as a fallback.
246
  """
247
  st.write("Parsing response for provider:", provider)
248
  try:
249
+ # For non-HuggingFace providers, extract the raw text from the response.
250
  if provider == "HuggingFace":
 
251
  if isinstance(response, list) and response and "generated_text" in response[0]:
252
  raw_text = response[0]["generated_text"]
253
  else:
254
  self.log_error("Unexpected HuggingFace response format.")
255
  return []
256
  else:
 
 
257
  if response and hasattr(response, "choices") and response.choices:
258
  raw_text = response.choices[0].message.content
259
  else:
260
  self.log_error("Unexpected response format from provider.")
261
  return []
262
 
263
+ # Attempt to parse using json.loads first.
264
  try:
265
  qa_list = json.loads(raw_text)
 
 
 
 
 
266
  except json.JSONDecodeError as e:
267
+ self.log_error(f"JSON Parsing Error: {e}. Attempting fallback with ast.literal_eval. Raw output: {raw_text}")
268
+ try:
269
+ qa_list = ast.literal_eval(raw_text)
270
+ except Exception as e2:
271
+ self.log_error(f"ast.literal_eval failed: {e2}")
272
+ return []
273
+
274
+ if isinstance(qa_list, list):
275
+ return qa_list
276
+ else:
277
+ self.log_error("Parsed output is not a list.")
278
  return []
279
  except Exception as e:
280
  self.log_error(f"Response Parsing Error: {e}")
 
283
 
284
  # ============ UI Components ============
285
 
286
+ def config_ui(generator: QADataGenerator) -> None:
287
  """Display configuration options in the sidebar."""
288
  with st.sidebar:
289
  st.header("Configuration")
 
300
  api_key = st.text_input(f"{provider} API Key", type="password")
301
  st.session_state.api_key = api_key
302
 
303
+ def input_ui(generator: QADataGenerator) -> None:
304
  """Display input data source options using tabs."""
305
  st.subheader("Input Data Sources")
306
  tabs = st.tabs(["Text", "PDF", "CSV", "API", "Database"])
 
346
  st.session_state.inputs.append(generator.handle_db({"connection": db_conn, "query": db_query}))
347
  st.success("Database input added!")
348
 
349
+ def output_ui(generator: QADataGenerator) -> None:
350
  """Display the generated Q&A pairs and provide a download option."""
351
  st.subheader("Q&A Pairs Output")
352
  if st.session_state.qa_pairs:
 
361
  else:
362
  st.info("No Q&A pairs generated yet.")
363
 
364
+ def logs_ui() -> None:
365
  """Display error logs and debugging information in an expandable section."""
366
  with st.expander("Error Logs & Debug Info", expanded=False):
367
  if st.session_state.error_logs:
 
370
  else:
371
  st.write("No logs yet.")
372
 
373
+ def main() -> None:
374
+ """Main Streamlit application entry point."""
375
  st.set_page_config(page_title="Advanced Q&A Synthetic Generator", layout="wide")
376
  st.title("Advanced Q&A Synthetic Generator")
377
  st.markdown(