Update app.py
app.py
CHANGED
@@ -14,7 +14,8 @@ from typing import Dict, Any
 
 # Constants for Default Values and API URLs
 HF_API_URL = "https://api-inference.huggingface.co/models/"
-DEFAULT_TEMPERATURE = 0.
+DEFAULT_TEMPERATURE = 0.1  # Lower Temperature
+MODEL = "mixtral-8x7b-32768"  # constant string
 
 class SyntheticDataGenerator:
     """
@@ -34,7 +35,7 @@ class SyntheticDataGenerator:
             },
             "Groq": {
                 "client": lambda key: groq.Groq(api_key=key),
-                "models": [
+                "models": [MODEL]
             },
             "HuggingFace": {
                 "client": lambda key: {"headers": {"Authorization": f"Bearer {key}"}},
@@ -42,7 +43,7 @@ class SyntheticDataGenerator:
             },
             "Google": {
                 "client": lambda key: self._configure_google_genai(key),  # Using a custom configure function
-                "models": ["gemini-
+                "models": ["gemini-pro"]  # Use gemini-pro; consider adding newer Gemini models when released
             },
         }
 
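Reviewer note: the registry above maps each provider name to a client factory and a model list, which is what lets the sidebar stay generic. A minimal standalone sketch of the same pattern (the entries and model names here are illustrative, not the app's exact configuration):

```python
# Minimal sketch of the provider-registry pattern; entries are illustrative.
from typing import Any, Dict

PROVIDERS: Dict[str, Dict[str, Any]] = {
    "HuggingFace": {
        # Header-based auth, as in the diff above
        "client": lambda key: {"headers": {"Authorization": f"Bearer {key}"}},
        "models": ["mistralai/Mistral-7B-Instruct-v0.2"],  # hypothetical entry
    },
    "Groq": {
        # Stand-in for groq.Groq(api_key=key) so the sketch runs without the SDK
        "client": lambda key: f"<groq client for key {key[:4]}...>",
        "models": ["mixtral-8x7b-32768"],
    },
}

def make_client(provider: str, api_key: str) -> Any:
    """Look up the provider's factory and build its client."""
    return PROVIDERS[provider]["client"](api_key)

print(make_client("HuggingFace", "hf_abcdef"))
```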
@@ -76,8 +77,8 @@ class SyntheticDataGenerator:
                 'errors': []
             },
             'config': {
-                'provider': "
-                'model':
+                'provider': "Groq",
+                'model': MODEL,
                 'temperature': DEFAULT_TEMPERATURE
             }
         }
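For context on why the constructor seeds these defaults: Streamlit reruns the whole script on every interaction, so state must be initialized idempotently or user choices get clobbered. A sketch of that pattern with the same config shape (run with `streamlit run`; the keys mirror the diff, the surrounding structure is assumed):

```python
# Sketch: idempotent session-state defaults (run with `streamlit run app.py`).
import streamlit as st

DEFAULTS = {
    "config": {
        "provider": "Groq",
        "model": "mixtral-8x7b-32768",
        "temperature": 0.1,
    },
}

for key, value in DEFAULTS.items():
    # Only seed missing keys so user selections survive reruns
    if key not in st.session_state:
        st.session_state[key] = value

st.write(st.session_state["config"])
```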
@@ -160,13 +161,22 @@ class SyntheticDataGenerator:
                     stream = img['stream']
                     width = int(stream.get('Width', 0))
                     height = int(stream.get('Height', 0))
-
-
-
-
-
+                    image_data = stream.get_data()  # Get the raw image data
+                    if width > 0 and height > 0 and image_data:  # Check image_data before decoding
+                        try:
+                            image = Image.frombytes("RGB", (width, height), image_data)
+                            images.append({
+                                "data": image,
+                                "meta": {"dims": (width, height)}
+                            })
+                        except Exception as e:
+                            self.log_error(f"Image Creation Error: {str(e)}")  # Log specific image creation errors
+                    else:
+                        self.log_error(f"Image Error: Insufficient image data or invalid dimensions (width={width}, height={height})")
+
+
         except Exception as e:
-            self.log_error(f"Image Error: {str(e)}")
+            self.log_error(f"Image Extraction Error: {str(e)}")  # More general extraction error
         return images
 
     # Core Generation Engine
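One caveat worth flagging on the new image path: `Image.frombytes` expects raw, uncompressed pixel data whose length is exactly `width * height * 3` for RGB, while PDF image streams are often Flate- or DCT-encoded, so `stream.get_data()` may not decode cleanly for every image. A self-contained sketch of the size check that guards against that (the helper name is ours):

```python
# Sketch: validate the buffer before handing it to Image.frombytes().
from PIL import Image

def bytes_to_rgb_image(data: bytes, width: int, height: int) -> Image.Image:
    expected = width * height * 3  # 3 bytes per pixel for RGB
    if len(data) != expected:
        raise ValueError(f"got {len(data)} bytes, expected {expected}")
    return Image.frombytes("RGB", (width, height), data)

if __name__ == "__main__":
    raw = bytes([255, 0, 0] * 4)  # a 2x2 red square as raw RGB bytes
    img = bytes_to_rgb_image(raw, 2, 2)
    print(img.size, img.getpixel((0, 0)))  # (2, 2) (255, 0, 0)
```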
@@ -198,6 +208,7 @@ class SyntheticDataGenerator:
         client = client_initializer(api_key)
 
         for i, input_data in enumerate(st.session_state.inputs):
+
             st.session_state.processing['progress'] = (i+1)/len(st.session_state.inputs)
 
             if st.session_state.config['provider'] == "HuggingFace":
@@ -219,6 +230,8 @@ class SyntheticDataGenerator:
     def _standard_inference(self, client, input_data):
         """Performs inference using standard OpenAI-compatible API."""
         try:
+
+            # st.write(input_data['text'])  # Debugging: inspect the input data
             return client.chat.completions.create(
                 model=st.session_state.config['model'],
                 messages=[{
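For reference, the call shape `_standard_inference` relies on is the OpenAI-compatible chat-completions interface, which Groq and several other providers also expose. A hedged sketch (the base URL and model are assumptions; swap in your own key before running):

```python
# Sketch of the OpenAI-compatible call shape used by _standard_inference.
from openai import OpenAI

client = OpenAI(
    api_key="YOUR_KEY",
    base_url="https://api.groq.com/openai/v1",  # assumed Groq-compatible endpoint
)

response = client.chat.completions.create(
    model="mixtral-8x7b-32768",
    messages=[{"role": "user", "content": "Generate one Q/A pair as JSON."}],
    temperature=0.1,  # matches DEFAULT_TEMPERATURE in the diff
)
print(response.choices[0].message.content)
```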
@@ -249,13 +262,16 @@ class SyntheticDataGenerator:
     def _google_inference(self, client, input_data):
         """Performs inference using Google Generative AI API."""
         try:
-
             model = client(st.session_state.config['model'])  # Instantiate the model with the selected model name
             response = model.generate_content(
                 self._build_prompt(input_data),
                 generation_config=genai.types.GenerationConfig(temperature=st.session_state.config['temperature'])
 
             )
+
+            st.write("Google API Response:")  # Debugging: print the raw response
+            st.write(response.text)
+
             return response
         except Exception as e:
             self.log_error(f"Google GenAI Inference Error: {e}")
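The `client(...)` call above only works if `_configure_google_genai` returns something like `genai.GenerativeModel`; that helper isn't shown in this diff. A direct sketch of the underlying google-generativeai calls the method wraps, assuming that package:

```python
# Sketch of the google-generativeai calls behind _google_inference.
import google.generativeai as genai

genai.configure(api_key="YOUR_GOOGLE_API_KEY")  # replace before running
model = genai.GenerativeModel("gemini-pro")
response = model.generate_content(
    "Generate one question/answer pair about PDFs as JSON.",
    generation_config=genai.types.GenerationConfig(temperature=0.1),
)
print(response.text)
```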
@@ -263,9 +279,13 @@ class SyntheticDataGenerator:
 
     def _build_prompt(self, input_data):
         """Builds the prompt for the LLM based on the input data type."""
-        base = "Generate
+        base = "Generate a JSON list of 3 dictionaries like this: \n"
+        base += '[{"question":"Example Question", "answer":"Example Answer"},'
+        base += '{"question":"Example Question", "answer":"Example Answer"},'
+        base += '{"question":"Example Question", "answer":"Example Answer"}]'
+        base += 'Here is the data:\n'
         if input_data['meta']['type'] == 'csv':
-            return base + "
+            return base + "Data:\n" + input_data['text']
         elif input_data['meta']['type'] == 'api':
             return base + "API response:\n" + input_data['text']
         return base + input_data['text']
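To make the few-shot scaffold concrete, here is a standalone reproduction of `_build_prompt` that prints what the model actually sees for a CSV input (the sample data is invented for illustration). Note the csv branch now emits both "Here is the data:" and "Data:", which may be intentional emphasis or a small redundancy:

```python
# Standalone reproduction of the prompt scaffold from the diff above.
def build_prompt(text: str, input_type: str) -> str:
    base = "Generate a JSON list of 3 dictionaries like this: \n"
    base += '[{"question":"Example Question", "answer":"Example Answer"},'
    base += '{"question":"Example Question", "answer":"Example Answer"},'
    base += '{"question":"Example Question", "answer":"Example Answer"}]'
    base += 'Here is the data:\n'
    if input_type == 'csv':
        return base + "Data:\n" + text
    elif input_type == 'api':
        return base + "API response:\n" + text
    return base + text

print(build_prompt("name,role\nada,engineer", "csv"))  # made-up sample CSV
```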
@@ -294,8 +314,16 @@ class SyntheticDataGenerator:
                     return []  # Return empty in case of parsing failure
             else:
                 # Assuming JSON response from other providers (OpenAI, Deepseek, Groq)
-
-
+                if not response or not response.choices or not response.choices[0].message.content:
+                    self.log_error("Empty or malformed response from LLM.")
+                    return []
+
+                try:
+                    json_output = json.loads(response.choices[0].message.content)  # load the JSON data
+                    return json_output.get("qa_pairs", [])  # Return the qa_pairs
+                except json.JSONDecodeError as e:
+                    self.log_error(f"JSON Parse Error: {e}. Raw Response: {response.choices[0].message.content}")
+                    return []
         except Exception as e:
             self.log_error(f"Parse Error: {e}. Raw Response: {response}")
             return []
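One mismatch the new parser inherits: the prompt in `_build_prompt` asks for a bare JSON list, but `json_output.get("qa_pairs", [])` assumes a dict, so a list-shaped response raises `AttributeError` and falls through to the generic handler. A sketch of a parser that tolerates both shapes (the function name is ours):

```python
# Sketch: accept both the bare JSON list the prompt requests and a
# {"qa_pairs": [...]} wrapper that some models produce.
import json

def parse_qa_pairs(content: str) -> list:
    try:
        data = json.loads(content)
    except json.JSONDecodeError:
        return []
    if isinstance(data, list):
        return data
    if isinstance(data, dict):
        return data.get("qa_pairs", [])
    return []

print(parse_qa_pairs('[{"question":"q","answer":"a"}]'))
print(parse_qa_pairs('{"qa_pairs": [{"question":"q","answer":"a"}]}'))
```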
@@ -327,7 +355,7 @@ def input_sidebar(gen: SyntheticDataGenerator):
         st.session_state['api_key'] = api_key  # Store API key
 
         model = st.selectbox("Model", provider_cfg["models"])
-        temp = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)
+        temp = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE)  # Lower default
 
         # Update session config
         st.session_state.config.update({