Vageesh1 commited on
Commit
e74c12b
·
1 Parent(s): 524180e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -13
app.py CHANGED
@@ -56,9 +56,9 @@ def ui():
56
  pil_image = PIL.Image.fromarray(image)
57
  image = preprocess(pil_image).unsqueeze(0).to(device)
58
 
59
- option = st.selectbox('Please select the Model',('Model', 'COCO Model','PreTrained Model','Fine Tuned Model'))
60
 
61
- if option=='Model':
62
  with torch.no_grad():
63
  prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
64
  prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
@@ -66,17 +66,7 @@ def ui():
66
 
67
  st.image(uploaded_file, width = 500, channels = 'RGB')
68
  st.markdown("**PREDICTION:** " + generated_text_prefix)
69
-
70
- elif option=='COCO Model':
71
- with torch.no_grad():
72
- prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
73
- prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
74
- generated_text_prefix = generate2(coco_model, tokenizer, embed=prefix_embed)
75
-
76
- st.image(uploaded_file, width = 500, channels = 'RGB')
77
- st.markdown("**PREDICTION:** " + generated_text_prefix)
78
-
79
- elif option=='PreTrained Model':
80
  out = inference(uploaded_file)
81
  st.image(uploaded_file, width = 500, channels = 'RGB')
82
  st.markdown("**PREDICTION:** " + out)
 
56
  pil_image = PIL.Image.fromarray(image)
57
  image = preprocess(pil_image).unsqueeze(0).to(device)
58
 
59
+ option = st.selectbox('Please select the Model',(' PreTrained Model','Trained Model','Fine Tuned Model'))
60
 
61
+ if option=='PreTrained Model':
62
  with torch.no_grad():
63
  prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
64
  prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
 
66
 
67
  st.image(uploaded_file, width = 500, channels = 'RGB')
68
  st.markdown("**PREDICTION:** " + generated_text_prefix)
69
+ elif option=='Trained Model':
 
 
 
 
 
 
 
 
 
 
70
  out = inference(uploaded_file)
71
  st.image(uploaded_file, width = 500, channels = 'RGB')
72
  st.markdown("**PREDICTION:** " + out)