mgbam committed
Commit 3045f18 · verified · 1 Parent(s): 58e9888

Update app.py

Files changed (1)
  1. app.py  +228 -132
app.py CHANGED
@@ -1,178 +1,274 @@
import streamlit as st
import pdfplumber
import pytesseract
- from PIL import Image
- import json
import pandas as pd
from io import BytesIO
- import time
from openai import OpenAI
import groq

- class SyntheticDataFactory:
-     PROVIDER_CONFIG = {
-         "Deepseek": {
-             "client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key),
-             "models": ["deepseek-chat"],
-             "key_label": "Deepseek API Key"
-         },
-         "OpenAI": {
-             "client": lambda key: OpenAI(api_key=key),
-             "models": ["gpt-4-turbo"],
-             "key_label": "OpenAI API Key"
-         },
-         "Groq": {
-             "client": lambda key: groq.Groq(api_key=key),
-             "models": ["mixtral-8x7b-32768", "llama2-70b-4096"],
-             "key_label": "Groq API Key"
-         }
-     }
-
    def __init__(self):
-         self.init_session_state()

-     def init_session_state(self):
-         if 'qa_data' not in st.session_state:
-             st.session_state.qa_data = []
-         if 'processing' not in st.session_state:
-             st.session_state.processing = {
                'stage': 'idle',
-                 'errors': [],
-                 'progress': 0
            }

-     def process_pdf(self, file):
-         """Process PDF with error handling"""
        try:
            with pdfplumber.open(file) as pdf:
-                 pages = pdf.pages
-                 for i, page in enumerate(pages):
-                     # Update progress
-                     st.session_state.processing['progress'] = (i+1)/len(pages)
-
-                     # Process page content
-                     text = page.extract_text() or ""
-                     images = self.process_images(page)
-
-                     # Store in session state
-                     st.session_state.qa_data.append({
-                         "page": i+1,
-                         "text": text,
-                         "images": images
-                     })
-                     time.sleep(0.1) # Simulate processing
-             return True
        except Exception as e:
-             st.error(f"PDF processing failed: {str(e)}")
-             return False

    def process_images(self, page):
-         """Robust image processing"""
        images = []
        for img in page.images:
            try:
-                 # Handle different PDF image formats
                stream = img['stream']
-                 width = int(stream.get('Width', stream.get('W', 0)))
-                 height = int(stream.get('Height', stream.get('H', 0)))
-
                if width > 0 and height > 0:
-                     image = Image.frombytes(
-                         "RGB" if 'ColorSpace' in stream else "L",
-                         (width, height),
-                         stream.get_data()
-                     )
-                     images.append(image)
            except Exception as e:
-                 st.warning(f"Image processing error: {str(e)[:100]}")
        return images

-     def generate_qa(self, provider, api_key, model, temp):
-         """Generate Q&A pairs with selected provider"""
        try:
-             client = self.PROVIDER_CONFIG[provider]["client"](api_key)

-             for item in st.session_state.qa_data:
-                 prompt = f"Generate 3 Q&A pairs from this financial content:\n{item['text']}\nOutput JSON format with keys: question, answer_1, answer_2"

-                 response = client.chat.completions.create(
-                     model=model,
-                     messages=[{"role": "user", "content": prompt}],
-                     temperature=temp,
-                     response_format={"type": "json_object"}
-                 )

-                 try:
-                     result = json.loads(response.choices[0].message.content)
-                     item["qa_pairs"] = result.get("qa_pairs", [])
-                 except json.JSONDecodeError:
-                     st.error("Failed to parse AI response")

-             st.session_state.processing['stage'] = 'complete'
            return True
-
        except Exception as e:
-             st.error(f"Generation failed: {str(e)}")
            return False

- def main():
-     st.set_page_config(
-         page_title="Enterprise Data Factory",
-         page_icon="🏭",
-         layout="wide"
-     )
-
-     factory = SyntheticDataFactory()
-
-     # Sidebar Configuration
    with st.sidebar:
-         st.header("⚙️ AI Configuration")
-         provider = st.selectbox("Provider", list(factory.PROVIDER_CONFIG.keys()))
-         config = factory.PROVIDER_CONFIG[provider]
-         api_key = st.text_input(config["key_label"], type="password")
-         model = st.selectbox("Model", config["models"])
        temp = st.slider("Temperature", 0.0, 1.0, 0.3)

-     # Main Interface
    st.title("🚀 Enterprise Synthetic Data Factory")

-     uploaded_file = st.file_uploader("Upload Financial PDF", type=["pdf"])

-     if uploaded_file and api_key and st.button("Start Synthetic Generation"):
-         with st.status("Processing document...", expanded=True) as status:
-             # Process PDF
-             st.write("Extracting text and images...")
-             if factory.process_pdf(uploaded_file):
-                 # Generate Q&A pairs
-                 st.write("Generating synthetic data...")
-                 if factory.generate_qa(provider, api_key, model, temp):
-                     status.update(label="Processing complete!", state="complete", expanded=False)

-     # Display Results
-     if st.session_state.processing.get('stage') == 'complete':
-         st.subheader("Generated Q&A Pairs")
-
-         # Convert to DataFrame
-         all_qa = []
-         for item in st.session_state.qa_data:
-             for qa in item.get("qa_pairs", []):
-                 qa["page"] = item["page"]
-                 all_qa.append(qa)

-         if len(all_qa) > 0:
-             df = pd.DataFrame(all_qa)
-             st.dataframe(df)
-
-             # Export options
-             csv = df.to_csv(index=False).encode('utf-8')
-             st.download_button(
-                 label="Download as CSV",
-                 data=csv,
-                 file_name="synthetic_data.csv",
-                 mime="text/csv"
-             )
-         else:
-             st.warning("No Q&A pairs generated. Check your document content and API settings.")

if __name__ == "__main__":
    main()
 
import streamlit as st
import pdfplumber
import pytesseract
import pandas as pd
+ import requests
+ import json
+ from PIL import Image
from io import BytesIO
from openai import OpenAI
import groq
+ import sqlalchemy
+ from typing import Dict, Any

+ class SyntheticDataGenerator:
    def __init__(self):
+         self.providers = {
+             "Deepseek": {
+                 "client": lambda key: OpenAI(base_url="https://api.deepseek.com/v1", api_key=key),
+                 "models": ["deepseek-chat"]
+             },
+             "OpenAI": {
+                 "client": lambda key: OpenAI(api_key=key),
+                 "models": ["gpt-4-turbo"]
+             },
+             "Groq": {
+                 "client": lambda key: groq.Groq(api_key=key),
+                 "models": ["mixtral-8x7b-32768"]
+             },
+             "HuggingFace": {
+                 "client": lambda key: {"headers": {"Authorization": f"Bearer {key}"}},
+                 "models": ["gpt2", "llama-2"]
+             }
+         }
+
+         self.input_handlers = {
+             "pdf": self.handle_pdf,
+             "text": self.handle_text,
+             "csv": self.handle_csv,
+             "api": self.handle_api,
+             "db": self.handle_db
+         }
+
+         self.init_session()

+     def init_session(self):
+         session_defaults = {
+             'inputs': [],
+             'qa_data': [],
+             'processing': {
                'stage': 'idle',
+                 'progress': 0,
+                 'errors': []
+             },
+             'config': {
+                 'provider': "Deepseek",
+                 'model': "deepseek-chat",
+                 'temperature': 0.3
            }
+         }
+
+         for key, val in session_defaults.items():
+             if key not in st.session_state:
+                 st.session_state[key] = val

+     # Input Processors
+     def handle_pdf(self, file):
        try:
            with pdfplumber.open(file) as pdf:
+                 return [{
+                     "text": page.extract_text() or "",
+                     "images": self.process_images(page),
+                     "meta": {"type": "pdf", "page": i+1}
+                 } for i, page in enumerate(pdf.pages)]
        except Exception as e:
+             self.log_error(f"PDF Error: {str(e)}")
+             return []
+
+     def handle_text(self, text):
+         return [{
+             "text": text,
+             "meta": {"type": "domain", "source": "manual"}
+         }]
+
+     def handle_csv(self, file):
+         try:
+             df = pd.read_csv(file)
+             return [{
+                 "text": "\n".join([f"{col}: {row[col]}" for col in df.columns]),
+                 "meta": {"type": "csv", "columns": list(df.columns)}
+             } for _, row in df.iterrows()]
+         except Exception as e:
+             self.log_error(f"CSV Error: {str(e)}")
+             return []
+
+     def handle_api(self, config):
+         try:
+             response = requests.get(config['url'], headers=config['headers'])
+             return [{
+                 "text": json.dumps(response.json()),
+                 "meta": {"type": "api", "endpoint": config['url']}
+             }]
+         except Exception as e:
+             self.log_error(f"API Error: {str(e)}")
+             return []
+
+     def handle_db(self, config):
+         try:
+             engine = sqlalchemy.create_engine(config['connection'])
+             with engine.connect() as conn:
+                 result = conn.execute(sqlalchemy.text(config['query']))
+                 return [{
+                     "text": "\n".join([f"{col}: {val}" for col, val in row._asdict().items()]),
+                     "meta": {"type": "db", "table": config.get('table', '')}
+                 } for row in result]
+         except Exception as e:
+             self.log_error(f"DB Error: {str(e)}")
+             return []

    def process_images(self, page):
        images = []
        for img in page.images:
            try:
                stream = img['stream']
+                 width = int(stream.get('Width', 0))
+                 height = int(stream.get('Height', 0))
                if width > 0 and height > 0:
+                     images.append({
+                         "data": Image.frombytes("RGB", (width, height), stream.get_data()),
+                         "meta": {"dims": (width, height)}
+                     })
            except Exception as e:
+                 self.log_error(f"Image Error: {str(e)}")
        return images

+     # Core Generation Engine
+     def generate(self, api_key: str) -> bool:
        try:
+             provider_cfg = self.providers[st.session_state.config['provider']]
+             client = provider_cfg["client"](api_key)

+             for i, input_data in enumerate(st.session_state.inputs):
+                 st.session_state.processing['progress'] = (i+1)/len(st.session_state.inputs)

+                 if st.session_state.config['provider'] == "HuggingFace":
+                     response = self._huggingface_inference(client, input_data)
+                 else:
+                     response = self._standard_inference(client, input_data)

+                 if response:
+                     st.session_state.qa_data.extend(self._parse_response(response))

            return True
        except Exception as e:
+             self.log_error(f"Generation Error: {str(e)}")
            return False

+     def _standard_inference(self, client, input_data):
+         return client.chat.completions.create(
+             model=st.session_state.config['model'],
+             messages=[{
+                 "role": "user",
+                 "content": self._build_prompt(input_data)
+             }],
+             temperature=st.session_state.config['temperature'],
+             response_format={"type": "json_object"}
+         )
+
+     def _huggingface_inference(self, client, input_data):
+         API_URL = "https://api-inference.huggingface.co/models/"
+         response = requests.post(
+             API_URL + st.session_state.config['model'],
+             headers=client["headers"],
+             json={"inputs": self._build_prompt(input_data)}
+         )
+         return response.json()
+
+     def _build_prompt(self, input_data):
+         base = "Generate 3 Q&A pairs from this financial content:\n"
+         if input_data['meta']['type'] == 'csv':
+             return base + "Structured data:\n" + input_data['text']
+         elif input_data['meta']['type'] == 'api':
+             return base + "API response:\n" + input_data['text']
+         return base + input_data['text']
+
+     def _parse_response(self, response):
+         try:
+             if st.session_state.config['provider'] == "HuggingFace":
+                 return response[0]['generated_text']
+             return json.loads(response.choices[0].message.content).get("qa_pairs", [])
+         except Exception as e:
+             self.log_error(f"Parse Error: {str(e)}")
+             return []
+
+     def log_error(self, message):
+         st.session_state.processing['errors'].append(message)
+         st.error(message)
+
+ # Streamlit UI Components
+ def input_sidebar(gen: SyntheticDataGenerator):
    with st.sidebar:
+         st.header("⚙️ Configuration")
+
+         # AI Provider Settings
+         provider = st.selectbox("Provider", list(gen.providers.keys()))
+         provider_cfg = gen.providers[provider]
+
+         api_key = st.text_input(f"{provider} API Key", type="password")
+         model = st.selectbox("Model", provider_cfg["models"])
        temp = st.slider("Temperature", 0.0, 1.0, 0.3)
+
+         # Update session config
+         st.session_state.config.update({
+             "provider": provider,
+             "model": model,
+             "temperature": temp
+         })
+
+         # Input Source Selection
+         st.header("🔗 Data Sources")
+         input_type = st.selectbox("Input Type", list(gen.input_handlers.keys()))
+
+         if input_type == "text":
+             domain_input = st.text_area("Domain Knowledge", height=150)
+             if st.button("Add Domain Input"):
+                 gen.input_handlers["text"](domain_input)
+
+         elif input_type == "csv":
+             csv_file = st.file_uploader("Upload CSV", type=["csv"])
+             if csv_file:
+                 gen.input_handlers["csv"](csv_file)
+
+         elif input_type == "api":
+             api_url = st.text_input("API Endpoint")
+             if st.button("Connect API"):
+                 gen.input_handlers["api"]({"url": api_url})
+
+         return api_key

+ def main_display(gen: SyntheticDataGenerator):
    st.title("🚀 Enterprise Synthetic Data Factory")

+     # Input Processing
+     col1, col2 = st.columns([3, 1])
+     with col1:
+         pdf_file = st.file_uploader("Upload Document", type=["pdf"])
+         if pdf_file:
+             gen.input_handlers["pdf"](pdf_file)

+     # Generation Controls
+     with col2:
+         if st.button("Start Generation"):
+             with st.status("Processing..."):
+                 gen.generate(st.session_state.get('api_key'))

+     # Results Display
+     if st.session_state.qa_data:
+         st.header("Generated Data")
+         df = pd.DataFrame(st.session_state.qa_data)
+         st.dataframe(df)

+         # Export Options
+         st.download_button(
+             "Export CSV",
+             df.to_csv(index=False),
+             "synthetic_data.csv"
+         )
+
+ def main():
+     gen = SyntheticDataGenerator()
+     api_key = input_sidebar(gen)
+     main_display(gen)

if __name__ == "__main__":
    main()
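
For reference, a minimal standalone sketch of the JSON-mode request that the new _standard_inference/_parse_response pair builds on. The API key, model name, and prompt wording below are placeholders mirroring the diff; they are not part of the commit.

    # Hedged sketch: JSON-mode Q&A generation with an OpenAI-compatible client.
    # Placeholder api_key/model; the "qa_pairs" key mirrors _parse_response above.
    import json
    from openai import OpenAI

    def generate_qa_pairs(text: str, api_key: str, model: str = "gpt-4-turbo") -> list:
        client = OpenAI(api_key=api_key)  # or base_url="https://api.deepseek.com/v1" for Deepseek
        prompt = (
            "Generate 3 Q&A pairs from this financial content:\n"
            f"{text}\n"
            'Return a JSON object with a top-level "qa_pairs" list.'
        )
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.3,
            response_format={"type": "json_object"},
        )
        try:
            return json.loads(response.choices[0].message.content).get("qa_pairs", [])
        except json.JSONDecodeError:
            return []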
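
The new _huggingface_inference helper posts prompts to the Hugging Face Inference API endpoint named in the diff. A hedged sketch of that request pattern outside Streamlit; the token and model are placeholders, and small text-generation models such as gpt2 return plain generated text rather than the JSON asked of the other providers.

    # Hedged sketch of the Inference API call pattern used by _huggingface_inference.
    import requests

    API_URL = "https://api-inference.huggingface.co/models/"

    def hf_generate(prompt: str, token: str, model: str = "gpt2") -> str:
        headers = {"Authorization": f"Bearer {token}"}  # placeholder token
        resp = requests.post(API_URL + model, headers=headers,
                             json={"inputs": prompt}, timeout=60)
        resp.raise_for_status()
        payload = resp.json()
        # Text-generation endpoints usually reply with [{"generated_text": "..."}].
        return payload[0].get("generated_text", "") if isinstance(payload, list) else str(payload)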
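
handle_db flattens SQLAlchemy result rows with Row._asdict(). A self-contained illustration of that pattern against an in-memory SQLite database; the table name and rows are invented for the example.

    # Hedged sketch of the handle_db row-flattening pattern; invented table/data.
    import sqlalchemy

    engine = sqlalchemy.create_engine("sqlite:///:memory:")
    with engine.begin() as conn:
        conn.execute(sqlalchemy.text("CREATE TABLE revenue (quarter TEXT, amount REAL)"))
        conn.execute(sqlalchemy.text("INSERT INTO revenue VALUES ('Q1', 1.2), ('Q2', 1.5)"))

    with engine.connect() as conn:
        result = conn.execute(sqlalchemy.text("SELECT * FROM revenue"))
        records = [
            {
                "text": "\n".join(f"{col}: {val}" for col, val in row._asdict().items()),
                "meta": {"type": "db", "table": "revenue"},
            }
            for row in result
        ]

    print(records[0]["text"])  # "quarter: Q1" and "amount: 1.2" on separate lines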