PseudoTerminal X commited on
Commit
17b9398
·
verified ·
1 Parent(s): b2a499d

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +80 -20
README.md CHANGED
@@ -126,27 +126,87 @@ You may reuse the base model text encoder for inference.
126
 
127
 
128
  ```python
 
129
  import torch
130
- from diffusers import DiffusionPipeline
131
  from lycoris import create_lycoris_from_weights
132
-
133
- model_id = 'black-forest-labs/FLUX.1-dev'
134
- adapter_id = 'pytorch_lora_weights.safetensors' # you will have to download this manually
135
- lora_scale = 1.0
136
- wrapper, _ = create_lycoris_from_weights(lora_scale, adapter_id, pipeline.transformer)
137
- wrapper.merge_to()
138
-
139
- prompt = "A photo-realistic image of a cat"
140
-
141
- pipeline.to('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
142
- image = pipeline(
143
- prompt=prompt,
144
- num_inference_steps=20,
145
- generator=torch.Generator(device='cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu').manual_seed(1641421826),
146
- width=1776,
147
- height=512,
148
- guidance_scale=3.0,
149
- ).images[0]
150
- image.save("output.png", format="PNG")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  ```
152
 
 
126
 
127
 
128
  ```python
129
+ import argparse
130
  import torch
131
+ from helpers.models.flux.pipeline import FluxPipeline as DiffusionPipeline
132
  from lycoris import create_lycoris_from_weights
133
+ from huggingface_hub import hf_hub_download
134
+
135
+ def generate_image(pipeline, prompt, output_file, num_inference_steps, width, height, guidance_scale, seed, device):
136
+ # Set device
137
+ pipeline.to(device)
138
+
139
+ # Generate image
140
+ generator = torch.Generator(device=device).manual_seed(seed)
141
+ image = pipeline(
142
+ prompt=prompt,
143
+ num_inference_steps=num_inference_steps,
144
+ generator=generator,
145
+ width=width,
146
+ height=height,
147
+ guidance_scale=guidance_scale,
148
+ ).images[0]
149
+
150
+ # Save image
151
+ output_file = "output.png"
152
+ image.save(output_file, format="PNG")
153
+ print(f"Image saved as {output_file}")
154
+
155
+ def main():
156
+ parser = argparse.ArgumentParser(description="Generate images using a custom diffusion pipeline with LoRA weights.")
157
+ parser.add_argument("--model_id", type=str, default='black-forest-labs/FLUX.1-dev', help="Model ID from Hugging Face Hub.")
158
+ parser.add_argument("--adapter_id", type=str, required=True, help="LoRA weights file.")
159
+ parser.add_argument("--lora_scale", type=float, default=1.0, help="Scale for LoRA weights.")
160
+ parser.add_argument("--output_file", type=str, default="output.png", help="Output file name for the generated image.")
161
+ parser.add_argument("--num_inference_steps", type=int, default=30, help="Number of inference steps.")
162
+ parser.add_argument("--guidance_scale", type=float, default=3.5, help="Guidance scale for the generation.")
163
+ parser.add_argument("--seed", type=int, default=1641421826, help="Random seed for reproducibility.")
164
+ parser.add_argument("--device", type=str, default='cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu', help="Device to run the model on.")
165
+
166
+ args = parser.parse_args()
167
+
168
+ # Load model and weights
169
+ hf_hub_download(repo_id="terminusresearch/flux-lokr-garfield-nomask", filename=args.adapter_id, local_dir="./")
170
+ pipeline = DiffusionPipeline.from_pretrained(args.model_id, torch_dtype=torch.bfloat16)
171
+
172
+ # Apply LoRA weights
173
+ wrapper, _ = create_lycoris_from_weights(args.lora_scale, args.adapter_id, pipeline.transformer)
174
+ wrapper.merge_to()
175
+
176
+ print("Model loaded successfully. Ready to generate images.")
177
+
178
+ while True:
179
+ user_input = input("Enter a prompt or 'quit' to exit: ")
180
+ if user_input.lower() == 'quit':
181
+ break
182
+
183
+ # Check for resolution command
184
+ if user_input.startswith("resolution:"):
185
+ resolution = user_input.split(":")[1]
186
+ width, height = map(int, resolution.split("x"))
187
+ print(f"Resolution set to {width}x{height}")
188
+ continue
189
+
190
+ prompt = user_input
191
+ output_file = args.output_file.replace(".png", f"_{prompt.replace(' ', '_')}.png")
192
+
193
+ # Use default or previously set resolution
194
+ width = locals().get('width', 1024)
195
+ height = locals().get('height', 1024)
196
+
197
+ generate_image(
198
+ pipeline=pipeline,
199
+ prompt=prompt,
200
+ output_file=output_file,
201
+ num_inference_steps=args.num_inference_steps,
202
+ width=width,
203
+ height=height,
204
+ guidance_scale=args.guidance_scale,
205
+ seed=args.seed,
206
+ device=args.device
207
+ )
208
+
209
+ if __name__ == "__main__":
210
+ main()
211
  ```
212