Vera-ZWY commited on
Commit
a708fda
1 Parent(s): 32821ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -13
app.py CHANGED
@@ -5,12 +5,13 @@ import matplotlib.pyplot as plt
5
  import os
6
  import pandas as pd
7
  from io import StringIO
 
8
 
9
  # Define your Hugging Face token (make sure to set it as an environment variable)
10
  HF_TOKEN = os.getenv("HF_TOKEN") # Replace with your actual token if not using an environment variable
11
 
12
  # Initialize the Gradio Client for the specified API
13
- client = Client("mangoesai/Elections_Comparison_Agent_V4", hf_token=HF_TOKEN)
14
 
15
  # client_name = ['2016 Election','2024 Election', 'Comparison two years']
16
 
@@ -79,22 +80,136 @@ def heatmap(top_n):
79
 
80
  return plt.gcf()
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  # Create Gradio interface
84
  with gr.Blocks(title="Reddit Election Analysis") as demo:
85
  gr.Markdown("# Reddit Public sentiment & Social topic distribution ")
86
- with gr.Row():
87
  with gr.Column():
88
- with gr.Row():
89
- top_n = gr.Dropdown(choices=[1,2,3,4,5,6,7,8,9,10])
90
- with gr.Row():
91
- fresh_btn = gr.Button("Refresh Heatmap")
92
  with gr.Column():
 
 
 
 
93
  output_heatmap = gr.Plot(
94
  label="Top Public sentiment & Social topic Heatmap",
95
  container=True, # Ensures the plot is contained within its area
96
  elem_classes="heatmap-plot" # Add a custom class for styling
97
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  gr.Markdown("# Reddit Election Posts/Comments Analysis")
100
  gr.Markdown("Ask questions about election-related comments and posts")
@@ -125,7 +240,8 @@ with gr.Blocks(title="Reddit Election Analysis") as demo:
125
  label="Response",
126
  lines=20
127
  )
128
-
 
129
  with gr.Row():
130
  output_plot = gr.Plot(
131
  label="Topic Distribution",
@@ -148,24 +264,38 @@ with gr.Blocks(title="Reddit Election Analysis") as demo:
148
  }
149
  </style>
