Lwasinam commited on
Commit
030fbc6
·
verified ·
1 Parent(s): 98fb2b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -31
app.py CHANGED
@@ -24,10 +24,6 @@ def process(model,image, tokenizer, device):
24
  model.eval()
25
  with torch.no_grad():
26
  encoder_input = image.unsqueeze(0).to(device) # (b, seq_len)
27
- # decoder_input = batch['decoder_input'].to(device) # (B, seq_len)
28
- # encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq_len)
29
- # decoder_mask = batch['decoder_mask'].to(device) # (B, 1, seq_len, seq_len)
30
-
31
  model_out = greedy_decode(model, encoder_input, None, tokenizer, 196,device)
32
  model_text = tokenizer.decode(model_out.detach().cpu().numpy())
33
  return model_text
@@ -111,13 +107,6 @@ def greedy_decode(model, source, source_mask, tokenizer_tgt, max_len, device):
111
 
112
  # Append next word
113
  decoder_input = torch.cat([decoder_input, next_word.unsqueeze(0)], dim=1)
114
- # # get next token
115
- # prob = model.project(out[:, -1])
116
- # _, next_word = torch.max(prob, dim=1)
117
- # # print(f'prob: {prob.shape}')
118
- # decoder_input = torch.cat(
119
- # [decoder_input, torch.empty(1, 1).long().fill_(next_word.item()).to(device)], dim=1
120
- # )
121
 
122
  if next_word.item() == eos_idx:
123
  break
@@ -127,7 +116,7 @@ def greedy_decode(model, source, source_mask, tokenizer_tgt, max_len, device):
127
  def image_base64(image):
128
 
129
 
130
- # with open('C:/AI/projects/vision_model_pretrained/validation/content/memory_image_23330.jpg', 'rb') as image_file:
131
  base64_bytes = base64.b64encode(image_file.read())
132
 
133
 
@@ -135,27 +124,11 @@ def image_base64(image):
135
  return base64_string
136
 
137
 
138
- def start():
139
- print('start')
140
- accelerator = Accelerator()
141
- device = accelerator.device
142
-
143
- config = get_config()
144
- tokenizer = get_or_build_tokenizer(config)
145
- model = get_model(config, len(tokenizer))
146
- model = accelerator.prepare(model)
147
- accelerator.load_state('model.tensors')
148
-
149
- image = image_base64()
150
-
151
-
152
 
153
- process(model, image, tokenizer, device)
154
 
155
- # start()
156
 
157
  def main():
158
- st.title("Image Captioning with Transformer Models")
159
  image = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
160
 
161
  if image is not None:
@@ -173,8 +146,7 @@ def main():
173
  model = get_model(config, len(tokenizer))
174
  model = accelerator.prepare(model)
175
  accelerator.load_state('models/')
176
- # model = get_model(config, len(tokenizer))
177
- # model.to(device)
178
 
179
 
180
  text_output = process(model, image, tokenizer, device)
 
24
  model.eval()
25
  with torch.no_grad():
26
  encoder_input = image.unsqueeze(0).to(device) # (b, seq_len)
 
 
 
 
27
  model_out = greedy_decode(model, encoder_input, None, tokenizer, 196,device)
28
  model_text = tokenizer.decode(model_out.detach().cpu().numpy())
29
  return model_text
 
107
 
108
  # Append next word
109
  decoder_input = torch.cat([decoder_input, next_word.unsqueeze(0)], dim=1)
 
 
 
 
 
 
 
110
 
111
  if next_word.item() == eos_idx:
112
  break
 
116
  def image_base64(image):
117
 
118
 
119
+
120
  base64_bytes = base64.b64encode(image_file.read())
121
 
122
 
 
124
  return base64_string
125
 
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
 
128
 
 
129
 
130
  def main():
131
+ st.title("Image Captioning with Vision Transformer")
132
  image = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
133
 
134
  if image is not None:
 
146
  model = get_model(config, len(tokenizer))
147
  model = accelerator.prepare(model)
148
  accelerator.load_state('models/')
149
+
 
150
 
151
 
152
  text_output = process(model, image, tokenizer, device)