File size: 7,457 Bytes
37a1e1f
 
 
2301b77
f746958
37a1e1f
 
074a81a
 
 
 
37a1e1f
 
074a81a
37a1e1f
 
36c4845
074a81a
37a1e1f
8c46c8a
 
 
074a81a
fc4c790
0417f88
 
8c46c8a
 
797e0cc
 
 
 
 
 
 
 
8c46c8a
 
87f9867
37a1e1f
 
 
720c562
2301b77
8c46c8a
242fe28
2301b77
edd39e5
5ecdabf
 
b12d0df
 
 
87f9867
720c562
 
 
 
 
 
7722e4b
8c46c8a
bef75dc
409a131
 
 
 
 
720c562
37a1e1f
ba8b2ea
 
2a4ca20
bef75dc
0e7df92
2301b77
2a4ca20
37a1e1f
720c562
86dbb58
 
 
 
12bf04e
e7da559
37a1e1f
12bf04e
 
 
 
 
fc4c790
dbedcf7
 
 
fc4c790
dbedcf7
 
 
bef75dc
2301b77
 
bef75dc
 
 
86dbb58
2301b77
fc4c790
37a1e1f
074a81a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from diffusers import StableDiffusionPipeline
from lora_diffusion import monkeypatch_lora, tune_lora_scale
import torch
import os, shutil
import gradio as gr
import subprocess

# Base Stable Diffusion checkpoint, used for inference and passed to the
# training script as --pretrained_model_name_or_path.
MODEL_NAME = "stabilityai/stable-diffusion-2-1-base"
# Directory that receives the uploaded training images (--instance_data_dir).
INSTANCE_DIR = "./data_example"
# Directory the training script writes into (--output_dir); *.pt weight files
# are looked up here after training.
OUTPUT_DIR = "./output_example"

# Kept as a separate name for existing references, but aliased to MODEL_NAME so
# the inference pipeline and the training run can never use different bases.
model_id = MODEL_NAME
# Single fp16 pipeline on the GPU, shared (as a module global) by every
# inference call.
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
# Seed the global torch RNG once at import time (not per request).
torch.manual_seed(1)

# Number of times LoRA layers have been injected into pipe.unet by
# monkeypatching(); injection happens only while this is 0.
counter = 0

#Getting Lora fine-tuned weights
def monkeypatching(alpha, in_prompt, wt):
    """Render ``in_prompt`` in the fine-tuned "hclu" style with LoRA weights.

    The first call loads the weights referenced by ``wt`` and injects them
    into ``pipe.unet`` via ``monkeypatch_lora``. Every call then sets the LoRA
    scale to ``alpha`` and runs the pipeline for 50 inference steps at
    guidance scale 7.

    NOTE: once LoRA layers have been injected (``counter != 0``), ``wt`` is
    not read again. A different weights file passed on a later call (for
    example freshly trained weights) does not replace the loaded ones.

    Args:
        alpha: LoRA scale passed to ``tune_lora_scale``.
        in_prompt: user prompt; appended (via ``str()``) to
            ``"style of hclu, "``.
        wt: path (str / os.PathLike) to a LoRA ``.pt`` file, or a Gradio file
            object exposing its path as ``.name``. Only used on the first call.

    Returns:
        The first image produced by the pipeline; it is also saved to
        ``./illust_lora.jpg``.

    Raises:
        ValueError: if no LoRA weights are loaded yet and ``wt`` is empty.
    """
    print("****** inside monkeypatching *******")
    print(f"in_prompt is - {str(in_prompt)}")
    global counter
    if counter == 0:
        # gr.File inputs arrive as tempfile wrappers carrying the path in
        # `.name`; plain paths (examples, direct callers) are used as-is.
        if isinstance(wt, (str, os.PathLike)):
            wt_path = wt
        else:
            wt_path = getattr(wt, "name", wt)
        if not wt_path:
            raise ValueError(
                "No LoRA weights to load: train a model or select the example weights first."
            )
        # SECURITY NOTE(review): torch.load unpickles the file, and `wt` may
        # come from a user-facing file component. Consider weights_only=True
        # if the installed torch version supports it.
        monkeypatch_lora(pipe.unet, torch.load(wt_path))
        # Mark the UNet as patched right away so a later failure cannot cause
        # a second injection on the next call.
        counter += 1
    tune_lora_scale(pipe.unet, alpha)
    prompt = "style of hclu, " + str(in_prompt)
    image = pipe(prompt, num_inference_steps=50, guidance_scale=7).images[0]
    image.save("./illust_lora.jpg")
    return image

