cdnuts commited on
Commit
5e7db30
·
verified ·
1 Parent(s): b4dd74c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -119,6 +119,8 @@ model = timm.create_model(
119
  ) # type: VisionTransformer
120
 
121
  safetensors.torch.load_model(model, "JTP_PILOT-e4-vit_so400m_patch14_siglip_384.safetensors")
 
 
122
  model.eval()
123
 
124
  with open("tagger_tags.json", "r") as file:
@@ -135,7 +137,7 @@ def run_classifier(image, threshold):
135
  global sorted_tag_score
136
  img = image.convert('RGB')
137
  tensor = transform(img).unsqueeze(0)
138
-
139
  with torch.no_grad():
140
  logits = model(tensor)
141
  probabilities = torch.nn.functional.sigmoid(logits[0])
 
119
  ) # type: VisionTransformer
120
 
121
  safetensors.torch.load_model(model, "JTP_PILOT-e4-vit_so400m_patch14_siglip_384.safetensors")
122
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
123
+ model.to(device)
124
  model.eval()
125
 
126
  with open("tagger_tags.json", "r") as file:
 
137
  global sorted_tag_score
138
  img = image.convert('RGB')
139
  tensor = transform(img).unsqueeze(0)
140
+ tensor = tensor.to(device)
141
  with torch.no_grad():
142
  logits = model(tensor)
143
  probabilities = torch.nn.functional.sigmoid(logits[0])