import spaces  # Hugging Face Spaces helper; unused directly but kept from the original (on ZeroGPU Spaces it must be imported before torch initializes CUDA)

import logging
import os
import warnings
import io
import base64
import asyncio
import time

import torch
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from PIL import Image
from huggingface_hub import snapshot_download, login

from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from concurrent.futures import ThreadPoolExecutor
import uvicorn

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Thread pool for running the blocking diffusion pipeline off the event loop.
executor = ThreadPoolExecutor()

if torch.cuda.is_available():
    device = "cuda"
    logger.info("CUDA is available. Using GPU.")
else:
    device = "cpu"
    logger.info("CUDA is not available. Using CPU.")

huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
if huggingface_token:
    login(token=huggingface_token)
    logger.info("Hugging Face token found and logged in.")
else:
    logger.warning("Hugging Face token not found in environment variables.")
model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    ignore_patterns=["*.md", "*.gitattributes"],
    local_dir="FLUX.1-dev",
    token=huggingface_token,
)
logger.info("Model downloaded to: %s", model_path)

logger.info("Loading ControlNet model.")
cache_dir = "./model_cache"
controlnet = FluxControlNetModel.from_pretrained(
    "jasperai/Flux.1-dev-Controlnet-Upscaler", torch_dtype=torch.bfloat16, cache_dir=cache_dir
).to(device)
logger.info("ControlNet model loaded successfully.")

logger.info("Loading pipeline.")
pipe = FluxControlNetPipeline.from_pretrained(
    model_path, controlnet=controlnet, torch_dtype=torch.bfloat16, cache_dir=cache_dir
).to(device)
logger.info("Pipeline loaded successfully.")
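# Optional, and an assumption rather than part of the original script: on GPUs with
# limited VRAM, diffusers can offload submodules to the CPU between forward passes.
# In that case, call pipe.enable_model_cpu_offload() instead of moving the whole
# pipeline with .to(device).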

MAX_SEED = 1000000
MAX_PIXEL_BUDGET = 1024 * 1024  # cap on the number of output pixels (width * height)


def process_input(input_image, upscale_factor):
    """Shrink the input if the upscaled output would exceed MAX_PIXEL_BUDGET,
    then round both dimensions down to multiples of 8."""
    w, h = input_image.size
    aspect_ratio = w / h
    was_resized = False

    if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
        warnings.warn("Requested output image is too large. Resizing to fit within pixel budget.")
        input_image = input_image.resize(
            (
                int(aspect_ratio * MAX_PIXEL_BUDGET**0.5 // upscale_factor),
                int(MAX_PIXEL_BUDGET**0.5 // aspect_ratio // upscale_factor),
            )
        )
        was_resized = True

    # Round both dimensions down to multiples of 8, as the pipeline requires.
    w, h = input_image.size
    w = w - w % 8
    h = h - h % 8

    return input_image.resize((w, h)), was_resized
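
# Sanity check of the budget math with illustrative numbers (not from the original):
# a 1024x1024 input at upscale_factor=4 would yield 4096x4096, about 16.8M pixels,
# which exceeds the 1,048,576-pixel budget. The input is therefore shrunk to
# 256x256 (sqrt(MAX_PIXEL_BUDGET) / upscale_factor = 1024 / 4), so the 4x output
# lands exactly on the 1024x1024 budget.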

def run_inference(input_image, upscale_factor, seed, num_inference_steps, controlnet_conditioning_scale):
    logger.info("Processing inference.")
    # Remember the pre-shrink size so the output can be scaled back to the
    # originally requested dimensions if the pixel budget forces a resize.
    original_w, original_h = input_image.size
    input_image, was_resized = process_input(input_image, upscale_factor)

    # The upscaler ControlNet is conditioned on a naive resize of the input.
    w, h = input_image.size
    control_image = input_image.resize((w * upscale_factor, h * upscale_factor))

    generator = torch.Generator().manual_seed(seed)

    logger.info("Running pipeline.")
    image = pipe(
        prompt="",
        control_image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=num_inference_steps,
        guidance_scale=3.5,
        height=control_image.size[1],
        width=control_image.size[0],
        generator=generator,
    ).images[0]

    # If the input was shrunk to fit the budget, resize the result back up to the
    # size the caller asked for (the pre-shrink size times the upscale factor).
    if was_resized:
        original_size = (original_w * upscale_factor, original_h * upscale_factor)
        image = image.resize(original_size)

    # Encode the result as a base64 JPEG for the JSON response.
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    return image_base64


@app.post("/infer")
async def infer(input_image: UploadFile = File(...),
                upscale_factor: int = Form(4),
                seed: int = Form(42),
                num_inference_steps: int = Form(28),
                controlnet_conditioning_scale: float = Form(0.6)):
    logger.info("Received request for inference.")

    start_time = time.time()

    contents = await input_image.read()
    # Convert to RGB so uploads with an alpha channel can still be re-encoded as JPEG.
    image = Image.open(io.BytesIO(contents)).convert("RGB")

    # Run the blocking pipeline in the thread pool so the event loop stays responsive.
    loop = asyncio.get_running_loop()
    base64_image = await loop.run_in_executor(
        executor, run_inference, image, upscale_factor, seed, num_inference_steps,
        controlnet_conditioning_scale,
    )

    time_taken = time.time() - start_time
    return JSONResponse(content={"base64_image": base64_image, "time_taken": time_taken})

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
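
# Example client call (illustrative; the field names match the /infer route above,
# while "input.png" and the localhost URL are placeholders):
#
#   import base64, requests
#
#   with open("input.png", "rb") as f:
#       resp = requests.post(
#           "http://localhost:7860/infer",
#           files={"input_image": f},
#           data={"upscale_factor": 4, "seed": 42, "num_inference_steps": 28,
#                 "controlnet_conditioning_scale": 0.6},
#       )
#   with open("upscaled.jpg", "wb") as out:
#       out.write(base64.b64decode(resp.json()["base64_image"]))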