Spaces:
Configuration error
Configuration error
File size: 3,221 Bytes
37a1e1f f746958 37a1e1f 074a81a 37a1e1f 074a81a 37a1e1f 074a81a 37a1e1f 8c46c8a 074a81a 0417f88 8c46c8a 0417f88 37a1e1f 720c562 8c46c8a e6d7a40 5ecdabf b12d0df 720c562 7722e4b 8c46c8a 60bca8f 720c562 37a1e1f 720c562 37a1e1f 7722e4b 720c562 37a1e1f 0417f88 37a1e1f 60bca8f 7722e4b 074a81a 37a1e1f 074a81a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
from diffusers import StableDiffusionPipeline
from lora_diffusion import monkeypatch_lora, tune_lora_scale
import torch
import os
import gradio as gr
import subprocess
# Base checkpoint shared by training (passed to accelerate) and inference.
MODEL_NAME = "stabilityai/stable-diffusion-2-1-base"
# Training images for DreamBooth-LoRA (passed as --instance_data_dir).
INSTANCE_DIR = "./data_example"
# Training output directory (passed as --output_dir); lora_weight.pt is read from here.
OUTPUT_DIR = "./output_example"
# Alias rather than a second literal so the inference model cannot drift from
# the model the LoRA weights were trained against.
model_id = MODEL_NAME
# Half-precision pipeline on the GPU; requires a CUDA device at import time.
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
# Seeds the global torch RNG once at startup. Presumably the pipeline draws its
# initial latents from this global RNG (no generator is passed), so only the
# first generation after startup is reproducible — TODO confirm.
torch.manual_seed(1)
# Becomes non-zero once monkeypatching() has injected the LoRA weights into
# pipe.unet, so the injection happens only once per process.
counter = 0
# Inference with the LoRA fine-tuned weights.
def monkeypatching(alpha, in_prompt):
    """Generate an image with the LoRA-patched Stable Diffusion pipeline.

    On the first call, loads ``lora_weight.pt`` from ``OUTPUT_DIR`` and injects it
    into ``pipe.unet`` with ``monkeypatch_lora``. Every call then sets the LoRA
    scale to ``alpha`` and runs the pipeline (50 steps, guidance scale 7) on the
    prompt ``"style of sks, " + in_prompt``.

    Args:
        alpha: LoRA scale forwarded to ``tune_lora_scale``.
        in_prompt: User text appended to the ``"style of sks, "`` trigger phrase.
            ``None`` is treated as an empty string.

    Returns:
        The generated image, which is also saved to ``./illust_lora.jpg``.

    Raises:
        FileNotFoundError: If no trained weights exist yet (train first).
    """
    global counter
    print("****** inside monkeypatching *******")
    print(f"in_prompt is - {str(in_prompt)}")

    if counter == 0:
        weights_path = os.path.join(OUTPUT_DIR, "lora_weight.pt")
        if not os.path.isfile(weights_path):
            raise FileNotFoundError(
                f"LoRA weights not found at {weights_path}; train the model first."
            )
        # NOTE(review): weights are injected only once per process. If the model
        # is retrained later, the new lora_weight.pt is not picked up until a
        # restart. The check-then-patch on ``counter`` is also unsynchronized,
        # so concurrent first calls could patch twice.
        monkeypatch_lora(pipe.unet, torch.load(weights_path))
        # Mark as patched before tuning, so a failing tune cannot cause a
        # second injection on the next call.
        counter += 1
    tune_lora_scale(pipe.unet, alpha)

    user_text = "" if in_prompt is None else str(in_prompt)
    prompt = "style of sks, " + user_text
    image = pipe(prompt, num_inference_steps=50, guidance_scale=7).images[0]
    image.save("./illust_lora.jpg")
    return image
