File size: 3,061 Bytes
0a82683
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import argparse
from src.pipeline_pe_clone import FluxPipeline
import torch
from PIL import Image

def parse_args():
    """Define and parse the command-line options for the generation script.

    Options are read from ``sys.argv``; argparse exits with an error for an
    unknown option or a ``--lora_name`` outside the allowed choices.

    Returns:
        argparse.Namespace: ``model_path``, ``image_path``, ``output_path``,
        ``height``, ``width``, ``prompt``, ``guidance_scale``, ``num_steps``
        and ``lora_name``.
    """
    # 'pretrained' selects the base model only; the rest name effect LoRAs.
    lora_choices = [
        'pretrained',
        'sksmagiceffects',
        'sksmonstercalledlulu',
        'skspaintingeffects',
        'sksedgeeffect',
        'skscatooneffect',
    ]
    # Help text listing one example prompt per effect; each example contains
    # that effect's trigger token.
    prompt_help = (
        "Different LoRA effects and their example prompts:\n"
        '    - sksmagiceffects: "add a halo and wings for the cat by sksmagiceffects"\n'
        '    - sksmonstercalledlulu: "add a red sksmonstercalledlulu hugging the cat"\n'
        '    - skspaintingeffects: "add a yellow flower on the cat\'s head and psychedelic colors and dynamic flows by skspaintingeffects"\n'
        '    - sksedgeeffect: "add yellow flames to the cat by sksedgeeffect"\n'
        '    - skscatooneffect: "add two hands holding the cat in skscatooneffect"\n'
        "    "
    )

    parser = argparse.ArgumentParser(description='FLUX image generation with LoRA')
    add = parser.add_argument
    add('--model_path', type=str, default="black-forest-labs/FLUX.1-dev",
        help='Path to pretrained model')
    add('--image_path', type=str, default="assets/1.png",
        help='Input image path')
    add('--output_path', type=str, default="output.png",
        help='Output image path')
    add('--height', type=int, default=768)
    add('--width', type=int, default=512)
    add('--prompt', type=str,
        default="add a halo and wings for the cat by sksmagiceffects",
        help=prompt_help)
    add('--guidance_scale', type=float, default=3.5)
    add('--num_steps', type=int, default=20,
        help='Number of inference steps')
    add('--lora_name', type=str, choices=lora_choices,
        default="sksmagiceffects",
        help='Name of LoRA weights to use. Use "pretrained" for base model only')
    return parser.parse_args()

def main():
    """Generate one edited image with a PhotoDoodle LoRA and save it.

    Loads the FLUX pipeline from ``--model_path`` in bfloat16 on CUDA and fuses
    the PhotoDoodle ``pretrain.safetensors`` LoRA into it. Unless
    ``--lora_name pretrained`` is given, it then loads the selected effect
    LoRA on top. The input image is resized to the requested output size and
    passed as the condition image. The first generated image is written to
    ``--output_path``.
    """
    args = parse_args()

    pipeline = FluxPipeline.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
    ).to('cuda')

    # Bake the PhotoDoodle base LoRA into the model weights, then unload the
    # adapter (diffusers' fuse-then-unload pattern) so the effect LoRA below
    # is loaded on top of the fused weights.
    pipeline.load_lora_weights("nicolaus-huang/PhotoDoodle", weight_name="pretrain.safetensors")
    pipeline.fuse_lora()
    pipeline.unload_lora_weights()

    # 'pretrained' means: use the fused base weights only, no effect LoRA.
    if args.lora_name != 'pretrained':
        pipeline.load_lora_weights("nicolaus-huang/PhotoDoodle", weight_name=f"{args.lora_name}.safetensors")

    # PIL's Image.resize expects (width, height); passing (height, width)
    # produced a transposed condition image whenever height != width.
    # The context manager closes the source file once the resized copy exists.
    with Image.open(args.image_path) as source:
        condition_image = source.resize((args.width, args.height)).convert("RGB")

    result = pipeline(
        prompt=args.prompt,
        condition_image=condition_image,
        height=args.height,
        width=args.width,
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.num_steps,
        max_sequence_length=512
    ).images[0]

    result.save(args.output_path)

# Script entry point: parse the CLI arguments and generate a single image.
if __name__ == "__main__":
    main()