Vera-ZWY's picture
Update app.py
cedf8bf verified
raw
history blame
9.36 kB
import gradio as gr
from gradio_client import Client, handle_file
import seaborn as sns
import matplotlib.pyplot as plt
import os
import pandas as pd
from io import StringIO, BytesIO
import base64
# from linePlot import plot_stacked_time_series, plot_emotion_topic_grid
# Hugging Face access token, read from the HF_TOKEN environment variable
# (None when it is unset). Keep the secret out of source code: configure it
# as an environment variable / Space secret rather than pasting it here.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Single shared Gradio client for the remote backend Space; the handlers in
# this module (chat, heatmap, line plot) all send their requests through it.
client = Client(
    "mangoesai/Elections_Comparison_Agent_V4.1",
    hf_token=HF_TOKEN,
)
# client_name = ['2016 Election','2024 Election', 'Comparison two years']
def stream_chat_with_rag(
    message: str,
    client_name: str
):
    """Ask the backend RAG agent a question and return only its text reply.

    Parameters
    ----------
    message : str
        The user's question, forwarded to the backend as ``query``.
    client_name : str
        Election selector forwarded as ``election_year`` (the UI offers
        "2016 Election", "2024 Election" and "Comparison two years").

    Returns
    -------
    The first element of the ``/process_query`` response. The second
    element (the "top works" output) is only logged, not returned.
    """
    reply, top_works = client.predict(
        query=message,
        election_year=client_name,
        api_name="/process_query",
    )
    # Debugging: echo both parts of the raw response to the server log.
    for heading, payload in (
        ("Raw answer from API:", reply),
        ("top works from API:", top_works),
    ):
        print(heading)
        print(payload)
    return reply
def heatmap(top_n):
    """Fetch the emotion-vs-topic pivot table from the backend and plot it.

    Parameters
    ----------
    top_n : int
        How many top items to request from the backend's
        ``/get_heatmap_pivot_table`` endpoint; also shown in the title.

    Returns
    -------
    matplotlib.figure.Figure
        Figure holding the seaborn heatmap, for display in ``gr.Plot``.

    Raises
    ------
    gr.Error
        If ``top_n`` is None (no value has been selected in the UI yet).
    """
    if top_n is None:
        # An unselected Gradio Dropdown delivers None; fail with a readable
        # message instead of sending top_n=None to the backend.
        raise gr.Error("Please select a Top N value before refreshing the heatmap.")

    pivot_table = client.predict(
        top_n=top_n,
        api_name="/get_heatmap_pivot_table",
    )
    print(pivot_table)
    print(type(pivot_table))
    # pivot_table is expected to be a dict like:
    # {'headers': ['Index', 'economy', 'human rights', 'immigrant', 'politics'],
    #  'data': [['anger', 55880.0, 557679.0, 147766.0, 180094.0],
    #           ['disgust', 26911.0, 123112.0, 64567.0, 46460.0], ...],
    #  'metadata': None}
    # Transform the dictionary into a DataFrame; the 'Index' column holds
    # the emotion labels and becomes the row index.
    df = pd.DataFrame(pivot_table["data"], columns=pivot_table["headers"])
    df = df.set_index("Index")

    # Draw on an explicit Figure/Axes instead of pyplot's implicit "current
    # figure": if handlers run concurrently they cannot grab each other's
    # figure via gcf(), and closing it below stops one more figure piling up
    # in pyplot's registry on every click (matplotlib warns past 20 open).
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        df,
        cmap="YlOrRd",
        cbar_kws={"label": "Weighted Frequency"},
        square=True,
        ax=ax,
    )
    ax.set_title(f"Top {top_n} Emotions vs Topics Weighted Frequency")
    ax.set_xlabel("Topics")
    ax.set_ylabel("Emotions")
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    fig.tight_layout()
    # Closing only unregisters the figure from pyplot; the Figure object
    # itself can still be rendered/saved by gr.Plot.
    plt.close(fig)
    return fig
