import asyncio
import spaces
from fastapi import FastAPI, HTTPException, UploadFile, File
from typing import Optional, Dict, Any
import torch
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    AutoPipelineForText2Image,
)
import gradio as gr
from PIL import Image
import numpy as np
import gc
from io import BytesIO
import base64
import functools

app = FastAPI()
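
# Text-to-image service combining a FastAPI backend with a Gradio UI.
# A ModelManager loads one diffusion pipeline at a time from the MODELS
# registry below and tracks GPU memory so models can be swapped or unloaded.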
# Comprehensive model registry
MODELS = {
    "SDXL-Base": {
        "model_id": "stabilityai/stable-diffusion-xl-base-1.0",
        "pipeline": StableDiffusionXLPipeline,
        "supports_img2img": True,
        "parameters": {
            "num_inference_steps": {"min": 1, "max": 100, "default": 50},
            "guidance_scale": {"min": 1, "max": 15, "default": 7.5},
            "width": {"min": 256, "max": 1024, "default": 512, "step": 64},
            "height": {"min": 256, "max": 1024, "default": 512, "step": 64}
        }
    },
    "SDXL-Turbo": {
        "model_id": "stabilityai/sdxl-turbo",
        "pipeline": AutoPipelineForText2Image,
        "supports_img2img": True,
        "parameters": {
            "num_inference_steps": {"min": 1, "max": 50, "default": 1},
            "guidance_scale": {"min": 0.0, "max": 20.0, "default": 7.5},
            "width": {"min": 256, "max": 1024, "default": 512, "step": 64},
            "height": {"min": 256, "max": 1024, "default": 512, "step": 64}
        }
    },
    "SD-1.5": {
        "model_id": "runwayml/stable-diffusion-v1-5",
        "pipeline": StableDiffusionPipeline,
        "supports_img2img": True,
        "parameters": {
            "num_inference_steps": {"min": 1, "max": 50, "default": 30},
            "guidance_scale": {"min": 1, "max": 20, "default": 7.5},
            "width": {"min": 256, "max": 1024, "default": 512, "step": 64},
            "height": {"min": 256, "max": 1024, "default": 512, "step": 64}
        }
    },
    "Waifu-Diffusion": {
        "model_id": "hakurei/waifu-diffusion",
        "pipeline": StableDiffusionPipeline,
        "supports_img2img": True,
        "parameters": {
            "num_inference_steps": {"min": 1, "max": 100, "default": 50},
            "guidance_scale": {"min": 1, "max": 15, "default": 7.5},
            "width": {"min": 256, "max": 1024, "default": 512, "step": 64},
            "height": {"min": 256, "max": 1024, "default": 512, "step": 64}
        }
    },
    "Flux": {
        # NOTE: this repo id may not exist on the Hub; the open-weight FLUX
        # checkpoints are published under ids such as "black-forest-labs/FLUX.1-dev".
        "model_id": "black-forest-labs/flux-1-1-dev",
        "pipeline": AutoPipelineForText2Image,
        "supports_img2img": True,
        "parameters": {
            "num_inference_steps": {"min": 1, "max": 50, "default": 25},
            "guidance_scale": {"min": 1, "max": 15, "default": 7.5},
            "width": {"min": 256, "max": 1024, "default": 512, "step": 64},
            "height": {"min": 256, "max": 1024, "default": 512, "step": 64}
        }
    }
}
class ModelManager:
    def __init__(self):
        self.current_model = None
        self.current_pipeline = None
        self.model_cache: Dict[str, Any] = {}
        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        self._dtype = torch.float16 if self._device == "cuda" else torch.float32

    def _clear_memory(self):
        """Clear CUDA memory and garbage collect"""
        if self.current_pipeline is not None:
            del self.current_pipeline
            self.current_pipeline = None
        # The lru_cache on get_model_config also holds a reference to the
        # pipeline; clear it so the weights can actually be freed.
        self.get_model_config.cache_clear()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        gc.collect()

    @functools.lru_cache(maxsize=1)
    def get_model_config(self, model_id: str, pipeline_class):
        """Load and cache the pipeline for a model"""
        return pipeline_class.from_pretrained(
            model_id,
            torch_dtype=self._dtype,
            variant="fp16" if self._device == "cuda" else None,
        )

    def load_model(self, model_name: str):
        """Load model with memory optimization"""
        if self.current_model != model_name:
            self._clear_memory()
            try:
                model_info = MODELS[model_name]
                self.current_pipeline = self.get_model_config(
                    model_info["model_id"],
                    model_info["pipeline"]
                )
                if hasattr(self.current_pipeline, 'enable_xformers_memory_efficient_attention'):
                    self.current_pipeline.enable_xformers_memory_efficient_attention()
                if self._device == "cuda":
                    # CPU offload moves submodules to the GPU on demand and
                    # handles device placement, so no explicit .to("cuda") is needed.
                    self.current_pipeline.enable_model_cpu_offload()
                self.current_model = model_name
            except Exception as e:
                self._clear_memory()
                raise RuntimeError(f"Failed to load model {model_name}: {str(e)}")
        return self.current_pipeline

    def unload_current_model(self):
        """Explicitly unload the current model"""
        self._clear_memory()
        self.current_model = None

    def get_memory_status(self):
        """Get current memory usage status (values in GB)"""
        if not torch.cuda.is_available():
            return {"status": "CPU Mode"}
        return {
            "total": torch.cuda.get_device_properties(0).total_memory / 1e9,
            "allocated": torch.cuda.memory_allocated() / 1e9,
            "cached": torch.cuda.memory_reserved() / 1e9,
            "free": (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()) / 1e9
        }
class ModelContext:
    """Context manager that loads a model on entry and unloads it on error."""

    def __init__(self, model_name: str):
        self.model_name = model_name

    def __enter__(self):
        return model_manager.load_model(self.model_name)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            model_manager.unload_current_model()


model_manager = ModelManager()
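
# Illustrative use of the context manager (not called anywhere directly;
# generate_image below goes through the same path):
#     with ModelContext("SD-1.5") as pipe:
#         image = pipe("a watercolor fox").images[0]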
async def generate_image(
    model_name: str,
    prompt: str,
    height: int = 512,
    width: int = 512,
    num_inference_steps: Optional[int] = None,
    guidance_scale: Optional[float] = None,
    reference_image: Optional[Image.Image] = None
) -> dict:
    try:
        with ModelContext(model_name) as pipeline:
            pre_mem = model_manager.get_memory_status()

            # Process the reference image if provided (note: the loaded pipeline
            # must accept an `image` argument for this to take effect).
            if reference_image is not None and MODELS[model_name]["supports_img2img"]:
                reference_image = reference_image.resize((width, height))

            # Generate image
            generation_params = {
                "prompt": prompt,
                "height": height,
                "width": width,
                "num_inference_steps": num_inference_steps or MODELS[model_name]["parameters"]["num_inference_steps"]["default"],
                "guidance_scale": guidance_scale or MODELS[model_name]["parameters"]["guidance_scale"]["default"]
            }
            if reference_image is not None:
                generation_params["image"] = reference_image

            image = pipeline(**generation_params).images[0]

            # Convert to base64
            buffered = BytesIO()
            image.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode()

            post_mem = model_manager.get_memory_status()

            return {
                "status": "success",
                "image_base64": img_str,
                "memory": {
                    "before": pre_mem,
                    "after": post_mem
                }
            }
    except Exception as e:
        model_manager.unload_current_model()
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/generate")
async def generate_image_endpoint(
model_name: str,
prompt: str,
height: int = 512,
width: int = 512,
num_inference_steps: Optional[int] = None,
guidance_scale: Optional[float] = None,
reference_image: UploadFile = File(None)
):
ref_img = None
if reference_image:
content = await reference_image.read()
ref_img = Image.open(BytesIO(content))
return await generate_image(
model_name=model_name,
prompt=prompt,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
reference_image=ref_img
)
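
# Example request (illustrative; assumes the API is served on localhost:8000
# as in the __main__ block below):
#   curl -X POST "http://localhost:8000/generate?model_name=SD-1.5&prompt=a%20cat%20in%20a%20hat"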
@app.get("/memory")
async def get_memory_status():
    return model_manager.get_memory_status()


@app.post("/unload")
async def unload_model():
    model_manager.unload_current_model()
    return {"status": "success", "message": "Model unloaded"}
def create_gradio_interface():
    with gr.Blocks() as interface:
        gr.Markdown("# Text-to-Image Generation Interface")

        with gr.Row():
            with gr.Column(scale=2):
                model_dropdown = gr.Dropdown(
                    choices=list(MODELS.keys()),
                    value=list(MODELS.keys())[0],
                    label="Select Model"
                )
                prompt = gr.Textbox(
                    lines=3,
                    label="Prompt",
                    placeholder="Enter your image description here..."
                )
                with gr.Row():
                    height = gr.Slider(
                        minimum=256,
                        maximum=1024,
                        value=512,
                        step=64,
                        label="Height"
                    )
                    width = gr.Slider(
                        minimum=256,
                        maximum=1024,
                        value=512,
                        step=64,
                        label="Width"
                    )
                with gr.Row():
                    num_steps = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Number of Inference Steps"
                    )
                    guidance = gr.Slider(
                        minimum=1,
                        maximum=15,
                        value=7.5,
                        step=0.1,
                        label="Guidance Scale"
                    )
                reference_image = gr.Image(
                    type="pil",
                    label="Reference Image (optional)"
                )
                with gr.Row():
                    generate_btn = gr.Button("Generate", variant="primary")
                    unload_btn = gr.Button("Unload Model")

            with gr.Column(scale=2):
                output_image = gr.Image(label="Generated Image")
                memory_status = gr.JSON(
                    label="Memory Status",
                    value=model_manager.get_memory_status()
                )
        def update_params(model_name):
            # Return gr.update objects in the same order as the outputs list
            # wired up below: height, width, num_steps, guidance.
            model_config = MODELS[model_name]["parameters"]
            return [
                gr.update(
                    minimum=model_config["height"]["min"],
                    maximum=model_config["height"]["max"],
                    value=model_config["height"]["default"],
                    step=model_config["height"]["step"]
                ),
                gr.update(
                    minimum=model_config["width"]["min"],
                    maximum=model_config["width"]["max"],
                    value=model_config["width"]["default"],
                    step=model_config["width"]["step"]
                ),
                gr.update(
                    minimum=model_config["num_inference_steps"]["min"],
                    maximum=model_config["num_inference_steps"]["max"],
                    value=model_config["num_inference_steps"]["default"]
                ),
                gr.update(
                    minimum=model_config["guidance_scale"]["min"],
                    maximum=model_config["guidance_scale"]["max"],
                    value=model_config["guidance_scale"]["default"]
                )
            ]
        @spaces.GPU
        def generate(model_name, prompt_text, h, w, steps, guide_scale, ref_img):
            # generate_image is a coroutine (it also backs the FastAPI endpoint),
            # so run it to completion here; slider values arrive as floats and
            # are cast to the integer types the pipelines expect.
            response = asyncio.run(generate_image(
                model_name=model_name,
                prompt=prompt_text,
                height=int(h),
                width=int(w),
                num_inference_steps=int(steps),
                guidance_scale=guide_scale,
                reference_image=ref_img
            ))
            return Image.open(BytesIO(base64.b64decode(response["image_base64"])))
        def unload_and_report():
            # Free the current pipeline, then show the refreshed memory numbers.
            model_manager.unload_current_model()
            return model_manager.get_memory_status()

        model_dropdown.change(
            update_params,
            inputs=[model_dropdown],
            outputs=[height, width, num_steps, guidance]
        )

        generate_btn.click(
            generate,
            inputs=[
                model_dropdown,
                prompt,
                height,
                width,
                num_steps,
                guidance,
                reference_image
            ],
            outputs=[output_image]
        )

        unload_btn.click(
            unload_and_report,
            outputs=[memory_status]
        )

    return interface
if __name__ == "__main__":
    import uvicorn
    from threading import Thread

    # Launch Gradio interface
    interface = create_gradio_interface()
    gradio_thread = Thread(
        target=interface.launch,
        kwargs={
            "server_name": "0.0.0.0",
            "server_port": 7860,
            "share": True
        }
    )
    gradio_thread.start()

    # Launch FastAPI
    uvicorn.run(app, host="0.0.0.0", port=8000)
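
# Illustrative Python client for the API above (assumes the server is reachable
# on localhost:8000; "out.png" is just an example filename):
#     import requests, base64
#     r = requests.post(
#         "http://localhost:8000/generate",
#         params={"model_name": "SD-1.5", "prompt": "a watercolor fox"},
#     )
#     with open("out.png", "wb") as f:
#         f.write(base64.b64decode(r.json()["image_base64"]))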