Update app.py
app.py CHANGED
@@ -1,178 +1,274 @@
 import streamlit as st
 import pdfplumber
 import pytesseract
-from PIL import Image
-import json
 import pandas as pd
 from io import BytesIO
-import time
 from openai import OpenAI
 import groq

-class
-    PROVIDER_CONFIG = {
-        "Deepseek": {
-            "client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key),
-            "models": ["deepseek-chat"],
-            "key_label": "Deepseek API Key"
-        },
-        "OpenAI": {
-            "client": lambda key: OpenAI(api_key=key),
-            "models": ["gpt-4-turbo"],
-            "key_label": "OpenAI API Key"
-        },
-        "Groq": {
-            "client": lambda key: groq.Groq(api_key=key),
-            "models": ["mixtral-8x7b-32768", "llama2-70b-4096"],
-            "key_label": "Groq API Key"
-        }
-    }

     def __init__(self):
-        self.

-    def
                 'stage': 'idle',
-                '
-                '
             }

         try:
             with pdfplumber.open(file) as pdf:
-                # Process page content
-                text = page.extract_text() or ""
-                images = self.process_images(page)
-                # Store in session state
-                st.session_state.qa_data.append({
-                    "page": i+1,
-                    "text": text,
-                    "images": images
-                })
-                time.sleep(0.1)  # Simulate processing
-            return True
         except Exception as e:
-            return

     def process_images(self, page):
-        """Robust image processing"""
         images = []
         for img in page.images:
             try:
-                # Handle different PDF image formats
                 stream = img['stream']
-                width = int(stream.get('Width',
-                height = int(stream.get('Height',
                 if width > 0 and height > 0:
-                    "RGB"
-                    (width, height)
-                    )
-                    images.append(image)
             except Exception as e:
         return images

         try:
-            for
-                response_format={"type": "json_object"}
-            )
-                item["qa_pairs"] = result.get("qa_pairs", [])
-            except json.JSONDecodeError:
-                st.error("Failed to parse AI response")
-        st.session_state.processing['stage'] = 'complete'
             return True
         except Exception as e:
             return False

-    def
     with st.sidebar:
-        st.header("⚙️
         temp = st.slider("Temperature", 0.0, 1.0, 0.3)

     st.title("🚀 Enterprise Synthetic Data Factory")

-    st.
-        # Generate Q&A pairs
-        st.write("Generating synthetic data...")
-        if factory.generate_qa(provider, api_key, model, temp):
-            status.update(label="Processing complete!", state="complete", expanded=False)

-    # Display
-    if st.session_state.
-        st.
-        all_qa = []
-        for item in st.session_state.qa_data:
-            for qa in item.get("qa_pairs", []):
-                qa["page"] = item["page"]
-                all_qa.append(qa)
-        )
-    else:
-        st.warning("No Q&A pairs generated. Check your document content and API settings.")

 if __name__ == "__main__":
     main()

 import streamlit as st
 import pdfplumber
 import pytesseract
 import pandas as pd
+import requests
+import json
+from PIL import Image
 from io import BytesIO
 from openai import OpenAI
 import groq
+import sqlalchemy
+from typing import Dict, Any

+class SyntheticDataGenerator:
     def __init__(self):
+        self.providers = {
+            "Deepseek": {
+                "client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key),
+                "models": ["deepseek-chat"]
+            },
+            "OpenAI": {
+                "client": lambda key: OpenAI(api_key=key),
+                "models": ["gpt-4-turbo"]
+            },
+            "Groq": {
+                "client": lambda key: groq.Groq(api_key=key),
+                "models": ["mixtral-8x7b-32768"]
+            },
+            "HuggingFace": {
+                "client": lambda key: {"headers": {"Authorization": f"Bearer {key}"}},
+                "models": ["gpt2", "llama-2"]
+            }
+        }
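+        # Note: the HuggingFace "client" is only an auth-header dict for the raw
+        # Inference API (see _huggingface_inference); "llama-2" reads like a
+        # placeholder rather than a full hub repo id (e.g. meta-llama/Llama-2-7b-hf).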
+
+        self.input_handlers = {
+            "pdf": self.handle_pdf,
+            "text": self.handle_text,
+            "csv": self.handle_csv,
+            "api": self.handle_api,
+            "db": self.handle_db
+        }
+
+        self.init_session()

+    def init_session(self):
+        session_defaults = {
+            'inputs': [],
+            'qa_data': [],
+            'processing': {
                 'stage': 'idle',
+                'progress': 0,
+                'errors': []
+            },
+            'config': {
+                'provider': "Deepseek",
+                'model': "deepseek-chat",
+                'temperature': 0.3
             }
+        }
+
+        for key, val in session_defaults.items():
+            if key not in st.session_state:
+                st.session_state[key] = val

+    # Input Processors
+    def handle_pdf(self, file):
         try:
             with pdfplumber.open(file) as pdf:
+                return [{
+                    "text": page.extract_text() or "",
+                    "images": self.process_images(page),
+                    "meta": {"type": "pdf", "page": i+1}
+                } for i, page in enumerate(pdf.pages)]
         except Exception as e:
+            self.log_error(f"PDF Error: {str(e)}")
+            return []
+
+    def handle_text(self, text):
+        return [{
+            "text": text,
+            "meta": {"type": "domain", "source": "manual"}
+        }]
+
+    def handle_csv(self, file):
+        try:
+            df = pd.read_csv(file)
+            return [{
+                "text": "\n".join([f"{col}: {row[col]}" for col in df.columns]),
+                "meta": {"type": "csv", "columns": list(df.columns)}
+            } for _, row in df.iterrows()]
+        except Exception as e:
+            self.log_error(f"CSV Error: {str(e)}")
+            return []
+
+    def handle_api(self, config):
+        try:
+            # .get() so a config without custom headers (as sent from the
+            # sidebar, which only passes "url") does not raise a KeyError
+            response = requests.get(config['url'], headers=config.get('headers', {}))
+            return [{
+                "text": json.dumps(response.json()),
+                "meta": {"type": "api", "endpoint": config['url']}
+            }]
+        except Exception as e:
+            self.log_error(f"API Error: {str(e)}")
+            return []
+
+    def handle_db(self, config):
+        try:
+            engine = sqlalchemy.create_engine(config['connection'])
+            with engine.connect() as conn:
+                result = conn.execute(sqlalchemy.text(config['query']))
+                return [{
+                    "text": "\n".join([f"{col}: {val}" for col, val in row._asdict().items()]),
+                    "meta": {"type": "db", "table": config.get('table', '')}
+                } for row in result]
+        except Exception as e:
+            self.log_error(f"DB Error: {str(e)}")
+            return []

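+    # A hypothetical config for handle_db (no sidebar branch builds one yet):
+    #   {"connection": "sqlite:///finance.db",
+    #    "query": "SELECT date, ticker, close FROM prices LIMIT 10",
+    #    "table": "prices"}
+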
     def process_images(self, page):
         images = []
         for img in page.images:
             try:
                 stream = img['stream']
+                width = int(stream.get('Width', 0))
+                height = int(stream.get('Height', 0))
                 if width > 0 and height > 0:
+                    images.append({
+                        "data": Image.frombytes("RGB", (width, height), stream.get_data()),
+                        "meta": {"dims": (width, height)}
+                    })
             except Exception as e:
+                self.log_error(f"Image Error: {str(e)}")
         return images

+    # Core Generation Engine
+    def generate(self, api_key: str) -> bool:
         try:
+            provider_cfg = self.providers[st.session_state.config['provider']]
+            client = provider_cfg["client"](api_key)

+            for i, input_data in enumerate(st.session_state.inputs):
+                st.session_state.processing['progress'] = (i+1)/len(st.session_state.inputs)

+                if st.session_state.config['provider'] == "HuggingFace":
+                    response = self._huggingface_inference(client, input_data)
+                else:
+                    response = self._standard_inference(client, input_data)

+                if response:
+                    st.session_state.qa_data.extend(self._parse_response(response))

             return True
         except Exception as e:
+            self.log_error(f"Generation Error: {str(e)}")
             return False

+    def _standard_inference(self, client, input_data):
+        return client.chat.completions.create(
+            model=st.session_state.config['model'],
+            messages=[{
+                "role": "user",
+                "content": self._build_prompt(input_data)
+            }],
+            temperature=st.session_state.config['temperature'],
+            response_format={"type": "json_object"}
+        )
+
+    def _huggingface_inference(self, client, input_data):
+        API_URL = "https://api-inference.huggingface.co/models/"
+        response = requests.post(
+            API_URL + st.session_state.config['model'],
+            headers=client["headers"],
+            json={"inputs": self._build_prompt(input_data)}
+        )
+        return response.json()
+
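+    # For text-generation models the Inference API answers with a list like
+    # [{"generated_text": "..."}]; _parse_response below depends on that shape.
+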
+    def _build_prompt(self, input_data):
+        base = "Generate 3 Q&A pairs from this financial content:\n"
+        if input_data['meta']['type'] == 'csv':
+            return base + "Structured data:\n" + input_data['text']
+        elif input_data['meta']['type'] == 'api':
+            return base + "API response:\n" + input_data['text']
+        return base + input_data['text']
+
+    def _parse_response(self, response):
+        try:
+            if st.session_state.config['provider'] == "HuggingFace":
+                # Wrap the raw completion in a one-record list; extending
+                # qa_data with the bare string would splice in characters.
+                return [{"generated_text": response[0]['generated_text']}]
+            return json.loads(response.choices[0].message.content).get("qa_pairs", [])
+        except Exception as e:
+            self.log_error(f"Parse Error: {str(e)}")
+            return []

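+    # The JSON-mode providers are expected to return a top-level object like
+    # {"qa_pairs": [{"question": "...", "answer": "..."}]}; the field names
+    # inside each pair are an assumption, not enforced anywhere in this file.
+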
+    def log_error(self, message):
+        st.session_state.processing['errors'].append(message)
+        st.error(message)
+
+# Streamlit UI Components
+def input_sidebar(gen: SyntheticDataGenerator):
     with st.sidebar:
+        st.header("⚙️ Configuration")
+
+        # AI Provider Settings
+        provider = st.selectbox("Provider", list(gen.providers.keys()))
+        provider_cfg = gen.providers[provider]
+
+        api_key = st.text_input(f"{provider} API Key", type="password")
+        model = st.selectbox("Model", provider_cfg["models"])
         temp = st.slider("Temperature", 0.0, 1.0, 0.3)
+
+        # Update session config; keep the key where generate() can find it
+        st.session_state['api_key'] = api_key
+        st.session_state.config.update({
+            "provider": provider,
+            "model": model,
+            "temperature": temp
+        })
+
+        # Input Source Selection
+        st.header("🔗 Data Sources")
+        input_type = st.selectbox("Input Type", list(gen.input_handlers.keys()))
+
+        if input_type == "text":
+            domain_input = st.text_area("Domain Knowledge", height=150)
+            if st.button("Add Domain Input"):
+                # Store the handler's records; previously the return value was dropped
+                st.session_state.inputs.extend(gen.input_handlers["text"](domain_input))
+
+        elif input_type == "csv":
+            csv_file = st.file_uploader("Upload CSV", type=["csv"])
+            if csv_file:
+                st.session_state.inputs.extend(gen.input_handlers["csv"](csv_file))
+
+        elif input_type == "api":
+            api_url = st.text_input("API Endpoint")
+            if st.button("Connect API"):
+                st.session_state.inputs.extend(gen.input_handlers["api"]({"url": api_url}))
+
+        return api_key

+
def main_display(gen: SyntheticDataGenerator):
|
240 |
st.title("🚀 Enterprise Synthetic Data Factory")
|
241 |
|
242 |
+
# Input Processing
|
243 |
+
col1, col2 = st.columns([3, 1])
|
244 |
+
with col1:
|
245 |
+
pdf_file = st.file_uploader("Upload Document", type=["pdf"])
|
246 |
+
if pdf_file:
|
247 |
+
gen.input_handlers["pdf"](pdf_file)
|
248 |
|
249 |
+
# Generation Controls
|
250 |
+
with col2:
|
251 |
+
if st.button("Start Generation"):
|
252 |
+
with st.status("Processing..."):
|
253 |
+
gen.generate(st.session_state.get('api_key'))
|
|
|
|
|
|
|
|
|
254 |
|
+    # Results Display
+    if st.session_state.qa_data:
+        st.header("Generated Data")
+        df = pd.DataFrame(st.session_state.qa_data)
+        st.dataframe(df)

+        # Export Options
+        st.download_button(
+            "Export CSV",
+            df.to_csv(index=False),
+            "synthetic_data.csv"
+        )
+
+def main():
+    gen = SyntheticDataGenerator()
+    input_sidebar(gen)  # also stashes the API key in session state
+    main_display(gen)

 if __name__ == "__main__":
     main()
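
For reference, a minimal sketch of the record shape handle_csv emits (illustrative data, not from the commit):

import pandas as pd
from io import StringIO

# Hypothetical two-column CSV
df = pd.read_csv(StringIO("ticker,price\nAAPL,191.2\nMSFT,415.3"))
records = [{
    "text": "\n".join(f"{col}: {row[col]}" for col in df.columns),
    "meta": {"type": "csv", "columns": list(df.columns)}
} for _, row in df.iterrows()]
# records[0]["text"] -> "ticker: AAPL\nprice: 191.2"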