Spaces:
Runtime error
Runtime error
mojtaba-nafez
commited on
Commit
·
db11718
1
Parent(s):
1bc9b9d
add image projection weights and configs
Browse files- app.py +11 -10
- config.py +1 -1
- projections/resnet50_best_image_projection.pt +3 -0
app.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
-
from models import
|
2 |
-
from inference import
|
3 |
from utils import get_poem_embeddings
|
4 |
import config as CFG
|
5 |
import json
|
@@ -10,7 +10,11 @@ def greet_user(name):
|
|
10 |
return "Hello " + name + " Welcome to Gradio!😎"
|
11 |
|
12 |
if __name__ == "__main__":
|
13 |
-
model =
|
|
|
|
|
|
|
|
|
14 |
model.eval()
|
15 |
# Inference: Output some example predictions and write them in a file
|
16 |
with open('poem_embeddings.json', encoding="utf-8") as f:
|
@@ -20,17 +24,14 @@ if __name__ == "__main__":
|
|
20 |
print(poem_embeddings.shape)
|
21 |
poems = [p['beyt'] for p in pe]
|
22 |
|
23 |
-
def gradio_make_predictions(
|
24 |
-
beyts =
|
25 |
return "\n".join(beyts)
|
26 |
|
27 |
CFG.batch_size = 512
|
28 |
-
# print(poem_embeddings[0])
|
29 |
-
# with open('poem_embeddings.json'.format(CFG.poem_encoder_model, CFG.text_encoder_model),'w', encoding="utf-8") as f:
|
30 |
-
# f.write(json.dumps(poem_embeddings, indent= 4))
|
31 |
|
32 |
-
|
33 |
output = gr.Textbox()
|
34 |
|
35 |
-
app = gr.Interface(fn = gradio_make_predictions, inputs=
|
36 |
app.launch()
|
|
|
1 |
+
from models import CLIPModel
|
2 |
+
from inference import predict_poems_from_image
|
3 |
from utils import get_poem_embeddings
|
4 |
import config as CFG
|
5 |
import json
|
|
|
10 |
return "Hello " + name + " Welcome to Gradio!😎"
|
11 |
|
12 |
if __name__ == "__main__":
|
13 |
+
model = CLIPModel(image_encoder_pretrained=True,
|
14 |
+
text_encoder_pretrained=True,
|
15 |
+
text_projection_trainable=False,
|
16 |
+
is_image_poem_pair=True
|
17 |
+
).to(CFG.device)
|
18 |
model.eval()
|
19 |
# Inference: Output some example predictions and write them in a file
|
20 |
with open('poem_embeddings.json', encoding="utf-8") as f:
|
|
|
24 |
print(poem_embeddings.shape)
|
25 |
poems = [p['beyt'] for p in pe]
|
26 |
|
27 |
+
def gradio_make_predictions(image):
|
28 |
+
beyts = predict_poems_from_image(model, poem_embeddings, image, poems, n=10)
|
29 |
return "\n".join(beyts)
|
30 |
|
31 |
CFG.batch_size = 512
|
|
|
|
|
|
|
32 |
|
33 |
+
image_input = gr.Image(type="filepath")
|
34 |
output = gr.Textbox()
|
35 |
|
36 |
+
app = gr.Interface(fn = gradio_make_predictions, inputs=image_input, outputs=output)
|
37 |
app.launch()
|
config.py
CHANGED
@@ -103,7 +103,7 @@ image_encoder_weights_load_path = ""
|
|
103 |
image_encoder_weights_save_path = "{}_best_image_encoder.pt".format(image_encoder_model)
|
104 |
image_embedding = 2048 # embedding dim of image encoder's output (for one token)
|
105 |
# keep this an empty string if you want to use a freshly initialized projection module. else give the path to projection model
|
106 |
-
image_projection_load_path = ""
|
107 |
# path to save projection to
|
108 |
image_projection_save_path = "{}_best_image_projection.pt".format(image_encoder_model)
|
109 |
image_encoder_trainable = False # if set to false, this encoder's frozen and its weights won't be saved at all.
|
|
|
103 |
image_encoder_weights_save_path = "{}_best_image_encoder.pt".format(image_encoder_model)
|
104 |
image_embedding = 2048 # embedding dim of image encoder's output (for one token)
|
105 |
# keep this an empty string if you want to use a freshly initialized projection module. else give the path to projection model
|
106 |
+
image_projection_load_path = os.path.join(file_dirname, "projections/{}_best_image_projection.pt".format(image_encoder_model))
|
107 |
# path to save projection to
|
108 |
image_projection_save_path = "{}_best_image_projection.pt".format(image_encoder_model)
|
109 |
image_encoder_trainable = False # if set to false, this encoder's frozen and its weights won't be saved at all.
|
projections/resnet50_best_image_projection.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:145f6d7ab06dac0f59906a6c62d19cdfaa5e09e8b0c3a5f2e1a3c975f31ca184
|
3 |
+
size 12601871
|