hysts commited on
Commit
4bd7dce
·
1 Parent(s): f689e96

Allow changing LoRA scaling alpha

Browse files
Files changed (2) hide show
  1. app_inference.py +6 -0
  2. inference.py +8 -4
app_inference.py CHANGED
@@ -99,6 +99,11 @@ def create_inference_demo(pipe: InferencePipeline,
99
  max_lines=1,
100
  placeholder='Example: "A picture of a sks dog in a bucket"'
101
  )
 
 
 
 
 
102
  seed = gr.Slider(label='Seed',
103
  minimum=0,
104
  maximum=100000,
@@ -149,6 +154,7 @@ def create_inference_demo(pipe: InferencePipeline,
149
  inputs = [
150
  lora_model_id,
151
  prompt,
 
152
  seed,
153
  num_steps,
154
  guidance_scale,
 
99
  max_lines=1,
100
  placeholder='Example: "A picture of a sks dog in a bucket"'
101
  )
102
+ alpha = gr.Slider(label='LoRA alpha',
103
+ minimum=0,
104
+ maximum=2,
105
+ step=0.05,
106
+ value=1)
107
  seed = gr.Slider(label='Seed',
108
  minimum=0,
109
  maximum=100000,
 
154
  inputs = [
155
  lora_model_id,
156
  prompt,
157
+ alpha,
158
  seed,
159
  num_steps,
160
  guidance_scale,
inference.py CHANGED
@@ -73,6 +73,7 @@ class InferencePipeline:
73
  self,
74
  lora_model_id: str,
75
  prompt: str,
 
76
  seed: int,
77
  n_steps: int,
78
  guidance_scale: float,
@@ -83,8 +84,11 @@ class InferencePipeline:
83
  self.load_pipe(lora_model_id)
84
 
85
  generator = torch.Generator(device=self.device).manual_seed(seed)
86
- out = self.pipe(prompt,
87
- num_inference_steps=n_steps,
88
- guidance_scale=guidance_scale,
89
- generator=generator) # type: ignore
 
 
 
90
  return out.images[0]
 
73
  self,
74
  lora_model_id: str,
75
  prompt: str,
76
+ lora_scale: float,
77
  seed: int,
78
  n_steps: int,
79
  guidance_scale: float,
 
84
  self.load_pipe(lora_model_id)
85
 
86
  generator = torch.Generator(device=self.device).manual_seed(seed)
87
+ out = self.pipe(
88
+ prompt,
89
+ num_inference_steps=n_steps,
90
+ guidance_scale=guidance_scale,
91
+ generator=generator,
92
+ cross_attention_kwargs={'scale': lora_scale},
93
+ ) # type: ignore
94
  return out.images[0]