Spaces:
Sleeping
Sleeping
dolphinium
commited on
Commit
·
8290c25
1
Parent(s):
840c57d
enhance viz code generation prompt
Browse files
app.py
CHANGED
@@ -328,33 +328,174 @@ def llm_generate_visualization_code(query_context, facet_data):
|
|
328 |
"""Generates Python code for visualization based on query and data."""
|
329 |
prompt = f"""
|
330 |
You are a Python Data Visualization expert specializing in Matplotlib and Seaborn.
|
331 |
-
Your task is to generate Python code to create a single, insightful visualization.
|
332 |
|
333 |
-
**
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
339 |
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
351 |
"""
|
352 |
try:
|
353 |
-
|
|
|
|
|
|
|
354 |
code = re.sub(r'^```python\s*|\s*```$', '', response.text, flags=re.MULTILINE)
|
355 |
return code
|
356 |
except Exception as e:
|
357 |
-
print(f"Error in llm_generate_visualization_code: {e}")
|
358 |
return None
|
359 |
|
360 |
def execute_viz_code_and_get_path(viz_code, facet_data):
|
|
|
328 |
"""Generates Python code for visualization based on query and data."""
|
329 |
prompt = f"""
|
330 |
You are a Python Data Visualization expert specializing in Matplotlib and Seaborn.
|
331 |
+
Your task is to generate robust, error-free Python code to create a single, insightful visualization based on the user's query and the provided Solr facet data.
|
332 |
|
333 |
+
**User's Analytical Goal:**
|
334 |
+
"{query_context}"
|
335 |
+
|
336 |
+
**Aggregated Data (from Solr Facets):**
|
337 |
+
```json
|
338 |
+
{json.dumps(facet_data, indent=2)}
|
339 |
+
```
|
340 |
+
|
341 |
+
---
|
342 |
+
### **CRITICAL INSTRUCTIONS: CODE GENERATION RULES**
|
343 |
+
You MUST follow these rules to avoid errors.
|
344 |
+
|
345 |
+
**1. Identify the Data Structure FIRST:**
|
346 |
+
Before writing any code, analyze the `facet_data` JSON to determine its structure. There are three common patterns. Choose the correct template below.
|
347 |
+
|
348 |
+
* **Pattern A: Simple `terms` Facet.** The JSON has ONE main key (besides "count") which contains a list of "buckets". Each bucket has a "val" and a "count". Use this for standard bar charts.
|
349 |
+
* **Pattern B: Multiple `query` Facets.** The JSON has MULTIPLE keys (besides "count"), and each key is an object containing metrics like "count" or "sum(...)". Use this for comparing a few distinct items (e.g., "oral vs injection").
|
350 |
+
* **Pattern C: Nested `terms` Facet.** The JSON has one main key with a list of "buckets", but inside EACH bucket, there are nested metric objects. This is used for grouped comparisons (e.g., "compare 2024 vs 2025 across categories"). This almost always requires `pandas`.
|
351 |
+
|
352 |
+
**2. Use the Correct Parsing Template:**
|
353 |
+
|
354 |
+
---
|
355 |
+
**TEMPLATE FOR PATTERN A (Simple Bar Chart from `terms` facet):**
|
356 |
+
```python
|
357 |
+
import matplotlib.pyplot as plt
|
358 |
+
import seaborn as sns
|
359 |
+
import pandas as pd
|
360 |
+
|
361 |
+
plt.style.use('seaborn-v0_8-whitegrid')
|
362 |
+
fig, ax = plt.subplots(figsize=(12, 8))
|
363 |
+
|
364 |
+
# Dynamically find the main facet key (the one with 'buckets')
|
365 |
+
facet_key = None
|
366 |
+
for key, value in facet_data.items():
|
367 |
+
if isinstance(value, dict) and 'buckets' in value:
|
368 |
+
facet_key = key
|
369 |
+
break
|
370 |
+
|
371 |
+
if facet_key:
|
372 |
+
buckets = facet_data[facet_key].get('buckets', [])
|
373 |
+
# Check if buckets contain data
|
374 |
+
if buckets:
|
375 |
+
df = pd.DataFrame(buckets)
|
376 |
+
# Check for a nested metric or use 'count'
|
377 |
+
if 'total_deal_value' in df.columns and pd.api.types.is_dict_like(df['total_deal_value'].iloc):
|
378 |
+
# Example for nested sum metric
|
379 |
+
df['value'] = df['total_deal_value'].apply(lambda x: x.get('sum', 0))
|
380 |
+
y_axis_label = 'Sum of Total Deal Value'
|
381 |
+
else:
|
382 |
+
df.rename(columns={{'count': 'value'}}, inplace=True)
|
383 |
+
y_axis_label = 'Count'
|
384 |
+
|
385 |
+
sns.barplot(data=df, x='val', y='value', ax=ax, palette='viridis')
|
386 |
+
ax.set_xlabel('Category')
|
387 |
+
ax.set_ylabel(y_axis_label)
|
388 |
+
else:
|
389 |
+
ax.text(0.5, 0.5, 'No data in buckets to plot.', ha='center')
|
390 |
+
|
391 |
+
|
392 |
+
ax.set_title('Your Insightful Title Here')
|
393 |
+
# Correct way to rotate labels to prevent errors
|
394 |
+
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
|
395 |
+
plt.tight_layout()
|
396 |
+
```
|
397 |
+
---
|
398 |
+
**TEMPLATE FOR PATTERN B (Comparison Bar Chart from `query` facets):**
|
399 |
+
```python
|
400 |
+
import matplotlib.pyplot as plt
|
401 |
+
import seaborn as sns
|
402 |
+
import pandas as pd
|
403 |
+
|
404 |
+
plt.style.use('seaborn-v0_8-whitegrid')
|
405 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
406 |
+
|
407 |
+
labels = []
|
408 |
+
values = []
|
409 |
+
# Iterate through top-level keys, skipping the 'count'
|
410 |
+
for key, data_dict in facet_data.items():
|
411 |
+
if key == 'count' or not isinstance(data_dict, dict):
|
412 |
+
continue
|
413 |
+
# Extract the label (e.g., 'oral_deals' -> 'Oral')
|
414 |
+
label = key.replace('_deals', '').replace('_', ' ').title()
|
415 |
+
# Find the metric value, which is NOT 'count'
|
416 |
+
metric_value = 0
|
417 |
+
for sub_key, sub_value in data_dict.items():
|
418 |
+
if sub_key != 'count':
|
419 |
+
metric_value = sub_value
|
420 |
+
break # Found the metric
|
421 |
+
labels.append(label)
|
422 |
+
values.append(metric_value)
|
423 |
+
|
424 |
+
if labels:
|
425 |
+
sns.barplot(x=labels, y=values, ax=ax, palette='mako')
|
426 |
+
ax.set_ylabel('Total Deal Value') # Or other metric name
|
427 |
+
ax.set_xlabel('Category')
|
428 |
+
else:
|
429 |
+
ax.text(0.5, 0.5, 'No query facet data to plot.', ha='center')
|
430 |
+
|
431 |
+
|
432 |
+
ax.set_title('Your Insightful Title Here')
|
433 |
+
plt.tight_layout()
|
434 |
+
```
|
435 |
+
---
|
436 |
+
**TEMPLATE FOR PATTERN C (Grouped Bar Chart from nested `terms` facet):**
|
437 |
+
```python
|
438 |
+
import matplotlib.pyplot as plt
|
439 |
+
import seaborn as sns
|
440 |
+
import pandas as pd
|
441 |
|
442 |
+
plt.style.use('seaborn-v0_8-whitegrid')
|
443 |
+
fig, ax = plt.subplots(figsize=(14, 8))
|
444 |
+
|
445 |
+
# Find the key that has the buckets
|
446 |
+
facet_key = None
|
447 |
+
for key, value in facet_data.items():
|
448 |
+
if isinstance(value, dict) and 'buckets' in value:
|
449 |
+
facet_key = key
|
450 |
+
break
|
451 |
+
|
452 |
+
if facet_key and facet_data[facet_key].get('buckets'):
|
453 |
+
# This list comprehension is robust for parsing nested metrics
|
454 |
+
plot_data = []
|
455 |
+
for bucket in facet_data[facet_key]['buckets']:
|
456 |
+
category = bucket['val']
|
457 |
+
# Find all nested metrics (e.g., total_deal_value_2025)
|
458 |
+
for sub_key, sub_value in bucket.items():
|
459 |
+
if isinstance(sub_value, dict) and 'sum' in sub_value:
|
460 |
+
# Extracts year from 'total_deal_value_2025' -> '2025'
|
461 |
+
year = sub_key.split('_')[-1]
|
462 |
+
value = sub_value['sum']
|
463 |
+
plot_data.append({{'Category': category, 'Year': year, 'Value': value}})
|
464 |
+
|
465 |
+
if plot_data:
|
466 |
+
df = pd.DataFrame(plot_data)
|
467 |
+
sns.barplot(data=df, x='Category', y='Value', hue='Year', ax=ax)
|
468 |
+
ax.set_ylabel('Total Deal Value')
|
469 |
+
ax.set_xlabel('Business Model')
|
470 |
+
# Correct way to rotate labels to prevent errors
|
471 |
+
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
|
472 |
+
else:
|
473 |
+
ax.text(0.5, 0.5, 'No nested data found to plot.', ha='center')
|
474 |
+
else:
|
475 |
+
ax.text(0.5, 0.5, 'No data in buckets to plot.', ha='center')
|
476 |
+
|
477 |
+
ax.set_title('Your Insightful Title Here')
|
478 |
+
plt.tight_layout()
|
479 |
+
```
|
480 |
+
---
|
481 |
+
**3. Final Code Generation:**
|
482 |
+
- **DO NOT** include `plt.show()`.
|
483 |
+
- **DO** set a dynamic and descriptive `ax.set_title()`, `ax.set_xlabel()`, and `ax.set_ylabel()`.
|
484 |
+
- **DO NOT** wrap the code in ```python ... ```. Output only the raw Python code.
|
485 |
+
- Adapt the chosen template to the specific keys and metrics in the provided `facet_data`.
|
486 |
+
|
487 |
+
**Your Task:**
|
488 |
+
Now, generate the Python code.
|
489 |
"""
|
490 |
try:
|
491 |
+
# Increase the timeout for potentially complex generation
|
492 |
+
generation_config = genai.types.GenerationConfig(temperature=0, max_output_tokens=2048)
|
493 |
+
response = llm_model.generate_content(prompt, generation_config=generation_config)
|
494 |
+
# Clean the response to remove markdown formatting
|
495 |
code = re.sub(r'^```python\s*|\s*```$', '', response.text, flags=re.MULTILINE)
|
496 |
return code
|
497 |
except Exception as e:
|
498 |
+
print(f"Error in llm_generate_visualization_code: {e}\nRaw response: {response.text}")
|
499 |
return None
|
500 |
|
501 |
def execute_viz_code_and_get_path(viz_code, facet_data):
|