HighCWu commited on
Commit
917e749
·
1 Parent(s): ebb7aea

add peft, set_adapter

Browse files
Files changed (4) hide show
  1. model.py +2 -2
  2. pipeline.py +1 -1
  3. requirements.txt +1 -0
  4. unet.py +5 -3
model.py CHANGED
@@ -44,7 +44,7 @@ class Model:
44
  and self.pipe is not None
45
  ):
46
  unet: UNet2DConditionModelEx = self.pipe.unet
47
- unet.activate_adapters([task_name])
48
  return self.pipe
49
  unet: UNet2DConditionModelEx = UNet2DConditionModelEx.from_pretrained(
50
  base_model_id, subfolder="unet", torch_dtype=torch.float16
@@ -82,7 +82,7 @@ class Model:
82
  if task_name == self.task_name:
83
  return
84
  unet: UNet2DConditionModelEx = self.pipe.unet
85
- unet.activate_adapters([task_name])
86
  self.task_name = task_name
87
 
88
  def get_prompt(self, prompt: str, additional_prompt: str) -> str:
 
44
  and self.pipe is not None
45
  ):
46
  unet: UNet2DConditionModelEx = self.pipe.unet
47
+ unet.set_adapter(task_name)
48
  return self.pipe
49
  unet: UNet2DConditionModelEx = UNet2DConditionModelEx.from_pretrained(
50
  base_model_id, subfolder="unet", torch_dtype=torch.float16
 
82
  if task_name == self.task_name:
83
  return
84
  unet: UNet2DConditionModelEx = self.pipe.unet
85
+ unet.set_adapter(task_name)
86
  self.task_name = task_name
87
 
88
  def get_prompt(self, prompt: str, additional_prompt: str) -> str:
pipeline.py CHANGED
@@ -949,7 +949,7 @@ class StableDiffusionControlLoraV3Pipeline(
949
  if adapter_name_ori is not None:
950
  break
951
 
952
- unet.activate_adapters()
953
 
954
  @torch.no_grad()
955
  @replace_example_docstring(EXAMPLE_DOC_STRING)
 
949
  if adapter_name_ori is not None:
950
  break
951
 
952
+ unet.activate_extra_condition_adapters()
953
 
954
  @torch.no_grad()
955
  @replace_example_docstring(EXAMPLE_DOC_STRING)
requirements.txt CHANGED
@@ -6,6 +6,7 @@ gradio>=4.26.0
6
  huggingface-hub>=0.16.4
7
  mediapipe>=0.10.1
8
  opencv-python-headless>=4.8.0.74
 
9
  safetensors>=0.3.1
10
  torch>=2.0.1
11
  torchvision>=0.15.2
 
6
  huggingface-hub>=0.16.4
7
  mediapipe>=0.10.1
8
  opencv-python-headless>=4.8.0.74
9
+ peft>=0.11.1
10
  safetensors>=0.3.1
11
  torch>=2.0.1
12
  torchvision>=0.15.2
unet.py CHANGED
@@ -145,11 +145,13 @@ class UNet2DConditionModelEx(UNet2DConditionModel):
145
 
146
  return self
147
 
148
- def activate_adapters(self, adapter_names: Union[List[str], None] = None):
149
  lora_layers = [layer for layer in self.modules() if isinstance(layer, LoraLayer)]
150
  for lora_layer in lora_layers:
151
- _adapter_names = adapter_names or list(lora_layer.scaling.keys())
152
- lora_layer.set_adapter(_adapter_names)
 
 
153
 
154
  def set_extra_condition_scale(self, scale: Union[float, List[float]] = 1.0):
155
  if isinstance(scale, float):
 
145
 
146
  return self
147
 
148
+ def activate_extra_condition_adapters(self):
149
  lora_layers = [layer for layer in self.modules() if isinstance(layer, LoraLayer)]
150
  for lora_layer in lora_layers:
151
+ adapter_names = [k for k in lora_layer.scaling.keys() if k in self.config.extra_condition_names]
152
+ adapter_names += lora_layer.active_adapters
153
+ adapter_names = list(set(adapter_names))
154
+ lora_layer.set_adapter(adapter_names)
155
 
156
  def set_extra_condition_scale(self, scale: Union[float, List[float]] = 1.0):
157
  if isinstance(scale, float):