|
import gradio as gr |
|
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler |
|
import torch |
|
from PIL import Image |
|
import random |
|
from peft import PeftModel, LoraConfig |
|
|
|
# Base checkpoint identifier handed to StableDiffusionPipeline.from_pretrained.
model_id: str = "CompVis/stable-diffusion-v1-4"
# Adapter identifier handed to PeftModel.from_pretrained to wrap the base UNet.
lora_model_id: str = "codermert/mert_flux"
|
|
|
def load_model():
    """Build the CPU Stable Diffusion pipeline with the LoRA adapter applied.

    Loads ``model_id`` in float32 (suitable for CPU inference), replaces the
    scheduler with a DPM-Solver multistep scheduler built from the original
    scheduler's config, moves the pipeline to the CPU and wraps its UNet with
    the PEFT adapter stored at ``lora_model_id``.

    Returns:
        The configured ``StableDiffusionPipeline``.
    """
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        # Disable the safety checker at load time instead of loading the whole
        # checker model into memory and discarding it afterwards;
        # requires_safety_checker=False silences diffusers' warning about it.
        safety_checker=None,
        requires_safety_checker=False,
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to("cpu")

    # PeftModel.from_pretrained reads the adapter config from the repo itself,
    # so no separate LoraConfig load is needed.
    # NOTE(review): the repo name suggests a FLUX LoRA; FLUX adapters target a
    # different architecture than the SD v1.4 UNet and may fail to load or
    # apply here — confirm the adapter was trained for this base model.
    pipe.unet = PeftModel.from_pretrained(pipe.unet, lora_model_id)

    return pipe


# Loaded once at import time and shared by every generation request.
pipe = load_model()
|
|
|
def generate_image(prompt, negative_prompt, steps, cfg_scale, seed, strength=None, pipeline=None):
    """Generate one image from a text prompt.

    Args:
        prompt: Text describing the desired image.
        negative_prompt: Text describing what the image should avoid.
        steps: Number of inference steps. Coerced to ``int`` because UI sliders
            may deliver numbers as floats.
        cfg_scale: Classifier-free guidance scale.
        seed: RNG seed. ``-1`` requests a fresh random seed in
            ``[1, 1_000_000_000]``. Coerced to ``int`` because
            ``torch.Generator.manual_seed`` does not accept floats.
        strength: Accepted for interface compatibility; currently unused.
        pipeline: Callable diffusion pipeline to use; defaults to the
            module-level ``pipe``.

    Returns:
        ``(image, seed_used)``: the first image produced by the pipeline and the
        integer seed that generated it, so the result can be reproduced.
    """
    if pipeline is None:
        pipeline = pipe

    seed = int(seed)
    if seed == -1:
        seed = random.randint(1, 1000000000)

    generator = torch.Generator().manual_seed(seed)

    with torch.no_grad():
        image = pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=int(steps),
            guidance_scale=float(cfg_scale),
            generator=generator,
        ).images[0]

    return image, seed
|
|
|
# Stylesheet for the Blocks app: centres the column tagged
# elem_id="app-container" and caps its width.
css = """
#app-container {
    max-width: 800px;
    margin-left: auto;
    margin-right: auto;
}
"""

# (prompt, negative prompt) pairs shown as clickable examples in the UI.
_EXAMPLE_PAIRS = (
    ("A beautiful landscape with mountains and a lake", "ugly, deformed"),
    ("A futuristic cityscape at night", "daytime, rural"),
    ("A portrait of a smiling person in a colorful outfit", "monochrome, frowning"),
)
examples = [list(pair) for pair in _EXAMPLE_PAIRS]
|
|
|
def _generate_from_ui(prompt, negative, n_steps, cfg, seed_value):
    """Forward the UI's values to ``generate_image``.

    The UI exposes no strength control (generate_image does not use that
    argument), so ``None`` is supplied for it explicitly.
    """
    return generate_image(prompt, negative, n_steps, cfg, seed_value, None)


with gr.Blocks(theme='default', css=css) as app:
    gr.HTML("<center><h1>Mert Flux LoRA Explorer (CPU Version)</h1></center>")
    with gr.Column(elem_id="app-container"):
        with gr.Row():
            text_prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt here", lines=2)
            negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="What to avoid in the image", lines=2)
        with gr.Row():
            with gr.Column():
                steps = gr.Slider(label="Sampling steps", value=30, minimum=10, maximum=50, step=1)
                cfg_scale = gr.Slider(label="CFG Scale", value=7.5, minimum=1, maximum=15, step=0.5)
            with gr.Column():
                # -1 asks generate_image to pick a random seed.
                seed = gr.Slider(label="Seed", value=-1, minimum=-1, maximum=1000000000, step=1)
        with gr.Row():
            generate_button = gr.Button("Generate", variant='primary')
        with gr.Row():
            image_output = gr.Image(type="pil", label="Generated Image", show_download_button=True)
        with gr.Row():
            # precision=0 shows the seed as an integer rather than a float.
            seed_output = gr.Number(label="Seed Used", precision=0)

        gr.Examples(examples=examples, inputs=[text_prompt, negative_prompt])

    # The original wiring referenced an undefined `strength` component, which
    # raised NameError while building the app. The wrapper supplies it instead;
    # api_name keeps the endpoint named after generate_image.
    generate_button.click(
        _generate_from_ui,
        inputs=[text_prompt, negative_prompt, steps, cfg_scale, seed],
        outputs=[image_output, seed_output],
        api_name="generate_image",
    )

app.launch()