shredder-31 committed
Commit b43f4b4 · verified · 1 Parent(s): 9907984

Update main.py

Files changed (1):
  1. main.py +21 -8
main.py CHANGED
@@ -6,9 +6,8 @@ import torchvision.transforms as T
 from utils import load_checkpoint
 from trainning import ImgCap, beam_search_caption, decoder
 
-def ImgCap_inference(img, beam_width):
-    root_path = "/teamspace/studios/this_studio"
-    with open(f"{root_path}/ImgCap/vocab.pkl", 'rb') as f:
+def initialize():
+    with open("vocab.pkl", 'rb') as f:
         vocab = pickle.load(f)
 
     transforms = T.Compose([
@@ -18,11 +17,13 @@ def ImgCap_inference(img, beam_width):
         T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
     ])
 
-    checkpoint_path = f"{root_path}/ImgCap/trainning/checkpoints/checkpoint_epoch_40.pth"
-
+    checkpoint_path = "checkpoint_epoch_40.pth"
     model = ImgCap(cnn_feature_size=1024, lstm_hidden_size=1024, embedding_dim=1024, num_layers=2, vocab_size=len(vocab))
     model, _, _, _, _, _, _ = load_checkpoint(checkpoint_path=checkpoint_path, model=model)
 
+    return model, vocab, transforms
+
+def ImgCap_inference(img, beam_width, model, vocab, transforms):
     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
     img = transforms(img).unsqueeze(0)
 
@@ -32,15 +33,27 @@
 
 if __name__ == "__main__":
     footer_html = "<p style='text-align: center; font-size: 16px;'>Developed by Sherif Ahmed</p>"
+
+    img1_path = "1 (1).jpeg"
+    img2_path = "1 (2).jpg"
+
+    examples = [
+        [img1_path, 2],
+        [img2_path, 5],
+    ]
+
+    model, vocab, transforms = initialize()
 
     interface = gr.Interface(
-        fn=ImgCap_inference,
+        fn=lambda img, beam_width: ImgCap_inference(img, beam_width, model, vocab, transforms),
         inputs=[
             'image',
             gr.Slider(minimum=1, maximum=5, step=1, label="Beam Width")
         ],
         outputs=gr.Textbox(label="Generated Caption"),
         title="ImgCap",
-        article=footer_html
+        article=footer_html,
+        examples=examples
     )
-    interface.launch()
+
+    interface.launch(debug=True)
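
The core change is the usual Gradio pattern of doing expensive setup (checkpoint, vocab pickle, transforms) once at startup and binding the loaded objects into the request handler with a closure, instead of reloading them on every call. A minimal sketch of that pattern, with placeholder names rather than this repo's actual model code:

import gradio as gr

def initialize():
    # Stand-in for the one-time setup: loading the checkpoint,
    # the vocab pickle, and the torchvision transforms.
    return {"status": "loaded once at startup"}

if __name__ == "__main__":
    state = initialize()  # runs once, before the server starts

    # The lambda closes over `state`, so each request reuses the
    # already-loaded objects instead of re-reading them from disk.
    demo = gr.Interface(
        fn=lambda text: f"{text} ({state['status']})",
        inputs="text",
        outputs="text",
        title="closure demo",
    )
    demo.launch()

The lambda wrapper also keeps the Gradio-facing signature down to the two UI inputs (image and beam width), which is why the commit routes the extra model/vocab/transforms arguments through the closure rather than exposing them as interface inputs.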