Vageesh1 commited on
Commit
bbba82b
·
1 Parent(s): c0ebb9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -56,9 +56,9 @@ def ui():
56
  pil_image = PIL.Image.fromarray(image)
57
  image = preprocess(pil_image).unsqueeze(0).to(device)
58
 
59
- option = st.selectbox('Please select the Model',('PreTrained Model','Trained Model'))
60
 
61
- if option=='PreTrained Model':
62
  with torch.no_grad():
63
  prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
64
  prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
@@ -66,12 +66,12 @@ def ui():
66
 
67
  st.image(uploaded_file, width = 500, channels = 'RGB')
68
  st.markdown("**PREDICTION:** " + generated_text_prefix)
69
- elif option=='Trained Model':
70
  out = inference(uploaded_file)
71
  st.image(uploaded_file, width = 500, channels = 'RGB')
72
  st.markdown("**PREDICTION:** " + out)
73
 
74
- elif option=='Fine Tuned Model':
75
  out=show_n_generate(uploaded_file, greedy = False, model = model_trained)
76
  st.image(uploaded_file, width = 500, channels = 'RGB')
77
  st.markdown("**PREDICTION:** " + out)
 
56
  pil_image = PIL.Image.fromarray(image)
57
  image = preprocess(pil_image).unsqueeze(0).to(device)
58
 
59
+ option = st.selectbox('Please select the Model',('Clip Captioning','Attention Decoder','VIT+GPT2'))
60
 
61
+ if option=='Clip Captioning':
62
  with torch.no_grad():
63
  prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
64
  prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
 
66
 
67
  st.image(uploaded_file, width = 500, channels = 'RGB')
68
  st.markdown("**PREDICTION:** " + generated_text_prefix)
69
+ elif option=='Attention Decoder':
70
  out = inference(uploaded_file)
71
  st.image(uploaded_file, width = 500, channels = 'RGB')
72
  st.markdown("**PREDICTION:** " + out)
73
 
74
+ elif option=='VIT+GPT2':
75
  out=show_n_generate(uploaded_file, greedy = False, model = model_trained)
76
  st.image(uploaded_file, width = 500, channels = 'RGB')
77
  st.markdown("**PREDICTION:** " + out)