Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -7,12 +7,22 @@ import json
|
|
7 |
from PIL import Image
|
8 |
from io import BytesIO
|
9 |
from openai import OpenAI
|
|
|
10 |
import groq
|
11 |
import sqlalchemy
|
12 |
from typing import Dict, Any
|
13 |
|
|
|
|
|
|
|
|
|
14 |
class SyntheticDataGenerator:
|
|
|
|
|
|
|
|
|
15 |
def __init__(self):
|
|
|
16 |
self.providers = {
|
17 |
"Deepseek": {
|
18 |
"client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key),
|
@@ -29,9 +39,13 @@ class SyntheticDataGenerator:
|
|
29 |
"HuggingFace": {
|
30 |
"client": lambda key: {"headers": {"Authorization": f"Bearer {key}"}},
|
31 |
"models": ["gpt2", "llama-2"]
|
32 |
-
}
|
|
|
|
|
|
|
|
|
33 |
}
|
34 |
-
|
35 |
self.input_handlers = {
|
36 |
"pdf": self.handle_pdf,
|
37 |
"text": self.handle_text,
|
@@ -39,10 +53,20 @@ class SyntheticDataGenerator:
|
|
39 |
"api": self.handle_api,
|
40 |
"db": self.handle_db
|
41 |
}
|
42 |
-
|
43 |
self.init_session()
|
44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
def init_session(self):
|
|
|
46 |
session_defaults = {
|
47 |
'inputs': [],
|
48 |
'qa_data': [],
|
@@ -54,34 +78,42 @@ class SyntheticDataGenerator:
|
|
54 |
'config': {
|
55 |
'provider': "Deepseek",
|
56 |
'model': "deepseek-chat",
|
57 |
-
'temperature':
|
58 |
}
|
59 |
}
|
60 |
-
|
61 |
for key, val in session_defaults.items():
|
62 |
if key not in st.session_state:
|
63 |
st.session_state[key] = val
|
64 |
|
65 |
# Input Processors
|
66 |
def handle_pdf(self, file):
|
67 |
-
|
|
|
68 |
with pdfplumber.open(file) as pdf:
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
def handle_text(self, text):
|
|
|
79 |
return [{
|
80 |
"text": text,
|
81 |
"meta": {"type": "domain", "source": "manual"}
|
82 |
}]
|
83 |
|
84 |
def handle_csv(self, file):
|
|
|
85 |
try:
|
86 |
df = pd.read_csv(file)
|
87 |
return [{
|
@@ -93,17 +125,21 @@ class SyntheticDataGenerator:
|
|
93 |
return []
|
94 |
|
95 |
def handle_api(self, config):
|
|
|
96 |
try:
|
97 |
response = requests.get(config['url'], headers=config['headers'])
|
|
|
98 |
return [{
|
99 |
"text": json.dumps(response.json()),
|
100 |
"meta": {"type": "api", "endpoint": config['url']}
|
101 |
}]
|
102 |
-
except
|
103 |
self.log_error(f"API Error: {str(e)}")
|
104 |
return []
|
105 |
|
|
|
106 |
def handle_db(self, config):
|
|
|
107 |
try:
|
108 |
engine = sqlalchemy.create_engine(config['connection'])
|
109 |
with engine.connect() as conn:
|
@@ -117,6 +153,7 @@ class SyntheticDataGenerator:
|
|
117 |
return []
|
118 |
|
119 |
def process_images(self, page):
|
|
|
120 |
images = []
|
121 |
for img in page.images:
|
122 |
try:
|
@@ -134,130 +171,237 @@ class SyntheticDataGenerator:
|
|
134 |
|
135 |
# Core Generation Engine
|
136 |
def generate(self, api_key: str) -> bool:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
try:
|
138 |
provider_cfg = self.providers[st.session_state.config['provider']]
|
139 |
-
|
140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
for i, input_data in enumerate(st.session_state.inputs):
|
142 |
st.session_state.processing['progress'] = (i+1)/len(st.session_state.inputs)
|
143 |
-
|
144 |
if st.session_state.config['provider'] == "HuggingFace":
|
145 |
response = self._huggingface_inference(client, input_data)
|
|
|
|
|
146 |
else:
|
147 |
response = self._standard_inference(client, input_data)
|
148 |
-
|
149 |
if response:
|
150 |
-
|
151 |
-
|
|
|
152 |
return True
|
153 |
except Exception as e:
|
154 |
self.log_error(f"Generation Error: {str(e)}")
|
155 |
return False
|
156 |
|
157 |
def _standard_inference(self, client, input_data):
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
|
|
|
|
|
|
|
|
|
|
167 |
|
168 |
def _huggingface_inference(self, client, input_data):
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
|
177 |
def _build_prompt(self, input_data):
|
178 |
-
|
|
|
179 |
if input_data['meta']['type'] == 'csv':
|
180 |
return base + "Structured data:\n" + input_data['text']
|
181 |
elif input_data['meta']['type'] == 'api':
|
182 |
return base + "API response:\n" + input_data['text']
|
183 |
return base + input_data['text']
|
184 |
|
185 |
-
def _parse_response(self, response):
|
|
|
186 |
try:
|
187 |
-
if
|
188 |
return response[0]['generated_text']
|
189 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
190 |
except Exception as e:
|
191 |
-
self.log_error(f"Parse Error: {
|
192 |
return []
|
193 |
|
194 |
def log_error(self, message):
|
|
|
195 |
st.session_state.processing['errors'].append(message)
|
196 |
st.error(message)
|
197 |
|
198 |
# Streamlit UI Components
|
199 |
def input_sidebar(gen: SyntheticDataGenerator):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
200 |
with st.sidebar:
|
201 |
st.header("⚙️ Configuration")
|
202 |
-
|
203 |
# AI Provider Settings
|
204 |
provider = st.selectbox("Provider", list(gen.providers.keys()))
|
205 |
provider_cfg = gen.providers[provider]
|
206 |
-
|
207 |
api_key = st.text_input(f"{provider} API Key", type="password")
|
|
|
|
|
208 |
model = st.selectbox("Model", provider_cfg["models"])
|
209 |
-
temp = st.slider("Temperature", 0.0, 1.0,
|
210 |
-
|
211 |
# Update session config
|
212 |
st.session_state.config.update({
|
213 |
"provider": provider,
|
214 |
"model": model,
|
215 |
"temperature": temp
|
216 |
})
|
217 |
-
|
218 |
# Input Source Selection
|
219 |
st.header("🔗 Data Sources")
|
220 |
input_type = st.selectbox("Input Type", list(gen.input_handlers.keys()))
|
221 |
-
|
222 |
if input_type == "text":
|
223 |
domain_input = st.text_area("Domain Knowledge", height=150)
|
224 |
if st.button("Add Domain Input"):
|
225 |
-
gen.input_handlers["text"](domain_input)
|
226 |
-
|
227 |
elif input_type == "csv":
|
228 |
csv_file = st.file_uploader("Upload CSV", type=["csv"])
|
229 |
if csv_file:
|
230 |
-
|
231 |
-
|
232 |
elif input_type == "api":
|
233 |
api_url = st.text_input("API Endpoint")
|
234 |
-
|
235 |
-
|
236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
237 |
return api_key
|
238 |
|
239 |
def main_display(gen: SyntheticDataGenerator):
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
st.title("🚀 Enterprise Synthetic Data Factory")
|
241 |
-
|
242 |
# Input Processing
|
243 |
col1, col2 = st.columns([3, 1])
|
244 |
with col1:
|
245 |
pdf_file = st.file_uploader("Upload Document", type=["pdf"])
|
246 |
if pdf_file:
|
247 |
-
|
248 |
-
|
249 |
# Generation Controls
|
250 |
with col2:
|
251 |
if st.button("Start Generation"):
|
252 |
with st.status("Processing..."):
|
253 |
-
|
254 |
-
|
|
|
|
|
|
|
255 |
# Results Display
|
256 |
if st.session_state.qa_data:
|
257 |
st.header("Generated Data")
|
258 |
df = pd.DataFrame(st.session_state.qa_data)
|
259 |
st.dataframe(df)
|
260 |
-
|
261 |
# Export Options
|
262 |
st.download_button(
|
263 |
"Export CSV",
|
@@ -266,6 +410,7 @@ def main_display(gen: SyntheticDataGenerator):
|
|
266 |
)
|
267 |
|
268 |
def main():
|
|
|
269 |
gen = SyntheticDataGenerator()
|
270 |
api_key = input_sidebar(gen)
|
271 |
main_display(gen)
|
|
|
7 |
from PIL import Image
|
8 |
from io import BytesIO
|
9 |
from openai import OpenAI
|
10 |
+
import google.generativeai as genai # Added Google GenAI
|
11 |
import groq
|
12 |
import sqlalchemy
|
13 |
from typing import Dict, Any
|
14 |
|
15 |
+
# Module-level constants: inference endpoint base and default sampling temperature.
HF_API_URL = "https://api-inference.huggingface.co/models/"
DEFAULT_TEMPERATURE = 0.3
|
18 |
+
|
19 |
class SyntheticDataGenerator:
|
20 |
+
"""
|
21 |
+
A class to generate synthetic Q&A data from various input sources using different LLM providers.
|
22 |
+
"""
|
23 |
+
|
24 |
def __init__(self):
|
25 |
+
"""Initializes the SyntheticDataGenerator with supported providers, input handlers, and session state."""
|
26 |
self.providers = {
|
27 |
"Deepseek": {
|
28 |
"client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key),
|
|
|
39 |
"HuggingFace": {
|
40 |
"client": lambda key: {"headers": {"Authorization": f"Bearer {key}"}},
|
41 |
"models": ["gpt2", "llama-2"]
|
42 |
+
},
|
43 |
+
"Google": {
|
44 |
+
"client": lambda key: self._configure_google_genai(key), # Using a custom configure function
|
45 |
+
"models": ["gemini-2.0-pro"] # Add supported Gemini models. Consider adding "gemini-1.5-pro" when released.
|
46 |
+
},
|
47 |
}
|
48 |
+
|
49 |
self.input_handlers = {
|
50 |
"pdf": self.handle_pdf,
|
51 |
"text": self.handle_text,
|
|
|
53 |
"api": self.handle_api,
|
54 |
"db": self.handle_db
|
55 |
}
|
56 |
+
|
57 |
self.init_session()
|
58 |
|
59 |
+
def _configure_google_genai(self, api_key: str):
    """Configure the Google Generative AI SDK with the given API key.

    Args:
        api_key: The Google API key used to authenticate.

    Returns:
        The ``genai.GenerativeModel`` class on success (``generate`` later
        instantiates it with the configured model name), or ``None`` if
        configuration failed.
    """
    try:
        genai.configure(api_key=api_key)
        # Return the model class, not an instance: the concrete model name
        # is only known at generation time (session config).
        return genai.GenerativeModel
    except Exception as e:  # SDK may raise several exception types
        # Route through log_error so the failure is also recorded in
        # st.session_state.processing['errors'], consistent with every
        # other error path in this class (previously only st.error).
        self.log_error(f"Error configuring Google GenAI: {e}")
        return None  # callers must handle a failed configuration
|
67 |
+
|
68 |
def init_session(self):
|
69 |
+
"""Initializes the Streamlit session state with default values."""
|
70 |
session_defaults = {
|
71 |
'inputs': [],
|
72 |
'qa_data': [],
|
|
|
78 |
'config': {
|
79 |
'provider': "Deepseek",
|
80 |
'model': "deepseek-chat",
|
81 |
+
'temperature': DEFAULT_TEMPERATURE
|
82 |
}
|
83 |
}
|
84 |
+
|
85 |
for key, val in session_defaults.items():
|
86 |
if key not in st.session_state:
|
87 |
st.session_state[key] = val
|
88 |
|
89 |
# Input Processors
|
90 |
def handle_pdf(self, file):
    """Extract per-page text and images from an uploaded PDF file.

    Returns a list with one record per page, or an empty list on failure.
    """
    try:
        with pdfplumber.open(file) as pdf:
            pages = []
            # Page numbers are 1-based in the metadata.
            for page_number, page in enumerate(pdf.pages, start=1):
                pages.append({
                    "text": page.extract_text() or "",
                    "images": self.process_images(page),
                    "meta": {"type": "pdf", "page": page_number},
                })
            return pages
    except Exception as e:
        self.log_error(f"PDF Error: {str(e)}")
        return []
|
107 |
|
108 |
def handle_text(self, text):
    """Wrap manually entered domain text in the standard input-record format."""
    record = {
        "text": text,
        "meta": {"type": "domain", "source": "manual"},
    }
    return [record]
|
114 |
|
115 |
def handle_csv(self, file):
|
116 |
+
"""Reads a CSV file and prepares data for Q&A generation."""
|
117 |
try:
|
118 |
df = pd.read_csv(file)
|
119 |
return [{
|
|
|
125 |
return []
|
126 |
|
127 |
def handle_api(self, config):
    """Fetch JSON from an API endpoint and wrap it as an input record.

    Args:
        config: Mapping with ``'url'`` and ``'headers'`` keys.

    Returns:
        A single-element list containing the serialized response, or an
        empty list on any request failure.
    """
    try:
        # A finite timeout keeps an unresponsive endpoint from hanging
        # the Streamlit app indefinitely (requests has no default timeout).
        response = requests.get(config['url'], headers=config['headers'], timeout=30)
        response.raise_for_status()  # Raise HTTPError for 4xx/5xx responses
        return [{
            "text": json.dumps(response.json()),
            "meta": {"type": "api", "endpoint": config['url']}
        }]
    except requests.exceptions.RequestException as e:
        self.log_error(f"API Error: {str(e)}")
        return []
|
139 |
|
140 |
+
|
141 |
def handle_db(self, config):
|
142 |
+
"""Connects to a database and executes a query."""
|
143 |
try:
|
144 |
engine = sqlalchemy.create_engine(config['connection'])
|
145 |
with engine.connect() as conn:
|
|
|
153 |
return []
|
154 |
|
155 |
def process_images(self, page):
|
156 |
+
"""Extracts and processes images from a PDF page."""
|
157 |
images = []
|
158 |
for img in page.images:
|
159 |
try:
|
|
|
171 |
|
172 |
# Core Generation Engine
|
173 |
def generate(self, api_key: str) -> bool:
    """Generate Q&A pairs from all queued inputs using the selected provider.

    Args:
        api_key (str): The API key for the selected LLM provider.

    Returns:
        bool: True if generation completed, False on configuration or
        runtime error (errors are logged via ``log_error``).
    """
    try:
        provider_name = st.session_state.config['provider']
        provider_cfg = self.providers[provider_name]
        client_initializer = provider_cfg["client"]  # client init function

        # Check that the key is not an empty string.
        if not api_key:
            st.error("API Key cannot be empty.")
            return False

        # Nothing queued: tell the user instead of silently "succeeding"
        # with zero generated rows (previously returned True here).
        if not st.session_state.inputs:
            st.warning("No inputs to process. Add a data source first.")
            return False

        # Initialize the client; for Google the initializer may return
        # None when SDK configuration fails.
        client = client_initializer(api_key)
        if provider_name == "Google" and not client:
            return False

        total = len(st.session_state.inputs)
        for i, input_data in enumerate(st.session_state.inputs):
            st.session_state.processing['progress'] = (i + 1) / total

            if provider_name == "HuggingFace":
                response = self._huggingface_inference(client, input_data)
            elif provider_name == "Google":
                response = self._google_inference(client, input_data)
            else:
                response = self._standard_inference(client, input_data)

            if response:
                # The parser needs the provider to pick a response shape.
                st.session_state.qa_data.extend(
                    self._parse_response(response, provider_name))

        return True
    except Exception as e:
        self.log_error(f"Generation Error: {str(e)}")
        return False
|
218 |
|
219 |
def _standard_inference(self, client, input_data):
    """Run one chat completion against an OpenAI-compatible endpoint.

    Returns the raw completion object, or None on failure.
    """
    cfg = st.session_state.config
    try:
        prompt = self._build_prompt(input_data)
        message = {"role": "user", "content": prompt}
        return client.chat.completions.create(
            model=cfg['model'],
            messages=[message],
            temperature=cfg['temperature'],
            # Ask the API to return strict JSON.
            response_format={"type": "json_object"},
        )
    except Exception as e:
        self.log_error(f"OpenAI Inference Error: {e}")
        return None
|
234 |
|
235 |
def _huggingface_inference(self, client, input_data):
    """Call the Hugging Face Inference API for the configured model.

    Args:
        client: Dict carrying the auth ``headers`` built by the provider lambda.
        input_data: One input record; its text is folded into the prompt.

    Returns:
        The decoded JSON response, or None on any HTTP/network error.
    """
    try:
        response = requests.post(
            HF_API_URL + st.session_state.config['model'],
            headers=client["headers"],
            json={"inputs": self._build_prompt(input_data)},
            # Model inference can be slow, but must never hang forever
            # (requests has no default timeout).
            timeout=60,
        )
        response.raise_for_status()  # surface 4xx/5xx as exceptions
        return response.json()
    except requests.exceptions.RequestException as e:
        self.log_error(f"Hugging Face Inference Error: {e}")
        return None
|
248 |
+
|
249 |
+
def _google_inference(self, client, input_data):
    """Run one generation request against the Google Generative AI API.

    Returns the raw response object, or None on failure.
    """
    try:
        # `client` is the GenerativeModel class returned by
        # _configure_google_genai; instantiate it with the selected model.
        model = client(st.session_state.config['model'])
        generation_config = genai.types.GenerationConfig(
            temperature=st.session_state.config['temperature'])
        return model.generate_content(
            self._build_prompt(input_data),
            generation_config=generation_config,
        )
    except Exception as e:
        self.log_error(f"Google GenAI Inference Error: {e}")
        return None
|
263 |
|
264 |
def _build_prompt(self, input_data):
    """Build the LLM prompt for one input record.

    The instructions ask for a JSON *object* with a top-level "qa_pairs"
    key, matching what ``_parse_response`` extracts via
    ``json.loads(...).get("qa_pairs", [])``. (The previous wording asked
    for a bare JSON list, which the parser would always reduce to [].)
    """
    base = (
        "Generate 3 Q&A pairs from this financial content. "
        "Respond with a JSON object of the form "
        '{"qa_pairs": [{"question": "...", "answer": "..."}, ...]}:\n'
    )
    input_type = input_data['meta']['type']
    if input_type == 'csv':
        return base + "Structured data:\n" + input_data['text']
    if input_type == 'api':
        return base + "API response:\n" + input_data['text']
    return base + input_data['text']
|
272 |
|
273 |
+
def _parse_response(self, response, provider):
    """Parse a provider response into a list of Q&A pairs.

    Args:
        response: Raw provider response (shape varies per provider).
        provider: Provider name, used to pick the parsing strategy.

    Returns:
        list: Dicts with 'question' and 'answer' keys; empty on failure.
    """
    try:
        if provider == "HuggingFace":
            # The HF text-generation payload carries the model output as a
            # string in response[0]['generated_text']. Decode and validate
            # it like the other providers instead of returning the raw
            # string: callers do qa_data.extend(...), and extending with a
            # string would append it character-by-character.
            raw_text = response[0]['generated_text']
            qa_pairs = json.loads(raw_text).get("qa_pairs", [])
            return self._validate_qa_pairs(qa_pairs)
        elif provider == "Google":
            # Expecting a text response from Gemini.
            try:
                json_string = response.text.strip()  # strip stray whitespace
                qa_pairs = json.loads(json_string).get("qa_pairs", [])
                return self._validate_qa_pairs(qa_pairs)
            except (json.JSONDecodeError, ValueError) as e:
                self.log_error(f"Google JSON Parse Error: {e}. Raw Response: {response.text}")
                return []  # parsing failure yields no pairs
        else:
            # OpenAI-compatible providers (OpenAI, Deepseek, Groq) return
            # JSON in the first choice's message content.
            json_output = json.loads(response.choices[0].message.content)
            return self._validate_qa_pairs(json_output.get("qa_pairs", []))
    except Exception as e:
        self.log_error(f"Parse Error: {e}. Raw Response: {response}")
        return []

@staticmethod
def _validate_qa_pairs(qa_pairs):
    """Validate the qa_pairs structure shared by all provider branches."""
    if not isinstance(qa_pairs, list):
        raise ValueError("Expected a list of QA pairs.")
    for pair in qa_pairs:
        if not isinstance(pair, dict) or "question" not in pair or "answer" not in pair:
            raise ValueError("Each item in the list must be a dictionary with 'question' and 'answer' keys.")
    return qa_pairs
|
302 |
|
303 |
def log_error(self, message):
    """Record an error in session state and surface it in the UI."""
    error_log = st.session_state.processing['errors']
    error_log.append(message)
    st.error(message)
|
307 |
|
308 |
# Streamlit UI Components
|
309 |
def input_sidebar(gen: SyntheticDataGenerator):
    """Render the configuration sidebar and collect data-source inputs.

    Args:
        gen (SyntheticDataGenerator): The generator whose providers and
            input handlers drive the widgets.

    Returns:
        str: The API key entered by the user.
    """
    with st.sidebar:
        st.header("⚙️ Configuration")

        # AI Provider Settings
        provider = st.selectbox("Provider", list(gen.providers.keys()))
        provider_cfg = gen.providers[provider]

        api_key = st.text_input(f"{provider} API Key", type="password")
        st.session_state['api_key'] = api_key  # stored for the Generate button

        model = st.selectbox("Model", provider_cfg["models"])
        temp = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)

        # Update session config
        st.session_state.config.update({
            "provider": provider,
            "model": model,
            "temperature": temp
        })

        # Input Source Selection
        st.header("🔗 Data Sources")
        input_type = st.selectbox("Input Type", list(gen.input_handlers.keys()))

        if input_type == "text":
            domain_input = st.text_area("Domain Knowledge", height=150)
            if st.button("Add Domain Input"):
                st.session_state.inputs.append(gen.input_handlers["text"](domain_input)[0])

        elif input_type == "csv":
            csv_file = st.file_uploader("Upload CSV", type=["csv"])
            # Button guard (matching the other branches): without it the
            # uploaded file was re-ingested — and duplicated — on every
            # Streamlit rerun.
            if csv_file and st.button("Add CSV Input"):
                st.session_state.inputs.extend(gen.input_handlers["csv"](csv_file))

        elif input_type == "api":
            api_url = st.text_input("API Endpoint")
            api_headers = st.text_area("API Headers (JSON format, optional)", height=50)
            headers = {}
            try:
                if api_headers:
                    headers = json.loads(api_headers)
            except json.JSONDecodeError:
                st.error("Invalid JSON format for API headers.")
            if st.button("Add API Input"):
                st.session_state.inputs.extend(
                    gen.input_handlers["api"]({"url": api_url, "headers": headers}))

        elif input_type == "db":
            db_connection = st.text_input("Database Connection String")
            db_query = st.text_area("Database Query")
            db_table = st.text_input("Table Name (optional)")
            if st.button("Add DB Input"):
                st.session_state.inputs.extend(
                    gen.input_handlers["db"]({
                        "connection": db_connection,
                        "query": db_query,
                        "table": db_table
                    }))

        return api_key
|
373 |
|
374 |
def main_display(gen: SyntheticDataGenerator):
|
375 |
+
"""
|
376 |
+
Creates the main display area in the Streamlit UI.
|
377 |
+
|
378 |
+
Args:
|
379 |
+
gen (SyntheticDataGenerator): The SyntheticDataGenerator instance.
|
380 |
+
"""
|
381 |
st.title("🚀 Enterprise Synthetic Data Factory")
|
382 |
+
|
383 |
# Input Processing
|
384 |
col1, col2 = st.columns([3, 1])
|
385 |
with col1:
|
386 |
pdf_file = st.file_uploader("Upload Document", type=["pdf"])
|
387 |
if pdf_file:
|
388 |
+
st.session_state.inputs.extend(gen.input_handlers["pdf"](pdf_file))
|
389 |
+
|
390 |
# Generation Controls
|
391 |
with col2:
|
392 |
if st.button("Start Generation"):
|
393 |
with st.status("Processing..."):
|
394 |
+
if not st.session_state.get('api_key'):
|
395 |
+
st.error("Please provide an API Key.")
|
396 |
+
else:
|
397 |
+
gen.generate(st.session_state.get('api_key'))
|
398 |
+
|
399 |
# Results Display
|
400 |
if st.session_state.qa_data:
|
401 |
st.header("Generated Data")
|
402 |
df = pd.DataFrame(st.session_state.qa_data)
|
403 |
st.dataframe(df)
|
404 |
+
|
405 |
# Export Options
|
406 |
st.download_button(
|
407 |
"Export CSV",
|
|
|
410 |
)
|
411 |
|
412 |
def main():
    """Entry point: build the generator and render the Streamlit UI."""
    generator = SyntheticDataGenerator()
    # The sidebar returns the API key, but generation reads it from
    # session state, so the return value is not needed here.
    input_sidebar(generator)
    main_display(generator)
|