# NOTE(review): the lines below are page chrome scraped from the hosting UI
# (Spaces runtime status, file size, commit hashes, and a line-number gutter),
# not part of the program. Preserved as comments so the file parses:
#   Spaces: Runtime error / Runtime error
#   File size: 2,073 Bytes
#   6bae932 6a0af53 6bae932 3e81177 6bae932 4598c6c 1870c21 8e97e43 af2cfcd
#   6bae932 7963de3 6bae932 98b8e77 6bae932 98b8e77 | 1..71 (gutter)
# load both base & refiner
from io import BytesIO
import torch
from diffusers import DiffusionPipeline
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from cache.local_cache import ttl_cache
from config import settings
# Module-level setup: both pipelines are loaded once at import time and stay
# resident on the GPU for the lifetime of the process.
router = APIRouter()

# Base diffusion pipeline (presumably SDXL base, given the refiner pairing
# below — confirm against settings.base_sd_model). fp16 weights to halve VRAM.
base = DiffusionPipeline.from_pretrained(
    settings.base_sd_model, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
base.to("cuda")  # hard CUDA requirement; no CPU fallback here
# base.enable_model_cpu_offload()
base.enable_attention_slicing()  # lower peak VRAM at some speed cost

# Refiner pipeline: reuses the base pipeline's second text encoder and VAE
# rather than loading duplicate copies, saving GPU memory.
refiner = DiffusionPipeline.from_pretrained(
    settings.refiner_sd_model,
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
)
refiner.to("cuda")
# refiner.enable_model_cpu_offload()
refiner.enable_attention_slicing()
@router.get("/generate")
@ttl_cache(key_name='prompt', media_type="image/png", ttl_secs=20)
async def generate(prompt: str):
"""
generate image
"""
# Define how many steps and what % of steps to be run on each experts (80/20) here
n_steps = 40
high_noise_frac = 0.8
negative = "disfigured, ugly, bad, immature, cartoon, anime, 3d, painting, b&w, sketch, blurry, deformed, bad anatomy, poorly drawn face, mutation, multiple people."
prompt = f"single image. single model. {prompt}. zoomed in. full-body. real person. realistic. 4k. best quality."
print(prompt)
# run both experts
image = base(
prompt=prompt,
negative_prompt=negative,
num_inference_steps=n_steps,
denoising_end=high_noise_frac,
output_type="latent",
).images[0]
final_image = refiner(
prompt=prompt,
negative_prompt=negative,
num_inference_steps=n_steps,
denoising_start=high_noise_frac,
image=image,
).images[0]
memory_stream = BytesIO()
final_image.save(memory_stream, format="PNG")
image_data = memory_stream.getvalue() # get bytes of the image
memory_stream.seek(0)
return StreamingResponse(memory_stream, media_type="image/png"), image_data