|
import gradio as gr |
|
from gradio_client import Client, handle_file |
|
import seaborn as sns |
|
import matplotlib.pyplot as plt |
|
import os |
|
import pandas as pd |
|
from io import StringIO |
|
|
|
|
|
# Hugging Face access token taken from the environment; None when the
# HF_TOKEN variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Remote Gradio Space whose endpoints ("/process_query" and
# "/get_heatmap_pivot_table") are called by the handlers in this file.
# The client object is created once, at import time.
_SPACE_ID = "mangoesai/Elections_Comparison_Agent_V4"
client = Client(_SPACE_ID, hf_token=HF_TOKEN)
|
|
|
|
|
|
|
|
|
|
|
def stream_chat_with_rag(
    message: str,
    client_name: str
):
    """Ask the remote RAG agent a question and return its answer.

    Despite the name, this performs a single blocking request; nothing is
    streamed.

    Args:
        message: The user's question, sent to ``/process_query`` as ``query``.
        client_name: Sent as ``election_year``. The UI wires the
            election-year radio selector to this argument.

    Returns:
        The first item of the endpoint's two-item result (shown in the
        response textbox). The second item is only printed for debugging
        and is discarded.
    """
    reply, top_works = client.predict(
        query=message,
        election_year=client_name,
        api_name="/process_query",
    )

    # Debug output: a heading line followed by the raw value, for each item.
    debug_lines = (
        ("Raw answer from API:", reply),
        ("top works from API:", top_works),
    )
    for heading, payload in debug_lines:
        print(heading)
        print(payload)

    return reply
|
|
|
|
|
|
|
|
|
def heatmap(top_n):
    """Fetch the emotion-vs-topic pivot table from the remote Space and plot it.

    Args:
        top_n: Value of the "top N" dropdown. It is forwarded unchanged as the
            endpoint's ``top_n`` argument and interpolated into the plot title.

    Returns:
        matplotlib.figure.Figure: The rendered heatmap, for display in ``gr.Plot``.
    """
    pivot_table = client.predict(
        top_n=top_n,
        api_name="/get_heatmap_pivot_table",
    )
    print(pivot_table)
    print(type(pivot_table))

    # The endpoint returns a Gradio dataframe payload shaped like:
    # {'headers': ['Index', 'economy', 'human rights', 'immigrant', 'politics'],
    #  'data': [['anger', 55880.0, 557679.0, 147766.0, 180094.0],
    #           ['disgust', 26911.0, 123112.0, 64567.0, 46460.0],
    #           ['fear', 51466.0, 188898.0, 113174.0, 150578.0],
    #           ['neutral', 77005.0, 192945.0, 20549.0, 190793.0]],
    #  'metadata': None}
    # The "Index" column holds the emotion labels and becomes the row index.
    # (Previously this read from an undefined name, raising NameError on
    # every refresh.)
    df = pd.DataFrame(pivot_table["data"], columns=pivot_table["headers"])
    df = df.set_index("Index")

    # Draw on an explicit Figure/Axes so the plot does not depend on pyplot's
    # global "current figure", which is shared by every request in the server.
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        df,
        cmap="YlOrRd",
        cbar_kws={"label": "Weighted Frequency"},
        square=True,
        ax=ax,
    )
    ax.set_title(f"Top {top_n} Emotions vs Topics Weighted Frequency")
    ax.set_xlabel("Topics")
    ax.set_ylabel("Emotions")
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    fig.tight_layout()

    # Drop the figure from pyplot's registry. Without this, every refresh
    # leaves one more open figure alive for the lifetime of the server.
    # The returned Figure object can still be rendered/saved after closing.
    plt.close(fig)
    return fig
|
|
|
|
|
|
|
# Assemble the page: a heatmap panel (top-N selector, refresh button, plot)
# followed by a Q&A panel (election selector, question box, response area).
with gr.Blocks(title="Reddit Election Analysis") as demo:
    gr.Markdown("# Reddit Public sentiment & Social topic distribution ")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                # Passed to heatmap() as top_n when the refresh button is clicked.
                # NOTE(review): no `label` and no explicit `value` are set; depending
                # on the installed Gradio version the dropdown may start unselected
                # and send None to heatmap() — confirm.
                top_n = gr.Dropdown(choices=[1,2,3,4,5,6,7,8,9,10])
            with gr.Row():
                fresh_btn = gr.Button("Refresh Heatmap")
        with gr.Column():
            # Output of fresh_btn.click; sized by the .heatmap-plot CSS rule below.
            output_heatmap = gr.Plot(
                label="Top Public sentiment & Social topic Heatmap",
                container=True,
                elem_classes="heatmap-plot"
            )

    gr.Markdown("# Reddit Election Posts/Comments Analysis")
    gr.Markdown("Ask questions about election-related comments and posts")

    with gr.Row():
        with gr.Column():
            # Forwarded to stream_chat_with_rag() as its second argument.
            year_selector = gr.Radio(
                choices=["2016 Election", "2024 Election", "Comparison two years"],
                label="Select Election Year",
                value="2016 Election"
            )

            # Forwarded to stream_chat_with_rag() as its first argument.
            query_input = gr.Textbox(
                label="Your Question",
                placeholder="Ask about election comments or posts..."
            )

            submit_btn = gr.Button("Submit")

            gr.Markdown("""
            ## Example Questions:
            - Is there any comments don't like the election results
            - Summarize the main discussions about voting process
            - What are the common opinions about candidates?
            """)
        with gr.Column():
            # Output of submit_btn.click.
            output_text = gr.Textbox(
                label="Response",
                lines=20
            )

    with gr.Row():
        # NOTE(review): no event listener in this file writes to output_plot.
        # stream_chat_with_rag() receives a second value from the API but only
        # prints it. Either wire it up or remove this component.
        output_plot = gr.Plot(
            label="Topic Distribution",
            container=True,
            elem_classes="topic-plot"
        )

    # Page CSS for the two plot components, matched via their elem_classes.
    gr.HTML("""
    <style>
    .topic-plot {
        min-height: 600px;
        width: 100%;
        margin: auto;
    }
    .heatmap-plot {
        min-height: 400px;
        width: 100%;
        margin: auto;
    }
    </style>
    """)

    # Refresh button -> heatmap(top_n) -> heatmap plot.
    fresh_btn.click(
        fn=heatmap,
        inputs=top_n,
        outputs=output_heatmap
    )

    # Submit button -> stream_chat_with_rag(question, election) -> response box.
    submit_btn.click(
        fn=stream_chat_with_rag,
        inputs=[query_input, year_selector],
        outputs=output_text
    )
|
|
|
|
|
# Launch the UI only when this file is executed directly, not when imported.
# NOTE(review): share=True presumably requests a public share link; confirm
# this is intended outside local testing.
if __name__ == "__main__":
    demo.launch(share=True)