dolphinium committed on
Commit
8290c25
·
1 Parent(s): 840c57d

enhance viz code generation prompt

Browse files
Files changed (1) hide show
  1. app.py +161 -20
app.py CHANGED
@@ -328,33 +328,174 @@ def llm_generate_visualization_code(query_context, facet_data):
328
  """Generates Python code for visualization based on query and data."""
329
  prompt = f"""
330
  You are a Python Data Visualization expert specializing in Matplotlib and Seaborn.
331
- Your task is to generate Python code to create a single, insightful visualization.
332
 
333
- **Context:**
334
- 1. **User's Analytical Goal:** "{query_context}"
335
- 2. **Aggregated Data (from Solr Facets):**
336
- ```json
337
- {json.dumps(facet_data, indent=2)}
338
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
 
340
- **Instructions:**
341
- 1. **Goal:** Write Python code to generate a chart that best visualizes the answer to the user's goal using the provided data.
342
- 2. **Data Access:** The data is available in a Python dictionary named `facet_data`. Your code must parse this dictionary.
343
- 3. **Code Requirements:**
344
- * Start with `import matplotlib.pyplot as plt` and `import seaborn as sns`.
345
- * Use `plt.style.use('seaborn-v0_8-whitegrid')` and `fig, ax = plt.subplots(figsize=(12, 7))`. Plot using the `ax` object.
346
- * Always include a clear `ax.set_title(...)`, `ax.set_xlabel(...)`, and `ax.set_ylabel(...)`.
347
- * Dynamically find the primary facet key and extract the 'buckets'.
348
- * For each bucket, extract the 'val' (label) and the relevant metric ('count' or a nested metric).
349
- * Use `plt.tight_layout()` and rotate x-axis labels if needed.
350
- 4. **Output Format:** ONLY output raw Python code. Do not wrap it in ```python ... ```. Do not include `plt.show()` or any explanation.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
  """
352
  try:
353
- response = llm_model.generate_content(prompt)
 
 
 
354
  code = re.sub(r'^```python\s*|\s*```$', '', response.text, flags=re.MULTILINE)
355
  return code
356
  except Exception as e:
357
- print(f"Error in llm_generate_visualization_code: {e}")
358
  return None
359
 
360
  def execute_viz_code_and_get_path(viz_code, facet_data):
 
328
  """Generates Python code for visualization based on query and data."""
