# auto-ml-gradio / app.py
# Hugging Face Space by harikrishnad1997 — last commit 03371ed ("Update app.py").
import gradio as gr
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from io import StringIO
def plot_histogram(file_contents, column, ax=None):
    """Draw a histogram of one CSV column.

    Args:
        file_contents: Full CSV text (a ``str``), parsed with ``pd.read_csv``.
        column: Name of the column to plot; must be present in the CSV header.
        ax: Matplotlib axes to draw on. When ``None``, seaborn falls back to
            the current pyplot axes.

    Returns:
        The axes the histogram was drawn on.
    """
    custom_df = pd.read_csv(StringIO(file_contents))
    # histplot returns the axes it actually used, so ax=None no longer crashes
    # on the set_* calls below.
    ax = sns.histplot(custom_df[column], ax=ax)
    ax.set_title(f'Histogram for {column}')
    ax.set_xlabel(column)
    ax.set_ylabel('Frequency')
    return ax
def plot_scatter(file_contents, x_axis, y_axis, ax=None):
    """Draw a scatter plot of two CSV columns.

    Args:
        file_contents: Full CSV text (a ``str``), parsed with ``pd.read_csv``.
        x_axis: Column plotted on the horizontal axis.
        y_axis: Column plotted on the vertical axis.
        ax: Matplotlib axes to draw on. When ``None``, seaborn falls back to
            the current pyplot axes.

    Returns:
        The axes the scatter plot was drawn on.
    """
    custom_df = pd.read_csv(StringIO(file_contents))
    # scatterplot returns the axes it actually used, so ax=None no longer
    # crashes on the set_* calls below.
    ax = sns.scatterplot(x=x_axis, y=y_axis, data=custom_df, ax=ax)
    ax.set_title(f'Scatter Plot ({x_axis} vs {y_axis})')
    ax.set_xlabel(x_axis)
    ax.set_ylabel(y_axis)
    return ax
def _read_upload_text(file):
    """Return the contents of an uploaded CSV as text.

    Accepts an in-memory buffer (anything with ``getvalue``), an object whose
    ``.name`` is a path on disk (e.g. a temp-file wrapper), or a path string.
    Bytes are decoded as UTF-8, dropping a leading byte-order mark if present.
    """
    if hasattr(file, "getvalue"):
        data = file.getvalue()
        return data.decode("utf-8-sig") if isinstance(data, bytes) else data
    # Temp-file wrappers expose their on-disk path as .name; a bare path
    # string has no .name and is used as-is.
    path = getattr(file, "name", file)
    with open(path, encoding="utf-8-sig") as fh:
        return fh.read()


def layout_fn(file, text, text_1, text_2):
    """Render the histogram and scatter plot side by side.

    Args:
        file: The uploaded CSV as delivered by the file input, or ``None``.
        text: Column to plot in the histogram.
        text_1: Column for the scatter plot's X axis.
        text_2: Column for the scatter plot's Y axis.

    Returns:
        A Matplotlib figure with two axes. An axes whose inputs are missing
        shows a placeholder message instead of a plot.
    """
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    # Read the upload once and share the text between both plots.
    contents = _read_upload_text(file) if file else None

    if file and text:
        plot_histogram(contents, text, ax=axes[0])
    else:
        axes[0].text(0.5, 0.5, "Upload a CSV and select a column", ha='center', va='center')

    if file and text_1 and text_2:
        plot_scatter(contents, text_1, text_2, ax=axes[1])
    else:
        axes[1].text(0.5, 0.5, "Upload a CSV, select X and Y columns", ha='center', va='center')

    fig.suptitle("Data Visualization")
    # Lay out this figure explicitly rather than whatever pyplot considers current.
    fig.tight_layout()
    return fig
def update_choices(file):
    """Build dropdown updates listing the uploaded CSV's column names.

    Used as the ``change`` handler of the file input defined below.

    Args:
        file: The upload as delivered by the file input. Depending on the
            Gradio version this is a path string or a temp-file wrapper whose
            ``.name`` is the path; it is ``None`` when the upload is cleared.

    Returns:
        A 3-tuple of ``gr.update`` objects for the histogram, scatter-X and
        scatter-Y dropdowns. Each offers the CSV's columns (empty when there is
        no file) and clears the previous selection, which may not exist in the
        new file.
    """
    choices = []
    if file:
        # Path-carrying objects expose .name; buffers/paths are passed as-is.
        source = getattr(file, "name", file)
        # nrows=0 parses only the header row, which is all we need here.
        choices = list(pd.read_csv(source, nrows=0).columns)
    return tuple(gr.update(choices=choices, value=None) for _ in range(3))


# The gr.inputs namespace was removed in Gradio 4, and Interface offers no
# set_config()/run(fn_change=...) hook for changing dropdown choices after an
# upload. Blocks lets the file input's change event refresh the column lists.
with gr.Blocks(title="Data Visualization Tool") as interface:
    gr.Markdown(
        "# Data Visualization Tool\n"
        "Upload a CSV file, select columns for histogram and scatter plots."
    )
    file_input = gr.File(label="Upload CSV file", file_types=[".csv"])
    with gr.Row():
        hist_column = gr.Dropdown(label="Select Column (Histogram)", choices=[])
        x_column = gr.Dropdown(label="Select X-axis (Scatter)", choices=[])
        y_column = gr.Dropdown(label="Select Y-axis (Scatter)", choices=[])
    submit = gr.Button("Submit")
    plot_output = gr.Plot()

    file_input.change(
        update_choices,
        inputs=file_input,
        outputs=[hist_column, x_column, y_column],
    )
    submit.click(
        layout_fn,
        inputs=[file_input, hist_column, x_column, y_column],
        outputs=plot_output,
    )

if __name__ == "__main__":
    interface.launch(share=True)