Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -56,9 +56,9 @@ def ui():
|
|
56 |
pil_image = PIL.Image.fromarray(image)
|
57 |
image = preprocess(pil_image).unsqueeze(0).to(device)
|
58 |
|
59 |
-
option = st.selectbox('Please select the Model',('
|
60 |
|
61 |
-
if option=='Model':
|
62 |
with torch.no_grad():
|
63 |
prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
|
64 |
prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
|
@@ -66,17 +66,7 @@ def ui():
|
|
66 |
|
67 |
st.image(uploaded_file, width = 500, channels = 'RGB')
|
68 |
st.markdown("**PREDICTION:** " + generated_text_prefix)
|
69 |
-
|
70 |
-
elif option=='COCO Model':
|
71 |
-
with torch.no_grad():
|
72 |
-
prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
|
73 |
-
prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
|
74 |
-
generated_text_prefix = generate2(coco_model, tokenizer, embed=prefix_embed)
|
75 |
-
|
76 |
-
st.image(uploaded_file, width = 500, channels = 'RGB')
|
77 |
-
st.markdown("**PREDICTION:** " + generated_text_prefix)
|
78 |
-
|
79 |
-
elif option=='PreTrained Model':
|
80 |
out = inference(uploaded_file)
|
81 |
st.image(uploaded_file, width = 500, channels = 'RGB')
|
82 |
st.markdown("**PREDICTION:** " + out)
|
|
|
56 |
pil_image = PIL.Image.fromarray(image)
|
57 |
image = preprocess(pil_image).unsqueeze(0).to(device)
|
58 |
|
59 |
+
option = st.selectbox('Please select the Model',(' PreTrained Model','Trained Model','Fine Tuned Model'))
|
60 |
|
61 |
+
if option=='PreTrained Model':
|
62 |
with torch.no_grad():
|
63 |
prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
|
64 |
prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
|
|
|
66 |
|
67 |
st.image(uploaded_file, width = 500, channels = 'RGB')
|
68 |
st.markdown("**PREDICTION:** " + generated_text_prefix)
|
69 |
+
elif option=='Trained Model':
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
out = inference(uploaded_file)
|
71 |
st.image(uploaded_file, width = 500, channels = 'RGB')
|
72 |
st.markdown("**PREDICTION:** " + out)
|