import logging
import warnings
import os
import torch
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from PIL import Image
from huggingface_hub import snapshot_download, login
import io
import base64
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from concurrent.futures import ThreadPoolExecutor
import uvicorn
import asyncio
import time  # Used to measure end-to-end inference time

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI app for image processing
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ThreadPoolExecutor for running blocking inference off the event loop
executor = ThreadPoolExecutor()

# Determine the device (GPU or CPU)
if torch.cuda.is_available():
    device = "cuda"
    logger.info("CUDA is available. Using GPU.")
else:
    device = "cpu"
    logger.info("CUDA is not available. Using CPU.")

# Log in to the Hugging Face Hub if a token is available
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
if huggingface_token:
    login(token=huggingface_token)
    logger.info("Hugging Face token found and logged in.")
else:
    logger.warning("Hugging Face token not found in environment variables.")

# Download the base model using snapshot_download
model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    ignore_patterns=["*.md", "*.gitattributes"],
    local_dir="FLUX.1-dev",
    token=huggingface_token,
)
logger.info("Model downloaded to: %s", model_path)

# Load pipeline
logger.info("Loading ControlNet model.")
cache_dir = "./model_cache"
controlnet = FluxControlNetModel.from_pretrained(
    "jasperai/Flux.1-dev-Controlnet-Upscaler",
    torch_dtype=torch.bfloat16,
    cache_dir=cache_dir,
).to(device)
logger.info("ControlNet model loaded successfully.")

logger.info("Loading pipeline.")
pipe = FluxControlNetPipeline.from_pretrained(
    model_path,
    controlnet=controlnet,
    torch_dtype=torch.bfloat16,
    cache_dir=cache_dir,
).to(device)
logger.info("Pipeline loaded successfully.")

MAX_SEED = 1000000
MAX_PIXEL_BUDGET = 1024 * 1024


# @spaces.GPU  # uncomment (together with `import spaces`) when deploying on HF Spaces ZeroGPU
def process_input(input_image, upscale_factor):
    w, h = input_image.size
    aspect_ratio = w / h
    was_resized = False

    # Downscale the input if the requested output would exceed the pixel budget
    if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
        warnings.warn(
            "Requested output image is too large. Resizing to fit within pixel budget."
        )
        input_image = input_image.resize(
            (
                int(aspect_ratio * MAX_PIXEL_BUDGET**0.5 // upscale_factor),
                int(MAX_PIXEL_BUDGET**0.5 // aspect_ratio // upscale_factor),
            )
        )
        was_resized = True

    # Round dimensions down to a multiple of 8, as required by the pipeline
    w, h = input_image.size
    w = w - w % 8
    h = h - h % 8

    return input_image.resize((w, h)), was_resized


# @spaces.GPU  # uncomment (together with `import spaces`) when deploying on HF Spaces ZeroGPU
def run_inference(input_image, upscale_factor, seed, num_inference_steps,
                  controlnet_conditioning_scale):
    logger.info("Processing inference.")

    # Record the target output size before any budget-driven downscaling,
    # so the result can be restored to the originally requested dimensions
    target_size = (input_image.width * upscale_factor,
                   input_image.height * upscale_factor)

    input_image, was_resized = process_input(input_image, upscale_factor)

    # Upscale the conditioning image for ControlNet processing
    w, h = input_image.size
    control_image = input_image.resize((w * upscale_factor, h * upscale_factor))

    # Seed the random generator for reproducible inference
    generator = torch.Generator().manual_seed(seed)

    # Perform inference using the pipeline
    logger.info("Running pipeline.")
    image = pipe(
        prompt="",
        control_image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=num_inference_steps,
        guidance_scale=3.5,
        height=control_image.size[1],
        width=control_image.size[0],
        generator=generator,
    ).images[0]

    # Resize the output back to the originally requested dimensions if the
    # input had to be shrunk to fit the pixel budget
    if was_resized:
        image = image.resize(target_size)

    # Encode the output image as a base64 JPEG string
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    return image_base64


@app.post("/infer")
async def infer(
    input_image: UploadFile = File(...),
    upscale_factor: int = Form(4),
    seed: int = Form(42),
    num_inference_steps: int = Form(28),
    controlnet_conditioning_scale: float = Form(0.6),
):
    logger.info("Received request for inference.")

    # Start timing the entire inference process
    start_time = time.time()

    # Read the uploaded image (UploadFile.read() already returns bytes)
    contents = await input_image.read()
    image = Image.open(io.BytesIO(contents))

    # Run the blocking inference in a worker thread so the event loop stays free
    loop = asyncio.get_running_loop()
    base64_image = await loop.run_in_executor(
        executor, run_inference, image, upscale_factor, seed,
        num_inference_steps, controlnet_conditioning_scale,
    )

    # Calculate the time taken
    time_taken = time.time() - start_time

    return JSONResponse(content={"base64_image": base64_image,
                                 "time_taken": time_taken})


if __name__ == "__main__":
    # Start the FastAPI server
    uvicorn.run(app, host="0.0.0.0", port=7860)