# def linePlot_time_series(viz_type, weight, top_n):
# result = client.predict(
# viz_type=viz_type,
# weight=weight,
# top_n=top_n,
# api_name="/linePlot_time_series"
# )
# print("============== timeseries df transfer from pivate to public ===============")
# print(result)
# print(type(result))
# df = pd.DataFrame(result['data'], columns=result['headers'])
# df.set_index('Index', inplace=True)
# return df
# def update_visualization(viz_type, weight, top_n):
# """
# Update visualization based on user inputs and selected visualization type
# Parameters:
# -----------
# viz_type : str
# Type of visualization to show ('emotions', 'topics', or 'grid')
# weight : float
# Weight for scoring (0-1)
# top_n : int
# Number of top items to show
# """
# try:
# # return None, "Error: Start date must be before end date"
# series = linePlot_time_series(viz_type, weight, top_n)
# if viz_type == "emotions":
# # Create emotion time series
# # series = linePlot_time_series(viz_type, weight, top_n)
# fig = plot_stacked_time_series(
# series,
# f'Top {top_n} Emotions Popularity'
# )
# message = "Emotion time series updated"
# elif viz_type == "topics":
# # Create topic time series
# # series = linePlot_time_series(viz_type, weight, top_n)
# fig = plot_stacked_time_series(
# series,
# f'Top {top_n} Topics Popularity'
# )
# message = "Topic time series updated"
# else: # viz_type == "grid"
# # Create emotion-topic grid
# # pair_series = linePlot_time_series(viz_type, weight, top_n)
# fig = plot_emotion_topic_grid(series, top_n)
# message = "Emotion-Topic grid updated"
# return fig, message
# except Exception as e:
# return None, f"Error: {str(e)}"
def decode_plot(plot_base64):
    """Decode a base64-encoded PNG from the backend into a matplotlib Figure.

    Parameters
    ----------
    plot_base64 : dict
        Mapping whose ``'plot'`` value is the PNG image as base64 text. The
        original code expected a data-URI style prefix ending in a comma
        (e.g. ``"data:image/png;base64,..."``); bare base64 is accepted too.

    Returns
    -------
    matplotlib.figure.Figure
        A new figure that shows only the image, with its axes hidden.
    """
    encoded = plot_base64["plot"]
    # Drop an optional "...," prefix. Base64 text never contains a comma, so
    # the last piece is always the payload, with or without a prefix.
    payload = encoded.split(",", 1)[-1]
    img = plt.imread(BytesIO(base64.b64decode(payload)), format="PNG")

    # Use a dedicated figure: plt.imshow() would paint onto whatever pyplot
    # figure happens to be current (e.g. a leftover heatmap), and plt.show()
    # has no place in a server request handler.
    fig, ax = plt.subplots()
    ax.imshow(img)
    ax.axis("off")
    # Unregister from pyplot so figures do not accumulate across requests;
    # the returned Figure object remains renderable.
    plt.close(fig)
    return fig
def linePlot(viz_type, weight, top_n):
    """Request a time-series plot from the backend and return it for the UI.

    Parameters
    ----------
    viz_type : str
        Visualization kind forwarded to the backend (the UI offers
        "emotions", "topics" and "grid").
    weight : float
        Score-vs-frequency weight forwarded to the backend.
    top_n : int
        Number of top items forwarded to the backend.

    Returns
    -------
    tuple
        ``(figure, description)``, one value per output component wired to
        this handler (the plot and the status textbox). Returning only the
        figure made Gradio reject the response for a two-output event.
    """
    result = client.predict(
        viz_type=viz_type,
        weight=weight,
        top_n=top_n,
        api_name="/linePlot_3C1",
    )
    # result is a tuple of (dict holding the base64 plot, description string).
    plot_payload = result[0]
    description = result[1] if len(result) > 1 else ""
    return decode_plot(plot_payload), description
