File size: 9,659 Bytes
96911b6
 
0ed32cd
16d491d
96911b6
1cf4aa1
290ac64
 
375dd21
fb10e94
9a76e3e
6ce4a98
96911b6
 
 
 
 
a708fda
96911b6
a84325a
96911b6
 
 
7fd5c4f
571cdc8
96911b6
89661b3
96911b6
3b25a34
96911b6
571cdc8
96911b6
 
 
 
2f83adb
96911b6
 
7fd5c4f
 
3b25a34
4f6e76c
3b25a34
 
375dd21
1fe2261
d766d8b
3b25a34
 
 
 
 
 
 
 
c04fc5a
3b25a34
c04fc5a
 
 
 
 
 
fb10e94
3b25a34
c04fc5a
 
 
 
a244f3c
c04fc5a
 
3b25a34
7fd5c4f
 
 
 
 
 
 
 
 
 
 
 
 
571cdc8
 
d766d8b
8edaa73
b36c45b
0ed32cd
028cea4
0ed32cd
d766d8b
4bc7cb3
51319c6
 
 
 
 
 
 
 
 
 
 
1cf4aa1
51319c6
366588b
51319c6
0ed32cd
 
9fc9533
0ed32cd
 
 
 
 
 
 
 
 
 
 
 
0f36100
a708fda
4f6e76c
 
 
 
 
 
 
 
a708fda
 
4ebe04e
6ce4a98
4ebe04e
 
 
 
 
 
cedf8bf
4f6e76c
 
 
 
 
 
 
 
 
 
4ebe04e
 
a708fda
e1b9d08
bb3ba32
3d7d31a
a708fda
d766d8b
a708fda
a59e807
a708fda
d766d8b
a59e807
a708fda
d766d8b
 
 
 
 
a708fda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb3ba32
 
e1b9d08
bb3ba32
e1b9d08
5e8cbf8
1fe2261
 
 
2f83adb
1fe2261
 
 
bb3ba32
9a76e3e
 
 
 
bb3ba32
9a76e3e
bb3ba32
2f83adb
 
 
 
847e897
 
2f83adb
2e61e42
 
 
 
 
 
5e8cbf8
90273d7
 
 
b7a0eba
90273d7
571cdc8
5e8cbf8
3b25a34
410d48c
7a222d6
a84325a
 
7a222d6
 
410d48c
7a222d6
 
 
 
 
bb3ba32
 
 
 
d766d8b
 
 
 
 
3b25a34
 
 
 
 
bb3ba32
 
a708fda
 
 
 
d766d8b
 
 
 
 
a708fda
 
4ebe04e
a708fda
 
 
bb3ba32
3b25a34
 
 
 
 
 
bb3ba32
 
96911b6
d766d8b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
import gradio as gr
from gradio_client import Client, handle_file
import seaborn as sns
import matplotlib.pyplot as plt
import os
import pandas as pd
from io import StringIO, BytesIO
import base64
import json
import plotly.graph_objects as go
# import plotly.io as pio
# from linePlot import plot_stacked_time_series, plot_emotion_topic_grid

# Hugging Face access token taken from the HF_TOKEN environment variable
# (None when the variable is unset). Set the variable rather than pasting a
# token into source code.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Single shared gradio_client connection to the remote app identified by this
# Hugging Face id; the handlers below send all their predict() calls through it.
client = Client("mangoesai/Elections_Comparison_Agent_V4.1", hf_token=HF_TOKEN)

def stream_chat_with_rag(
    message: str,
    history: list,
    year: str
):
    """Answer one chat question by querying the remote RAG endpoint.

    Used as the ``fn`` of ``gr.ChatInterface(type="messages")``.

    Args:
        message: The user's question, sent as ``query``.
        history: Previous chat turns supplied by ``gr.ChatInterface``. It is
            not sent to the endpoint and is deliberately not modified here.
        year: Election-year selection, sent verbatim as ``election_year``.

    Returns:
        The answer text returned by the ``/process_query`` endpoint.
    """
    # The endpoint returns (answer, sources); only the answer is displayed.
    answer, _sources = client.predict(
        query=message,
        election_year=year,
        api_name="/process_query",
    )

    # Debugging: print the raw response.
    print("Raw answer from API:")
    print(answer)

    # Do not append to `history`: with type="messages" the chat history holds
    # {"role", "content"} entries, and pushing a (message, response) tuple into
    # that list mutated caller-owned state in the wrong format. It could also
    # raise TypeError when the answer was not a string.
    return answer

