Magix / app.py
Singularity666's picture
Update app.py
66f79e7 verified
raw
history blame
2.93 kB
import gradio as gr
import os
import shutil
from pathlib import Path
from main import fine_tune_model
from diffusers import StableDiffusionPipeline, DDIMScheduler
import torch
# Base model identifier handed to fine_tune_model as the starting checkpoint.
MODEL_NAME: str = "runwayml/stable-diffusion-v1-5"
# Output location passed to fine_tune_model; generate_images loads its
# pipeline from this same directory.
OUTPUT_DIR: str = "/content/stable_diffusion_weights/custom_model"
def _save_instance_image(image, directory, index):
    """Store one uploaded image in ``directory`` as ``instance_<index>.<ext>``.

    ``image`` may be any of the following (the form depends on the Gradio
    component and version):

    * an object with a ``save(path)`` method, e.g. a PIL image;
    * a file path (``str`` / ``os.PathLike``);
    * a temp-file wrapper exposing the on-disk path as ``.name``.

    Raises:
        TypeError: if ``image`` is none of the above.
    """
    if hasattr(image, "save"):
        # In-memory image: encode it as PNG, as the original code did.
        image.save(os.path.join(directory, f"instance_{index}.png"))
        return
    if isinstance(image, (str, os.PathLike)):
        src = os.fspath(image)
    else:
        src = getattr(image, "name", None)
    if not src:
        raise TypeError(f"Unsupported image upload type: {type(image).__name__}")
    # Copy the uploaded file unchanged, keeping its real extension so the file
    # name matches its encoding.
    # NOTE(review): assumes fine_tune_model reads every file in the directory
    # regardless of extension - confirm in main.py.
    suffix = Path(src).suffix or ".png"
    shutil.copyfile(src, os.path.join(directory, f"instance_{index}{suffix}"))


def fine_tune(instance_prompt, images):
    """Fine-tune MODEL_NAME on the uploaded instance images.

    The images are written into a freshly emptied staging directory, which is
    then passed to ``fine_tune_model`` together with the prompt, MODEL_NAME
    and OUTPUT_DIR.

    Args:
        instance_prompt: Text prompt describing the instance subject.
        images: The upload value. It may be ``None``, a single image, or a
            list/tuple of images. Each image may be an object with
            ``save()``, a file path, or a temp-file wrapper with ``.name``.

    Returns:
        A status message for the UI's output textbox.
    """
    if not (instance_prompt or "").strip():
        return "Please enter an instance prompt."
    if images is None:
        images = []
    elif not isinstance(images, (list, tuple)):
        # A single-image component returns one object rather than a list.
        images = [images]
    if not images:
        return "Please upload at least one image."

    instance_data_dir = "/content/instance_images"
    # Start from an empty directory so images from a previous run are not reused.
    if os.path.exists(instance_data_dir):
        shutil.rmtree(instance_data_dir)
    os.makedirs(instance_data_dir, exist_ok=True)
    for i, img in enumerate(images):
        _save_instance_image(img, instance_data_dir, i)

    fine_tune_model(instance_data_dir, instance_prompt, MODEL_NAME, OUTPUT_DIR)
    return "Model fine-tuning complete."


def generate_images(prompt, num_samples, height, width, num_inference_steps,
                    guidance_scale, seed=1337, model_dir=None):
    """Generate images with the fine-tuned Stable Diffusion pipeline.

    The pipeline is loaded in float16 onto a CUDA device, so a GPU is
    required. Its scheduler is replaced with DDIM before sampling.

    Args:
        prompt: Text prompt to render.
        num_samples: Number of images to generate for the prompt.
        height: Output height in pixels.
        width: Output width in pixels.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance scale.
        seed: Seed for the CUDA generator. The fixed default keeps results
            reproducible for identical settings.
        model_dir: Local directory holding the fine-tuned pipeline.
            Defaults to OUTPUT_DIR.

    Returns:
        The ``.images`` list produced by the pipeline.

    Raises:
        FileNotFoundError: if ``model_dir`` is not an existing directory,
            e.g. because no model has been fine-tuned yet.
    """
    model_dir = OUTPUT_DIR if model_dir is None else model_dir
    if not os.path.isdir(model_dir):
        raise FileNotFoundError(
            f"No fine-tuned model found at {model_dir!r}; run 'Fine-Tune Model' first."
        )

    pipe = StableDiffusionPipeline.from_pretrained(
        model_dir, safety_checker=None, torch_dtype=torch.float16
    ).to("cuda")
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    g_cuda = torch.Generator(device="cuda").manual_seed(int(seed))
    with torch.autocast("cuda"), torch.inference_mode():
        images = pipe(
            prompt,
            # Gradio number widgets can deliver floats (e.g. 512.0), but these
            # arguments are used as tensor sizes/repeat counts and must be ints.
            height=int(height),
            width=int(width),
            num_images_per_prompt=int(num_samples),
            num_inference_steps=int(num_inference_steps),
            guidance_scale=float(guidance_scale),
            generator=g_cuda,
        ).images
    return images


def gradio_app():
    """Build the two-tab Gradio UI (fine-tune / generate) and launch it."""
    with gr.Blocks() as demo:
        with gr.Tab("Fine-Tune Model"):
            with gr.Row():
                with gr.Column():
                    instance_prompt = gr.Textbox(label="Instance Prompt")
                    # gr.Image holds a single picture and has no ``multiple``
                    # option, so use a multi-file upload restricted to images.
                    image_input = gr.File(
                        label="Upload Images", file_count="multiple", file_types=["image"]
                    )
                    fine_tune_button = gr.Button("Fine-Tune Model")
                    output_text = gr.Textbox(label="Output")
            fine_tune_button.click(
                fine_tune, inputs=[instance_prompt, image_input], outputs=output_text
            )
        with gr.Tab("Generate Images"):
            with gr.Row():
                with gr.Column():
                    prompt = gr.Textbox(label="Prompt")
                    # precision=0 makes these widgets return integers.
                    num_samples = gr.Number(label="Number of Samples", value=1, precision=0)
                    guidance_scale = gr.Number(label="Guidance Scale", value=7.5)
                    height = gr.Number(label="Height", value=512, precision=0)
                    width = gr.Number(label="Width", value=512, precision=0)
                    num_inference_steps = gr.Slider(
                        label="Steps", value=50, minimum=1, maximum=100, step=1
                    )
                    generate_button = gr.Button("Generate Images")
                with gr.Column():
                    gallery = gr.Gallery(label="Generated Images")
            generate_button.click(
                generate_images,
                inputs=[prompt, num_samples, height, width, num_inference_steps, guidance_scale],
                outputs=gallery,
            )
    demo.launch()
if __name__ == "__main__":
    # Launch the web UI only when executed as a script, never on import.
    gradio_app()