from diffusers import StableDiffusionPipeline
from lora_diffusion import monkeypatch_lora, tune_lora_scale
import torch
import os
import gradio as gr

# Base model and local paths shared by training and inference
MODEL_NAME = "stabilityai/stable-diffusion-2-1-base"
INSTANCE_DIR = "./data_example"   # training images
OUTPUT_DIR = "./output_example"   # trained LoRA weights are written here

# Load the base pipeline once and keep it on the GPU
pipe = StableDiffusionPipeline.from_pretrained(MODEL_NAME, torch_dtype=torch.float16).to("cuda")
torch.manual_seed(1)

# Tracks whether the UNet has already been monkey-patched with LoRA weights
counter = 0

# Patch the UNet with the fine-tuned LoRA weights (once) and generate an image
def monkeypatching(alpha, in_prompt):
    print("****** inside monkeypatching *******")
    print(f"in_prompt is - {str(in_prompt)}")
    global counter
    if counter == 0:
        # First call: apply the trained LoRA weights to the UNet, then set the scale
        monkeypatch_lora(pipe.unet, torch.load("./output_example/lora_weight.pt"))
        tune_lora_scale(pipe.unet, alpha)
        counter += 1
    else:
        # Weights are already patched; only adjust the LoRA scale
        tune_lora_scale(pipe.unet, alpha)
    prompt = "style of sks, " + str(in_prompt)
    image = pipe(prompt, num_inference_steps=50, guidance_scale=7).images[0]
    image.save("./illust_lora.jpg")
    return image

# Launch DreamBooth-style LoRA training as a subprocess via `accelerate launch`
def accelerate_train_lora(steps):
    print("*********** inside accelerate_train_lora ***********")
    os.system(f'accelerate launch ./train_lora_dreambooth.py \
      --pretrained_model_name_or_path={MODEL_NAME} \
      --instance_data_dir={INSTANCE_DIR} \
      --output_dir={OUTPUT_DIR} \
      --instance_prompt="style of sks" \
      --resolution=512 \
      --train_batch_size=1 \
      --gradient_accumulation_steps=1 \
      --learning_rate=1e-4 \
      --lr_scheduler="constant" \
      --lr_warmup_steps=0 \
      --max_train_steps={int(steps)}')
    print("*********** completing accelerate_train_lora ***********")
    # The training script writes the LoRA weights into OUTPUT_DIR
    return "./output_example/lora_weight.pt"
    
with gr.Blocks() as demo:
    with gr.Row():
        in_images = gr.File(label="Upload images to fine-tune for LoRA", file_count="multiple")
        in_steps = gr.Number(label="Enter number of training steps")
        in_alpha = gr.Slider(0.1, 1.0, step=0.01, label="Set alpha level - higher values are more likely to overfit")
    with gr.Row():
        b1 = gr.Button(value="Train LoRA model")
        b2 = gr.Button(value="Inference using LoRA model")
    with gr.Row():
        in_prompt = gr.Textbox(label="Enter a prompt for the fine-tuned LoRA model", visible=True)
        out_image = gr.Image(label="Image generated by the LoRA model")
        out_file = gr.File(label="LoRA trained model weights")
    # Note: the uploaded images are not wired into training here; training reads from INSTANCE_DIR
    b1.click(fn=accelerate_train_lora, inputs=in_steps, outputs=out_file)
    b2.click(fn=monkeypatching, inputs=[in_alpha, in_prompt], outputs=out_image)

demo.queue(concurrency_count=3)
demo.launch(debug=True, show_error=True)
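
# Minimal sketch (for reference only, mirrors monkeypatching() above) of reusing the
# trained weights outside this app; the 0.8 scale and the example prompt are illustrative:
#   pipe = StableDiffusionPipeline.from_pretrained(MODEL_NAME, torch_dtype=torch.float16).to("cuda")
#   monkeypatch_lora(pipe.unet, torch.load("./output_example/lora_weight.pt"))
#   tune_lora_scale(pipe.unet, 0.8)
#   image = pipe("style of sks, baby lion", num_inference_steps=50, guidance_scale=7).images[0]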