Commit fc4c790 by ysharma (HF staff) · Parent: ef27330

update example

Files changed (1): app.py (+7 -7)

app.py CHANGED
@@ -21,17 +21,17 @@ torch.manual_seed(1)
 counter = 0
 
 #Getting Lora fine-tuned weights
-def monkeypatching(alpha, in_prompt, example_wt): #, prompt, pipe): finetuned_lora_weights
+def monkeypatching(alpha, in_prompt, wt): #, prompt, pipe): finetuned_lora_weights
     print("****** inside monkeypatching *******")
     print(f"in_prompt is - {str(in_prompt)}")
     global counter
     if counter == 0 :
-        if example_wt is None :
-            monkeypatch_lora(pipe.unet, torch.load("./output_example/lora_weight.pt")) #finetuned_lora_weights
+        if wt == "./lora_playgroundai_wt.pt" :
+            monkeypatch_lora(pipe.unet, torch.load(wt)) #finetuned_lora_weights
             tune_lora_scale(pipe.unet, alpha) #1.00)
             counter +=1
         else:
-            monkeypatch_lora(pipe.unet, torch.load(example_wt)) #finetuned_lora_weights
+            monkeypatch_lora(pipe.unet, torch.load("./output_example/lora_weight.pt")) #finetuned_lora_weights
             tune_lora_scale(pipe.unet, alpha) #1.00)
             counter +=1
     else :
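The rewritten function flips the weight-selection test: rather than loading the bundled weights when `example_wt` is `None`, it now checks whether the incoming path is the cached example file and, if so, loads it directly via `torch.load(wt)`; any other first call falls back to the freshly trained `./output_example/lora_weight.pt`. A minimal sketch of the underlying patch-and-tune flow, assuming the `lora_diffusion` package (cloneofsimo/lora), a CUDA device, and a placeholder base model and prompt not taken from this commit:

import torch
from diffusers import StableDiffusionPipeline
from lora_diffusion import monkeypatch_lora, tune_lora_scale

# Placeholder base model; app.py builds its `pipe` elsewhere.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Inject the LoRA weights into the UNet once (app.py's `counter` guard
# exists so the pipe is only patched on the first call), then scale
# their contribution with alpha.
monkeypatch_lora(pipe.unet, torch.load("./lora_playgroundai_wt.pt"))
tune_lora_scale(pipe.unet, 0.65)  # 1.0 = full LoRA effect; lower = weaker

image = pipe("lion", num_inference_steps=50).images[0]
image.save("lion_lora.png")

Setting the scale to 1.0 applies the LoRA at full strength; the UI's 0.1-1.0 alpha slider maps directly onto this value.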
@@ -81,17 +81,17 @@ with gr.Blocks() as demo:
     gr.Markdown("Advance settings for a number of Training Steps and Alpha. Set alpha to 1.0 to fully add LORA. If the LORA seems to have too much effect (i.e., overfitting), set alpha to a lower value. If the LORA seems to have too little effect, set the alpha higher. You can tune these two values to your needs.")
     in_steps = gr.Number(label="Enter the number of training steps", value = 4000)
     in_alpha = gr.Slider(0.1,1.0, step=0.01, label="Set Alpha level", value=0.5)
-    out_file = gr.File(label="Lora trained model weights", )
+    out_file = gr.File(label="Lora trained model weights" )
 
     gr.Examples(
         examples=[[0.65, "lion", "./lora_playgroundai_wt.pt" ]],
-        inputs=[in_alpha, in_prompt, example_wt],
+        inputs=[in_alpha, in_prompt, out_file],
         outputs=out_image,
         fn=monkeypatching,
         cache_examples=True,)
 
     b1.click(fn = accelerate_train_lora, inputs=in_steps, outputs=out_file)
-    b2.click(fn = monkeypatching, inputs=[in_alpha, in_prompt, example_wt], outputs=out_image)
+    b2.click(fn = monkeypatching, inputs=[in_alpha, in_prompt, out_file], outputs=out_image)
 
 demo.queue(concurrency_count=3)
 demo.launch(debug=True, show_error=True)
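On the UI side, the separate `example_wt` component is dropped: `out_file`, the `gr.File` that `accelerate_train_lora` writes its weights into, now doubles as the third input for both the cached example and the Generate button, so a fresh training run and the bundled example travel through the same component. A minimal self-contained sketch of this wiring, assuming Gradio 3.x (where `queue()` still takes `concurrency_count`), stub functions in place of the app's real `accelerate_train_lora` and `monkeypatching`, and a `./lora_playgroundai_wt.pt` file present next to the script:

import gradio as gr
import numpy as np

def accelerate_train_lora(steps):
    # Stub for the app's trainer: returns the path of the trained weights.
    return "./output_example/lora_weight.pt"

def monkeypatching(alpha, prompt, wt):
    # Stub for the app's inference function: returns a placeholder image.
    return np.zeros((64, 64, 3), dtype=np.uint8)

with gr.Blocks() as demo:
    in_prompt = gr.Textbox(label="Enter your prompt")
    in_steps = gr.Number(label="Enter the number of training steps", value=4000)
    in_alpha = gr.Slider(0.1, 1.0, step=0.01, label="Set Alpha level", value=0.5)
    out_image = gr.Image(label="Generated image")
    # One component plays two roles, training output and weights input,
    # which lets the cached example feed ./lora_playgroundai_wt.pt through it.
    out_file = gr.File(label="Lora trained model weights")
    b1 = gr.Button("Train LoRA")
    b2 = gr.Button("Generate")

    gr.Examples(
        examples=[[0.65, "lion", "./lora_playgroundai_wt.pt"]],
        inputs=[in_alpha, in_prompt, out_file],
        outputs=out_image,
        fn=monkeypatching,
        cache_examples=True,
    )

    b1.click(fn=accelerate_train_lora, inputs=in_steps, outputs=out_file)
    b2.click(fn=monkeypatching, inputs=[in_alpha, in_prompt, out_file], outputs=out_image)

demo.queue(concurrency_count=3)
demo.launch()

Reusing `out_file` keeps the example path and a user-trained weight file on one code path, which is what makes the `wt == "./lora_playgroundai_wt.pt"` check in `monkeypatching` sufficient to tell the two apart.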
 