pengHTYX commited on
Commit
501a6cc
1 Parent(s): 4799ad7
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -4,7 +4,7 @@ import fire
4
  import gradio as gr
5
  from PIL import Image
6
  from functools import partial
7
-
8
  import cv2
9
  import time
10
  import numpy as np
@@ -62,16 +62,16 @@ _GPU_ID = 0
62
  if not hasattr(Image, 'Resampling'):
63
  Image.Resampling = Image
64
 
65
-
66
  def sam_init():
67
  sam_checkpoint = os.path.join(os.path.dirname(__file__), "sam_pt", "sam_vit_h_4b8939.pth")
68
  model_type = "vit_h"
69
 
70
- sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=f"cuda:{_GPU_ID}")
71
  predictor = SamPredictor(sam)
72
  return predictor
73
 
74
-
75
  def sam_segment(predictor, input_image, *bbox_coords):
76
  bbox = np.array(bbox_coords)
77
  image = np.asarray(input_image)
@@ -143,7 +143,7 @@ def preprocess(predictor, input_image, chk_group=None, segment=True, rescale=Fal
143
  input_image = expand2square(input_image, (127, 127, 127, 0))
144
  return input_image, input_image.resize((320, 320), Image.Resampling.LANCZOS)
145
 
146
-
147
  def load_era3d_pipeline(cfg):
148
  # Load scheduler, tokenizer and models.
149
 
@@ -153,7 +153,7 @@ def load_era3d_pipeline(cfg):
153
  )
154
 
155
  if torch.cuda.is_available():
156
- pipeline.to('cuda:0')
157
  pipeline.unet.enable_xformers_memory_efficient_attention()
158
  # sys.main_lock = threading.Lock()
159
  return pipeline
@@ -168,7 +168,7 @@ def prepare_data(single_image, crop_size, cfg):
168
  return dataset[0]
169
 
170
  scene = 'scene'
171
-
172
  def run_pipeline(pipeline, cfg, single_image, guidance_scale, steps, seed, crop_size, chk_group=None):
173
  import pdb
174
  global scene
@@ -302,7 +302,7 @@ def run_demo():
302
 
303
  pipeline = load_era3d_pipeline(cfg)
304
  torch.set_grad_enabled(False)
305
- pipeline.to(f'cuda:{_GPU_ID}')
306
 
307
  predictor = sam_init()
308
 
 
4
  import gradio as gr
5
  from PIL import Image
6
  from functools import partial
7
+ +import spaces
8
  import cv2
9
  import time
10
  import numpy as np
 
62
  if not hasattr(Image, 'Resampling'):
63
  Image.Resampling = Image
64
 
65
66
  def sam_init():
67
  sam_checkpoint = os.path.join(os.path.dirname(__file__), "sam_pt", "sam_vit_h_4b8939.pth")
68
  model_type = "vit_h"
69
 
70
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device="cuda")
71
  predictor = SamPredictor(sam)
72
  return predictor
73
 
74
75
  def sam_segment(predictor, input_image, *bbox_coords):
76
  bbox = np.array(bbox_coords)
77
  image = np.asarray(input_image)
 
143
  input_image = expand2square(input_image, (127, 127, 127, 0))
144
  return input_image, input_image.resize((320, 320), Image.Resampling.LANCZOS)
145
 
146
147
  def load_era3d_pipeline(cfg):
148
  # Load scheduler, tokenizer and models.
149
 
 
153
  )
154
 
155
  if torch.cuda.is_available():
156
+ pipeline.to('cuda')
157
  pipeline.unet.enable_xformers_memory_efficient_attention()
158
  # sys.main_lock = threading.Lock()
159
  return pipeline
 
168
  return dataset[0]
169
 
170
  scene = 'scene'
171
172
  def run_pipeline(pipeline, cfg, single_image, guidance_scale, steps, seed, crop_size, chk_group=None):
173
  import pdb
174
  global scene
 
302
 
303
  pipeline = load_era3d_pipeline(cfg)
304
  torch.set_grad_enabled(False)
305
+ pipeline.to('cuda')
306
 
307
  predictor = sam_init()
308