diffusion / app.py
YaArtemNosenko's picture
[DEBUG] Dedug custom nn
bc604aa verified
raw
history blame
13.8 kB
import gradio as gr
import numpy as np
import random
from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline
from peft import PeftModel, PeftConfig
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
# Model list including your LoRA model
MODEL_LIST = [
"CompVis/stable-diffusion-v1-4",
"stabilityai/sdxl-turbo",
"runwayml/stable-diffusion-v1-5",
"stabilityai/stable-diffusion-2-1",
"YaArtemNosenko/dino_stickers",
]
if torch.cuda.is_available():
torch_dtype = torch.float16
else:
torch_dtype = torch.float32
# Cache to avoid re-initializing pipelines repeatedly
model_cache = {}
def load_pipeline(model_id,
lora_scale,
controlnet_checkbox,
controlnet_mode,
ip_adapter_checkbox,
ip_adapter_scale
):
"""
Loads or retrieves a cached DiffusionPipeline.
If the chosen model is your LoRA adapter, then load the base model
(CompVis/stable-diffusion-v1-4) and apply the LoRA weights.
"""
if model_id in model_cache:
return model_cache[model_id]
if controlnet_checkbox:
if controlnet_mode == "depth_map":
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/sd-controlnet-depth",
cache_dir="./models_cache",
torch_dtype=torch_dtype
)
elif controlnet_mode == "pose_estimation":
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/sd-controlnet-openpose",
cache_dir="./models_cache",
torch_dtype=torch_dtype
)
elif controlnet_mode == "normal_map":
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/sd-controlnet-normal",
cache_dir="./models_cache",
torch_dtype=torch_dtype
)
elif controlnet_mode == "scribbles":
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/sd-controlnet-scribble",
cache_dir="./models_cache",
torch_dtype=torch_dtype
)
else:
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/sd-controlnet-canny",
cache_dir="./models_cache",
torch_dtype=torch_dtype
)
if model_id == "YaArtemNosenko/dino_stickers":
# Use the specified base model for your LoRA adapter.
base_model = "CompVis/stable-diffusion-v1-4"
# Load the LoRA weights
pipe = StableDiffusionControlNetPipeline.from_pretrained(base_model,
controlnet=controlnet,
torch_dtype=torch_dtype,
safety_checker=None).to(device)
pipe.unet = PeftModel.from_pretrained(
pipe.unet,
model_id,
subfolder="unet",
torch_dtype=torch_dtype
)
pipe.text_encoder = PeftModel.from_pretrained(
pipe.text_encoder,
model_id,
subfolder="text_encoder",
torch_dtype=torch_dtype
)
else:
pipe = StableDiffusionControlNetPipeline.from_pretrained(model_id,
controlnet=controlnet,
torch_dtype=torch_dtype,
safety_checker=None).to(device)
# params['image'] = controlnet_image
# params['controlnet_conditioning_scale'] = float(controlnet_strength)
else:
if model_id == "YaArtemNosenko/dino_stickers":
base_model = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(base_model, torch_dtype=torch_dtype)
# Load the LoRA weights
pipe.unet = PeftModel.from_pretrained(
pipe.unet,
model_id,
subfolder="unet",
torch_dtype=torch_dtype
)
pipe.text_encoder = PeftModel.from_pretrained(
pipe.text_encoder,
model_id,
subfolder="text_encoder",
torch_dtype=torch_dtype
)
else:
pipe = StableDiffusionPipeline.from_pretrained(model_id,
torch_dtype=torch_dtype,
safety_checker=None).to(device)
pipe.unet.load_state_dict({k: lora_scale * v for k, v in pipe.unet.state_dict().items()})
pipe.text_encoder.load_state_dict({k: lora_scale * v for k, v in pipe.text_encoder.state_dict().items()})
if ip_adapter_checkbox:
pipe.load_ip_adapter("h94/IP-Adapter",
subfolder="models",
weight_name="ip-adapter-plus_sd15.bin"
)
pipe.set_ip_adapter_scale(ip_adapter_scale)
# params['ip_adapter_image'] = ip_adapter_image
pipe.to(device)
model_cache[model_id] = pipe
return pipe
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
def infer(
model_id,
prompt,
negative_prompt,
seed,
randomize_seed,
width,
height,
guidance_scale,
num_inference_steps,
lora_scale, # New parameter for adjusting LoRA scale
controlnet_checkbox=False, # используем ли мы controlnet
controlnet_conditioning_scale=0.0, # вес для controlnet
controlnet_mode="edge_detection", # вариант controlnet
controlnet_image=None, # картинка для controlnet
ip_adapter_checkbox=False, # используется ли ip адаптера
ip_adapter_scale=0.0, # вес для ip адаптера
ip_adapter_image=None, # картинка для ip адаптера
progress=gr.Progress(track_tqdm=True),
):
# Load the pipeline for the chosen model
generator = torch.Generator(device=device).manual_seed(seed)
params = {'prompt': prompt,
'negative_prompt': negative_prompt,
'guidance_scale': guidance_scale,
'num_inference_steps': num_inference_steps,
'width': width,
'height': height,
'generator': generator
}
pipe = load_pipeline(model_id,
lora_scale,
controlnet_checkbox,
controlnet_mode,
ip_adapter_checkbox,
ip_adapter_scale
)
if randomize_seed:
seed = random.randint(0, MAX_SEED)
# If using the LoRA model, update the LoRA scale if supported.
# if model_id == "YaArtemNosenko/dino_stickers":
# # This assumes your pipeline's unet has a method to update the LoRA scale.
# if hasattr(pipe.unet, "set_lora_scale"):
# pipe.unet.set_lora_scale(lora_scale)
# else:
# print("Warning: LoRA scale adjustment method not found on UNet.")
# если используем controlnet
if controlnet_checkbox:
params['image'] = controlnet_image
params['controlnet_conditioning_scale'] = float(controlnet_conditioning_scale)
# если используем IP адаптер
if ip_adapter_checkbox:
params['ip_adapter_image'] = ip_adapter_image
image = pipe(**params).images[0]
return image, seed
def controlnet_params(show_extra):
return gr.update(visible=show_extra)
examples = [
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
"An astronaut riding a green horse",
"A delicious ceviche cheesecake slice",
]
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
with gr.Blocks(css=css) as demo:
with gr.Column(elem_id="col-container"):
gr.Markdown(" # Text-to-Image Gradio Template")
with gr.Row():
# Dropdown to select the model from Hugging Face
model_id = gr.Dropdown(
label="Model",
choices=MODEL_LIST,
value=MODEL_LIST[0], # Default model
)
with gr.Row():
prompt = gr.Text(
label="Prompt",
show_label=False,
max_lines=1,
placeholder="Enter your prompt",
container=False,
)
run_button = gr.Button("Run", scale=0, variant="primary")
result = gr.Image(label="Result", show_label=False)
with gr.Accordion("Advanced Settings", open=False):
negative_prompt = gr.Text(
label="Negative prompt",
max_lines=1,
placeholder="Enter a negative prompt",
)
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=42, # Default seed
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Row():
width = gr.Slider(
label="Width",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=1024,
)
height = gr.Slider(
label="Height",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=1024,
)
with gr.Row():
guidance_scale = gr.Slider(
label="Guidance scale",
minimum=0.0,
maximum=20.0,
step=0.5,
value=7.0,
)
num_inference_steps = gr.Slider(
label="Number of inference steps",
minimum=1,
maximum=100,
step=1,
value=20,
)
# New slider for LoRA scale.
lora_scale = gr.Slider(
label="LoRA Scale",
minimum=0.0,
maximum=2.0,
step=0.1,
value=1.0,
info="Adjust the influence of the LoRA weights",
)
with gr.Row():
controlnet_checkbox = gr.Checkbox(
label="ControlNet",
value=False
)
with gr.Column(visible=False) as controlnet_params:
controlnet_conditioning_scale = gr.Slider(
label="ControlNet conditioning scale",
minimum=0.0,
maximum=1.0,
step=0.01,
value=1.0,
)
controlnet_mode = gr.Dropdown(
label="ControlNet mode",
choices=["edge_detection",
"depth_map",
"pose_estimation",
"normal_map",
"scribbles"],
value="edge_detection",
max_choices=1
)
controlnet_image = gr.Image(
label="ControlNet condition image",
type="pil",
format="png"
)
controlnet_checkbox.change(
fn=lambda x: gr.Row.update(visible=x),
inputs=controlnet_checkbox,
outputs=controlnet_params
)
with gr.Row():
ip_adapter_checkbox = gr.Checkbox(
label="IPAdapter",
value=False
)
with gr.Column(visible=False) as ip_adapter_params:
ip_adapter_scale = gr.Slider(
label="IPAdapter scale",
minimum=0.0,
maximum=1.0,
step=0.01,
value=1.0,
)
ip_adapter_image = gr.Image(
label="IPAdapter condition image",
type="pil"
)
ip_adapter_checkbox.change(
fn=lambda x: gr.Row.update(visible=x),
inputs=ip_adapter_checkbox,
outputs=ip_adapter_params
)
gr.Examples(examples=examples, inputs=[prompt])
gr.on(
triggers=[run_button.click, prompt.submit],
fn=infer,
inputs=[model_id,
prompt,
negative_prompt,
seed,
randomize_seed,
width,
height,
guidance_scale,
num_inference_steps,
lora_scale, # Pass the new slider value
controlnet_checkbox,
controlnet_conditioning_scale,
controlnet_mode,
controlnet_image,
ip_adapter_checkbox,
ip_adapter_scale,
ip_adapter_image
],
outputs=[result, seed],
)
if __name__ == "__main__":
demo.launch()