File size: 16,611 Bytes
47728dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
import numpy as np
import re
from typing import Dict, List, Union, Optional, Any

def create_line_chart(
    data: Union[pd.DataFrame, Dict[str, List[Union[int, float]]], List[Dict[str, Union[int, float]]]],
    title: str = "Line Chart",
    x_label: str = "X-Axis",
    y_label: str = "Y-Axis",
    color_sequence: Optional[List[str]] = None,
    height: int = 400,
    width: int = 700
) -> go.Figure:
    """
    Create a line chart using Plotly.

    The first column of the data supplies the x values; every remaining column
    is drawn as its own line (with markers), named after the column. When
    ``color_sequence`` is given, the i-th line uses
    ``color_sequence[i % len(color_sequence)]`` — for a single series as well
    as for several.

    Args:
        data: Data for the chart. Can be a pandas DataFrame, a dictionary with lists as values,
              or a list of dictionaries.
        title: Title of the chart.
        x_label: Label for the x-axis.
        y_label: Label for the y-axis.
        color_sequence: Optional list of colors for the lines (cycled if shorter
            than the number of series).
        height: Height of the chart in pixels.
        width: Width of the chart in pixels.

    Returns:
        A Plotly Figure object.

    Raises:
        ValueError: If ``data`` is not one of the supported types, or if it has
            fewer than two columns (an x column plus at least one y column).
    """
    # Convert data to pandas DataFrame if it's not already
    if isinstance(data, dict):
        df = pd.DataFrame(data)
    elif isinstance(data, list) and all(isinstance(item, dict) for item in data):
        df = pd.DataFrame(data)
    elif isinstance(data, pd.DataFrame):
        df = data
    else:
        raise ValueError("Data must be a pandas DataFrame, a dictionary with lists as values, or a list of dictionaries.")

    # An x column plus at least one y column is required; without this check an
    # empty input raises a bare IndexError and a single column gives a blank chart.
    if len(df.columns) < 2:
        raise ValueError("Data must contain at least two columns: x values followed by one or more y series.")

    fig = go.Figure()
    x_col = df.columns[0]
    # One trace per y column; the two-column case is simply a single iteration,
    # so it honours color_sequence exactly like the multi-series case.
    for i, col in enumerate(df.columns[1:]):
        color = color_sequence[i % len(color_sequence)] if color_sequence else None
        fig.add_trace(go.Scatter(
            x=df[x_col],
            y=df[col],
            mode='lines+markers',
            name=col,
            line=dict(color=color) if color else None
        ))

    # Update layout
    fig.update_layout(
        title=title,
        xaxis_title=x_label,
        yaxis_title=y_label,
        height=height,
        width=width,
        template="plotly_white",
        hovermode="x unified"
    )

    return fig

def create_bar_chart(
    data: Union[pd.DataFrame, Dict[str, List[Union[int, float]]], List[Dict[str, Union[int, float]]]],
    title: str = "Bar Chart",
    x_label: str = "X-Axis",
    y_label: str = "Y-Axis",
    color_sequence: Optional[List[str]] = None,
    orientation: str = 'v',  # 'v' for vertical, 'h' for horizontal
    height: int = 400,
    width: int = 700
) -> go.Figure:
    """
    Build a (possibly grouped) Plotly bar chart.

    The first column holds the categories. With exactly two columns the chart is
    drawn through ``plotly.express``; with any other column count every column
    after the first becomes its own ``go.Bar`` trace, grouped side by side.

    Args:
        data: A pandas DataFrame, a dict mapping column names to lists, or a
              list of row dictionaries.
        title: Chart title.
        x_label: Text for the x-axis title.
        y_label: Text for the y-axis title.
        color_sequence: Optional bar colors, cycled across series.
        orientation: 'v' draws vertical bars; any other value draws horizontal
            bars with the categories on the y-axis.
        height: Figure height in pixels.
        width: Figure width in pixels.

    Returns:
        The configured ``go.Figure``.

    Raises:
        ValueError: If ``data`` is not one of the supported input types.
    """
    if isinstance(data, pd.DataFrame):
        frame = data
    elif isinstance(data, dict) or (
        isinstance(data, list) and all(isinstance(row, dict) for row in data)
    ):
        frame = pd.DataFrame(data)
    else:
        raise ValueError("Data must be a pandas DataFrame, a dictionary with lists as values, or a list of dictionaries.")

    horizontal = orientation != 'v'
    columns = list(frame.columns)

    if len(columns) == 2:
        # Single series: let plotly.express handle it.
        category, measure = columns
        if horizontal:
            fig = px.bar(
                frame,
                x=measure,
                y=category,
                title=title,
                orientation='h',
                color_discrete_sequence=color_sequence,
            )
        else:
            fig = px.bar(
                frame,
                x=category,
                y=measure,
                title=title,
                color_discrete_sequence=color_sequence,
            )
    else:
        # Several series: one go.Bar trace per value column.
        fig = go.Figure()
        category = columns[0]
        for idx, series in enumerate(columns[1:]):
            shade = color_sequence[idx % len(color_sequence)] if color_sequence else None
            if horizontal:
                bar = go.Bar(
                    x=frame[series],
                    y=frame[category],
                    name=series,
                    marker_color=shade,
                    orientation='h',
                )
            else:
                bar = go.Bar(
                    x=frame[category],
                    y=frame[series],
                    name=series,
                    marker_color=shade,
                )
            fig.add_trace(bar)

    fig.update_layout(
        title=title,
        xaxis_title=x_label,
        yaxis_title=y_label,
        height=height,
        width=width,
        template="plotly_white",
        barmode='group'
    )
    return fig