# Create Gradio interface
with gr.Blocks(title="Reddit Election Analysis") as demo:
    # ---- Section 1: emotion x topic heatmap --------------------------------
    gr.Markdown("# Reddit Public sentiment & Social topic distribution ")
    with gr.Row():
        with gr.Column():
            # A label plus a default value, so clicking "Refresh Heatmap"
            # before choosing anything never sends None to heatmap().
            top_n = gr.Dropdown(
                choices=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                value=5,
                label="Top N",
            )
            fresh_btn = gr.Button("Refresh Heatmap")
        with gr.Column():
            output_heatmap = gr.Plot(
                label="Top Public sentiment & Social topic Heatmap",
                container=True,  # Ensures the plot is contained within its area
                elem_classes="heatmap-plot",  # Sized by the CSS injected below
            )

    # ---- Section 2: time-series line plots ---------------------------------
    gr.Markdown("# Get the time series of the Public sentiment & Social topic")
    with gr.Row():
        with gr.Column(scale=1):
            # Control panel
            # NOTE(review): lineGraph_type is not passed to any event handler;
            # linePlot() receives viz_dropdown instead. Confirm whether this
            # control should be removed or wired up.
            lineGraph_type = gr.Dropdown(choices=['emotions', 'topics', '2Dmatrix'])
            weight_slider = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.5,
                step=0.1,
                label="Weight (Score vs. Frequency)",
            )
            top_n_slider = gr.Slider(
                minimum=2,
                maximum=10,
                value=5,
                step=1,
                label="Top N Items",
            )
            viz_dropdown = gr.Dropdown(
                choices=["emotions", "topics", "grid"],
                value="emotions",
                label="Visualization Type",
                info="Select the type of visualization to display",
            )
            linePlot_btn = gr.Button("Update Visualizations")
            linePlot_status_text = gr.Textbox(label="Status", interactive=False)
        with gr.Column(scale=3):
            time_series_fig = gr.Plot()

    # ---- Section 3: question answering over election posts/comments --------
    gr.Markdown("# Reddit Election Posts/Comments Analysis")
    gr.Markdown("Ask questions about election-related comments and posts")
    with gr.Row():
        with gr.Column():
            year_selector = gr.Radio(
                choices=["2016 Election", "2024 Election", "Comparison two years"],
                label="Select Election Year",
                value="2016 Election",
            )
            query_input = gr.Textbox(
                label="Your Question",
                placeholder="Ask about election comments or posts...",
            )
            submit_btn = gr.Button("Submit")
            gr.Markdown("""
## Example Questions:
- Is there any comments don't like the election results
- Summarize the main discussions about voting process
- What are the common opinions about candidates?
""")
        with gr.Column():
            output_text = gr.Textbox(
                label="Response",
                lines=20,
            )

    gr.Markdown("## Top works of the relevant Q&A")
    with gr.Row():
        # NOTE(review): no event currently writes to output_plot.
        output_plot = gr.Plot(
            label="Topic Distribution",
            container=True,  # Ensures the plot is contained within its area
            elem_classes="topic-plot",  # Sized by the CSS injected below
        )

    # Custom CSS giving the plot containers a minimum height and full width.
    gr.HTML("""
<style>
.topic-plot {
min-height: 600px;
width: 100%;
margin: auto;
}
.heatmap-plot {
min-height: 400px;
width: 100%;
margin: auto;
}
</style>
""")

    # ---- Event wiring --------------------------------------------------------
    fresh_btn.click(
        fn=heatmap,
        inputs=top_n,
        outputs=output_heatmap,
    )
    # linePlot must return one value per output: (figure, status message).
    linePlot_btn.click(
        fn=linePlot,
        inputs=[viz_dropdown, weight_slider, top_n_slider],
        outputs=[time_series_fig, linePlot_status_text],
    )
    # Only the text answer is wired; stream_chat_with_rag returns a single value.
    submit_btn.click(
        fn=stream_chat_with_rag,
        inputs=[query_input, year_selector],
        outputs=output_text,
    )
def _main():
    """Start the Gradio server for ``demo``, requesting a public share link."""
    demo.launch(share=True)


if __name__ == "__main__":
    _main()