Image Classification
timm
drhead commited on
Commit
938a123
1 Parent(s): e8fc417

update inference script for comma separated tags and using cuda if available

Browse files
Files changed (1) hide show
  1. inference_gradio.py +18 -8
inference_gradio.py CHANGED
@@ -1,18 +1,15 @@
1
  import json
2
 
3
- from PIL import Image
4
  import gradio as gr
 
 
 
 
5
  import torch
6
  from torchvision.transforms import transforms
7
  from torchvision.transforms import InterpolationMode
8
  import torchvision.transforms.functional as TF
9
 
10
- import timm
11
- from timm.models import VisionTransformer
12
- import safetensors.torch
13
-
14
-
15
- torch.jit.script = lambda f: f
16
  torch.set_grad_enabled(False)
17
 
18
  class Fit(torch.nn.Module):
@@ -123,13 +120,26 @@ model = timm.create_model(
123
  safetensors.torch.load_model(model, "JTP_PILOT/JTP_PILOT-e4-vit_so400m_patch14_siglip_384.safetensors")
124
  model.eval()
125
 
 
 
 
 
 
126
  with open("JTP_PILOT/tags.json", "r") as file:
127
  tags = json.load(file) # type: dict
128
  allowed_tags = list(tags.keys())
129
 
 
 
 
130
  def create_tags(image, threshold):
131
  img = image.convert('RGB')
132
- tensor = transform(img).unsqueeze(0)
 
 
 
 
 
133
 
134
  with torch.no_grad():
135
  logits = model(tensor)
 
1
  import json
2
 
 
3
  import gradio as gr
4
+ from PIL import Image
5
+ import safetensors.torch
6
+ import timm
7
+ from timm.models import VisionTransformer
8
  import torch
9
  from torchvision.transforms import transforms
10
  from torchvision.transforms import InterpolationMode
11
  import torchvision.transforms.functional as TF
12
 
 
 
 
 
 
 
13
  torch.set_grad_enabled(False)
14
 
15
  class Fit(torch.nn.Module):
 
120
  safetensors.torch.load_model(model, "JTP_PILOT/JTP_PILOT-e4-vit_so400m_patch14_siglip_384.safetensors")
121
  model.eval()
122
 
123
+ if torch.cuda.is_available():
124
+ model.cuda()
125
+ if torch.cuda.get_device_capability()[0] >= 7: # tensor cores
126
+ model.half()
127
+
128
  with open("JTP_PILOT/tags.json", "r") as file:
129
  tags = json.load(file) # type: dict
130
  allowed_tags = list(tags.keys())
131
 
132
+ for idx, tag in enumerate(allowed_tags):
133
+ allowed_tags[idx] = tag.replace("_", " ")
134
+
135
  def create_tags(image, threshold):
136
  img = image.convert('RGB')
137
+ tensor = transform(img).unsqueeze(0) # type: torch.Tensor
138
+
139
+ if torch.cuda.is_available():
140
+ tensor.cuda()
141
+ if torch.cuda.get_device_capability()[0] >= 7:
142
+ tensor.half()
143
 
144
  with torch.no_grad():
145
  logits = model(tensor)