mojtaba-nafez commited on
Commit
db11718
·
1 Parent(s): 1bc9b9d

add image projection weights and configs

Browse files
app.py CHANGED
@@ -1,5 +1,5 @@
1
- from models import PoemTextModel
2
- from inference import predict_poems_from_text
3
  from utils import get_poem_embeddings
4
  import config as CFG
5
  import json
@@ -10,7 +10,11 @@ def greet_user(name):
10
  return "Hello " + name + " Welcome to Gradio!😎"
11
 
12
  if __name__ == "__main__":
13
- model = PoemTextModel(poem_encoder_pretrained=True, text_encoder_pretrained=True).to(CFG.device)
 
 
 
 
14
  model.eval()
15
  # Inference: Output some example predictions and write them in a file
16
  with open('poem_embeddings.json', encoding="utf-8") as f:
@@ -20,17 +24,14 @@ if __name__ == "__main__":
20
  print(poem_embeddings.shape)
21
  poems = [p['beyt'] for p in pe]
22
 
23
- def gradio_make_predictions(text):
24
- beyts = predict_poems_from_text(model, poem_embeddings, text, poems, n=10)
25
  return "\n".join(beyts)
26
 
27
  CFG.batch_size = 512
28
- # print(poem_embeddings[0])
29
- # with open('poem_embeddings.json'.format(CFG.poem_encoder_model, CFG.text_encoder_model),'w', encoding="utf-8") as f:
30
- # f.write(json.dumps(poem_embeddings, indent= 4))
31
 
32
- text_input = gr.Textbox(label = "Enter the text to find poem beyts for")
33
  output = gr.Textbox()
34
 
35
- app = gr.Interface(fn = gradio_make_predictions, inputs=text_input, outputs=output)
36
  app.launch()
 
1
+ from models import CLIPModel
2
+ from inference import predict_poems_from_image
3
  from utils import get_poem_embeddings
4
  import config as CFG
5
  import json
 
10
  return "Hello " + name + " Welcome to Gradio!😎"
11
 
12
  if __name__ == "__main__":
13
+ model = CLIPModel(image_encoder_pretrained=True,
14
+ text_encoder_pretrained=True,
15
+ text_projection_trainable=False,
16
+ is_image_poem_pair=True
17
+ ).to(CFG.device)
18
  model.eval()
19
  # Inference: Output some example predictions and write them in a file
20
  with open('poem_embeddings.json', encoding="utf-8") as f:
 
24
  print(poem_embeddings.shape)
25
  poems = [p['beyt'] for p in pe]
26
 
27
+ def gradio_make_predictions(image):
28
+ beyts = predict_poems_from_image(model, poem_embeddings, image, poems, n=10)
29
  return "\n".join(beyts)
30
 
31
  CFG.batch_size = 512
 
 
 
32
 
33
+ image_input = gr.Image(type="filepath")
34
  output = gr.Textbox()
35
 
36
+ app = gr.Interface(fn = gradio_make_predictions, inputs=image_input, outputs=output)
37
  app.launch()
config.py CHANGED
@@ -103,7 +103,7 @@ image_encoder_weights_load_path = ""
103
  image_encoder_weights_save_path = "{}_best_image_encoder.pt".format(image_encoder_model)
104
  image_embedding = 2048 # embedding dim of image encoder's output (for one token)
105
  # keep this an empty string if you want to use a freshly initialized projection module. else give the path to projection model
106
- image_projection_load_path = ""
107
  # path to save projection to
108
  image_projection_save_path = "{}_best_image_projection.pt".format(image_encoder_model)
109
  image_encoder_trainable = False # if set to false, this encoder's frozen and its weights won't be saved at all.
 
103
  image_encoder_weights_save_path = "{}_best_image_encoder.pt".format(image_encoder_model)
104
  image_embedding = 2048 # embedding dim of image encoder's output (for one token)
105
  # keep this an empty string if you want to use a freshly initialized projection module. else give the path to projection model
106
+ image_projection_load_path = os.path.join(file_dirname, "projections/{}_best_image_projection.pt".format(image_encoder_model))
107
  # path to save projection to
108
  image_projection_save_path = "{}_best_image_projection.pt".format(image_encoder_model)
109
  image_encoder_trainable = False # if set to false, this encoder's frozen and its weights won't be saved at all.
projections/resnet50_best_image_projection.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:145f6d7ab06dac0f59906a6c62d19cdfaa5e09e8b0c3a5f2e1a3c975f31ca184
3
+ size 12601871