def accelerate_train_lora(steps):
    """Run LoRA DreamBooth fine-tuning through ``accelerate`` and return the weights path.

    Launches ``./train_lora_dreambooth.py`` on the images in ``INSTANCE_DIR``
    with the instance prompt ``style of sks``. Output goes to ``OUTPUT_DIR``.
    Blocks until the training process exits.

    Args:
        steps: Number of training steps (``--max_train_steps``). Converted with
            ``int()``, so a float coming from a ``gr.Number`` is truncated.

    Returns:
        Path to ``lora_weight.pt`` inside ``OUTPUT_DIR``.

    Raises:
        ValueError: If ``steps`` is missing or less than 1.
        subprocess.CalledProcessError: If the training command exits non-zero.
            This stops the UI from offering a stale or missing weights file.
    """
    print("*********** inside accelerate_train_lora ***********")
    if steps is None:
        raise ValueError("Number of training steps is required.")
    max_steps = int(steps)
    if max_steps < 1:
        raise ValueError(f"Number of training steps must be at least 1, got {steps}.")

    # Argument list (no shell): every value reaches the script verbatim, exactly
    # as the previous shell string did after quote removal.
    command = [
        "accelerate", "launch", "./train_lora_dreambooth.py",
        f"--pretrained_model_name_or_path={MODEL_NAME}",
        f"--instance_data_dir={INSTANCE_DIR}",
        f"--output_dir={OUTPUT_DIR}",
        "--instance_prompt=style of sks",
        "--resolution=512",
        "--train_batch_size=1",
        "--gradient_accumulation_steps=1",
        "--learning_rate=1e-4",
        "--lr_scheduler=constant",
        "--lr_warmup_steps=0",
        f"--max_train_steps={max_steps}",
    ]
    # check=True surfaces training failures instead of silently returning a path.
    subprocess.run(command, check=True)
    print("*********** completing accelerate_train_lora ***********")
    return os.path.join(OUTPUT_DIR, "lora_weight.pt")
def _train_from_uploads(steps, files):
    """Copy any uploaded images into INSTANCE_DIR, then run LoRA training.

    Args:
        steps: Training step count forwarded to ``accelerate_train_lora``.
        files: Value of the multi-file upload component. Depending on the Gradio
            version, each item is a path string or a tempfile-like object with
            a ``.name`` path. ``None`` or empty means "train on INSTANCE_DIR as is".

    Returns:
        The weights path returned by ``accelerate_train_lora``.
    """
    import shutil  # local import: only needed when the train button is used

    if files:
        os.makedirs(INSTANCE_DIR, exist_ok=True)
        for item in files:
            # NOTE(review): uploads are added alongside whatever images are
            # already in INSTANCE_DIR (e.g. bundled examples); nothing is removed.
            shutil.copy(getattr(item, "name", item), INSTANCE_DIR)
    return accelerate_train_lora(steps)


with gr.Blocks() as demo:
    with gr.Row():
        in_images = gr.File(label="Upload images to fine-tune for LORA", file_count="multiple")
        # precision=0 makes the component hand an integer step count to the handler.
        in_steps = gr.Number(label="Enter number of steps", precision=0)
        in_alpha = gr.Slider(0.1, 1.0, step=0.01, label="Set Alpha level - higher value has more chances to overfit")
    with gr.Row():
        b1 = gr.Button(value="Train LORA model")
        b2 = gr.Button(value="Inference using LORA model")
    with gr.Row():
        in_prompt = gr.Textbox(label="Enter a prompt for fine-tuned LORA model", visible=True)
        out_image = gr.Image(label="Image generated by LORA model")
        out_file = gr.File(label="Lora trained model weights")
    # Uploaded images are now actually used for training (previously ignored).
    b1.click(fn=_train_from_uploads, inputs=[in_steps, in_images], outputs=out_file)
    b2.click(fn=monkeypatching, inputs=[in_alpha, in_prompt], outputs=out_image)

# Handle one event at a time. All handlers share the single global ``pipe`` on one
# GPU, and inference mutates its UNet (LoRA injection and scale), so concurrent
# training/inference requests would race on that shared state.
demo.queue(concurrency_count=1)
demo.launch(debug=True, show_error=True)