def topic_plot_gener(message: str, year: str):
    """Fetch the top-words plot for a question and rebuild it as a Plotly figure.

    Args:
        message: Question whose retrieved sources should be topicalized,
            sent as ``query``.
        year: Election-year selection, sent verbatim as ``election_year``.

    Returns:
        A ``plotly.graph_objects.Figure`` rebuilt from the JSON string stored
        under the ``'plot'`` key of the ``/topics_plot_genera`` response.
    """
    response = client.predict(
        query=message,
        election_year=year,
        api_name="/topics_plot_genera",
    )
    # Debugging: print the raw response.
    print(response)

    # Assumes response['plot'] is a Plotly figure serialized with to_json(),
    # i.e. an object with "data" and (presumably) "layout" -- TODO confirm.
    plot_json = json.loads(response["plot"])

    # Keep the serialized layout (titles, axes, sizing) together with the
    # traces; passing only "data" silently dropped it. .get() keeps payloads
    # that carry no layout working exactly as before.
    return go.Figure(data=plot_json["data"], layout=plot_json.get("layout"))


# def predict(message, history):
#     history_langchain_format = []
#     for msg in history:
#         if msg['role'] == "user":
#             history_langchain_format.append(HumanMessage(content=msg['content']))
#         elif msg['role'] == "assistant":
#             history_langchain_format.append(AIMessage(content=msg['content']))
#     history_langchain_format.append(HumanMessage(content=message))
#     gpt_response = llm(history_langchain_format)
#     return gpt_response.content
    


    
def heatmap(top_n):
    """Render the emotions-vs-topics heatmap for the top ``top_n`` items.

    Args:
        top_n: Number of top items requested from the
            ``/get_heatmap_pivot_table`` endpoint; also shown in the title.

    Returns:
        A matplotlib ``Figure`` containing the heatmap.
    """
    pivot_table = client.predict(
        top_n=top_n,
        api_name="/get_heatmap_pivot_table",
    )
    # Debugging: print the raw response and its type.
    print(pivot_table)
    print(type(pivot_table))

    # pivot_table is a dict like:
    # {'headers': ['Index', 'economy', 'human rights', 'immigrant', 'politics'],
    #  'data': [['anger', 55880.0, 557679.0, 147766.0, 180094.0],
    #           ['disgust', 26911.0, 123112.0, 64567.0, 46460.0],
    #           ['fear', 51466.0, 188898.0, 113174.0, 150578.0],
    #           ['neutral', 77005.0, 192945.0, 20549.0, 190793.0]],
    #  'metadata': None}
    # The 'Index' column holds the emotion labels and becomes the row index.
    df = pd.DataFrame(pivot_table['data'], columns=pivot_table['headers'])
    df = df.set_index('Index')

    # Draw on an explicit Figure/Axes rather than pyplot's implicit "current
    # figure", which other handlers running at the same time may also touch.
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        df,
        cmap='YlOrRd',
        cbar_kws={'label': 'Weighted Frequency'},
        square=True,
        ax=ax,
    )

    ax.set_title(f'Top {top_n} Emotions vs Topics Weighted Frequency')
    ax.set_xlabel('Topics')
    ax.set_ylabel('Emotions')
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    fig.tight_layout()

    # Unregister the figure from pyplot so repeated clicks don't accumulate
    # open figures in the server process; the Figure object can still be
    # saved/rendered by the caller.
    plt.close(fig)
    return fig



# def decode_plot(plot_base64, top_n):
#     plot_bytes = base64.b64decode(plot_base64['plot'].split(',')[1])
#     img = plt.imread(BytesIO(plot_bytes), format='PNG')
#     plt.figure(figsize = (12, 2*top_n), dpi = 150)
#     plt.imshow(img)
#     plt.axis('off')
#     plt.show()
#     return plt.gcf()


def linePlot(viz_type, weight, top_n):
    """Fetch the time-series visualization and wrap its PNG in a matplotlib Figure.

    Args:
        viz_type: Visualization type forwarded to the ``/linePlot_3C1`` endpoint.
        weight: Score-vs-frequency weight forwarded to the endpoint.
        top_n: Number of top items; also scales the figure height (2 in per item).

    Returns:
        ``(figure, description)``: a matplotlib ``Figure`` showing the decoded
        image and the description string returned by the endpoint.
    """
    result = client.predict(
        viz_type=viz_type,
        weight=weight,
        top_n=top_n,
        api_name="/linePlot_3C1",
    )
    # result is a tuple: (dict with a base64 PNG under 'plot', description string)
    plot_payload = result[0]
    description = result[1]

    # Strip an optional "data:image/png;base64," prefix. A bare base64 string
    # (no comma) is decoded as-is instead of raising IndexError.
    encoded = plot_payload['plot'].split(',', 1)[-1]
    plot_bytes = base64.b64decode(encoded)
    img = plt.imread(BytesIO(plot_bytes), format='PNG')

    fig, ax = plt.subplots(figsize=(12, 2 * top_n), dpi=150)
    ax.imshow(img)
    ax.axis('off')

    # No plt.show() here: in a server it is at best a no-op (with a warning on
    # non-interactive backends) and can block on interactive ones. Close the
    # figure so pyplot does not keep one alive per request; the returned
    # Figure can still be rendered by the caller.
    plt.close(fig)
    return fig, description

    

