mgbam commited on
Commit
7b16658
·
verified ·
1 Parent(s): 8d86c18

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -17
app.py CHANGED
@@ -14,7 +14,8 @@ from typing import Dict, Any
14
 
15
  # Constants for Default Values and API URLs
16
  HF_API_URL = "https://api-inference.huggingface.co/models/"
17
- DEFAULT_TEMPERATURE = 0.3
 
18
 
19
  class SyntheticDataGenerator:
20
  """
@@ -34,7 +35,7 @@ class SyntheticDataGenerator:
34
  },
35
  "Groq": {
36
  "client": lambda key: groq.Groq(api_key=key),
37
- "models": ["mixtral-8x7b-32768"]
38
  },
39
  "HuggingFace": {
40
  "client": lambda key: {"headers": {"Authorization": f"Bearer {key}"}},
@@ -42,7 +43,7 @@ class SyntheticDataGenerator:
42
  },
43
  "Google": {
44
  "client": lambda key: self._configure_google_genai(key), # Using a custom configure function
45
- "models": ["gemini-2.0-pro"] # Add supported Gemini models. Consider adding "gemini-1.5-pro" when released.
46
  },
47
  }
48
 
@@ -76,8 +77,8 @@ class SyntheticDataGenerator:
76
  'errors': []
77
  },
78
  'config': {
79
- 'provider': "Deepseek",
80
- 'model': "deepseek-chat",
81
  'temperature': DEFAULT_TEMPERATURE
82
  }
83
  }
@@ -160,13 +161,22 @@ class SyntheticDataGenerator:
160
  stream = img['stream']
161
  width = int(stream.get('Width', 0))
162
  height = int(stream.get('Height', 0))
163
- if width > 0 and height > 0:
164
- images.append({
165
- "data": Image.frombytes("RGB", (width, height), stream.get_data()),
166
- "meta": {"dims": (width, height)}
167
- })
 
 
 
 
 
 
 
 
 
168
  except Exception as e:
169
- self.log_error(f"Image Error: {str(e)}")
170
  return images
171
 
172
  # Core Generation Engine
@@ -198,6 +208,7 @@ class SyntheticDataGenerator:
198
  client = client_initializer(api_key)
199
 
200
  for i, input_data in enumerate(st.session_state.inputs):
 
201
  st.session_state.processing['progress'] = (i+1)/len(st.session_state.inputs)
202
 
203
  if st.session_state.config['provider'] == "HuggingFace":
@@ -219,6 +230,8 @@ class SyntheticDataGenerator:
219
  def _standard_inference(self, client, input_data):
220
  """Performs inference using standard OpenAI-compatible API."""
221
  try:
 
 
222
  return client.chat.completions.create(
223
  model=st.session_state.config['model'],
224
  messages=[{
@@ -249,13 +262,16 @@ class SyntheticDataGenerator:
249
  def _google_inference(self, client, input_data):
250
  """Performs inference using Google Generative AI API."""
251
  try:
252
-
253
  model = client(st.session_state.config['model']) # Instantiate the model with the selected model name
254
  response = model.generate_content(
255
  self._build_prompt(input_data),
256
  generation_config = genai.types.GenerationConfig(temperature=st.session_state.config['temperature'])
257
 
258
  )
 
 
 
 
259
  return response
260
  except Exception as e:
261
  self.log_error(f"Google GenAI Inference Error: {e}")
@@ -263,9 +279,13 @@ class SyntheticDataGenerator:
263
 
264
  def _build_prompt(self, input_data):
265
  """Builds the prompt for the LLM based on the input data type."""
266
- base = "Generate 3 Q&A pairs from this financial content, formatted as a JSON list of dictionaries with 'question' and 'answer' keys:\n"
 
 
 
 
267
  if input_data['meta']['type'] == 'csv':
268
- return base + "Structured data:\n" + input_data['text']
269
  elif input_data['meta']['type'] == 'api':
270
  return base + "API response:\n" + input_data['text']
271
  return base + input_data['text']
@@ -294,8 +314,16 @@ class SyntheticDataGenerator:
294
  return [] # Return empty in case of parsing failure
295
  else:
296
  # Assuming JSON response from other providers (OpenAI, Deepseek, Groq)
297
- json_output = json.loads(response.choices[0].message.content) # load the JSON data
298
- return json_output.get("qa_pairs", []) # Return the qa_pairs
 
 
 
 
 
 
 
 
299
  except Exception as e:
300
  self.log_error(f"Parse Error: {e}. Raw Response: {response}")
301
  return []
@@ -327,7 +355,7 @@ def input_sidebar(gen: SyntheticDataGenerator):
327
  st.session_state['api_key'] = api_key #Store API Key
328
 
329
  model = st.selectbox("Model", provider_cfg["models"])
330
- temp = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
331
 
332
  # Update session config
333
  st.session_state.config.update({
 
14
 
15
  # Constants for Default Values and API URLs
16
  HF_API_URL = "https://api-inference.huggingface.co/models/"
17
+ DEFAULT_TEMPERATURE = 0.1  # Lowered default temperature for more deterministic generations
18
+ MODEL = "mixtral-8x7b-32768"  # Default Groq model identifier
19
 
20
  class SyntheticDataGenerator:
21
  """
 
