"""Gradio front-end for the Reddit election analysis agent.

Every handler in this module forwards the user's input to the remote
``mangoesai/Elections_Comparison_Agent_V4.1`` Gradio app through
``gradio_client`` and converts the response into something the local UI can
render (answer text, a matplotlib figure, or a Plotly figure).
"""

import base64
import json
import os
from io import BytesIO, StringIO

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objects as go
import seaborn as sns
from gradio_client import Client, handle_file
from matplotlib.figure import Figure

# Hugging Face token for the remote Space, read from the environment so it is
# never hard-coded in source.
HF_TOKEN = os.getenv("HF_TOKEN")

# Shared client for the remote analysis Space, created once at import time.
client = Client("mangoesai/Elections_Comparison_Agent_V4.1", hf_token=HF_TOKEN)


def stream_chat_with_rag(message: str, history: list, year: str) -> str:
    """Answer a chat message with the remote RAG pipeline.

    Args:
        message: The user's question.
        history: Chat history supplied by ``gr.ChatInterface``. It is not sent
            to the remote API and is deliberately left untouched.
        year: Election selection from the year radio, forwarded as
            ``election_year``.

    Returns:
        The answer text returned by the remote ``/process_query`` endpoint.
    """
    # The endpoint also returns the retrieved sources; they are not displayed.
    answer, _sources = client.predict(
        query=message,
        election_year=year,
        api_name="/process_query",
    )
    print("Raw answer from API:")
    print(answer)
    # Do not append to ``history``: ChatInterface(type="messages") records the
    # turn itself and expects dict entries, not (user, bot) tuples.
    return answer


def topic_plot_gener(message: str, year: str) -> go.Figure:
    """Build the top-words Plotly figure for a question's RAG sources.

    Args:
        message: Question whose retrieved sources should be topicalized.
        year: Election selection, forwarded as ``election_year``.

    Returns:
        A ``plotly.graph_objects.Figure`` rebuilt from the JSON string that the
        remote ``/topics_plot_genera`` endpoint returns under the ``"plot"`` key.
    """
    result = client.predict(
        query=message,
        election_year=year,
        api_name="/topics_plot_genera",
    )
    print(result)
    plot_json = json.loads(result["plot"])
    # Keep the serialized layout (titles, axes, sizing) along with the traces,
    # and return the figure object itself: ``Figure.show()`` returns None, which
    # would leave the gr.Plot output empty.
    return go.Figure(data=plot_json["data"], layout=plot_json.get("layout"))


def heatmap(top_n):
    """Render the emotion-vs-topic weighted-frequency heatmap.

    Args:
        top_n: Number of top items, forwarded to the remote
            ``/get_heatmap_pivot_table`` endpoint and shown in the title.

    Returns:
        A matplotlib ``Figure`` containing the heatmap, for ``gr.Plot``.
    """
    pivot_table = client.predict(
        top_n=top_n,
        api_name="/get_heatmap_pivot_table",
    )
    print(pivot_table)
    print(type(pivot_table))

    # pivot_table is a dict like:
    # {'headers': ['Index', 'economy', 'human rights', 'immigrant', 'politics'],
    #  'data': [['anger', 55880.0, 557679.0, 147766.0, 180094.0],
    #           ['disgust', 26911.0, 123112.0, 64567.0, 46460.0],
    #           ['fear', 51466.0, 188898.0, 113174.0, 150578.0],
    #           ['neutral', 77005.0, 192945.0, 20549.0, 190793.0]],
    #  'metadata': None}
    df = pd.DataFrame(pivot_table["data"], columns=pivot_table["headers"])
    df = df.set_index("Index")

    # Draw on a standalone Figure rather than pyplot's global state: handlers
    # run once per request, and pyplot keeps every figure it creates alive until
    # it is explicitly closed.
    fig = Figure(figsize=(10, 8))
    ax = fig.add_subplot()
    sns.heatmap(
        df,
        ax=ax,
        cmap="YlOrRd",
        cbar_kws={"label": "Weighted Frequency"},
        square=True,
    )
    ax.set_title(f"Top {top_n} Emotions vs Topics Weighted Frequency")
    ax.set_xlabel("Topics")
    ax.set_ylabel("Emotions")
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    fig.tight_layout()
    return fig


