File size: 3,132 Bytes
21bd1c0
 
 
 
2bd2954
21bd1c0
 
2bd2954
21bd1c0
2bd2954
21bd1c0
 
 
 
 
 
 
 
 
 
2bd2954
21bd1c0
2bd2954
21bd1c0
 
 
 
 
 
 
 
 
9c3faa5
71fc0ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21bd1c0
71fc0ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21bd1c0
71fc0ff
 
21bd1c0
 
71fc0ff
9c3faa5
 
 
 
71fc0ff
9c3faa5
 
 
 
 
71fc0ff
 
fb3a99d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import gradio as gr
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from io import StringIO

def plot_histogram(file_contents, column, ax=None):
    """Plot a histogram of one column of a CSV document.

    Args:
        file_contents: The CSV document as text; ``bytes`` are decoded as
            UTF-8 (a leading BOM is dropped).
        column: Name of the column to plot.
        ax: Optional matplotlib Axes to draw on. When omitted, a new 8x6
            figure is created, as before.

    Returns:
        ``matplotlib.pyplot`` when ``ax`` is omitted (unchanged behaviour, so
        callers can still ``savefig``/``show`` it); otherwise the given ``ax``.

    Raises:
        KeyError: If ``column`` is not a column of the CSV.
    """
    if isinstance(file_contents, (bytes, bytearray)):
        # Upload buffers often yield bytes; StringIO only accepts str.
        file_contents = bytes(file_contents).decode("utf-8-sig")
    custom_df = pd.read_csv(StringIO(file_contents))

    target = ax
    if target is None:
        plt.figure(figsize=(8, 6))
        target = plt.gca()

    sns.histplot(custom_df[column], ax=target)
    target.set_title(f'Histogram for {column}')
    target.set_xlabel(column)
    target.set_ylabel('Frequency')
    return plt if ax is None else ax

def plot_scatter(file_contents, x_axis, y_axis, ax=None):
    """Draw a scatter plot of two columns of a CSV document.

    Args:
        file_contents: The CSV document as text; ``bytes`` are decoded as
            UTF-8 (a leading BOM is dropped).
        x_axis: Column plotted on the horizontal axis.
        y_axis: Column plotted on the vertical axis.
        ax: Optional matplotlib Axes to draw on. When omitted, a new 8x6
            figure is created, as before.

    Returns:
        ``matplotlib.pyplot`` when ``ax`` is omitted (unchanged behaviour);
        otherwise the given ``ax``.
    """
    if isinstance(file_contents, (bytes, bytearray)):
        # Upload buffers often yield bytes; StringIO only accepts str.
        file_contents = bytes(file_contents).decode("utf-8-sig")
    custom_df = pd.read_csv(StringIO(file_contents))

    target = ax
    if target is None:
        plt.figure(figsize=(8, 6))
        target = plt.gca()

    sns.scatterplot(x=x_axis, y=y_axis, data=custom_df, ax=target)
    target.set_title(f'Scatter Plot ({x_axis} vs {y_axis})')
    target.set_xlabel(x_axis)
    target.set_ylabel(y_axis)
    return plt if ax is None else ax

def _read_uploaded_csv(upload):
    """Parse an uploaded CSV into a DataFrame; return None if nothing was uploaded.

    Accepted shapes: raw ``bytes``; an in-memory buffer exposing
    ``getvalue()``; an object with a ``.name`` path (temp-file wrapper); or a
    path string. Which one arrives depends on how the upload component is
    configured.
    """
    if upload is None:
        return None
    if hasattr(upload, "getvalue"):
        contents = upload.getvalue()
        if isinstance(contents, (bytes, bytearray)):
            contents = bytes(contents).decode("utf-8-sig")
        return pd.read_csv(StringIO(contents))
    if isinstance(upload, (bytes, bytearray)):
        # utf-8-sig also drops a leading BOM so the first header is clean.
        return pd.read_csv(StringIO(bytes(upload).decode("utf-8-sig")))
    return pd.read_csv(getattr(upload, "name", upload))