329
  prompt = f"""
330
  You are a Python Data Visualization expert specializing in Matplotlib and Seaborn.
331
+ Your task is to generate robust, error-free Python code to create a single, insightful visualization based on the user's query and the provided Solr facet data.
332
 
333
+ **User's Analytical Goal:**
334
+ "{query_context}"
335
+
336
+ **Aggregated Data (from Solr Facets):**
337
+ ```json
338
+ {json.dumps(facet_data, indent=2)}
339
+ ```
340
+
341
+ ---
342
+ ### **CRITICAL INSTRUCTIONS: CODE GENERATION RULES**
343
+ You MUST follow these rules to avoid errors.
344
+
345
+ **1. Identify the Data Structure FIRST:**
346
+ Before writing any code, analyze the `facet_data` JSON to determine its structure. There are three common patterns. Choose the correct template below.
347
+
348
+ * **Pattern A: Simple `terms` Facet.** The JSON has ONE main key (besides "count") which contains a list of "buckets". Each bucket has a "val" and a "count". Use this for standard bar charts.
349
+ * **Pattern B: Multiple `query` Facets.** The JSON has MULTIPLE keys (besides "count"), and each key is an object containing metrics like "count" or "sum(...)". Use this for comparing a few distinct items (e.g., "oral vs injection").
350
+ * **Pattern C: Nested `terms` Facet.** The JSON has one main key with a list of "buckets", but inside EACH bucket, there are nested metric objects. This is used for grouped comparisons (e.g., "compare 2024 vs 2025 across categories"). This almost always requires `pandas`.
351
+
352
+ **2. Use the Correct Parsing Template:**
353
+
354
+ ---
355
+ **TEMPLATE FOR PATTERN A (Simple Bar Chart from `terms` facet):**
356
+ ```python
357
+ import matplotlib.pyplot as plt
358
+ import seaborn as sns
359
+ import pandas as pd
360
+
361
+ plt.style.use('seaborn-v0_8-whitegrid')
362
+ fig, ax = plt.subplots(figsize=(12, 8))
363
+
364
+ # Dynamically find the main facet key (the one with 'buckets')
365
+ facet_key = None
366
+ for key, value in facet_data.items():
367
+ if isinstance(value, dict) and 'buckets' in value:
368
+ facet_key = key
369
+ break
370
+
371
+ if facet_key:
372
+ buckets = facet_data[facet_key].get('buckets', [])
373
+ # Check if buckets contain data
374
+ if buckets:
375
+ df = pd.DataFrame(buckets)
376
+ # Check for a nested metric or use 'count'
377
+ if 'total_deal_value' in df.columns and pd.api.types.is_dict_like(df['total_deal_value'].iloc):
378
+ # Example for nested sum metric
379
+ df['value'] = df['total_deal_value'].apply(lambda x: x.get('sum', 0))
380
+ y_axis_label = 'Sum of Total Deal Value'
381
+ else:
382
+ df.rename(columns={{'count': 'value'}}, inplace=True)
383
+ y_axis_label = 'Count'
384
+
385
+ sns.barplot(data=df, x='val', y='value', ax=ax, palette='viridis')
386
+ ax.set_xlabel('Category')
387
+ ax.set_ylabel(y_axis_label)
388
+ else:
389
+ ax.text(0.5, 0.5, 'No data in buckets to plot.', ha='center')
390
+
391
+
392
+ ax.set_title('Your Insightful Title Here')
393
+ # Correct way to rotate labels to prevent errors
394
+ plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
395
+ plt.tight_layout()
396
+ ```
397
+ ---
398
+ **TEMPLATE FOR PATTERN B (Comparison Bar Chart from `query` facets):**
399
+ ```python
400
+ import matplotlib.pyplot as plt
401
+ import seaborn as sns
402
+ import pandas as pd
403
+
404
+ plt.style.use('seaborn-v0_8-whitegrid')
405
+ fig, ax = plt.subplots(figsize=(10, 6))
406
+
407
+ labels = []
408
+ values = []
409
+ # Iterate through top-level keys, skipping the 'count'
410
+ for key, data_dict in facet_data.items():
411
+ if key == 'count' or not isinstance(data_dict, dict):
412
+ continue
413
+ # Extract the label (e.g., 'oral_deals' -> 'Oral')
414
+ label = key.replace('_deals', '').replace('_', ' ').title()
415
+ # Find the metric value, which is NOT 'count'
416
+ metric_value = 0
417
+ for sub_key, sub_value in data_dict.items():
418
+ if sub_key != 'count':
419
+ metric_value = sub_value
420
+ break # Found the metric
421
+ labels.append(label)
422
+ values.append(metric_value)
423
+
424
+ if labels:
425
+ sns.barplot(x=labels, y=values, ax=ax, palette='mako')
426
+ ax.set_ylabel('Total Deal Value') # Or other metric name
427
+ ax.set_xlabel('Category')
428
+ else:
429
+ ax.text(0.5, 0.5, 'No query facet data to plot.', ha='center')
430
+
431
+
432
+ ax.set_title('Your Insightful Title Here')
433
+ plt.tight_layout()
434
+ ```
435
+ ---
436
+ **TEMPLATE FOR PATTERN C (Grouped Bar Chart from nested `terms` facet):**
437
+ ```python
438
+ import matplotlib.pyplot as plt
439
+ import seaborn as sns
440
+ import pandas as pd
441
 
442
+ plt.style.use('seaborn-v0_8-whitegrid')
443
+ fig, ax = plt.subplots(figsize=(14, 8))
444
+
445
+ # Find the key that has the buckets
446
+ facet_key = None
447
+ for key, value in facet_data.items():
448
+ if isinstance(value, dict) and 'buckets' in value:
449
+ facet_key = key
450
+ break
451
+
452
+ if facet_key and facet_data[facet_key].get('buckets'):
453
+ # This explicit loop is robust for parsing nested metrics
454
+ plot_data = []
455
+ for bucket in facet_data[facet_key]['buckets']:
456
+ category = bucket['val']
457
+ # Find all nested metrics (e.g., total_deal_value_2025)
458
+ for sub_key, sub_value in bucket.items():
459
+ if isinstance(sub_value, dict) and 'sum' in sub_value:
460
+ # Extracts year from 'total_deal_value_2025' -> '2025'
461
+ year = sub_key.split('_')[-1]
462
+ value = sub_value['sum']
463
+ plot_data.append({{'Category': category, 'Year': year, 'Value': value}})
464
+
465
+ if plot_data:
466
+ df = pd.DataFrame(plot_data)
467
+ sns.barplot(data=df, x='Category', y='Value', hue='Year', ax=ax)
468
+ ax.set_ylabel('Total Deal Value')
469
+ ax.set_xlabel('Business Model')
470
+ # Correct way to rotate labels to prevent errors
471
+ plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
472
+ else:
473
+ ax.text(0.5, 0.5, 'No nested data found to plot.', ha='center')
474
+ else:
475
+ ax.text(0.5, 0.5, 'No data in buckets to plot.', ha='center')
476
+
477
+ ax.set_title('Your Insightful Title Here')
478
+ plt.tight_layout()
479
+ ```
480
+ ---
481
+ **3. Final Code Generation:**
482
+ - **DO NOT** include `plt.show()`.
483
+ - **DO** set a dynamic and descriptive `ax.set_title()`, `ax.set_xlabel()`, and `ax.set_ylabel()`.
484
+ - **DO NOT** wrap the code in ```python ... ```. Output only the raw Python code.
485
+ - Adapt the chosen template to the specific keys and metrics in the provided `facet_data`.
486
+
487
+ **Your Task:**
488
+ Now, generate the Python code.
489
  """
490
  try:
491
+ # Use deterministic output (temperature=0) and cap the response at 2048 output tokens
492
+ generation_config = genai.types.GenerationConfig(temperature=0, max_output_tokens=2048)
493
+ response = llm_model.generate_content(prompt, generation_config=generation_config)
494
+ # Clean the response to remove markdown formatting
495
  code = re.sub(r'^```python\s*|\s*```$', '', response.text, flags=re.MULTILINE)
496
  return code
497
  except Exception as e:
498
+ print(f"Error in llm_generate_visualization_code: {e}\nRaw response: {response.text}")
499
  return None
500
 
501
  def execute_viz_code_and_get_path(viz_code, facet_data):