150
  """)
 
 
 
 
151
  fresh_btn.click(
152
  fn=heatmap,
153
  inputs=top_n,
154
  outputs=output_heatmap
155
  )
 
 
 
 
 
 
156
 
 
 
 
 
 
157
  # Update both outputs when submit is clicked
158
- # submit_btn.click(
159
- # fn=stream_chat_with_rag,
160
- # inputs=[query_input, year_selector],
161
- # outputs=[output_text, output_plot]
162
- # )
163
  submit_btn.click(
164
- fn=stream_chat_with_rag,
165
  inputs=[query_input, year_selector],
166
  outputs=output_text
167
  )
168
 
169
 
 
 
 
 
170
  if __name__ == "__main__":
171
  demo.launch(share=True)
 
5
  import os
6
  import pandas as pd
7
  from io import StringIO
8
+ from linePlot import plot_stacked_time_series, plot_emotion_topic_grid
9
 
10
  # Define your Hugging Face token (make sure to set it as an environment variable)
11
  HF_TOKEN = os.getenv("HF_TOKEN") # Replace with your actual token if not using an environment variable
12
 
13
  # Initialize the Gradio Client for the specified API
14
+ client = Client("mangoesai/Elections_Comparison_Agent_V4.1", hf_token=HF_TOKEN)
15
 
16
  # client_name = ['2016 Election','2024 Election', 'Comparison two years']
17
 
 
80
 
81
  return plt.gcf()
82
 
83
+ def linePlot_time_series(viz_type, weight, top_n):
84
+ result = client.predict(
85
+ viz_type=viz_type,
86
+ weight=weight,
87
+ top_n=top_n,
88
+ api_name="/linePlot_time_series"
89
+ )
90
+ print("============== timeseries df transfer from pivate to public ===============")
91
+ print(result)
92
+ print(type(result))
93
+ return result
94
+
95
+
96
+ def update_visualization(viz_type, weight, top_n):
97
+ """
98
+ Update visualization based on user inputs and selected visualization type
99
+
100
+ Parameters:
101
+ -----------
102
+ viz_type : str
103
+ Type of visualization to show ('emotions', 'topics', or 'grid')
104
+ weight : float
105
+ Weight for scoring (0-1)
106
+ top_n : int
107
+ Number of top items to show
108
+ """
109
+ try:
110
+
111
+ # return None, "Error: Start date must be before end date"
112
+ series = linePlot_time_series(viz_type, weight, top_n)
113
+ if viz_type == "emotions":
114
+ # Create emotion time series
115
+ # series = linePlot_time_series(viz_type, weight, top_n)
116
+ fig = plot_stacked_time_series(
117
+ series,
118
+ f'Top {top_n} Emotions Popularity'
119
+ )
120
+ message = "Emotion time series updated"
121
+
122
+ elif viz_type == "topics":
123
+ # Create topic time series
124
+ # series = linePlot_time_series(viz_type, weight, top_n)
125
+ fig = plot_stacked_time_series(
126
+ series,
127
+ f'Top {top_n} Topics Popularity'
128
+ )
129
+ message = "Topic time series updated"
130
+
131
+ else: # viz_type == "grid"
132
+ # Create emotion-topic grid
133
+ # pair_series = linePlot_time_series(viz_type, weight, top_n)
134
+ fig = plot_emotion_topic_grid(series, top_n)
135
+ message = "Emotion-Topic grid updated"
136
+
137
+ return fig, message
138
+
139
+ except Exception as e:
140
+ return None, f"Error: {str(e)}"
141
+
142
+
143
+
144
+
145
+
146
 
147
  # Create Gradio interface
148
  with gr.Blocks(title="Reddit Election Analysis") as demo:
149
  gr.Markdown("# Reddit Public sentiment & Social topic distribution ")
150
+ with gr.Row():
151
  with gr.Column():
152
+ top_n = gr.Dropdown(choices=[1,2,3,4,5,6,7,8,9,10])
153
+ table_btn = gr.Button("Overall pivot table")
154
+ show_pivot_table = gr.Dataframe(headers=['Index'] + list(df.columns))
155
+
156
  with gr.Column():
157
+
158
+ # top_n = gr.Dropdown(choices=[1,2,3,4,5,6,7,8,9,10])
159
+ fresh_btn = gr.Button("Refresh Heatmap")
160
+ # with gr.Row():
161
  output_heatmap = gr.Plot(
162
  label="Top Public sentiment & Social topic Heatmap",
163
  container=True, # Ensures the plot is contained within its area
164
  elem_classes="heatmap-plot" # Add a custom class for styling
165
  )
166
+ gr.Markdown("# Get the time series of the Public sentiment & Social topic")
167
+ with gr.Row():
168
+ with gr.Column(scale=1):
169
+ # Control panel
170
+ lineGraph_type = gr.Dropdown(choices = ['emotions', 'topics', '2Dmatrix'])
171
+
172
+ weight_slider = gr.Slider(
173
+ minimum=0,
174
+ maximum=1,
175
+ value=0.5,
176
+ step=0.1,
177
+ label="Weight (Score vs. Frequency)"
178
+ )
179
+
180
+ top_n_slider = gr.Slider(
181
+ minimum=2,
182
+ maximum=10,
183
+ value=5,
184
+ step=1,
185
+ label="Top N Items"
186
+ )
187
+
188
+ # start_date_picker = gr.Date(
189
+ # value=date_min.date(),
190
+ # label="Start Date",
191
+ # info=f"Available from: {date_min.strftime('%Y-%m-%d')}"
192
+ # )
193
+
194
+ # end_date_picker = gr.Date(
195
+ # value=date_max.date(),
196
+ # label="End Date",
197
+ # info=f"Available until: {date_min.strftime('%Y-%m-%d')}"
198
+ # )
199
+
200
+
201
+ # with gr.Column():
202
+ viz_dropdown = gr.Dropdown(
203
+ choices=["emotions", "topics", "grid"],
204
+ value="emotions",
205
+ label="Visualization Type",
206
+ info="Select the type of visualization to display"
207
+ )
208
+ linePlot_btn = gr.Button("Update Visualizations")
209
+ linePlot_status_text = gr.Textbox(label="Status", interactive=False)
210
+
211
+ with gr.Column(scale=3):
212
+ time_series_fig = gr.Plot()
213
 
214
  gr.Markdown("# Reddit Election Posts/Comments Analysis")
215
  gr.Markdown("Ask questions about election-related comments and posts")
 
240
  label="Response",
241
  lines=20
242
  )
243
+
244
+ gr.Markdown("## Top works of the relevant Q&A")
245
  with gr.Row():
246
  output_plot = gr.Plot(
247
  label="Topic Distribution",
 
264
  }
265
  </style>
266
  """)
267
+ # topics_df = gr.Dataframe(value=df, label="Data Input")
268
+
269
+
270
+
271
  fresh_btn.click(
272
  fn=heatmap,
273
  inputs=top_n,
274
  outputs=output_heatmap
275
  )
276
+
277
+ linePlot_btn.click(
278
+ fn = update_visualization,
279
+ inputs = [viz_dropdown,weight_slider,top_n_slider],
280
+ outputs = [time_series_fig, linePlot_status_text]
281
+ )
282
 
283
+ table_btn.click(
284
+ fn=get_heatmap_pivot_table,
285
+ inputs= top_n,
286
+ outputs=show_pivot_table
287
+ )
288
  # Update both outputs when submit is clicked
 
 
 
 
 
289
  submit_btn.click(
290
+ fn=process_query,
291
  inputs=[query_input, year_selector],
292
  outputs=output_text
293
  )
294
 
295
 
296
+ if __name__ == "__main__":
297
+ demo.launch(share=True)
298
+
299
+
300
  if __name__ == "__main__":
301
  demo.launch(share=True)