# Vera-ZWY's picture
# Update app.py
# 4f6e76c verified
# raw
# history blame
# 7.77 kB
import gradio as gr
from gradio_client import Client, handle_file
import seaborn as sns
import matplotlib.pyplot as plt
import os
import pandas as pd
from io import StringIO, BytesIO
import base64
# from linePlot import plot_stacked_time_series, plot_emotion_topic_grid
# Hugging Face access token, read from the HF_TOKEN environment variable
# (None when the variable is unset). Configure it as an environment variable
# or Space secret rather than hard-coding a token in source.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Shared Gradio client for the remote "mangoesai/Elections_Comparison_Agent_V4.1"
# Space; the request handlers in this module send their API calls through it.
client = Client(
    "mangoesai/Elections_Comparison_Agent_V4.1",
    hf_token=HF_TOKEN,
)

# Accepted `client_name` values for stream_chat_with_rag (they match the
# year_selector radio choices): '2016 Election', '2024 Election',
# 'Comparison two years'.
def stream_chat_with_rag(
    message: str,
    client_name: str
):
    """Ask the remote RAG agent a question and return its answer plus plot.

    Args:
        message: The user's free-text question, sent to the API as ``query``.
        client_name: Election selector, sent to the API as ``election_year``
            (one of the year_selector radio choices, e.g. "2016 Election").

    Returns:
        A ``(answer, figure)`` tuple: the text answer returned by the
        ``/process_query`` endpoint and a matplotlib Figure displaying the
        PNG image the endpoint returned.
    """
    answer, plot_payload = client.predict(
        query=message,
        election_year=client_name,
        api_name="/process_query"
    )

    # Debugging: print the raw response.
    print("Raw answer from API:")
    print(answer)
    print("top works from API:")
    print(plot_payload)

    # The 'plot' entry is expected to be a base64 data URI
    # ("data:image/png;base64,<payload>"). Decode only the part after the
    # first comma; a bare base64 string without the prefix also works.
    data_uri = plot_payload['plot']
    plot_bytes = base64.b64decode(data_uri.split(',', 1)[-1])
    img = plt.imread(BytesIO(plot_bytes), format='PNG')

    # Use an explicit figure instead of plt.gcf(): Gradio can serve requests
    # concurrently and pyplot's "current figure" is process-global state.
    fig, ax = plt.subplots(dpi=150)
    ax.imshow(img)
    ax.axis('off')
    # Release pyplot's reference so figures don't pile up across requests;
    # the Figure object itself remains renderable by gr.Plot.
    plt.close(fig)
    return answer, fig
def heatmap(top_n):
    """Fetch the emotion-vs-topic pivot table and render it as a heatmap.

    Args:
        top_n: Value forwarded as ``top_n`` to the ``/get_heatmap_pivot_table``
            endpoint; also shown in the plot title.

    Returns:
        A matplotlib Figure containing the heatmap.
    """
    pivot_table = client.predict(
        top_n=top_n,
        api_name="/get_heatmap_pivot_table"
    )
    print(pivot_table)
    print(type(pivot_table))

    # pivot_table is a dict like:
    # {'headers': ['Index', 'economy', 'human rights', 'immigrant', 'politics'],
    #  'data': [['anger', 55880.0, 557679.0, 147766.0, 180094.0],
    #           ['disgust', 26911.0, 123112.0, 64567.0, 46460.0],
    #           ['fear', 51466.0, 188898.0, 113174.0, 150578.0],
    #           ['neutral', 77005.0, 192945.0, 20549.0, 190793.0]],
    #  'metadata': None}
    # Convert it to a DataFrame with the 'Index' column (emotions) as rows.
    df = pd.DataFrame(pivot_table['data'], columns=pivot_table['headers'])
    df = df.set_index('Index')

    # Draw on an explicit figure/axes rather than pyplot's global current
    # figure, which concurrent Gradio requests could otherwise share.
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(df,
                cmap='YlOrRd',
                cbar_kws={'label': 'Weighted Frequency'},
                square=True,
                ax=ax)
    ax.set_title(f'Top {top_n} Emotions vs Topics Weighted Frequency')
    ax.set_xlabel('Topics')
    ax.set_ylabel('Emotions')
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    fig.tight_layout()
    # Drop pyplot's reference to avoid accumulating open figures per click;
    # the returned Figure is still renderable by gr.Plot.
    plt.close(fig)
    return fig