35
  },
36
  "Groq": {
37
  "client": lambda key: groq.Groq(api_key=key),
38
+ "models": [MODEL]
39
  },
40
  "HuggingFace": {
41
  "client": lambda key: {"headers": {"Authorization": f"Bearer {key}"}},
 
43
  },
44
  "Google": {
45
  "client": lambda key: self._configure_google_genai(key), # Using a custom configure function
46
+ "models": ["gemini-pro"] # Currently supported Gemini model; add newer Gemini models here as they become available.
47
  },
48
  }
49
 
 
77
  'errors': []
78
  },
79
  'config': {
80
+ 'provider': "Groq",
81
+ 'model': MODEL,
82
  'temperature': DEFAULT_TEMPERATURE
83
  }
84
  }
 
161
  stream = img['stream']
162
  width = int(stream.get('Width', 0))
163
  height = int(stream.get('Height', 0))
164
+ image_data = stream.get_data() # Get the image data
165
+ if width > 0 and height > 0 and image_data: #CHECK image_data
166
+ try:
167
+ image = Image.frombytes("RGB", (width, height), image_data)
168
+ images.append({
169
+ "data": image,
170
+ "meta": {"dims": (width, height)}
171
+ })
172
+ except Exception as e:
173
+ self.log_error(f"Image Creation Error: {str(e)}") # Log specific image creation errors.
174
+ else:
175
+ self.log_error(f"Image Error: Insufficient image data or invalid dimensions (width={width}, height={height})")
176
+
177
+
178
  except Exception as e:
179
+ self.log_error(f"Image Extraction Error: {str(e)}") # More general extraction error
180
  return images
181
 
182
  # Core Generation Engine
 
208
  client = client_initializer(api_key)
209
 
210
  for i, input_data in enumerate(st.session_state.inputs):
211
+
212
  st.session_state.processing['progress'] = (i+1)/len(st.session_state.inputs)
213
 
214
  if st.session_state.config['provider'] == "HuggingFace":
 
230
  def _standard_inference(self, client, input_data):
231
  """Performs inference using standard OpenAI-compatible API."""
232
  try:
233
+
234
+ #st.write(input_data['text']) # debugging data
235
  return client.chat.completions.create(
236
  model=st.session_state.config['model'],
237
  messages=[{
 
262
  def _google_inference(self, client, input_data):
263
  """Performs inference using Google Generative AI API."""
264
  try:
 
265
  model = client(st.session_state.config['model']) # Instantiate the model with the selected model name
266
  response = model.generate_content(
267
  self._build_prompt(input_data),
268
  generation_config = genai.types.GenerationConfig(temperature=st.session_state.config['temperature'])
269
 
270
  )
271
+
272
+ st.write("Google API Response:") # Debugging: Print the raw response
273
+ st.write(response.text)
274
+
275
  return response
276
  except Exception as e:
277
  self.log_error(f"Google GenAI Inference Error: {e}")
 
279
 
280
  def _build_prompt(self, input_data):
281
  """Builds the prompt for the LLM based on the input data type."""
282
+ base = "Generate a JSON list of 3 dictionaries like this: \n"
283
+ base+= '[{"question":"Example Question", "answer":"Example Answer"},'
284
+ base+= '{"question":"Example Question", "answer":"Example Answer"},'
285
+ base+= '{"question":"Example Question", "answer":"Example Answer"}]'
286
+ base+= 'Here is the data:\n'
287
  if input_data['meta']['type'] == 'csv':
288
+ return base + "Data:\n" + input_data['text']
289
  elif input_data['meta']['type'] == 'api':
290
  return base + "API response:\n" + input_data['text']
291
  return base + input_data['text']
 
314
  return [] # Return empty in case of parsing failure
315
  else:
316
  # Assuming JSON response from other providers (OpenAI, Deepseek, Groq)
317
+ if not response or not response.choices or not response.choices[0].message.content:
318
+ self.log_error("Empty or malformed response from LLM.")
319
+ return []
320
+
321
+ try:
322
+ json_output = json.loads(response.choices[0].message.content) # load the JSON data
323
+ return json_output.get("qa_pairs", []) # Return the qa_pairs
324
+ except json.JSONDecodeError as e:
325
+ self.log_error(f"JSON Parse Error: {e}. Raw Response: {response.choices[0].message.content}")
326
+ return []
327
  except Exception as e:
328
  self.log_error(f"Parse Error: {e}. Raw Response: {response}")
329
  return []
 
355
  st.session_state['api_key'] = api_key #Store API Key
356
 
357
  model = st.selectbox("Model", provider_cfg["models"])
358
+ temp = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)  # Slider default comes from DEFAULT_TEMPERATURE
359
 
360
  # Update session config
361
  st.session_state.config.update({