def create_scatter_plot(
    data: Union[pd.DataFrame, Dict[str, List[Union[int, float]]], List[Dict[str, Union[int, float]]]],
    title: str = "Scatter Plot",
    x_label: str = "X-Axis",
    y_label: str = "Y-Axis",
    color_column: Optional[str] = None,
    size_column: Optional[str] = None,
    hover_data: Optional[List[str]] = None,
    height: int = 400,
    width: int = 700
) -> go.Figure:
    """
    Create a scatter plot using Plotly.

    The first column supplies x and the second supplies y. ``color_column``,
    ``size_column`` and ``hover_data`` are applied whenever they name existing
    columns, regardless of how many columns the data has. Names that are not
    columns of the data are ignored rather than passed on to Plotly, which
    would raise for them.

    Args:
        data: Data for the chart. Can be a pandas DataFrame, a dictionary with lists as values,
              or a list of dictionaries.
        title: Title of the chart.
        x_label: Label for the x-axis.
        y_label: Label for the y-axis.
        color_column: Optional column name to use for coloring points.
        size_column: Optional column name to use for sizing points.
        hover_data: Optional list of column names to include in hover information.
        height: Height of the chart in pixels.
        width: Width of the chart in pixels.

    Returns:
        A Plotly Figure object.

    Raises:
        ValueError: If ``data`` is not one of the supported types, or if it has
            fewer than two columns.
    """
    # Convert data to pandas DataFrame if it's not already
    if isinstance(data, dict):
        df = pd.DataFrame(data)
    elif isinstance(data, list) and all(isinstance(item, dict) for item in data):
        df = pd.DataFrame(data)
    elif isinstance(data, pd.DataFrame):
        df = data
    else:
        raise ValueError("Data must be a pandas DataFrame, a dictionary with lists as values, or a list of dictionaries.")

    # x and y both come from positional columns; fail clearly instead of IndexError.
    if len(df.columns) < 2:
        raise ValueError("Data must contain at least two columns (x and y) for a scatter plot.")

    x_col = df.columns[0]
    y_col = df.columns[1]

    # Only forward optional encodings that refer to real columns, so that a
    # typo degrades gracefully for color, size and hover data alike.
    color = color_column if color_column is not None and color_column in df.columns else None
    size = size_column if size_column is not None and size_column in df.columns else None
    hover = [col for col in hover_data if col in df.columns] if hover_data else []

    fig = px.scatter(
        df,
        x=x_col,
        y=y_col,
        color=color,
        size=size,
        hover_data=hover or None,
        title=title
    )

    # Update layout
    fig.update_layout(
        title=title,
        xaxis_title=x_label,
        yaxis_title=y_label,
        height=height,
        width=width,
        template="plotly_white"
    )

    return fig