# def decode_plot(plot_base64, top_n):
# plot_bytes = base64.b64decode(plot_base64['plot'].split(',')[1])
# img = plt.imread(BytesIO(plot_bytes), format='PNG')
# plt.figure(figsize = (12, 2*top_n), dpi = 150)
# plt.imshow(img)
# plt.axis('off')
# plt.show()
# return plt.gcf()
def linePlot(viz_type, weight, top_n):
    """Request a time-series visualization from the API and display it.

    Args:
        viz_type: Visualization type forwarded to the API (the UI offers
            "emotions", "topics" or "grid").
        weight: Score-vs-frequency weight forwarded to the API.
        top_n: Number of top items forwarded to the API; also scales the
            figure height (2 inches per item).

    Returns:
        A ``(figure, description)`` tuple: a matplotlib Figure showing the
        PNG the API returned, and the API's accompanying text (``result[1]``).
    """
    result = client.predict(
        viz_type=viz_type,
        weight=weight,
        top_n=top_n,
        api_name="/linePlot_3C1"
    )
    # result is a tuple: (dict holding the base64-encoded plot, description string).
    plot_payload = result[0]
    description = result[1]

    # The 'plot' entry is expected to be a base64 data URI
    # ("data:image/png;base64,<payload>"). Decode only the part after the
    # first comma; a bare base64 string without the prefix also works.
    data_uri = plot_payload['plot']
    plot_bytes = base64.b64decode(data_uri.split(',', 1)[-1])
    img = plt.imread(BytesIO(plot_bytes), format='PNG')

    # Explicit figure instead of plt.gcf(), which is process-global and can
    # be swapped by a concurrent request.
    fig, ax = plt.subplots(figsize=(12, 2 * top_n), dpi=150)
    ax.imshow(img)
    ax.axis('off')
    # Release pyplot's reference so open figures don't accumulate.
    plt.close(fig)
    return fig, description
# Create Gradio interface.
# Custom CSS giving the plot containers a sensible minimum size; passed via
# gr.Blocks(css=...) so it applies app-wide.
_PLOT_CSS = """
.topic-plot {
    min-height: 600px;
    width: 100%;
    margin: auto;
}
.heatmap-plot {
    min-height: 400px;
    width: 100%;
    margin: auto;
}
"""

with gr.Blocks(title="Reddit Election Analysis", css=_PLOT_CSS) as demo:
    gr.Markdown("# Reddit Public sentiment & Social topic distribution ")
    with gr.Row():
        with gr.Column():
            # A default value ensures heatmap() never receives None when the
            # user clicks refresh before picking an option.
            top_n = gr.Dropdown(
                choices=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                value=5,
                label="Top N"
            )
            fresh_btn = gr.Button("Refresh Heatmap")
        with gr.Column():
            output_heatmap = gr.Plot(
                label="Top Public sentiment & Social topic Heatmap",
                container=True,  # Ensures the plot is contained within its area
                elem_classes="heatmap-plot"  # Custom class styled by _PLOT_CSS
            )

    gr.Markdown("# Get the time series of the Public sentiment & Social topic")
    with gr.Row():
        with gr.Column(scale=1):
            # Control panel
            weight_slider = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.5,
                step=0.1,
                label="Weight (Score vs. Frequency)"
            )
            top_n_slider = gr.Slider(
                minimum=2,
                maximum=10,
                value=5,
                step=1,
                label="Top N Items"
            )
            viz_dropdown = gr.Dropdown(
                choices=["emotions", "topics", "grid"],
                value="emotions",
                label="Visualization Type",
                info="Select the type of visualization to display"
            )
            linePlot_btn = gr.Button("Update Visualizations")
            linePlot_status_text = gr.Textbox(label="Status", interactive=False)
        with gr.Column(scale=3):
            time_series_fig = gr.Plot()

    gr.Markdown("# Reddit Election Posts/Comments Analysis")
    gr.Markdown("Ask questions about election-related comments and posts")
    with gr.Row():
        with gr.Column():
            year_selector = gr.Radio(
                choices=["2016 Election", "2024 Election", "Comparison two years"],
                label="Select Election Year",
                value="2016 Election"
            )
            query_input = gr.Textbox(
                label="Your Question",
                placeholder="Ask about election comments or posts..."
            )
            submit_btn = gr.Button("Submit")
            gr.Markdown("""
## Example Questions:
- Is there any comments don't like the election results
- Summarize the main discussions about voting process
- What are the common opinions about candidates?
""")
        with gr.Column():
            output_text = gr.Textbox(
                label="Response",
                lines=20
            )

    gr.Markdown("## Top works of the relevant Q&A")
    with gr.Row():
        output_plot = gr.Plot(
            label="Topic Distribution",
            container=True,  # Ensures the plot is contained within its area
            elem_classes="topic-plot"  # Custom class styled by _PLOT_CSS
        )

    fresh_btn.click(
        fn=heatmap,
        inputs=top_n,
        outputs=output_heatmap
    )
    linePlot_btn.click(
        fn=linePlot,
        inputs=[viz_dropdown, weight_slider, top_n_slider],
        outputs=[time_series_fig, linePlot_status_text]
    )
    # stream_chat_with_rag returns (answer, figure): update both outputs.
    submit_btn.click(
        fn=stream_chat_with_rag,
        inputs=[query_input, year_selector],
        outputs=[output_text, output_plot]
    )

if __name__ == "__main__":
    demo.launch(share=True)