def layout_fn(inputs, column=None, x_axis=None, y_axis=None):
    """Render a histogram and a scatter plot side by side for an uploaded CSV.

    Gradio invokes event handlers with one positional argument per input
    component, so the expected call is ``layout_fn(upload, column, x_axis,
    y_axis)``. For backward compatibility, ``inputs`` may instead be a single
    object carrying ``file``, ``text``, ``text_1`` and ``text_2`` attributes.

    Args:
        inputs: The uploaded CSV (see ``_read_uploaded_csv`` for accepted
            shapes) or the legacy attribute object described above.
        column: Column for the histogram; falsy means "not selected".
        x_axis: Column for the scatter plot's x axis.
        y_axis: Column for the scatter plot's y axis.

    Returns:
        A matplotlib Figure with two subplots. A subplot whose inputs are
        missing, or that name a column absent from the CSV, shows a hint
        instead of a plot.
    """
    if (
        column is None and x_axis is None and y_axis is None
        and all(hasattr(inputs, a) for a in ("file", "text", "text_1", "text_2"))
    ):
        column, x_axis, y_axis = inputs.text, inputs.text_1, inputs.text_2
        inputs = inputs.file

    # Parse once and reuse the frame for both subplots.
    data = _read_uploaded_csv(inputs)
    known = set(data.columns) if data is not None else set()

    fig, (hist_ax, scatter_ax) = plt.subplots(1, 2, figsize=(16, 6))

    if column and column in known:
        sns.histplot(data[column], ax=hist_ax)
        hist_ax.set_title(f'Histogram for {column}')
        hist_ax.set_xlabel(column)
        hist_ax.set_ylabel('Frequency')
    else:
        hist_ax.text(0.5, 0.5, "Upload a CSV and select a column", ha='center', va='center')

    if x_axis and y_axis and {x_axis, y_axis} <= known:
        sns.scatterplot(x=x_axis, y=y_axis, data=data, ax=scatter_ax)
        scatter_ax.set_title(f'Scatter Plot ({x_axis} vs {y_axis})')
        scatter_ax.set_xlabel(x_axis)
        scatter_ax.set_ylabel(y_axis)
    else:
        scatter_ax.text(0.5, 0.5, "Upload a CSV, select X and Y columns", ha='center', va='center')

    fig.suptitle("Data Visualization")
    fig.tight_layout()
    return fig


def update_choices(inputs):
    """Build updates that refresh the three column dropdowns after an upload.

    Args:
        inputs: Value of the file component. The component below is created
            with ``type="binary"``, so this is normally raw bytes; a path
            string or an object with a ``.name`` path is also accepted.
            ``None`` or empty input clears the choices.

    Returns:
        A 3-tuple of ``gr.update`` payloads, one per dropdown. Each lists the
        CSV's column names and clears any stale selection.
    """
    choices = []
    if inputs:
        try:
            if isinstance(inputs, (bytes, bytearray)):
                source = StringIO(bytes(inputs).decode("utf-8-sig"))
            else:
                source = getattr(inputs, "name", inputs)
            # nrows=0 parses only the header row, which is all that is needed here.
            choices = list(pd.read_csv(source, nrows=0).columns)
        except (UnicodeDecodeError, pd.errors.EmptyDataError, pd.errors.ParserError):
            # An unreadable upload leaves the dropdowns empty instead of crashing the handler.
            choices = []
    return tuple(gr.update(choices=choices, value=None) for _ in range(3))


# gr.Interface offers no hook to refresh dropdown choices after an upload, so
# the app is assembled with gr.Blocks and the file's change event is wired
# to update_choices explicitly.
with gr.Blocks(title="Data Visualization Tool") as interface:
    gr.Markdown(
        "# Data Visualization Tool\n"
        "Upload a CSV file, select columns for histogram and scatter plots."
    )
    # type="binary" hands the event handlers the raw file bytes.
    file_input = gr.components.File(label="Upload CSV file", type="binary")
    column_input = gr.components.Dropdown(label="Select Column (Histogram)", choices=[])
    x_axis_input = gr.components.Dropdown(label="Select X-axis (Scatter)", choices=[])
    y_axis_input = gr.components.Dropdown(label="Select Y-axis (Scatter)", choices=[])
    submit_button = gr.Button("Submit")
    plot_output = gr.components.Plot()

    file_input.change(
        update_choices,
        inputs=file_input,
        outputs=[column_input, x_axis_input, y_axis_input],
    )
    # The handler receives one positional argument per input, in this order.
    submit_button.click(
        layout_fn,
        inputs=[file_input, column_input, x_axis_input, y_axis_input],
        outputs=plot_output,
    )


# share=True asks Gradio for a public share link; drop it for local-only use.
interface.launch(share=True)