from fastapi import FastAPI, HTTPException, Response
import torch
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download, login
from safetensors.torch import load_file
from io import BytesIO
import os
import base64  # For encoding generated images as base64
from typing import List  # For type-hinting the list of prompts
app = FastAPI()

# Get the Hugging Face token from the environment
HF_TOKEN = os.getenv("HF_TOKEN")
def load_model():
    try:
        # Log in to Hugging Face if a token is provided
        if HF_TOKEN:
            login(token=HF_TOKEN)

        base = "stabilityai/stable-diffusion-xl-base-1.0"
        repo = "ByteDance/SDXL-Lightning"
        ckpt = "sdxl_lightning_4step_unet.safetensors"

        # Build the UNet from the base config, then load the Lightning weights
        unet = UNet2DConditionModel.from_config(
            base,
            subfolder="unet"
        ).to("cuda", torch.float16)
        unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))

        pipe = StableDiffusionXLPipeline.from_pretrained(
            base,
            unet=unet,
            torch_dtype=torch.float16,
            variant="fp16"
        ).to("cuda")

        # SDXL-Lightning expects "trailing" timestep spacing
        pipe.scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config,
            timestep_spacing="trailing"
        )
        return pipe
    except Exception as e:
        raise RuntimeError(f"Failed to load model: {e}") from e
# Load model at startup with error handling
try:
    pipe = load_model()
except Exception as e:
    print(f"Model initialization failed: {e}")
    raise
# Single-prompt endpoint. NOTE: the route path is an arbitrary choice;
# clients must use whatever path is registered here.
@app.get("/generate")
async def generate_image(prompt: str):
    try:
        # Generate a single image (Lightning checkpoint: 4 steps, CFG off)
        image = pipe(
            prompt,
            num_inference_steps=4,
            guidance_scale=0
        ).images[0]

        # Serialize the image to PNG and return it directly
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        buffer.seek(0)
        return Response(content=buffer.getvalue(), media_type="image/png")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Endpoint for a list of prompts; FastAPI reads a bare List[str]
# parameter from the JSON request body.
@app.post("/generate-multiple")
async def generate_multiple_images(prompts: List[str]):
    try:
        # Collect base64-encoded images, one per prompt
        generated_images = []
        for prompt in prompts:
            image = pipe(
                prompt,
                num_inference_steps=4,
                guidance_scale=0
            ).images[0]

            # Serialize to PNG, then base64-encode for the JSON response
            buffer = BytesIO()
            image.save(buffer, format="PNG")
            buffer.seek(0)
            image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

            generated_images.append({
                "prompt": prompt,
                "image_base64": image_base64
            })
        return {"images": generated_images}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
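# A minimal client-side sketch (assumes the /generate-multiple route above
# and the `requests` package): each base64 payload decodes back to PNG bytes.
#
#   import base64, requests
#   resp = requests.post(
#       "http://localhost:7860/generate-multiple",
#       json=["a red fox", "a blue bird"],
#   )
#   for i, item in enumerate(resp.json()["images"]):
#       with open(f"out_{i}.png", "wb") as f:
#           f.write(base64.b64decode(item["image_base64"]))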
@app.get("/health")
async def health_check():
    return {"status": "healthy"}
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
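# Example requests (illustrative; paths match the routes registered above,
# host/port match the uvicorn.run call):
#
#   curl "http://localhost:7860/generate?prompt=an%20astronaut%20riding%20a%20horse" \
#        --output image.png
#   curl "http://localhost:7860/health"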