def detect_visualization_request(user_input: str) -> Dict[str, Any]:
    """
    Detect if the user is requesting a visualization and extract relevant information.

    Keyword detection, chart-type detection and the data description work on a
    lowercased copy of the input. The title and axis labels are matched
    case-insensitively against the ORIGINAL input, so the user's capitalization
    is kept in the extracted values.

    Args:
        user_input: The user's input message.

    Returns:
        A dictionary containing:
        - 'is_visualization': Boolean indicating if a visualization is requested.
        - 'chart_type': The type of chart requested ('line', 'bar', 'scatter', or None).
        - 'data_description': Lowercased description of the data to visualize, or None.
        - 'parameters': Additional parameters extracted from the request
          (any of 'title', 'x_label', 'y_label').
    """
    # Convert to lowercase for case-insensitive matching
    user_input_lower = user_input.lower()

    # "graph" must start a word (or follow "bar"/"line", as in "bargraph");
    # otherwise "paragraph", "biography" or "geography" would count as requests.
    viz_pattern = r'plot|chart|visuali[sz]|display|(?:\b|bar|line)graph'
    is_visualization = re.search(viz_pattern, user_input_lower) is not None

    if not is_visualization:
        return {
            'is_visualization': False,
            'chart_type': None,
            'data_description': None,
            'parameters': {}
        }

    # Detect chart type
    chart_type = None
    if any(term in user_input_lower for term in ['line chart', 'line graph', 'line plot']):
        chart_type = 'line'
    elif any(term in user_input_lower for term in ['bar chart', 'bar graph', 'bar plot', 'histogram']):
        chart_type = 'bar'
    elif any(term in user_input_lower for term in ['scatter plot', 'scatterplot', 'scatter chart', 'scatter graph']):
        chart_type = 'scatter'

    # Extract data description; \b keeps words like "roof" from acting as "of".
    data_description = None
    data_patterns = [
        r'\b(?:of|for|using|with)\s+([^.?!]+?)(?:\s+(?:by|over|across|versus|vs\.?|against))',
        r'\b(?:of|for|using|with)\s+([^.?!]+?)(?:\s+data)',
        r'\b(?:of|for|using|with)\s+([^.?!]+?)(?:\s+(?:from|in))'
    ]

    for pattern in data_patterns:
        match = re.search(pattern, user_input_lower)
        if match:
            data_description = match.group(1).strip()
            break

    # If no match found with specific patterns, try a more general approach
    if not data_description:
        # Look for text between the chart type and the end of the sentence
        chart_type_terms = ['line chart', 'bar chart', 'scatter plot', 'chart', 'graph', 'plot']
        for term in chart_type_terms:
            if term in user_input_lower:
                parts = user_input_lower.split(term, 1)
                if len(parts) > 1:
                    # Extract text after the chart type until the end of the sentence
                    after_chart_type = parts[1].strip()
                    end_sentence = re.search(r'^[^.!?]*', after_chart_type)
                    if end_sentence:
                        data_description = end_sentence.group(0).strip()
                        # Remove common prepositions at the beginning
                        data_description = re.sub(r'^(?:of|for|using|with)\s+', '', data_description)
                        break

    # Extract additional parameters (title / axis labels).
    parameters = {}
    axis_words = r'(?:axis(?:\s+label)?|label)'
    # Between keyword and value: a ':' or '=' (with optional spaces), or whitespace.
    separator = r'(?:\s*[:=]\s*|\s+)'
    # The value may be quoted. It ends at a quote, sentence punctuation, a comma or
    # semicolon, the end of the input, or at "and <x/y axis | title>", so that
    # "x-axis Month and y-axis Revenue" yields "Month" for the x label only.
    value = (
        r'["\']?([^"\'.?!,;]+?)'
        r'(?=\s+and\s+(?:the\s+)?(?:[xy][-\s]?' + axis_words + r'|titled?)\b'
        r'|\s*[,;.?!"\']|\s*$)'
    )
    param_patterns = {
        'title': r'\btitled?' + separator + value,
        'x_label': r'\bx[-\s]?' + axis_words + separator + value,
        'y_label': r'\by[-\s]?' + axis_words + separator + value,
    }
    for key, pattern in param_patterns.items():
        # Search the original text so the extracted value keeps its casing.
        match = re.search(pattern, user_input, flags=re.IGNORECASE)
        if match:
            extracted = match.group(1).strip()
            if extracted:
                parameters[key] = extracted

    return {
        'is_visualization': is_visualization,
        'chart_type': chart_type,
        'data_description': data_description,
        'parameters': parameters
    }

