jameslahm commited on
Commit
c06ddc9
·
verified ·
1 Parent(s): ddcc136

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -7,7 +7,9 @@ from ultralytics.utils.torch_utils import smart_inference_mode
7
  from ultralytics.models.yolo.yoloe.predict_vp import YOLOEVPSegPredictor
8
  from gradio_image_prompter import ImagePrompter
9
  from huggingface_hub import hf_hub_download
 
10
 
 
11
  def init_model(model_id, is_pf=False):
12
  if not is_pf:
13
  path = hf_hub_download(repo_id="jameslahm/yoloe", filename=f"{model_id}-seg.pt")
@@ -16,10 +18,9 @@ def init_model(model_id, is_pf=False):
16
  path = hf_hub_download(repo_id="jameslahm/yoloe", filename=f"{model_id}-seg-pf.pt")
17
  model = YOLOE(path)
18
  model.eval()
19
- model.to("cuda" if torch.cuda.is_available() else "cpu")
20
  return model
21
 
22
-
23
  @smart_inference_mode()
24
  def yoloe_inference(image, prompts, target_image, model_id, image_size, conf_thresh, iou_thresh, prompt_type):
25
  model = init_model(model_id)
 
7
  from ultralytics.models.yolo.yoloe.predict_vp import YOLOEVPSegPredictor
8
  from gradio_image_prompter import ImagePrompter
9
  from huggingface_hub import hf_hub_download
10
+ import spaces
11
 
12
+ @spaces.GPU
13
  def init_model(model_id, is_pf=False):
14
  if not is_pf:
15
  path = hf_hub_download(repo_id="jameslahm/yoloe", filename=f"{model_id}-seg.pt")
 
18
  path = hf_hub_download(repo_id="jameslahm/yoloe", filename=f"{model_id}-seg-pf.pt")
19
  model = YOLOE(path)
20
  model.eval()
 
21
  return model
22
 
23
+ @spaces.GPU
24
  @smart_inference_mode()
25
  def yoloe_inference(image, prompts, target_image, model_id, image_size, conf_thresh, iou_thresh, prompt_type):
26
  model = init_model(model_id)