Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -25,6 +25,11 @@ current_vis = []
|
|
25 |
current_bpe = []
|
26 |
current_index = 0
|
27 |
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
def load_model(check_type):
|
30 |
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
@@ -130,11 +135,17 @@ def format_bpe_display(bpe):
|
|
130 |
# else:
|
131 |
# return None, "索引超出范围"
|
132 |
# 状态更新函数,利用传递的状态(vis, bpe)
|
133 |
-
|
|
|
|
|
|
|
134 |
if 0 <= x < len(vis):
|
135 |
-
return vis[x], format_bpe_display(bpe[x]),
|
136 |
else:
|
137 |
-
return None, "索引超出范围",
|
|
|
|
|
|
|
138 |
|
139 |
|
140 |
# Gradio界面
|
@@ -191,14 +202,17 @@ with gr.Blocks(title="BPE Visualization Demo") as demo:
|
|
191 |
# return image, vis, bpe_text, slider_max_val
|
192 |
|
193 |
@spaces.GPU
|
194 |
-
def on_run_clicked(model_type, image, text):
|
195 |
model, tokenizer, transform, device = load_model(model_type)
|
196 |
current_index = 0 # Reset index when new image is processed
|
197 |
image, vis, bpe = process_image(model, tokenizer, transform, device, model_type, image, text)
|
198 |
slider_max_val = len(bpe) - 1
|
199 |
bpe_text = format_bpe_display(bpe[current_index])
|
200 |
-
#
|
201 |
-
|
|
|
|
|
|
|
202 |
|
203 |
|
204 |
|
@@ -213,14 +227,12 @@ with gr.Blocks(title="BPE Visualization Demo") as demo:
|
|
213 |
# outputs=[prev_btn, index_slider, next_btn, bpe_display],
|
214 |
# )
|
215 |
# Gradio 按钮点击后的处理
|
|
|
216 |
run_btn.click(
|
217 |
on_run_clicked,
|
218 |
-
inputs=[model_type, image_input, text_input],
|
219 |
-
outputs=[orig_img, heatmap, bpe_display, index_slider, 'state',
|
220 |
-
|
221 |
-
lambda outputs: (gr.update(visible=True), gr.update(visible=True, maximum=outputs[3], value=0), gr.update(visible=True), gr.update(visible=True), outputs[4], outputs[5]),
|
222 |
-
inputs=index_slider,
|
223 |
-
outputs=[prev_btn, index_slider, next_btn, bpe_display, 'state', 'state']
|
224 |
)
|
225 |
|
226 |
prev_btn.click(
|
@@ -241,8 +253,8 @@ with gr.Blocks(title="BPE Visualization Demo") as demo:
|
|
241 |
# )
|
242 |
index_slider.change(
|
243 |
update_slider_index,
|
244 |
-
inputs=[index_slider, 'state'
|
245 |
-
outputs=[heatmap, bpe_display, 'state'
|
246 |
)
|
247 |
|
248 |
|
|
|
25 |
current_bpe = []
|
26 |
current_index = 0
|
27 |
|
28 |
+
# 设置初始状态
|
29 |
+
initial_state = {
|
30 |
+
"vis": [],
|
31 |
+
"bpe": []
|
32 |
+
}
|
33 |
|
34 |
def load_model(check_type):
|
35 |
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
135 |
# else:
|
136 |
# return None, "索引超出范围"
|
137 |
# 状态更新函数,利用传递的状态(vis, bpe)
|
138 |
+
# 使用状态信息来处理滑动条改变
|
139 |
+
def update_slider_index(x, state):
|
140 |
+
vis = state['vis']
|
141 |
+
bpe = state['bpe']
|
142 |
if 0 <= x < len(vis):
|
143 |
+
return vis[x], format_bpe_display(bpe[x]), state
|
144 |
else:
|
145 |
+
return None, "索引超出范围", state
|
146 |
+
|
147 |
+
|
148 |
+
|
149 |
|
150 |
|
151 |
# Gradio界面
|
|
|
202 |
# return image, vis, bpe_text, slider_max_val
|
203 |
|
204 |
@spaces.GPU
|
205 |
+
def on_run_clicked(model_type, image, text, state):
|
206 |
model, tokenizer, transform, device = load_model(model_type)
|
207 |
current_index = 0 # Reset index when new image is processed
|
208 |
image, vis, bpe = process_image(model, tokenizer, transform, device, model_type, image, text)
|
209 |
slider_max_val = len(bpe) - 1
|
210 |
bpe_text = format_bpe_display(bpe[current_index])
|
211 |
+
# 更新状态并返回
|
212 |
+
state['vis'] = vis
|
213 |
+
state['bpe'] = bpe
|
214 |
+
return image, vis[current_index], bpe_text, slider_max_val, state
|
215 |
+
|
216 |
|
217 |
|
218 |
|
|
|
227 |
# outputs=[prev_btn, index_slider, next_btn, bpe_display],
|
228 |
# )
|
229 |
# Gradio 按钮点击后的处理
|
230 |
+
# Gradio 按钮点击后的处理
|
231 |
run_btn.click(
|
232 |
on_run_clicked,
|
233 |
+
inputs=[model_type, image_input, text_input, 'state'],
|
234 |
+
outputs=[orig_img, heatmap, bpe_display, index_slider, 'state'],
|
235 |
+
_js="{state: { vis: [], bpe: []}}"
|
|
|
|
|
|
|
236 |
)
|
237 |
|
238 |
prev_btn.click(
|
|
|
253 |
# )
|
254 |
index_slider.change(
|
255 |
update_slider_index,
|
256 |
+
inputs=[index_slider, 'state'],
|
257 |
+
outputs=[heatmap, bpe_display, 'state']
|
258 |
)
|
259 |
|
260 |
|