mikitona committed
Commit 762954c · verified · 1 Parent(s): 2170957

Update app.py

Files changed (1):
  app.py  +16 -28
app.py CHANGED
@@ -23,7 +23,7 @@ from preprocess.humanparsing.run_parsing import Parsing
 from preprocess.openpose.run_openpose import OpenPose
 from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation
 from torchvision.transforms.functional import to_pil_image
-import time  # import the time module
+
 
 def pil_to_binary_mask(pil_image, threshold=0):
     np_image = np.array(pil_image)
@@ -38,22 +38,16 @@ def pil_to_binary_mask(pil_image, threshold=0):
     output_mask = Image.fromarray(mask)
     return output_mask
 
-# Set the duration
-duration = 60
-
-device = "cuda"
 
 base_path = 'yisol/IDM-VTON'
 example_path = os.path.join(os.path.dirname(__file__), 'example')
 
-# Load and initialize the models outside the function
 unet = UNet2DConditionModel.from_pretrained(
     base_path,
     subfolder="unet",
     torch_dtype=torch.float16,
-).to(device)
+)
 unet.requires_grad_(False)
-
 tokenizer_one = AutoTokenizer.from_pretrained(
     base_path,
     subfolder="tokenizer",
@@ -72,43 +66,38 @@ text_encoder_one = CLIPTextModel.from_pretrained(
     base_path,
     subfolder="text_encoder",
     torch_dtype=torch.float16,
-).to(device)
+)
 text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
     base_path,
     subfolder="text_encoder_2",
     torch_dtype=torch.float16,
-).to(device)
+)
 image_encoder = CLIPVisionModelWithProjection.from_pretrained(
     base_path,
     subfolder="image_encoder",
     torch_dtype=torch.float16,
-).to(device)
+)
 vae = AutoencoderKL.from_pretrained(
     base_path,
     subfolder="vae",
     torch_dtype=torch.float16,
-).to(device)
+)
 
 UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
     base_path,
     subfolder="unet_encoder",
     torch_dtype=torch.float16,
-).to(device)
+)
 
 parsing_model = Parsing(0)
 openpose_model = OpenPose(0)
 
-# Move the models to the GPU
-# Before the fix: parsing_model.model.to(device)
-parsing_model.parsenet.to(device)  # After the fix: use the correct attribute name
-openpose_model.preprocessor.body_estimation.model.to(device)
 UNet_Encoder.requires_grad_(False)
 image_encoder.requires_grad_(False)
 vae.requires_grad_(False)
 unet.requires_grad_(False)
 text_encoder_one.requires_grad_(False)
 text_encoder_two.requires_grad_(False)
-
 tensor_transfrom = transforms.Compose(
     [
         transforms.ToTensor(),
@@ -128,16 +117,19 @@ pipe = TryonPipeline.from_pretrained(
     scheduler=noise_scheduler,
     image_encoder=image_encoder,
     torch_dtype=torch.float16,
-).to(device)
+)
 pipe.unet_encoder = UNet_Encoder
 
-@spaces.GPU(duration=duration)  # use the duration variable
+
+@spaces.GPU(duration=60)  # set the execution time to 60 seconds
 def start_tryon(
     dict, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed, num_images
 ):
-    start_time = time.time()  # record the processing start time
+    device = "cuda"
 
-    # Removed the re-definition of the device variable and the .to(device) calls on the models
+    openpose_model.preprocessor.body_estimation.model.to(device)
+    pipe.to(device)
+    pipe.unet_encoder.to(device)
 
     garm_img = garm_img.convert("RGB").resize((768, 1024))
     human_img_orig = dict["background"].convert("RGB")
@@ -226,11 +218,6 @@ def start_tryon(
     yield output_images.value, mask_gray
 
     for i in range(int(num_images)):
-        # Check the elapsed time
-        elapsed_time = time.time() - start_time
-        if elapsed_time >= duration - 5:  # use the duration variable
-            break
-
         current_seed = seed + i if seed is not None and seed != -1 else None
         generator = (
             torch.Generator(device).manual_seed(int(current_seed)) if current_seed is not None else None
@@ -276,6 +263,7 @@ def start_tryon(
     # Return the final result
    return output_images.value, mask_gray
 
+
 garm_list = os.listdir(os.path.join(example_path, "cloth"))
 garm_list_path = [os.path.join(example_path, "cloth", garm) for garm in garm_list]
 
@@ -341,7 +329,7 @@ with image_blocks as demo:
         )
         seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42)
         num_images = gr.Slider(
-            label="Number of Images", minimum=1, maximum=10, step=1, value=1
+            label="Number of Images", minimum=1, maximum=10, step=1, value=1  # changed the maximum to 10
         )
 
         try_button.click(
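
Note on the pattern above: with the @spaces.GPU decorator (ZeroGPU Spaces), weights are loaded on the CPU at import time and moved to CUDA only inside the decorated function, which is also where the 60-second GPU window is declared. Below is a minimal, self-contained sketch of that pattern. The unet loading code is taken from the diff; the diffusers import, the handler name run_demo, and its seed argument are illustrative assumptions rather than code from app.py.

import spaces
import torch
from diffusers import UNet2DConditionModel

# Load weights once at import time, on the CPU, so the Space can start without a GPU attached.
unet = UNet2DConditionModel.from_pretrained(
    "yisol/IDM-VTON",
    subfolder="unet",
    torch_dtype=torch.float16,
)
unet.requires_grad_(False)


@spaces.GPU(duration=60)  # a GPU is attached only while this function runs
def run_demo(seed: int = 42) -> int:
    # Device placement happens here, inside the decorated function, mirroring the commit above.
    device = "cuda"
    unet.to(device)
    generator = torch.Generator(device).manual_seed(seed)
    # ... run inference with `unet` and `generator` here ...
    return generator.initial_seed()

In the commit itself the same idea is applied to the full try-on stack: pipe.to(device), pipe.unet_encoder.to(device), and the OpenPose body-estimation model are moved to CUDA inside start_tryon instead of at module level.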