yeq6x commited on
Commit
d086933
·
1 Parent(s): 79cb9f3

rm .to(device)

Browse files
Files changed (1) hide show
  1. scripts/process_utils.py +9 -10
scripts/process_utils.py CHANGED
@@ -54,7 +54,6 @@ def load_lora(pipeline, lora_path, alpha=0.75):
54
 
55
  def initialize_sotai_model():
56
  global device, torch_dtype
57
- print(f"Device: {device}, torch_dtype: {torch_dtype}")
58
 
59
  sotai_sd_model_path = get_file_path(os.environ["sotai_sd_model_name"], subfolder=os.environ["sd_models_dir"])
60
  controlnet_path1 = get_file_path(os.environ["controlnet_name1"], subfolder=os.environ["controlnet_dir2"])
@@ -66,19 +65,19 @@ def initialize_sotai_model():
66
  sotai_sd_model_path,
67
  torch_dtype=torch_dtype,
68
  use_safetensors=True
69
- ).to(device)
70
 
71
  # Load the ControlNet model
72
  controlnet1 = ControlNetModel.from_single_file(
73
  controlnet_path1,
74
  torch_dtype=torch_dtype
75
- ).to(device)
76
 
77
  # Load the ControlNet model
78
  controlnet2 = ControlNetModel.from_single_file(
79
  controlnet_path2,
80
  torch_dtype=torch_dtype
81
- ).to(device)
82
 
83
  # Create the ControlNet pipeline
84
  sotai_gen_pipe = StableDiffusionControlNetPipeline(
@@ -90,7 +89,7 @@ def initialize_sotai_model():
90
  safety_checker=sd_pipe.safety_checker,
91
  feature_extractor=sd_pipe.feature_extractor,
92
  controlnet=[controlnet1, controlnet2]
93
- ).to(device)
94
 
95
  # LoRAの適用
96
  lora_names = [
@@ -120,23 +119,23 @@ def initialize_refine_model():
120
  refine_sd_model_path,
121
  torch_dtype=torch_dtype,
122
  use_safetensors=True
123
- ).to(device)
124
 
125
  # controlnet_path = "models/cn/control_v11p_sd15_canny.pth"
126
  controlnet1 = ControlNetModel.from_single_file(
127
  controlnet_path3,
128
  torch_dtype=torch_dtype
129
- ).to(device)
130
 
131
  # Load the ControlNet model
132
  controlnet2 = ControlNetModel.from_single_file(
133
  controlnet_path4,
134
  torch_dtype=torch_dtype
135
- ).to(device)
136
 
137
  # Create the ControlNet pipeline
138
  refine_gen_pipe = StableDiffusionControlNetPipeline(
139
- vae=AutoencoderKL.from_single_file(vae_path, torch_dtype=torch_dtype).to(device),
140
  text_encoder=sd_pipe.text_encoder,
141
  tokenizer=sd_pipe.tokenizer,
142
  unet=sd_pipe.unet,
@@ -144,7 +143,7 @@ def initialize_refine_model():
144
  safety_checker=sd_pipe.safety_checker,
145
  feature_extractor=sd_pipe.feature_extractor,
146
  controlnet=[controlnet1, controlnet2], # 複数のControlNetを指定
147
- ).to(device)
148
 
149
  # スケジューラーの設定
150
  refine_gen_pipe.scheduler = UniPCMultistepScheduler.from_config(refine_gen_pipe.scheduler.config)
 
54
 
55
  def initialize_sotai_model():
56
  global device, torch_dtype
 
57
 
58
  sotai_sd_model_path = get_file_path(os.environ["sotai_sd_model_name"], subfolder=os.environ["sd_models_dir"])
59
  controlnet_path1 = get_file_path(os.environ["controlnet_name1"], subfolder=os.environ["controlnet_dir2"])
 
65
  sotai_sd_model_path,
66
  torch_dtype=torch_dtype,
67
  use_safetensors=True
68
+ )
69
 
70
  # Load the ControlNet model
71
  controlnet1 = ControlNetModel.from_single_file(
72
  controlnet_path1,
73
  torch_dtype=torch_dtype
74
+ )
75
 
76
  # Load the ControlNet model
77
  controlnet2 = ControlNetModel.from_single_file(
78
  controlnet_path2,
79
  torch_dtype=torch_dtype
80
+ )
81
 
82
  # Create the ControlNet pipeline
83
  sotai_gen_pipe = StableDiffusionControlNetPipeline(
 
89
  safety_checker=sd_pipe.safety_checker,
90
  feature_extractor=sd_pipe.feature_extractor,
91
  controlnet=[controlnet1, controlnet2]
92
+ )
93
 
94
  # LoRAの適用
95
  lora_names = [
 
119
  refine_sd_model_path,
120
  torch_dtype=torch_dtype,
121
  use_safetensors=True
122
+ )
123
 
124
  # controlnet_path = "models/cn/control_v11p_sd15_canny.pth"
125
  controlnet1 = ControlNetModel.from_single_file(
126
  controlnet_path3,
127
  torch_dtype=torch_dtype
128
+ )
129
 
130
  # Load the ControlNet model
131
  controlnet2 = ControlNetModel.from_single_file(
132
  controlnet_path4,
133
  torch_dtype=torch_dtype
134
+ )
135
 
136
  # Create the ControlNet pipeline
137
  refine_gen_pipe = StableDiffusionControlNetPipeline(
138
+ vae=AutoencoderKL.from_single_file(vae_path, torch_dtype=torch_dtype),
139
  text_encoder=sd_pipe.text_encoder,
140
  tokenizer=sd_pipe.tokenizer,
141
  unet=sd_pipe.unet,
 
143
  safety_checker=sd_pipe.safety_checker,
144
  feature_extractor=sd_pipe.feature_extractor,
145
  controlnet=[controlnet1, controlnet2], # 複数のControlNetを指定
146
+ )
147
 
148
  # スケジューラーの設定
149
  refine_gen_pipe.scheduler = UniPCMultistepScheduler.from_config(refine_gen_pipe.scheduler.config)