def generate_sample_data(data_description: str, chart_type: str) -> pd.DataFrame:
    """
    Generate sample data based on the description and chart type.
    This is a fallback when no actual data is available.

    Output is deterministic. A private ``np.random.RandomState(42)`` produces the
    same draws the legacy ``np.random.seed(42)`` + ``np.random.*`` calls would,
    but leaves NumPy's global random state untouched for the caller.

    Keyword matching against ``data_description`` is case-insensitive. The short
    keywords "age" and "product" must appear as whole words, so descriptions
    such as "average sales" or "production" do not trigger them.

    Args:
        data_description: Description of the data to generate (may be empty/None).
        chart_type: Type of chart ('line', 'bar', 'scatter'); any other value
            yields a generic 10-row 'X'/'Y' frame.

    Returns:
        A pandas DataFrame with two columns of sample data.
    """
    rng = np.random.RandomState(42)  # local, reproducible; no global side effect
    desc = data_description.lower() if data_description else ''

    def mentions(word_pattern: str) -> bool:
        # Regex keyword test against the lowercased description.
        return re.search(word_pattern, desc) is not None

    if chart_type == 'line':
        # Generate time series data
        dates = pd.date_range(start='2023-01-01', periods=30, freq='D')
        values = np.cumsum(rng.randn(30)) + 10
        df = pd.DataFrame({'Date': dates, 'Value': values})

        # Try to customize based on description
        if desc:
            if 'temperature' in desc or 'weather' in desc:
                df.columns = ['Date', 'Temperature (°C)']
                df['Temperature (°C)'] = rng.normal(20, 5, 30)
            elif 'stock' in desc or 'price' in desc:
                df.columns = ['Date', 'Price ($)']
                df['Price ($)'] = 100 + np.cumsum(rng.normal(0, 2, 30))
            elif 'sales' in desc or 'revenue' in desc:
                df.columns = ['Date', 'Sales ($)']
                df['Sales ($)'] = 1000 + np.cumsum(rng.normal(0, 100, 30))
            else:
                df.columns = ['Date', data_description.capitalize()]

    elif chart_type == 'bar':
        # Generate categorical data
        categories = ['A', 'B', 'C', 'D', 'E']
        values = rng.randint(10, 100, size=len(categories))
        df = pd.DataFrame({'Category': categories, 'Value': values})

        # Try to customize based on description
        if desc:
            if 'sales by region' in desc or 'regional' in desc:
                df['Category'] = ['North', 'South', 'East', 'West', 'Central']
                df.columns = ['Region', 'Sales ($)']
            elif mentions(r'\bproducts?\b'):
                df['Category'] = ['Product A', 'Product B', 'Product C', 'Product D', 'Product E']
                df.columns = ['Product', 'Units Sold']
            elif mentions(r'\bages?\b') or 'demographic' in desc:
                df['Category'] = ['0-18', '19-35', '36-50', '51-65', '65+']
                df.columns = ['Age Group', 'Count']
            else:
                df.columns = ['Category', data_description.capitalize()]

    elif chart_type == 'scatter':
        # Generate x-y data
        x = rng.normal(0, 1, 50)
        y = x + rng.normal(0, 0.5, 50)
        df = pd.DataFrame({'X': x, 'Y': y})

        # Try to customize based on description
        if desc:
            if 'height' in desc and 'weight' in desc:
                df['X'] = rng.normal(170, 10, 50)  # Heights in cm
                df['Y'] = df['X'] * 0.5 + rng.normal(0, 5, 50)  # Weights in kg
                df.columns = ['Height (cm)', 'Weight (kg)']
            elif mentions(r'\bages?\b') and ('income' in desc or 'salary' in desc):
                df['X'] = rng.normal(40, 10, 50)  # Ages
                df['Y'] = df['X'] * 1000 + 20000 + rng.normal(0, 5000, 50)  # Incomes
                df.columns = ['Age', 'Income ($)']
            elif 'study' in desc or 'exam' in desc:
                df['X'] = rng.normal(5, 2, 50)  # Study hours
                df['Y'] = df['X'] * 10 + 50 + rng.normal(0, 5, 50)  # Exam scores
                df.columns = ['Study Hours', 'Exam Score']
            else:
                # "<x> vs <y>" / "<x> versus <y>" (any case) names the two axes.
                x_label = 'X'
                y_label = 'Y'
                parts = re.split(r'\s+(?:vs\.?|versus)\s+', data_description, flags=re.IGNORECASE)
                if len(parts) == 2:
                    x_label = parts[0].strip().capitalize()
                    y_label = parts[1].strip().capitalize()
                df.columns = [x_label, y_label]

    else:
        # Default fallback
        df = pd.DataFrame({
            'X': range(1, 11),
            'Y': rng.randint(1, 100, 10)
        })

    return df