user-agent commited on
Commit
4ecab25
·
verified ·
1 Parent(s): 75f65e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -7,10 +7,13 @@ from transformers import AutoModelForImageClassification, AutoConfig
7
  import gradio as gr
8
  import spaces
9
 
 
 
10
  model_id = "thelabel/240903-image-tagging"
11
  config = AutoConfig.from_pretrained(model_id)
12
  model = AutoModelForImageClassification.from_pretrained(model_id)
13
- model.eval()
 
14
 
15
  # Standard ViT image transforms
16
  image_transform = transforms.Compose([
@@ -33,7 +36,7 @@ def predict_tags(image_url, threshold=0.5):
33
  if image is None:
34
  return [], "Could not load image from the provided URL."
35
 
36
- image_tensor = image_transform(image).unsqueeze(0)
37
  with torch.no_grad():
38
  logits = model(image_tensor).logits
39
  probs = torch.sigmoid(logits).squeeze()
 
7
  import gradio as gr
8
  import spaces
9
 
10
+
11
+
12
  model_id = "thelabel/240903-image-tagging"
13
  config = AutoConfig.from_pretrained(model_id)
14
  model = AutoModelForImageClassification.from_pretrained(model_id)
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ model.to(device)
17
 
18
  # Standard ViT image transforms
19
  image_transform = transforms.Compose([
 
36
  if image is None:
37
  return [], "Could not load image from the provided URL."
38
 
39
+ image_tensor = image_transform(image).unsqueeze(0).to(device)
40
  with torch.no_grad():
41
  logits = model(image_tensor).logits
42
  probs = torch.sigmoid(logits).squeeze()