mikitona committed
Commit 12179ac · verified · 1 Parent(s): 6dfb4e7

Update app.py

Files changed (1)
  1. app.py +17 -13
app.py CHANGED
@@ -41,15 +41,19 @@ def pil_to_binary_mask(pil_image, threshold=0):
 # Set the duration time
 duration = 60

+device = "cuda"
+
 base_path = 'yisol/IDM-VTON'
 example_path = os.path.join(os.path.dirname(__file__), 'example')

+# Load and initialize the models outside the handler function
 unet = UNet2DConditionModel.from_pretrained(
     base_path,
     subfolder="unet",
     torch_dtype=torch.float16,
-)
+).to(device)
 unet.requires_grad_(False)
+
 tokenizer_one = AutoTokenizer.from_pretrained(
     base_path,
     subfolder="tokenizer",
@@ -68,38 +72,42 @@ text_encoder_one = CLIPTextModel.from_pretrained(
     base_path,
     subfolder="text_encoder",
     torch_dtype=torch.float16,
-)
+).to(device)
 text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
     base_path,
     subfolder="text_encoder_2",
     torch_dtype=torch.float16,
-)
+).to(device)
 image_encoder = CLIPVisionModelWithProjection.from_pretrained(
     base_path,
     subfolder="image_encoder",
     torch_dtype=torch.float16,
-)
+).to(device)
 vae = AutoencoderKL.from_pretrained(
     base_path,
     subfolder="vae",
     torch_dtype=torch.float16,
-)
+).to(device)

 UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
     base_path,
     subfolder="unet_encoder",
     torch_dtype=torch.float16,
-)
+).to(device)

 parsing_model = Parsing(0)
 openpose_model = OpenPose(0)

+# Move the models to the GPU
+parsing_model.model.to(device)
+openpose_model.preprocessor.body_estimation.model.to(device)
 UNet_Encoder.requires_grad_(False)
 image_encoder.requires_grad_(False)
 vae.requires_grad_(False)
 unet.requires_grad_(False)
 text_encoder_one.requires_grad_(False)
 text_encoder_two.requires_grad_(False)
+
 tensor_transfrom = transforms.Compose(
     [
         transforms.ToTensor(),
@@ -119,20 +127,16 @@ pipe = TryonPipeline.from_pretrained(
     scheduler=noise_scheduler,
     image_encoder=image_encoder,
     torch_dtype=torch.float16,
-)
+).to(device)
 pipe.unet_encoder = UNet_Encoder

 @spaces.GPU(duration=duration) # use the duration variable
 def start_tryon(
     dict, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed, num_images
 ):
-    device = "cuda"
-
     start_time = time.time() # record the processing start time

-    openpose_model.preprocessor.body_estimation.model.to(device)
-    pipe.to(device)
-    pipe.unet_encoder.to(device)
+    # The device re-definition and the per-call model .to(device) moves are removed

     garm_img = garm_img.convert("RGB").resize((768, 1024))
     human_img_orig = dict["background"].convert("RGB")
@@ -355,4 +359,4 @@ with image_blocks as demo:
         api_name='tryon',
     )

-image_blocks.launch(show_error=True)
+image_blocks.launch(show_error=True)
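
For context, the net effect of this diff is a common ZeroGPU Spaces initialization pattern: load every model once at import time, move it to CUDA there, and keep the @spaces.GPU-decorated handler free of device setup. Below is a minimal, self-contained sketch of that pattern, not the full app.py: the single vae model, the encode handler, and its tensor argument are illustrative placeholders, while spaces.GPU(duration=...), the diffusers loading call, and the yisol/IDM-VTON checkpoint path come from the diff above.

import time

import spaces
import torch
from diffusers import AutoencoderKL

device = "cuda"
duration = 60  # GPU seconds reserved per call, as in the diff

# Load once at module scope, in half precision, directly onto the GPU.
vae = AutoencoderKL.from_pretrained(
    "yisol/IDM-VTON",
    subfolder="vae",
    torch_dtype=torch.float16,
).to(device)
vae.requires_grad_(False)  # inference only, no gradients needed

@spaces.GPU(duration=duration)
def encode(image):  # hypothetical handler; image is a (1, 3, H, W) float tensor
    # No device re-definition or weight .to(device) moves in here: the model
    # is already resident on the GPU when the Space attaches one.
    start_time = time.time()
    with torch.no_grad():
        latents = vae.encode(image.to(device, torch.float16)).latent_dist.sample()
    print(f"encode took {time.time() - start_time:.2f}s")
    return latents

Doing the .to(device) moves at load time means each decorated call spends its duration budget on inference rather than on re-transferring float16 weights, which is why the per-call pipe.to(device) and openpose moves could be deleted from start_tryon.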