def linePlot(viz_type, weight, top_n):
    """Fetch the time-series visualization from the remote app and display it.

    Args:
        viz_type: Visualization type from the dropdown, forwarded unchanged.
        weight: Score-vs-frequency weight from the slider, forwarded unchanged.
        top_n: Number of top items; also scales the figure height.

    Returns:
        ``(figure, description)``: a matplotlib ``Figure`` showing the PNG
        returned by the remote ``/linePlot_3C1`` endpoint, and the description
        string returned alongside it.
    """
    result = client.predict(
        viz_type=viz_type,
        weight=weight,
        top_n=top_n,
        api_name="/linePlot_3C1",
    )
    # result is a pair: (dict whose "plot" value looks like a data URI
    # "<header>,<base64 PNG>", description string).
    plot_payload, description = result[0], result[1]
    png_bytes = base64.b64decode(plot_payload["plot"].split(",")[1])
    img = plt.imread(BytesIO(png_bytes), format="PNG")

    # Draw on a standalone Figure instead of plt.show()/plt.gcf(): a server
    # process has nothing to show, on interactive backends show() may block or
    # close the figure before gcf() runs, and pyplot figures are never released.
    fig = Figure(figsize=(12, 2 * top_n), dpi=150)
    ax = fig.add_subplot()
    ax.imshow(img)
    ax.axis("off")
    return fig, description


# Create Gradio interface
with gr.Blocks(title="Reddit Election Analysis") as demo:
    gr.Markdown("# Reddit Public sentiment & Social topic distribution ")
    with gr.Row():
        with gr.Column():
            top_n = gr.Dropdown(choices=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
            fresh_btn = gr.Button("Refresh Heatmap")
        with gr.Column():
            output_heatmap = gr.Plot(
                label="Top Public sentiment & Social topic Heatmap",
                container=True,  # Keep the plot inside its own container.
                elem_classes="heatmap-plot",  # Hook for custom CSS styling.
            )

    gr.Markdown("# Get the time series of the Public sentiment & Social topic")
    with gr.Row():
        with gr.Column(scale=1):
            # Control panel.
            # NOTE(review): this dropdown is not wired to any event handler;
            # viz_dropdown below is the one passed to linePlot.
            lineGraph_type = gr.Dropdown(choices=["emotions", "topics", "2Dmatrix"])
            weight_slider = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.5,
                step=0.1,
                label="Weight (Score vs. Frequency)",
            )
            top_n_slider = gr.Slider(
                minimum=2,
                maximum=10,
                value=5,
                step=1,
                label="Top N Items",
            )
            viz_dropdown = gr.Dropdown(
                choices=["emotions", "topics", "grid"],
                value="emotions",
                label="Visualization Type",
                info="Select the type of visualization to display",
            )
            linePlot_btn = gr.Button("Update Visualizations")
            linePlot_status_text = gr.Textbox(label="Status", interactive=False)
        with gr.Column(scale=3):
            time_series_fig = gr.Plot()

    gr.Markdown("# Reddit Election Posts/Comments Analysis")
    gr.Markdown("Ask questions about election-related comments and posts")
    with gr.Row():
        with gr.Column(scale=1):
            # The choice strings are forwarded verbatim to the remote API as
            # ``election_year``.
            year_selector = gr.Radio(
                choices=["2016 Election", "2024 Election", "Comparison two years"],
                label="Select Election Year",
                value="2024 Election",
            )
            gr.Markdown(
                "## Example Questions:\n"
                "- Is there any comments don't like the election results\n"
                "- Summarize the main discussions about voting process\n"
                "- What're the common opinions about candidates?\n"
                "- What're common opinions about immigrant topic?\n"
            )
        with gr.Column(scale=2):
            gr.ChatInterface(
                stream_chat_with_rag,
                type="messages",
                additional_inputs=[year_selector],
            )

    gr.Markdown("## Top words of the relevant Q&A")
    with gr.Row():
        with gr.Column(scale=1):
            query_input = gr.Textbox(
                label="Your Question For Topicalize",
                placeholder=(
                    "Copy and paste your question here to visualize "
                    "the top words of the relevant topic"
                ),
            )
            topic_btn = gr.Button("Topicalize the RAG sources")
        with gr.Column(scale=2):
            topic_plot = gr.Plot(
                label="Top Words Distribution",
                container=True,  # Keep the plot inside its own container.
                elem_classes="topic-plot",  # Hook for custom CSS styling.
            )

    # Placeholder for custom plot-sizing CSS (currently empty).
    gr.HTML(""" """)

    fresh_btn.click(
        fn=heatmap,
        inputs=top_n,
        outputs=output_heatmap,
    )

    linePlot_btn.click(
        fn=linePlot,
        inputs=[viz_dropdown, weight_slider, top_n_slider],
        outputs=[time_series_fig, linePlot_status_text],
    )

    topic_btn.click(
        fn=topic_plot_gener,
        inputs=[query_input, year_selector],
        outputs=topic_plot,
    )

if __name__ == "__main__":
    demo.launch(share=True)