levihsu commited on
Commit
a0690fd
·
verified ·
1 Parent(s): c1569cf

Update ootd/inference_ootd.py

Browse files
Files changed (1) hide show
  1. ootd/inference_ootd.py +7 -7
ootd/inference_ootd.py CHANGED
@@ -33,7 +33,7 @@ MODEL_PATH = "./checkpoints/ootd"
33
  class OOTDiffusion:
34
 
35
  def __init__(self, gpu_id):
36
- self.gpu_id = 'cuda:' + str(gpu_id)
37
 
38
  vae = AutoencoderKL.from_pretrained(
39
  VAE_PATH,
@@ -64,12 +64,12 @@ class OOTDiffusion:
64
  use_safetensors=True,
65
  safety_checker=None,
66
  requires_safety_checker=False,
67
- ).to(self.gpu_id)
68
 
69
  self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
70
 
71
  self.auto_processor = AutoProcessor.from_pretrained(VIT_PATH)
72
- self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(VIT_PATH).to(self.gpu_id)
73
 
74
  self.tokenizer = CLIPTokenizer.from_pretrained(
75
  MODEL_PATH,
@@ -78,7 +78,7 @@ class OOTDiffusion:
78
  self.text_encoder = CLIPTextModel.from_pretrained(
79
  MODEL_PATH,
80
  subfolder="text_encoder",
81
- ).to(self.gpu_id)
82
 
83
 
84
  def tokenize_captions(self, captions, max_length):
@@ -107,14 +107,14 @@ class OOTDiffusion:
107
  generator = torch.manual_seed(seed)
108
 
109
  with torch.no_grad():
110
- prompt_image = self.auto_processor(images=image_garm, return_tensors="pt").to(self.gpu_id)
111
  prompt_image = self.image_encoder(prompt_image.data['pixel_values']).image_embeds
112
  prompt_image = prompt_image.unsqueeze(1)
113
  if model_type == 'hd':
114
- prompt_embeds = self.text_encoder(self.tokenize_captions([""], 2).to(self.gpu_id))[0]
115
  prompt_embeds[:, 1:] = prompt_image[:]
116
  elif model_type == 'dc':
117
- prompt_embeds = self.text_encoder(self.tokenize_captions([category], 3).to(self.gpu_id))[0]
118
  prompt_embeds = torch.cat([prompt_embeds, prompt_image], dim=1)
119
  else:
120
  raise ValueError("model_type must be \'hd\' or \'dc\'!")
 
33
  class OOTDiffusion:
34
 
35
  def __init__(self, gpu_id):
36
+ # self.gpu_id = 'cuda:' + str(gpu_id)
37
 
38
  vae = AutoencoderKL.from_pretrained(
39
  VAE_PATH,
 
64
  use_safetensors=True,
65
  safety_checker=None,
66
  requires_safety_checker=False,
67
+ )#.to(self.gpu_id)
68
 
69
  self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
70
 
71
  self.auto_processor = AutoProcessor.from_pretrained(VIT_PATH)
72
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(VIT_PATH)#.to(self.gpu_id)
73
 
74
  self.tokenizer = CLIPTokenizer.from_pretrained(
75
  MODEL_PATH,
 
78
  self.text_encoder = CLIPTextModel.from_pretrained(
79
  MODEL_PATH,
80
  subfolder="text_encoder",
81
+ )#.to(self.gpu_id)
82
 
83
 
84
  def tokenize_captions(self, captions, max_length):
 
107
  generator = torch.manual_seed(seed)
108
 
109
  with torch.no_grad():
110
+ prompt_image = self.auto_processor(images=image_garm, return_tensors="pt").to('cuda')
111
  prompt_image = self.image_encoder(prompt_image.data['pixel_values']).image_embeds
112
  prompt_image = prompt_image.unsqueeze(1)
113
  if model_type == 'hd':
114
+ prompt_embeds = self.text_encoder(self.tokenize_captions([""], 2).to('cuda'))[0]
115
  prompt_embeds[:, 1:] = prompt_image[:]
116
  elif model_type == 'dc':
117
+ prompt_embeds = self.text_encoder(self.tokenize_captions([category], 3).to('cuda'))[0]
118
  prompt_embeds = torch.cat([prompt_embeds, prompt_image], dim=1)
119
  else:
120
  raise ValueError("model_type must be \'hd\' or \'dc\'!")