Vera-ZWY commited on
Commit
3b25a34
·
verified ·
1 Parent(s): 5e8cbf8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -24
app.py CHANGED
@@ -26,7 +26,7 @@ def stream_chat_with_rag(
26
  ):
27
  # print(f"Message: {message}")
28
  #answer = client.predict(question=question, api_name="/run_graph")
29
- answer, fig = client.predict(
30
  query= message,
31
  election_year=year,
32
  api_name="/process_query"
@@ -38,16 +38,32 @@ def stream_chat_with_rag(
38
  print(answer)
39
  history.append((message, response +"\n"+ answer))
40
 
41
- # print("top works from API:")
42
- # print(fig)
43
 
44
- # fig_dict = json.loads(plotly_data['plot'])
 
45
 
46
- # # Render the figure
47
- # fig = pio.from_json(json.dumps(fig_dict))
48
- # fig.show()
49
  return answer
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  # def predict(message, history):
53
  # history_langchain_format = []
@@ -228,26 +244,26 @@ with gr.Blocks(title="Reddit Election Analysis") as demo:
228
  )
229
 
230
  gr.Markdown("## Top words of the relevant Q&A")
231
- # with gr.Row():
232
- # output_plot = gr.Plot(
233
- # label="Topic Distribution",
234
- # container=True, # Ensures the plot is contained within its area
235
- # elem_classes="topic-plot" # Add a custom class for styling
236
- # )
237
 
238
  # Add custom CSS to ensure proper plot sizing
239
  gr.HTML("""
240
  <style>
241
- # .topic-plot {
242
- # min-height: 600px;
243
- # width: 100%;
244
- # margin: auto;
245
- # }
246
  .heatmap-plot {
247
  min-height: 400px;
248
  width: 100%;
249
  margin: auto;
250
  }
 
 
 
 
 
251
  </style>
252
  """)
253
  # topics_df = gr.Dataframe(value=df, label="Data Input")
@@ -266,12 +282,12 @@ with gr.Blocks(title="Reddit Election Analysis") as demo:
266
  outputs = [time_series_fig, linePlot_status_text]
267
  )
268
 
269
- # # Update both outputs when submit is clicked
270
- # submit_btn.click(
271
- # fn=chat_function,
272
- # inputs=[query_input, year_selector],
273
- # outputs= chatbot
274
- # )
275
 
276
 
277
  if __name__ == "__main__":
 
26
  ):
27
  # print(f"Message: {message}")
28
  #answer = client.predict(question=question, api_name="/run_graph")
29
+ answer, sources = client.predict(
30
  query= message,
31
  election_year=year,
32
  api_name="/process_query"
 
38
  print(answer)
39
  history.append((message, response +"\n"+ answer))
40
 
41
+
 
42
 
43
+ # Render the figure
44
+
45
 
 
 
 
46
  return answer
47
 
48
+ def topic_plot_gener(message: str, year: str):
49
+ fig = client.predict(
50
+ query= message,
51
+ election_year=year,
52
+ api_name="/topics_plot_genera"
53
+ )
54
+ # print("top works from API:")
55
+ print(fig)
56
+ plot_base64 = fig
57
+
58
+ plot_bytes = base64.b64decode(plot_base64['plot'].split(',')[1])
59
+ img = plt.imread(BytesIO(plot_bytes), format='PNG')
60
+ plt.figure(figsize = (12, 6), dpi = 150)
61
+ plt.imshow(img)
62
+ plt.axis('off')
63
+ plt.show()
64
+
65
+ return plt.gcf()
66
+
67
 
68
  # def predict(message, history):
69
  # history_langchain_format = []
 
244
  )
245
 
246
  gr.Markdown("## Top words of the relevant Q&A")
247
+ with gr.Row():
248
+ topic_plot = gr.Plot(
249
+ label="Topic Distribution",
250
+ container=True, # Ensures the plot is contained within its area
251
+ elem_classes="topic-plot" # Add a custom class for styling
252
+ )
253
 
254
  # Add custom CSS to ensure proper plot sizing
255
  gr.HTML("""
256
  <style>
 
 
 
 
 
257
  .heatmap-plot {
258
  min-height: 400px;
259
  width: 100%;
260
  margin: auto;
261
  }
262
+ .topic-plot {
263
+ min-width: 600px;
264
+ height: 100%;
265
+ margin: auto;
266
+ }
267
  </style>
268
  """)
269
  # topics_df = gr.Dataframe(value=df, label="Data Input")
 
282
  outputs = [time_series_fig, linePlot_status_text]
283
  )
284
 
285
+ # Update both outputs when submit is clicked
286
+ topic_btn.click(
287
+ fn= topic_plot_gener,
288
+ inputs=[query_input, year_selector],
289
+ outputs= topic_plot
290
+ )
291
 
292
 
293
  if __name__ == "__main__":