ysharma (HF staff) committed
Commit 074a81a · Parent(s): 8c46c8a
Files changed (1):
  1. app.py (+13 -21)
app.py CHANGED
@@ -1,38 +1,27 @@
-#https://github.com/huggingface/diffusers/tree/main/examples/dreambooth
-#export
-MODEL_NAME="stabilityai/stable-diffusion-2-1-base"
-#export
-INSTANCE_DIR="./data_example"
-#export
-OUTPUT_DIR="./output_example"
-
-
 from diffusers import StableDiffusionPipeline
 from lora_diffusion import monkeypatch_lora, tune_lora_scale
 import torch
 import os
 import gradio as gr
-#os.system('python file.py')
 import subprocess
-# If your shell script has shebang,
-# you can omit shell=True argument.
-#subprocess.run("./run_lora_db.sh", shell=True)
 
-#####
+MODEL_NAME="stabilityai/stable-diffusion-2-1-base"
+INSTANCE_DIR="./data_example"
+OUTPUT_DIR="./output_example"
+
 model_id = "stabilityai/stable-diffusion-2-1-base"
 pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
-prompt = "style of sks, baby lion"
+#prompt = "style of sks, baby lion"
 torch.manual_seed(1)
 #image = pipe(prompt, num_inference_steps=50, guidance_scale= 7).images[0] #no need
 #image # nice. diffusers are cool. #no need
-finetuned_lora_weights = "./lora_weight.pt"
+#finetuned_lora_weights = "./lora_weight.pt"
 
 #global var
 counter = 0
 
-#####
-#my fine tuned weights
-def monkeypatching(alpha): #, prompt, pipe): finetuned_lora_weights
+#Getting Lora fine-tuned weights
+def monkeypatching(alpha, in_prompt): #, prompt, pipe): finetuned_lora_weights
     global counter
     if counter == 0 :
         monkeypatch_lora(pipe.unet, torch.load("./output_example/lora_weight.pt")) #finetuned_lora_weights
@@ -40,6 +29,7 @@ def monkeypatching(alpha): #, prompt, pipe): finetuned_lora_weights
         counter +=1
     else :
         tune_lora_scale(pipe.unet, alpha) #1.00)
+    prompt = "style of sks, " + in_prompt #"baby lion"
     image = pipe(prompt, num_inference_steps=50, guidance_scale=7).images[0]
     image.save("./illust_lora.jpg") #"./contents/illust_lora.jpg")
     return image
@@ -73,9 +63,11 @@ with gr.Blocks() as demo:
     b1 = gr.Button(value="Train LORA model")
     b2 = gr.Button(value="Inference using LORA model")
     with gr.Row():
+        in_prompt = gr.Textbox(label="Enter a prompt for fine-tuned LORA model")
         out_image = gr.Image(label="Image generated by LORA model")
         out_file = gr.File(label="Lora trained model weights")
     b1.click(fn = accelerate_train_lora, inputs=in_steps, outputs=out_file)
-    b2.click(fn = monkeypatching, inputs=in_alpha, outputs=out_image)
+    b2.click(fn = monkeypatching, inputs=[in_alpha, in_prompt], outputs=out_image)
 
-demo.launch(debug=True, show_error=True)
+demo.queue(concurrency_count=3)
+demo.launch(debug=True, show_error=True)
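
The main change here is that `monkeypatching()` now takes the prompt from the new textbox: the first click patches the UNet with the trained LoRA weights, and later clicks only retune the LoRA scale with the alpha slider before generating. Below is a minimal standalone sketch of that inference path, using the same `lora_diffusion` calls and file paths as the app; the alpha value and example prompt are placeholders.

```python
import torch
from diffusers import StableDiffusionPipeline
from lora_diffusion import monkeypatch_lora, tune_lora_scale

# Same base model and fp16 setup as app.py; requires a CUDA GPU.
pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16
).to("cuda")

# Patch the UNet once with the fine-tuned LoRA weights produced by training.
monkeypatch_lora(pipe.unet, torch.load("./output_example/lora_weight.pt"))
# Afterwards only the LoRA influence needs adjusting (the in_alpha slider).
tune_lora_scale(pipe.unet, 0.8)  # placeholder alpha

# The textbox value is appended to the "sks" instance token, as in monkeypatching().
prompt = "style of sks, " + "baby lion"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7).images[0]
image.save("./illust_lora.jpg")
```

The `counter` guard in the app exists because the UNet should only be monkeypatched once per process; subsequent requests reuse the patched UNet and merely rescale it.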
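
The commit also promotes the old `#export` comments to Python variables (`MODEL_NAME`, `INSTANCE_DIR`, `OUTPUT_DIR`). The `accelerate_train_lora` function wired to the Train button is outside the shown hunks, so the sketch below is only a guess at how those variables might be handed to the `run_lora_db.sh` script that the previously commented-out `subprocess.run` call launched; the function name, the `MAX_TRAIN_STEPS` variable, and the returned weight path are assumptions, not the app's actual implementation.

```python
import os
import subprocess

# Hypothetical sketch only -- the real accelerate_train_lora() is not shown in
# this diff. Assumes run_lora_db.sh reads these environment variables and
# writes lora_weight.pt into OUTPUT_DIR.
def train_lora_sketch(steps):
    env = dict(
        os.environ,
        MODEL_NAME="stabilityai/stable-diffusion-2-1-base",
        INSTANCE_DIR="./data_example",
        OUTPUT_DIR="./output_example",
        MAX_TRAIN_STEPS=str(int(steps)),  # value of the in_steps input
    )
    subprocess.run("./run_lora_db.sh", shell=True, env=env, check=True)
    return "./output_example/lora_weight.pt"  # path handed to the gr.File output
```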