Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import pathlib
|
2 |
|
3 |
import gradio as gr
|
|
|
4 |
import open_clip
|
5 |
import torch
|
6 |
|
@@ -25,6 +26,7 @@ def output_generate(image):
|
|
25 |
generated = model.generate(im, seq_len=20)
|
26 |
return open_clip.decode(generated[0].detach()).split("<end_of_text>")[0].replace("<start_of_text>", "")
|
27 |
|
|
|
28 |
def inference_caption(image, decoding_method="Beam search", rep_penalty=1.2, top_p=0.5, min_seq_len=5, seq_len=20):
|
29 |
im = transform(image).unsqueeze(0).to(device)
|
30 |
generation_type = "beam_search" if decoding_method == "Beam search" else "top_p"
|
|
|
1 |
import pathlib
|
2 |
|
3 |
import gradio as gr
|
4 |
+
import spaces
|
5 |
import open_clip
|
6 |
import torch
|
7 |
|
|
|
26 |
generated = model.generate(im, seq_len=20)
|
27 |
return open_clip.decode(generated[0].detach()).split("<end_of_text>")[0].replace("<start_of_text>", "")
|
28 |
|
29 |
+
@spaces.GPU
|
30 |
def inference_caption(image, decoding_method="Beam search", rep_penalty=1.2, top_p=0.5, min_seq_len=5, seq_len=20):
|
31 |
im = transform(image).unsqueeze(0).to(device)
|
32 |
generation_type = "beam_search" if decoding_method == "Beam search" else "top_p"
|