from contextlib import nullcontext
from io import BytesIO
import os
import threading
from typing import Optional, Union
import warnings

from compel import Compel
from fastapi.responses import StreamingResponse
from loguru import logger
from PIL import Image
import torch

from leptonai.photon import Photon, FileParam, get_file_content, HTTPException
EXAMPLE_IMAGE_BASE64 = "/9j/4AAQSkZJRgABAQAAAQABAAD/2wCEAAkGBxAQEBANDxIQEA8PDw8PDxUPEg8NDxUPFRIWFhURFRYYHSggGBolGxUVITEhJSkrLi4uFx8zODMsNygtLisBCgoKDg0OGBAQFysfHx8tKy4tKy0tKystLS0rKy0tLSstNy4tLy0tLS0tKy0tLSsrLS0rLS0tLS0tLS0rKzctK//AABEIAOEA4QMBEQACEQEDEQH/xAAbAAEAAgMBAQAAAAAAAAAAAAAAAQMCBAYHBf/EAEAQAQACAQIBCAUIBwkBAAAAAAABAgMEETEFBhIhQXGRoVFhgbHBBxMyQ1JyktEVIkJic4LhJFNjk6KywuLwFP/EABoBAQEAAwEBAAAAAAAAAAAAAAABAgMFBAb/xAAtEQEAAgIBAgMIAQUBAAAAAAAAAQIDEQQSUSFBkQUTIjFCUmFxMiMzgaHBFP/aAAwDAQACEQMRAD8A9uBIJBIAAAAAAAAAAAAAAAAAAAAAAAAAMAZQACQAAAAAAAAAAAAAAAAAAAAAAAAAYgmASAAAAAAAAAAAAAAAAAAAAAAAAAACASAAAAAAAAAAAAAAAAAAAAAAAAAAACIBIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAI3BIAAAI3BIAAAAAAAAAAAAAAAAAAOd5b546XSV6V5m28zEdHqrMx2RPb7Gi2esfLxerHxL2+fg4/V/K12YcEd95mfyaLcq3lD2U9n087S+ff5S9bb6PzdO6sT792qeTk7t9eBh7f7Uzz311/rZjuise6Guc2T7m6OHhj6YTHOXWW458vstaPixnJefOfVn/AObFH0x6M45Z1E8cuSe+1p+KdVu7L3NPtj0ZRypm/vLeMnVPc91TtHon9KZvt28U6p7r7qnZH6YzxwvbxlOue6+5p2j0ZRzg1NeGXJ+O8fFfeWjzljPHxT9Mei7Hzy1dP25mPXMW98M45GSPNrtwsM/S+7yTz1yXibZOj0YjebXjoV7otHVu9FOTfzeLLwMcfKdO20Oqrmx0zV+jkrFo7pe+s7jblXr02mOy9WIAAAAAAAAAAAAAADxzVx0t8d60yUi07VyVi8Rt2xvwn1uRaZiZ0+kpWJiNvnX5E00/VTTf7F7x5WmYhh1tnR+WMc3sP7Ns0d847f8AGE6mWpW15CrHDJf21ifieC7lbXkiI+sn8H/Y8DcrY0ER+3P4P6htP/yx9qfCPzDaJwR+95QmoNyqtjr6J9sx+RqF3Km+32Y9s3/M0m5a+TLaOG0d1axPjtuyhjLXyTaZibTNp/emZlshqs9t5qzvotN/Bp7nTx/xhwM/9y37fVZtQAAAAAAAAAAAAAADyLW02y5I9GS8eFpcjJHjL6TDO6R+mFatTcsrHqFZxHqETt6lETt6BFdkVXZRReBWvkgRq3hlEMZa8x1s4arPbeasf2LS/wAGnudPH/GHBz/3Lft9Vm1AAAAAAAAAAAAAAAPKeWqdHU56+jNkn2TaZj3uVljV5fQ8ad46/pr0lpelbWUGSiYkRjaQVyiqrqKbg1sqjVuyhjKiI62cNcvcuQMfQ0umrPGMGLfv6EbunSNVh89knd5n8t9kwAAAAAAAAAAAAAAAcZzv5vTM31uOY22i2Ws9XCIjpVnu26p9c7vJnwb+KHR4fK6dY7f4ch1xxie/jHjweGay68XiWdLwx0yWxYDcETIMJkFdpXRtr3tBo2pms24RM90TK9MsZlVOnt27V75jfw4s4rLGbQ6fmLzew6i98mXpXrh6G0fRpa1t+qe2Yjbh1cXqwY4nxlzuZntTVa+b02IexykgAAAAAAAiASAAAAAACrV4IyY74p4Xpak91omPikxuNLWdTEvGsmG0TtO8THVPfDl2nT6KmpjwZRW/pme/rY9TZ0soi3q8ITa6Zb29Eea7TUotafRH+r802uvywm0+rzNmmFrT6vwx8V2aYTktHbt3bR7jZ0qMl7TxmZ79zadMKpiZ4yyiWMxp6f8AJzp+jpLX7cma0x3RER74l78EfC4vNtvJrs6tueQAAAAAAABiCYBIAAAAAAPMucOl6GqzV7JvN47r/rfFzc0avLu8S3Viq0a1aHsZxRA6AMJoKwnGIptSFVTesKjXuqSq7WUMLPYeamDoaLT19OOL/jmbfF0scarD5/PbqyWl9Zm1AAAAAAAAMATAJBIAAAAAOJ584Ns2PJ9vH0fbWfytHg8XKr4xLq+z7fDMdnOQ8bpwziUVEyCJBhbYFGSYUa95Ua2SVSWOKs2tFY4zMRHfPVDOsbnTVedRMvccGKKUrSOFK1rHdEbOpD52Z3O1ggAAAAAAADCATAJgEgAAAAA57ntp+lp63jjjyRv923V7+i8/Irum3s4NtZNd3DOdLt1TFkZG67XTGZQYWlYFNga2SVGveVhjLf5sYPnNZp6f4tbT3V/Wnyq3Yo3aHk5VunHZ7M6LhAAAAAAAAAMIBIJgEgAAAAA0+WNP85gy4+2cduj96OuvnEMbxusw2YrdN4l5jXrcqYfRVkmrBmgUkFdwUXkhWtkVGteWUMZdL8nmKJ1nTnhixXt7Z2r7rS9XGj4nO59v6eu8vUIyw9rkMotAJAAAAAAABWCYkE7gkEgAAAAiQeW8paf5rPlx8IrktEfd36vLZy8katMPoePbqpWVUW9TU3onZGSJQV2lRr5BWtkWEa2RnDCXU8xabRmyemaUj2bzPvh7ONHzlyudPjEOux5p9L1Q5+mzj1NlRtY9QiNmmUFkWBIAAAAKwSCQSCQAAAAAcFz20/R1EZNurLSJ/mr+rPlFXh5Nfi33dj2ffdNdnway8bpMoBEgrsCi6jVyKNXIyhhLtuaeLo6as/bte8+O0eVYdDDGquLyrbyT+H3sdW6HlbGOqo2cdRi2KQguqC2ASAAADAE7AAkEgAAAAA5vnxpelgrljjiv1/dt1T5xVo5Fd132e3g36cmu7hYlzZd2GcSiomQYWkGvlso1MsqjTy2/ozrG2q9tRt6byLp9sWOkcK0rXwjrdSldQ4OS27TL7OLTMmqZbNNObYr64kGcUBnFQZAAAAAxAgEgkAAAAAAFWow1vW1LxFq2ia2ie2JNbWJ1O4cbr+Zt95nT5KzHZTNvWY/nrE7+2Pa8t+LE/wAZ06eL2hMeF42+Vl5B1dOOC0x6aTTJ5RO/k888XJD1152GfPTUyaPPHHBqP8jNMf7WucN/tlujkYZ+uPVRbT5uzDqJ7sGaZ8qnub/bK+/xffHqxryVq7/R02o/mxXx+d4iGcYMk+TXPMwx9X/V+Hmdyhk+rx4fXmy14emIx9Lw6m2vFt5vPf2hjj5RMvucj/J5Sloy6nLOe9Z3ita/N4Yn09HeZn2z7IevHhpTxeDLy75PDydpg0laxtENm3l2vikIidgSAAAAAACN/wD3UBs
CQAAAAAAAAQCJgU2U2bAAAbAmIREgAAAAAAAAAx6/V4gyAAAAAAAAAAAABGwGwGwJAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB/9k=" | |
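# The constant above is a small base64-encoded JPEG (the "/9j/" prefix is the
# base64 form of the JPEG start-of-image marker), presumably kept as a
# ready-made sample value for the input_image parameter of run() below.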
class JPEGResponse(StreamingResponse):
    media_type = "image/jpeg"


class ImgPilot(Photon):
    requirement_dependency = [
        "torch",
        "diffusers",
        "invisible-watermark",
        "compel",
        "Pillow",
    ]
    # By default, we use gpu.a10 as the computation resource shape, which
    # should be fast enough for this model.
    deployment_template = {
        "resource_shape": "gpu.a10",
        "env": {
            "MODEL": "SimianLuo/LCM_Dreamshaper_v7",
            "USE_TORCH_COMPILE": "false",
            "WIDTH": "768",
            "HEIGHT": "768",
            "PRINT_PROMPT": "false",
        },
    }
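    # What each env var controls (documented here for convenience; see init()
    # below for the exact usage):
    #   MODEL             Hugging Face model id loaded via AutoPipelineForImage2Image.
    #   USE_TORCH_COMPILE "true"/"t"/"1"/"yes"/"y" enables torch.compile on the UNet.
    #   WIDTH / HEIGHT    Output size; only read and enforced when torch.compile
    #                     is on, since the compiled graph is fixed to one resolution.
    #   PRINT_PROMPT      "true"/"t"/"1"/"yes"/"y" logs each incoming prompt.
    # Note: init() reads these with os.environ[...], so when running this file
    # outside a Lepton deployment the variables must be set in the shell.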
    # An A10 should be able to interleave IO and compute for up to roughly 8
    # concurrent requests (not tuned). We leave concurrency at 1 here, which
    # also keeps the torch.compile path usable (see init below).
    handler_max_concurrency = 1
    def init(self):
        from diffusers import AutoPipelineForImage2Image  # type: ignore

        cuda_available = torch.cuda.is_available()
        if cuda_available:
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        # Load the image-to-image pipeline named by MODEL; fp16 on GPU, fp32 on CPU.
        self.base = AutoPipelineForImage2Image.from_pretrained(
            os.environ["MODEL"],
            torch_dtype=torch.float16 if cuda_available else torch.float32,
        )
        # Disable the safety checker; note that this skips NSFW filtering.
        self.base.safety_checker = None
        self.base.requires_safety_checker = False
        # Serialize access to the pipeline across threads; nullcontext is a
        # no-op stand-in when only one request runs at a time.
        if self.handler_max_concurrency > 1:
            self.base_lock = threading.Lock()
        else:
            self.base_lock = nullcontext()
        self.print_prompt = os.environ["PRINT_PROMPT"].lower() in [
            "true",
            "t",
            "1",
            "yes",
            "y",
        ]
        logger.info(f"print_prompt: {self.print_prompt}")
        if cuda_available:
            self.base.to("cuda")
        self.use_torch_compile = os.environ["USE_TORCH_COMPILE"].lower() in [
            "true",
            "t",
            "1",
            "yes",
            "y",
        ]
        if self.use_torch_compile:
            if self.handler_max_concurrency > 1:
                warnings.warn(
                    "torch compile does not support multithreading, so we will"
                    " disable torch compile since handler_max_concurrency > 1."
                )
                # Actually disable it, so run() does not later check against
                # the never-initialized self.width / self.height.
                self.use_torch_compile = False
            else:
                self.width = int(os.environ["WIDTH"])
                self.height = int(os.environ["HEIGHT"])
                logger.info(
                    "Compiling model with torch.compile. Note that with torch"
                    " compile, your first invocation will be slow, but subsequent"
                    " invocations will be faster."
                )
                self.base.unet = torch.compile(
                    self.base.unet, mode="reduce-overhead", fullgraph=True
                )
        # Compel builds prompt embeddings directly, which lets prompts exceed
        # the 77-token limit of the CLIP tokenizer (see run() below).
        self.compel_proc = Compel(
            tokenizer=self.base.tokenizer,
            text_encoder=self.base.text_encoder,
            truncate_long_prompts=False,
        )  # type: ignore
        logger.info(f"Initialized model {os.environ['MODEL']}. cuda: {cuda_available}.")
    # Expose run() over HTTP. (Assumption: the original code used the
    # Photon.handler decorator here; without it, the method is not served.)
    @Photon.handler
    def run(
        self,
        prompt: str,
        seed: int,
        strength: float,
        steps: int,
        guidance_scale: float,
        width: int,
        height: int,
        lcm_steps: int,
        input_image: Optional[Union[str, FileParam]] = None,
    ) -> JPEGResponse:
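        """Generate an image from `prompt`, optionally conditioned on
        `input_image`, and stream the result back as JPEG.

        `seed` fixes the RNG; `strength`, `steps`, `guidance_scale`, and
        `lcm_steps` are forwarded to the LCM pipeline; `width`/`height` set
        the output size. `input_image` may be a FileParam or a string
        (e.g. a URL) accepted by get_file_content.
        """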
        from diffusers.utils import load_image  # type: ignore
        import time

        start = time.time()
        if self.print_prompt:
            logger.info(f"Prompt: {prompt}")
        # diffusers truncates prompts to 77 tokens. If the prompt is longer, we
        # build embeddings with compel instead (slower, but nothing is lost).
        tokens = self.base.tokenizer(prompt, return_tensors="pt")
        if tokens.input_ids.shape[1] > 77:
            prompt_embeds = self.compel_proc(prompt)
            prompt = None
        else:
            prompt_embeds = None
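        # For reference (an illustration, not original code): a ~100-token
        # prompt takes the compel path above and yields prompt_embeds padded
        # out to whole 77-token chunks (roughly shape [1, 154, hidden_dim])
        # rather than being cut off at 77 tokens.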
        if input_image is not None:
            image_file = get_file_content(input_image, return_file=True)
            pil_image = Image.open(image_file, formats=["JPEG", "PNG", "GIF", "BMP"])
            if self.use_torch_compile:
                # The compiled graph is fixed to one resolution, so reject any
                # other requested width/height.
                if width != self.width or height != self.height:
                    raise HTTPException(
                        status_code=400,
                        detail=(
                            f"width and height must be {self.width} and"
                            f" {self.height} when use_torch_compile is true."
                        ),
                    )
                # Resize the input image to the compiled resolution if needed.
                if pil_image.height != self.height or pil_image.width != self.width:
                    pil_image = pil_image.resize(
                        (self.width, self.height), Image.BILINEAR
                    )
            input_image = load_image(pil_image).convert("RGB")
        # Hold the lock so at most one thread drives the pipeline at a time.
        with self.base_lock:
            generator = torch.manual_seed(seed)
            output_image = self.base(
                prompt=prompt,
                prompt_embeds=prompt_embeds,
                generator=generator,
                image=input_image,
                strength=strength,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                width=width,
                height=height,
                lcm_origin_steps=lcm_steps,
                output_type="pil",
            )  # type: ignore
        # With the safety checker disabled in init(), this is effectively
        # always False; the check is kept for pipelines that do report it.
        nsfw_content_detected = (
            output_image.nsfw_content_detected[0]
            if "nsfw_content_detected" in output_image
            else False
        )  # type: ignore
        if nsfw_content_detected:
            raise HTTPException(status_code=400, detail="nsfw content detected")
        img_io = BytesIO()
        output_image.images[0].save(img_io, format="JPEG")  # type: ignore
        img_io.seek(0)
        logger.info(f"Produced output in {time.time() - start} seconds.")
        return JPEGResponse(img_io)
if __name__ == "__main__":
    p = ImgPilot()
    p.launch()
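# A minimal client sketch (assumptions, not part of the original code: the
# photon is running locally, Photon.launch() serves on the default port 8080,
# the handler is exposed as POST /run, and get_file_content accepts a base64
# string for input_image):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8080/run",
#       json={
#           "prompt": "a watercolor painting of a lighthouse at dusk",
#           "seed": 0,
#           "strength": 0.8,
#           "steps": 4,
#           "guidance_scale": 8.0,
#           "width": 768,
#           "height": 768,
#           "lcm_steps": 50,
#           "input_image": EXAMPLE_IMAGE_BASE64,
#       },
#   )
#   open("output.jpg", "wb").write(resp.content)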