YaArtemNosenko commited on
Commit
bc604aa
·
verified ·
1 Parent(s): e8eaea5

[DEBUG] Debug custom nn

Browse files
Files changed (1) hide show
  1. app.py +50 -30
app.py CHANGED
@@ -41,6 +41,7 @@ def load_pipeline(model_id,
41
  if model_id in model_cache:
42
  return model_cache[model_id]
43
 
 
44
  if controlnet_checkbox:
45
  if controlnet_mode == "depth_map":
46
  controlnet = ControlNetModel.from_pretrained(
@@ -72,39 +73,58 @@ def load_pipeline(model_id,
72
  cache_dir="./models_cache",
73
  torch_dtype=torch_dtype
74
  )
75
- pipe = StableDiffusionControlNetPipeline.from_pretrained(model_id,
76
- controlnet=controlnet,
77
- torch_dtype=torch_dtype,
78
- safety_checker=None).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  # params['image'] = controlnet_image
80
  # params['controlnet_conditioning_scale'] = float(controlnet_strength)
81
  else:
82
- pipe = StableDiffusionPipeline.from_pretrained(model_id,
83
- torch_dtype=torch_dtype,
84
- safety_checker=None).to(device)
85
-
86
- if model_id == "YaArtemNosenko/dino_stickers":
87
- # Use the specified base model for your LoRA adapter.
88
- base_model = "CompVis/stable-diffusion-v1-4"
89
- # Load the LoRA weights
90
- pipe.unet = PeftModel.from_pretrained(
91
- pipe.unet,
92
- model_id,
93
- subfolder="unet",
94
- torch_dtype=torch_dtype
95
- )
96
- pipe.text_encoder = PeftModel.from_pretrained(
97
- pipe.text_encoder,
98
- model_id,
99
- subfolder="text_encoder",
100
- torch_dtype=torch_dtype
101
- )
102
- pipe.unet.load_state_dict({k: lora_scale * v for k, v in pipe.unet.state_dict().items()})
103
- pipe.text_encoder.load_state_dict({k: lora_scale * v for k, v in pipe.text_encoder.state_dict().items()})
104
- else:
105
- pipe = StableDiffusionPipeline.from_pretrained(model_id,
106
- torch_dtype=torch_dtype
107
- )
108
 
109
  if ip_adapter_checkbox:
110
  pipe.load_ip_adapter("h94/IP-Adapter",
 
41
  if model_id in model_cache:
42
  return model_cache[model_id]
43
 
44
+
45
  if controlnet_checkbox:
46
  if controlnet_mode == "depth_map":
47
  controlnet = ControlNetModel.from_pretrained(
 
73
  cache_dir="./models_cache",
74
  torch_dtype=torch_dtype
75
  )
76
+ if model_id == "YaArtemNosenko/dino_stickers":
77
+ # Use the specified base model for your LoRA adapter.
78
+ base_model = "CompVis/stable-diffusion-v1-4"
79
+ # Load the LoRA weights
80
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(base_model,
81
+ controlnet=controlnet,
82
+ torch_dtype=torch_dtype,
83
+ safety_checker=None).to(device)
84
+ pipe.unet = PeftModel.from_pretrained(
85
+ pipe.unet,
86
+ model_id,
87
+ subfolder="unet",
88
+ torch_dtype=torch_dtype
89
+ )
90
+ pipe.text_encoder = PeftModel.from_pretrained(
91
+ pipe.text_encoder,
92
+ model_id,
93
+ subfolder="text_encoder",
94
+ torch_dtype=torch_dtype
95
+ )
96
+ else:
97
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(model_id,
98
+ controlnet=controlnet,
99
+ torch_dtype=torch_dtype,
100
+ safety_checker=None).to(device)
101
  # params['image'] = controlnet_image
102
  # params['controlnet_conditioning_scale'] = float(controlnet_strength)
103
  else:
104
+ if model_id == "YaArtemNosenko/dino_stickers":
105
+ base_model = "CompVis/stable-diffusion-v1-4"
106
+ pipe = StableDiffusionPipeline.from_pretrained(base_model, torch_dtype=torch_dtype)
107
+ # Load the LoRA weights
108
+ pipe.unet = PeftModel.from_pretrained(
109
+ pipe.unet,
110
+ model_id,
111
+ subfolder="unet",
112
+ torch_dtype=torch_dtype
113
+ )
114
+
115
+ pipe.text_encoder = PeftModel.from_pretrained(
116
+ pipe.text_encoder,
117
+ model_id,
118
+ subfolder="text_encoder",
119
+ torch_dtype=torch_dtype
120
+ )
121
+ else:
122
+ pipe = StableDiffusionPipeline.from_pretrained(model_id,
123
+ torch_dtype=torch_dtype,
124
+ safety_checker=None).to(device)
125
+
126
+ pipe.unet.load_state_dict({k: lora_scale * v for k, v in pipe.unet.state_dict().items()})
127
+ pipe.text_encoder.load_state_dict({k: lora_scale * v for k, v in pipe.text_encoder.state_dict().items()})
 
 
128
 
129
  if ip_adapter_checkbox:
130
  pipe.load_ip_adapter("h94/IP-Adapter",