def accelerate_train_lora(steps, images):
    """Fine-tune a LoRA on uploaded images and return the weights file path.

    Copies each uploaded image into ``INSTANCE_DIR`` and runs
    ``./train_lora_dreambooth.py`` through ``accelerate launch`` with the
    instance prompt "style of hclu". It then looks for a ``.pt`` file in
    ``OUTPUT_DIR``.

    Args:
        steps: number of training steps (converted with ``int()``).
        images: iterable of uploaded files. Each is either a Gradio file
            object (path in ``.name``) or a plain path. ``None`` copies
            nothing.

    Returns:
        ``os.path.join(OUTPUT_DIR, name)`` for the first ``*.pt`` file in
        sorted order, or ``None`` if no such file exists.

    Raises:
        subprocess.CalledProcessError: if the training command exits with a
            non-zero status.
    """
    print("*********** inside accelerate_train_lora ***********")
    # shutil.copy into a *missing* directory would create a single file named
    # like the directory and overwrite it with every image, so ensure it exists.
    os.makedirs(INSTANCE_DIR, exist_ok=True)
    for file in images or []:
        # gr.File gives tempfile wrappers (path in `.name`; original name in
        # `.orig_name`); accept plain paths too.
        if isinstance(file, (str, os.PathLike)):
            src = file
        else:
            src = getattr(file, "name", file)
        shutil.copy(src, INSTANCE_DIR)

    # Argument list instead of a shell string: nothing is re-parsed by a shell.
    train_cmd = [
        "accelerate", "launch", "./train_lora_dreambooth.py",
        f"--pretrained_model_name_or_path={MODEL_NAME}",
        f"--instance_data_dir={INSTANCE_DIR}",
        f"--output_dir={OUTPUT_DIR}",
        "--instance_prompt=style of hclu",
        "--resolution=512",
        "--train_batch_size=1",
        "--gradient_accumulation_steps=1",
        "--learning_rate=1e-4",
        "--lr_scheduler=constant",
        "--lr_warmup_steps=0",
        f"--max_train_steps={int(steps)}",
    ]
    # check=True: a failed run must not silently fall through to returning
    # stale weights from a previous run.
    subprocess.run(train_cmd, check=True)
    print("*********** completing accelerate_train_lora ***********")

    # sorted() makes the choice deterministic when several .pt files exist.
    for file in sorted(os.listdir(OUTPUT_DIR)):
        if file.endswith(".pt"):
            weights_path = os.path.join(OUTPUT_DIR, file)
            print(weights_path)
            # Return the full path: a bare file name would be resolved
            # relative to the current working directory, not OUTPUT_DIR.
            return weights_path
    return None
    
