haor committed (verified)
Commit a2004d8 · 1 Parent(s): 7d44f91

Update app.py

Files changed (1)
  1. app.py +3 -2
app.py CHANGED
@@ -69,7 +69,7 @@ def convert_numpy_types(data):
 def predict(image):
     model = MLP(768) # CLIP embedding dim is 768 for CLIP ViT L 14
     pthpath = "https://huggingface.co/haor/aesthetics/resolve/main/sac%2Blogos%2Bava1-l14-linearMSE.pth"
-    device = "cpu" if torch.cuda.is_available() else "cpu"
+    device = "cuda" if torch.cuda.is_available() else "cpu"
 
     model.load_state_dict(torch.hub.load_state_dict_from_url(pthpath, map_location=device))
     model.to(device).eval()
@@ -87,7 +87,8 @@ def predict(image):
     with torch.no_grad():
         img_emb = model2.encode_image(inputs)
         img_emb = normalized(img_emb.cpu().numpy())
-        prediction = model(torch.from_numpy(img_emb).to(device).type(torch.FloatTensor)).item()
+        img_emb_tensor = torch.from_numpy(img_emb).to(device)
+        prediction = model(img_emb_tensor).item()
 
     result = {
         "clip_aesthetic": prediction,