# Create Gradio interface
with gr.Blocks(title="Reddit Election Analysis") as demo:
    # --- Emotions-vs-topics heatmap ---
    gr.Markdown("# Reddit Public sentiment & Social topic distribution ")
    with gr.Row():
        with gr.Column():
            # A default value keeps "Refresh Heatmap" from sending top_n=None,
            # which happened when the dropdown started out empty.
            top_n = gr.Dropdown(
                choices=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                value=5,
                label="Top N",
            )
            fresh_btn = gr.Button("Refresh Heatmap")

        with gr.Column():
            output_heatmap = gr.Plot(
                label="Top Public sentiment & Social topic Heatmap",
                container=True,  # Ensures the plot is contained within its area
                elem_classes="heatmap-plot",  # Custom class styled by the CSS below
            )

    # --- Time series ---
    gr.Markdown("# Get the time series of the Public sentiment & Social topic")
    with gr.Row():
        with gr.Column(scale=1):
            # Control panel
            weight_slider = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.5,
                step=0.1,
                label="Weight (Score vs. Frequency)",
            )

            top_n_slider = gr.Slider(
                minimum=2,
                maximum=10,
                value=5,
                step=1,
                label="Top N Items",
            )

            viz_dropdown = gr.Dropdown(
                choices=["emotions", "topics", "grid"],
                value="emotions",
                label="Visualization Type",
                info="Select the type of visualization to display",
            )
            linePlot_btn = gr.Button("Update Visualizations")
            # Shows the description string returned alongside the plot.
            linePlot_status_text = gr.Textbox(label="Status", interactive=False)

        with gr.Column(scale=3):
            time_series_fig = gr.Plot()

    # --- RAG chat ---
    gr.Markdown("# Reddit Election Posts/Comments Analysis")
    gr.Markdown("Ask questions about election-related comments and posts")

    with gr.Row():
        with gr.Column(scale=1):
            # The selected label is forwarded verbatim as `election_year` to the
            # remote API, so these strings presumably must match what it accepts.
            year_selector = gr.Radio(
                choices=["2016 Election", "2024 Election", "Comparison two years"],
                label="Select Election Year",
                value="2024 Election",
            )

            gr.Markdown("""
            ## Example Questions:
            - Are there any comments that dislike the election results?
            - Summarize the main discussions about the voting process
            - What are the common opinions about the candidates?
            - What are the common opinions about the immigrant topic?
            """)

        with gr.Column(scale=2):
            gr.ChatInterface(
                stream_chat_with_rag,
                type="messages",
                additional_inputs=[year_selector],
            )

    # --- Top words of the retrieved sources ---
    gr.Markdown("## Top words of the relevant Q&A")
    with gr.Row():
        with gr.Column(scale=1):
            query_input = gr.Textbox(
                label="Your Question For Topicalize",
                placeholder="Copy and paste your question here to visualize the top words of the relevant topic",
            )
            topic_btn = gr.Button("Topicalize the RAG sources")
        with gr.Column(scale=2):
            topic_plot = gr.Plot(
                label="Top Words Distribution",
                container=True,  # Ensures the plot is contained within its area
                elem_classes="topic-plot",  # Custom class styled by the CSS below
            )

    # Add custom CSS to ensure proper plot sizing
    gr.HTML("""
        <style>
            .heatmap-plot {
                min-height: 400px;
                width: 100%;
                margin: auto;
            }
            .topic-plot {
                min-width: 600px;
                height: 100%;
                margin: auto;
            }
        </style>
    """)

    # --- Event wiring ---
    fresh_btn.click(
        fn=heatmap,
        inputs=top_n,
        outputs=output_heatmap,
    )

    linePlot_btn.click(
        fn=linePlot,
        inputs=[viz_dropdown, weight_slider, top_n_slider],
        outputs=[time_series_fig, linePlot_status_text],
    )

    # Build the top-words plot from the typed question and the election year
    # selected in the chat section.
    topic_btn.click(
        fn=topic_plot_gener,
        inputs=[query_input, year_selector],
        outputs=topic_plot,
    )


if __name__ == "__main__":
    # share=True additionally requests a temporary public share link.
    demo.launch(share=True)