Spaces:
Sleeping
Sleeping
import plotly.graph_objects as go | |
import plotly.express as px | |
import pandas as pd | |
import numpy as np | |
import re | |
from typing import Dict, List, Union, Optional, Any | |
def create_line_chart( | |
data: Union[pd.DataFrame, Dict[str, List[Union[int, float]]], List[Dict[str, Union[int, float]]]], | |
title: str = "Line Chart", | |
x_label: str = "X-Axis", | |
y_label: str = "Y-Axis", | |
color_sequence: Optional[List[str]] = None, | |
height: int = 400, | |
width: int = 700 | |
) -> go.Figure: | |
""" | |
Create a line chart using Plotly. | |
Args: | |
data: Data for the chart. Can be a pandas DataFrame, a dictionary with lists as values, | |
or a list of dictionaries. | |
title: Title of the chart. | |
x_label: Label for the x-axis. | |
y_label: Label for the y-axis. | |
color_sequence: Optional list of colors for the lines. | |
height: Height of the chart in pixels. | |
width: Width of the chart in pixels. | |
Returns: | |
A Plotly Figure object. | |
""" | |
fig = go.Figure() | |
# Convert data to pandas DataFrame if it's not already | |
if isinstance(data, dict): | |
df = pd.DataFrame(data) | |
elif isinstance(data, list) and all(isinstance(item, dict) for item in data): | |
df = pd.DataFrame(data) | |
elif isinstance(data, pd.DataFrame): | |
df = data | |
else: | |
raise ValueError("Data must be a pandas DataFrame, a dictionary with lists as values, or a list of dictionaries.") | |
# If the DataFrame has only two columns, use them as x and y | |
if len(df.columns) == 2: | |
x_col = df.columns[0] | |
y_col = df.columns[1] | |
fig.add_trace(go.Scatter(x=df[x_col], y=df[y_col], mode='lines+markers', name=y_col)) | |
else: | |
# Assume first column is x and the rest are y values | |
x_col = df.columns[0] | |
for i, col in enumerate(df.columns[1:]): | |
color = color_sequence[i % len(color_sequence)] if color_sequence else None | |
fig.add_trace(go.Scatter( | |
x=df[x_col], | |
y=df[col], | |
mode='lines+markers', | |
name=col, | |
line=dict(color=color) if color else None | |
)) | |
# Update layout | |
fig.update_layout( | |
title=title, | |
xaxis_title=x_label, | |
yaxis_title=y_label, | |
height=height, | |
width=width, | |
template="plotly_white", | |
hovermode="x unified" | |
) | |
return fig | |
def create_bar_chart( | |
data: Union[pd.DataFrame, Dict[str, List[Union[int, float]]], List[Dict[str, Union[int, float]]]], | |
title: str = "Bar Chart", | |
x_label: str = "X-Axis", | |
y_label: str = "Y-Axis", | |
color_sequence: Optional[List[str]] = None, | |
orientation: str = 'v', # 'v' for vertical, 'h' for horizontal | |
height: int = 400, | |
width: int = 700 | |
) -> go.Figure: | |
""" | |
Create a bar chart using Plotly. | |
Args: | |
data: Data for the chart. Can be a pandas DataFrame, a dictionary with lists as values, | |
or a list of dictionaries. | |
title: Title of the chart. | |
x_label: Label for the x-axis. | |
y_label: Label for the y-axis. | |
color_sequence: Optional list of colors for the bars. | |
orientation: 'v' for vertical bars, 'h' for horizontal bars. | |
height: Height of the chart in pixels. | |
width: Width of the chart in pixels. | |
Returns: | |
A Plotly Figure object. | |
""" | |
# Convert data to pandas DataFrame if it's not already | |
if isinstance(data, dict): | |
df = pd.DataFrame(data) | |
elif isinstance(data, list) and all(isinstance(item, dict) for item in data): | |
df = pd.DataFrame(data) | |
elif isinstance(data, pd.DataFrame): | |
df = data | |
else: | |
raise ValueError("Data must be a pandas DataFrame, a dictionary with lists as values, or a list of dictionaries.") | |
# Create the bar chart | |
if orientation == 'v': | |
# If the DataFrame has only two columns, use them as x and y | |
if len(df.columns) == 2: | |
x_col = df.columns[0] | |
y_col = df.columns[1] | |
fig = px.bar(df, x=x_col, y=y_col, title=title, color_discrete_sequence=color_sequence) | |
else: | |
# For multiple columns, create a grouped bar chart | |
fig = go.Figure() | |
x_col = df.columns[0] | |
for i, col in enumerate(df.columns[1:]): | |
color = color_sequence[i % len(color_sequence)] if color_sequence else None | |
fig.add_trace(go.Bar( | |
x=df[x_col], | |
y=df[col], | |
name=col, | |
marker_color=color | |
)) | |
else: # horizontal | |
# If the DataFrame has only two columns, use them as y and x | |
if len(df.columns) == 2: | |
y_col = df.columns[0] | |
x_col = df.columns[1] | |
fig = px.bar(df, y=y_col, x=x_col, title=title, orientation='h', color_discrete_sequence=color_sequence) | |
else: | |
# For multiple columns, create a grouped bar chart | |
fig = go.Figure() | |
y_col = df.columns[0] | |
for i, col in enumerate(df.columns[1:]): | |
color = color_sequence[i % len(color_sequence)] if color_sequence else None | |
fig.add_trace(go.Bar( | |
y=df[y_col], | |
x=df[col], | |
name=col, | |
marker_color=color, | |
orientation='h' | |
)) | |
# Update layout | |
fig.update_layout( | |
title=title, | |
xaxis_title=x_label, | |
yaxis_title=y_label, | |
height=height, | |
width=width, | |
template="plotly_white", | |
barmode='group' | |
) | |
return fig | |
def create_scatter_plot( | |
data: Union[pd.DataFrame, Dict[str, List[Union[int, float]]], List[Dict[str, Union[int, float]]]], | |
title: str = "Scatter Plot", | |
x_label: str = "X-Axis", | |
y_label: str = "Y-Axis", | |
color_column: Optional[str] = None, | |
size_column: Optional[str] = None, | |
hover_data: Optional[List[str]] = None, | |
height: int = 400, | |
width: int = 700 | |
) -> go.Figure: | |
""" | |
Create a scatter plot using Plotly. | |
Args: | |
data: Data for the chart. Can be a pandas DataFrame, a dictionary with lists as values, | |
or a list of dictionaries. | |
title: Title of the chart. | |
x_label: Label for the x-axis. | |
y_label: Label for the y-axis. | |
color_column: Optional column name to use for coloring points. | |
size_column: Optional column name to use for sizing points. | |
hover_data: Optional list of column names to include in hover information. | |
height: Height of the chart in pixels. | |
width: Width of the chart in pixels. | |
Returns: | |
A Plotly Figure object. | |
""" | |
# Convert data to pandas DataFrame if it's not already | |
if isinstance(data, dict): | |
df = pd.DataFrame(data) | |
elif isinstance(data, list) and all(isinstance(item, dict) for item in data): | |
df = pd.DataFrame(data) | |
elif isinstance(data, pd.DataFrame): | |
df = data | |
else: | |
raise ValueError("Data must be a pandas DataFrame, a dictionary with lists as values, or a list of dictionaries.") | |
# If the DataFrame has only two columns, use them as x and y | |
if len(df.columns) == 2: | |
x_col = df.columns[0] | |
y_col = df.columns[1] | |
fig = px.scatter(df, x=x_col, y=y_col, title=title) | |
else: | |
# Assume first two columns are x and y, and use additional columns for color, size, etc. | |
x_col = df.columns[0] | |
y_col = df.columns[1] | |
# Create the scatter plot | |
fig = px.scatter( | |
df, | |
x=x_col, | |
y=y_col, | |
color=color_column if color_column and color_column in df.columns else None, | |
size=size_column if size_column and size_column in df.columns else None, | |
hover_data=hover_data if hover_data else None, | |
title=title | |
) | |
# Update layout | |
fig.update_layout( | |
title=title, | |
xaxis_title=x_label, | |
yaxis_title=y_label, | |
height=height, | |
width=width, | |
template="plotly_white" | |
) | |
return fig | |
def detect_visualization_request(user_input: str) -> Dict[str, Any]: | |
""" | |
Detect if the user is requesting a visualization and extract relevant information. | |
Args: | |
user_input: The user's input message. | |
Returns: | |
A dictionary containing: | |
- 'is_visualization': Boolean indicating if a visualization is requested. | |
- 'chart_type': The type of chart requested ('line', 'bar', 'scatter', or None). | |
- 'data_description': Description of the data to visualize. | |
- 'parameters': Additional parameters extracted from the request. | |
""" | |
# Convert to lowercase for case-insensitive matching | |
user_input_lower = user_input.lower() | |
# Check for visualization keywords | |
viz_keywords = ['plot', 'chart', 'graph', 'visualize', 'visualisation', 'visualization', 'display'] | |
is_visualization = any(keyword in user_input_lower for keyword in viz_keywords) | |
if not is_visualization: | |
return { | |
'is_visualization': False, | |
'chart_type': None, | |
'data_description': None, | |
'parameters': {} | |
} | |
# Detect chart type | |
chart_type = None | |
if any(term in user_input_lower for term in ['line chart', 'line graph', 'line plot']): | |
chart_type = 'line' | |
elif any(term in user_input_lower for term in ['bar chart', 'bar graph', 'histogram']): | |
chart_type = 'bar' | |
elif any(term in user_input_lower for term in ['scatter plot', 'scatter chart', 'scatter graph']): | |
chart_type = 'scatter' | |
# Extract data description | |
data_description = None | |
data_patterns = [ | |
r'(?:of|for|using|with)\s+([^.?!]+?)(?:\s+(?:by|over|across|versus|vs\.?|against))', | |
r'(?:of|for|using|with)\s+([^.?!]+?)(?:\s+data)', | |
r'(?:of|for|using|with)\s+([^.?!]+?)(?:\s+(?:from|in))' | |
] | |
for pattern in data_patterns: | |
match = re.search(pattern, user_input_lower) | |
if match: | |
data_description = match.group(1).strip() | |
break | |
# If no match found with specific patterns, try a more general approach | |
if not data_description: | |
# Look for text between the chart type and the end of the sentence | |
chart_type_terms = ['line chart', 'bar chart', 'scatter plot', 'chart', 'graph', 'plot'] | |
for term in chart_type_terms: | |
if term in user_input_lower: | |
parts = user_input_lower.split(term, 1) | |
if len(parts) > 1: | |
# Extract text after the chart type until the end of the sentence | |
after_chart_type = parts[1].strip() | |
end_sentence = re.search(r'^[^.!?]*', after_chart_type) | |
if end_sentence: | |
data_description = end_sentence.group(0).strip() | |
# Remove common prepositions at the beginning | |
data_description = re.sub(r'^(?:of|for|using|with)\s+', '', data_description) | |
break | |
# Extract additional parameters | |
parameters = {} | |
# Title | |
title_match = re.search(r'title[d:]?\s+["\']?([^"\'.?!]+)["\']?', user_input_lower) | |
if title_match: | |
parameters['title'] = title_match.group(1).strip() | |
# X-axis label | |
x_label_match = re.search(r'x[-\s]?(?:axis|label)[:]?\s+["\']?([^"\'.?!]+)["\']?', user_input_lower) | |
if x_label_match: | |
parameters['x_label'] = x_label_match.group(1).strip() | |
# Y-axis label | |
y_label_match = re.search(r'y[-\s]?(?:axis|label)[:]?\s+["\']?([^"\'.?!]+)["\']?', user_input_lower) | |
if y_label_match: | |
parameters['y_label'] = y_label_match.group(1).strip() | |
return { | |
'is_visualization': is_visualization, | |
'chart_type': chart_type, | |
'data_description': data_description, | |
'parameters': parameters | |
} | |
def generate_sample_data(data_description: str, chart_type: str) -> pd.DataFrame: | |
""" | |
Generate sample data based on the description and chart type. | |
This is a fallback when no actual data is available. | |
Args: | |
data_description: Description of the data to generate. | |
chart_type: Type of chart ('line', 'bar', 'scatter'). | |
Returns: | |
A pandas DataFrame with sample data. | |
""" | |
np.random.seed(42) # For reproducibility | |
# Default data | |
if chart_type == 'line': | |
# Generate time series data | |
dates = pd.date_range(start='2023-01-01', periods=30, freq='D') | |
values = np.cumsum(np.random.randn(30)) + 10 | |
df = pd.DataFrame({'Date': dates, 'Value': values}) | |
# Try to customize based on description | |
if data_description: | |
if 'temperature' in data_description or 'weather' in data_description: | |
df.columns = ['Date', 'Temperature (°C)'] | |
df['Temperature (°C)'] = np.random.normal(20, 5, 30) | |
elif 'stock' in data_description or 'price' in data_description: | |
df.columns = ['Date', 'Price ($)'] | |
df['Price ($)'] = 100 + np.cumsum(np.random.normal(0, 2, 30)) | |
elif 'sales' in data_description or 'revenue' in data_description: | |
df.columns = ['Date', 'Sales ($)'] | |
df['Sales ($)'] = 1000 + np.cumsum(np.random.normal(0, 100, 30)) | |
else: | |
df.columns = ['Date', data_description.capitalize() if data_description else 'Value'] | |
elif chart_type == 'bar': | |
# Generate categorical data | |
categories = ['A', 'B', 'C', 'D', 'E'] | |
values = np.random.randint(10, 100, size=len(categories)) | |
df = pd.DataFrame({'Category': categories, 'Value': values}) | |
# Try to customize based on description | |
if data_description: | |
if 'sales by region' in data_description or 'regional' in data_description: | |
df['Category'] = ['North', 'South', 'East', 'West', 'Central'] | |
df.columns = ['Region', 'Sales ($)'] | |
elif 'product' in data_description: | |
df['Category'] = ['Product A', 'Product B', 'Product C', 'Product D', 'Product E'] | |
df.columns = ['Product', 'Units Sold'] | |
elif 'age' in data_description or 'demographic' in data_description: | |
df['Category'] = ['0-18', '19-35', '36-50', '51-65', '65+'] | |
df.columns = ['Age Group', 'Count'] | |
else: | |
df.columns = ['Category', data_description.capitalize() if data_description else 'Value'] | |
elif chart_type == 'scatter': | |
# Generate x-y data | |
x = np.random.normal(0, 1, 50) | |
y = x + np.random.normal(0, 0.5, 50) | |
df = pd.DataFrame({'X': x, 'Y': y}) | |
# Try to customize based on description | |
if data_description: | |
if 'height' in data_description and 'weight' in data_description: | |
df['X'] = np.random.normal(170, 10, 50) # Heights in cm | |
df['Y'] = df['X'] * 0.5 + np.random.normal(0, 5, 50) # Weights in kg | |
df.columns = ['Height (cm)', 'Weight (kg)'] | |
elif 'age' in data_description and ('income' in data_description or 'salary' in data_description): | |
df['X'] = np.random.normal(40, 10, 50) # Ages | |
df['Y'] = df['X'] * 1000 + 20000 + np.random.normal(0, 5000, 50) # Incomes | |
df.columns = ['Age', 'Income ($)'] | |
elif 'study' in data_description or 'exam' in data_description: | |
df['X'] = np.random.normal(5, 2, 50) # Study hours | |
df['Y'] = df['X'] * 10 + 50 + np.random.normal(0, 5, 50) # Exam scores | |
df.columns = ['Study Hours', 'Exam Score'] | |
else: | |
x_label = 'X' | |
y_label = 'Y' | |
if ' vs ' in data_description: | |
parts = data_description.split(' vs ') | |
if len(parts) == 2: | |
x_label = parts[0].strip().capitalize() | |
y_label = parts[1].strip().capitalize() | |
df.columns = [x_label, y_label] | |
else: | |
# Default fallback | |
df = pd.DataFrame({ | |
'X': range(1, 11), | |
'Y': np.random.randint(1, 100, 10) | |
}) | |
return df |