# Gradio UI: upload/train controls, inference controls, examples, and wiring.
with gr.Blocks() as demo:
    gr.Markdown("""<h1><center>LORA - Low-rank Adaptation for Fast Text-to-Image Diffusion Fine-tuning</center></h1>
    """)
    gr.HTML("<p>You can skip the queue by duplicating this space and upgrading to a GPU in settings: <a style='display:inline-block' href='https://huggingface.co/spaces/ysharma/Low-rank-Adaptation?duplicate=true'><img src='https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14' alt='Duplicate Space'></a></p>")
    gr.Markdown("""<b>NEW!!</b> : I have fine-tuned the SD model for 15,000 steps using 100 PlaygroundAI images and LORA. You can load this trained model using the example component. Load the weights and start using the Space with the Inference button. Feel free to toggle the Alpha value.""")
    gr.Markdown(
        "**Main Features**<br>"
        "- Fine-tune Stable Diffusion models twice as fast as the DreamBooth method by Low-rank Adaptation.<br>"
        "- Get insanely small end results, easy to share and download.<br>"
        "- Easy to use, compatible with diffusers.<br>"
        "- Sometimes even better performance than full fine-tuning<br><br>"
        "Please refer to the GitHub repo this Space is based on, here - "
        '<a href = "https://github.com/cloneofsimo/lora">LORA</a>. '
        "You can also refer to this tweet by AK to quote/retweet/like here on "
        '<a href="https://twitter.com/_akhaliq/status/1601120767009513472">Twitter</a>. '
        "This Gradio Space is an attempt to explore this novel LORA approach to fine-tune Stable Diffusion models, "
        "using the power and flexibility of Gradio! "
        "A higher number of steps results in longer training time and better fine-tuned SD models.<br><br>"
        "<b>To use this Space well:</b><br>"
        "- First, upload your set of images (4-5), then enter the number of fine-tuning steps, "
        "and then press the 'Train LORA model' button. This will produce your fine-tuned model weights.<br>"
        "- Enter a prompt, set the alpha value using the Slider (nearer to 1 implies overfitting to the uploaded images), "
        "and then press the 'Inference' button. This will produce an image by the newly fine-tuned model.<br>"
        "<b>Bonus:</b> You can download your fine-tuned model weights from the Gradio file component. "
        "The smaller size of LORA models (around 3-4 MB files) is the main highlight of this "
        "'Low-rank Adaptation' approach of fine-tuning."
    )

    with gr.Row():
        in_images = gr.File(label="Upload images to fine-tune for LORA", file_count="multiple")
        with gr.Column():
            b1 = gr.Button(value="Train LORA model")
            in_prompt = gr.Textbox(label="Enter a prompt for fine-tuned LORA model", visible=True)
            b2 = gr.Button(value="Inference using LORA model")

    with gr.Row():
        out_image = gr.Image(label="Image generated by LORA model")
        with gr.Column():
            with gr.Accordion("Advanced settings for Training and Inference", open=False):
                gr.Markdown("Advanced settings for the number of training steps and alpha. Set alpha to 1.0 to fully add LORA. If the LORA seems to have too much effect (i.e., overfitting), set alpha to a lower value. If the LORA seems to have too little effect, set the alpha higher. You can tune these two values to your needs.")
                in_steps = gr.Number(label="Enter the number of training steps", value=4000)
                in_alpha = gr.Slider(0.1, 1.0, step=0.01, label="Set Alpha level", value=0.5)
            # Dual role: output of training (b1) and weights input for
            # inference (b2 / first example set).
            out_file = gr.File(label="Lora trained model weights")

    # cache_examples=True makes Gradio run `fn` on these examples when the app
    # starts. NOTE: this first set calls monkeypatching() with the PlaygroundAI
    # weights, which sets `counter` to 1 before any user request.
    gr.Examples(
        examples=[[0.65, "lion", "./lora_playgroundai_wt.pt"]],
        inputs=[in_alpha, in_prompt, out_file],
        outputs=out_image,
        fn=monkeypatching,
        cache_examples=True,
    )
    # Also cached at start-up: this launches a 4000-step training run.
    gr.Examples(
        examples=[[4000, ['./simba1.jpg', './simba2.jpg', './simba3.jpg', './simba4.jpg']]],
        inputs=[in_steps, in_images],
        outputs=out_file,
        fn=accelerate_train_lora,
        cache_examples=True,
    )

    b1.click(fn=accelerate_train_lora, inputs=[in_steps, in_images], outputs=out_file)
    b2.click(fn=monkeypatching, inputs=[in_alpha, in_prompt, out_file], outputs=out_image)

demo.queue(concurrency_count=3)
demo.launch(debug=True, show_error=True)