Vera-ZWY commited on
Commit
4f6e76c
1 Parent(s): baddb8d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -75
app.py CHANGED
@@ -36,9 +36,15 @@ def stream_chat_with_rag(
36
  print(answer)
37
  print("top works from API:")
38
  print(fig)
39
-
 
 
 
 
 
 
40
  # return answer, fig
41
- return answer
42
 
43
 
44
 
@@ -81,79 +87,16 @@ def heatmap(top_n):
81
 
82
  return plt.gcf()
83
 
84
- # def linePlot_time_series(viz_type, weight, top_n):
85
- # result = client.predict(
86
- # viz_type=viz_type,
87
- # weight=weight,
88
- # top_n=top_n,
89
- # api_name="/linePlot_time_series"
90
- # )
91
-
92
- # print("============== timeseries df transfer from pivate to public ===============")
93
- # print(result)
94
- # print(type(result))
95
-
96
- # df = pd.DataFrame(result['data'], columns=result['headers'])
97
- # df.set_index('Index', inplace=True)
98
- # return df
99
-
100
-
101
- # def update_visualization(viz_type, weight, top_n):
102
- # """
103
- # Update visualization based on user inputs and selected visualization type
104
-
105
- # Parameters:
106
- # -----------
107
- # viz_type : str
108
- # Type of visualization to show ('emotions', 'topics', or 'grid')
109
- # weight : float
110
- # Weight for scoring (0-1)
111
- # top_n : int
112
- # Number of top items to show
113
- # """
114
- # try:
115
-
116
- # # return None, "Error: Start date must be before end date"
117
- # series = linePlot_time_series(viz_type, weight, top_n)
118
- # if viz_type == "emotions":
119
- # # Create emotion time series
120
- # # series = linePlot_time_series(viz_type, weight, top_n)
121
- # fig = plot_stacked_time_series(
122
- # series,
123
- # f'Top {top_n} Emotions Popularity'
124
- # )
125
- # message = "Emotion time series updated"
126
-
127
- # elif viz_type == "topics":
128
- # # Create topic time series
129
- # # series = linePlot_time_series(viz_type, weight, top_n)
130
- # fig = plot_stacked_time_series(
131
- # series,
132
- # f'Top {top_n} Topics Popularity'
133
- # )
134
- # message = "Topic time series updated"
135
-
136
- # else: # viz_type == "grid"
137
- # # Create emotion-topic grid
138
- # # pair_series = linePlot_time_series(viz_type, weight, top_n)
139
- # fig = plot_emotion_topic_grid(series, top_n)
140
- # message = "Emotion-Topic grid updated"
141
-
142
- # return fig, message
143
-
144
- # except Exception as e:
145
- # return None, f"Error: {str(e)}"
146
-
147
 
148
 
149
- def decode_plot(plot_base64, top_n):
150
- plot_bytes = base64.b64decode(plot_base64['plot'].split(',')[1])
151
- img = plt.imread(BytesIO(plot_bytes), format='PNG')
152
- plt.figure(figsize = (12, 2*top_n), dpi = 150)
153
- plt.imshow(img)
154
- plt.axis('off')
155
- plt.show()
156
- return plt.gcf()
157
 
158
 
159
  def linePlot(viz_type, weight, top_n):
@@ -165,8 +108,16 @@ def linePlot(viz_type, weight, top_n):
165
  api_name="/linePlot_3C1"
166
  )
167
  # print(result)
168
- # result is a tuble of dictionary of plot_base64, and a string message of description of the plot
169
- return decode_plot(result[0],top_n), result[1]
 
 
 
 
 
 
 
 
170
 
171
 
172
 
 
36
  print(answer)
37
  print("top works from API:")
38
  print(fig)
39
+
40
+ plot_bytes = base64.b64decode(fig['plot'].split(',')[1])
41
+ img = plt.imread(BytesIO(fig), format='PNG')
42
+ plt.figure(dpi = 150)
43
+ plt.imshow(img)
44
+ plt.axis('off')
45
+ plt.show()
46
  # return answer, fig
47
+ return answe, plt.gcf()
48
 
49
 
50
 
 
87
 
88
  return plt.gcf()
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
 
92
+ # def decode_plot(plot_base64, top_n):
93
+ # plot_bytes = base64.b64decode(plot_base64['plot'].split(',')[1])
94
+ # img = plt.imread(BytesIO(plot_bytes), format='PNG')
95
+ # plt.figure(figsize = (12, 2*top_n), dpi = 150)
96
+ # plt.imshow(img)
97
+ # plt.axis('off')
98
+ # plt.show()
99
+ # return plt.gcf()
100
 
101
 
102
  def linePlot(viz_type, weight, top_n):
 
108
  api_name="/linePlot_3C1"
109
  )
110
  # print(result)
111
+ # result is a tuble of dictionary of (plot_base64, str), string message of description of the plot
112
+ plot_base64 = result[0]
113
+
114
+ plot_bytes = base64.b64decode(plot_base64['plot'].split(',')[1])
115
+ img = plt.imread(BytesIO(plot_bytes), format='PNG')
116
+ plt.figure(figsize = (12, 2*top_n), dpi = 150)
117
+ plt.imshow(img)
118
+ plt.axis('off')
119
+ plt.show()
120
+ return plt.gcf(), result[1]
121
 
122
 
123