Update app.py
Browse files
app.py
CHANGED
@@ -69,7 +69,7 @@ def convert_numpy_types(data):
|
|
69 |
def predict(image):
|
70 |
model = MLP(768) # CLIP embedding dim is 768 for CLIP ViT L 14
|
71 |
pthpath = "https://huggingface.co/haor/aesthetics/resolve/main/sac%2Blogos%2Bava1-l14-linearMSE.pth"
|
72 |
-
device = "
|
73 |
|
74 |
model.load_state_dict(torch.hub.load_state_dict_from_url(pthpath, map_location=device))
|
75 |
model.to(device).eval()
|
@@ -87,7 +87,8 @@ def predict(image):
|
|
87 |
with torch.no_grad():
|
88 |
img_emb = model2.encode_image(inputs)
|
89 |
img_emb = normalized(img_emb.cpu().numpy())
|
90 |
-
|
|
|
91 |
|
92 |
result = {
|
93 |
"clip_aesthetic": prediction,
|
|
|
69 |
def predict(image):
|
70 |
model = MLP(768) # CLIP embedding dim is 768 for CLIP ViT L 14
|
71 |
pthpath = "https://huggingface.co/haor/aesthetics/resolve/main/sac%2Blogos%2Bava1-l14-linearMSE.pth"
|
72 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
73 |
|
74 |
model.load_state_dict(torch.hub.load_state_dict_from_url(pthpath, map_location=device))
|
75 |
model.to(device).eval()
|
|
|
87 |
with torch.no_grad():
|
88 |
img_emb = model2.encode_image(inputs)
|
89 |
img_emb = normalized(img_emb.cpu().numpy())
|
90 |
+
img_emb_tensor = torch.from_numpy(img_emb).to(device)
|
91 |
+
prediction = model(img_emb_tensor).item()
|
92 |
|
93 |
result = {
|
94 |
"clip_aesthetic": prediction,
|