import base64
import io
import logging
import os
import random
import shutil
import subprocess
import warnings
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import torch
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from flask import Flask, request, jsonify
from flask_cors import CORS
from huggingface_hub import snapshot_download, login
from PIL import Image

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)
CORS(app)


def check_disk_space():
    """Log current disk usage so the large model download can be sized against free space."""
    result = subprocess.run(['df', '-h'], capture_output=True, text=True)
    logger.info("Disk space usage:\n%s", result.stdout)


def clear_huggingface_cache():
    """Delete the local Hugging Face cache directory to free disk space before downloading."""
    cache_dir = os.path.expanduser('~/.cache/huggingface')
    if os.path.exists(cache_dir):
        shutil.rmtree(cache_dir)
        logger.info("Cleared Hugging Face cache at: %s", cache_dir)
    else:
        logger.info("No Hugging Face cache found.")


check_disk_space()
clear_huggingface_cache()

# In-memory store for inference results, keyed by process_id.
# None means "in progress"; a base64 string means "completed".
app.config['image_outputs'] = {}

# Thread pool that runs inference jobs submitted via /infer in the background.
executor = ThreadPoolExecutor()

if torch.cuda.is_available():
    device = "cuda"
    logger.info("CUDA is available. Using GPU.")
else:
    device = "cpu"
    logger.info("CUDA is not available. Using CPU.")

huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
if huggingface_token:
    login(token=huggingface_token)
    logger.info("Hugging Face token found and logged in.")
else:
    logger.warning("Hugging Face token not found in environment variables.")

# Download the FLUX.1-dev base model locally, skipping docs and git metadata.
model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    ignore_patterns=["*.md", "*.gitattributes"],
    local_dir="FLUX.1-dev",
    token=huggingface_token,
)
logger.info("Model downloaded to: %s", model_path)

logger.info("Loading ControlNet model.")
controlnet = FluxControlNetModel.from_pretrained(
    "jasperai/Flux.1-dev-Controlnet-Upscaler", torch_dtype=torch.bfloat16
).to(device)
logger.info("ControlNet model loaded successfully.")

logger.info("Loading pipeline.")
pipe = FluxControlNetPipeline.from_pretrained(
    model_path, controlnet=controlnet, torch_dtype=torch.bfloat16
).to(device)
logger.info("Pipeline loaded successfully.")

MAX_SEED = 1000000
MAX_PIXEL_BUDGET = 1024 * 1024  # Cap the generated output at roughly 1 megapixel.


def process_input(input_image, upscale_factor):
    """Shrink the input if the upscaled output would exceed MAX_PIXEL_BUDGET,
    then round its dimensions down to multiples of 8 as the pipeline requires."""
    w, h = input_image.size
    aspect_ratio = w / h
    was_resized = False

    if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
        warnings.warn("Requested output image is too large. Resizing to fit within pixel budget.")
        input_image = input_image.resize(
            (
                int(aspect_ratio * MAX_PIXEL_BUDGET**0.5 // upscale_factor),
                int(MAX_PIXEL_BUDGET**0.5 // aspect_ratio // upscale_factor),
            )
        )
        was_resized = True

    # Width and height must be divisible by 8 for the diffusion pipeline.
    w, h = input_image.size
    w = w - w % 8
    h = h - h % 8

    return input_image.resize((w, h)), was_resized
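
# For intuition, a worked example of the budget check above (illustrative numbers,
# not taken from the service): a 1024x1024 input with upscale_factor=4 targets
# 4096x4096, about 16.8 million pixels, far beyond the 1,048,576-pixel budget, so
# the input is first shrunk to 256x256 and the model then upscales it to 1024x1024.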


def run_inference(process_id, input_image, upscale_factor, seed, num_inference_steps, controlnet_conditioning_scale):
    logger.info("Processing inference for process_id: %s", process_id)
    # Remember the size the caller submitted before any budget-driven shrinking.
    original_w, original_h = input_image.size
    input_image, was_resized = process_input(input_image, upscale_factor)

    # The control image is the (possibly shrunk) input resized up to the target resolution.
    w, h = input_image.size
    control_image = input_image.resize((w * upscale_factor, h * upscale_factor))

    generator = torch.Generator().manual_seed(seed)

    logger.info("Running pipeline for process_id: %s", process_id)
    image = pipe(
        prompt="",
        control_image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=num_inference_steps,
        guidance_scale=3.5,
        height=control_image.size[1],
        width=control_image.size[0],
        generator=generator,
    ).images[0]

    if was_resized:
        # Scale the result back up to the size originally requested by the caller.
        original_size = (original_w * upscale_factor, original_h * upscale_factor)
        image = image.resize(original_size)

    # Encode the result as a base64 JPEG so it can travel in a JSON response.
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    app.config['image_outputs'][process_id] = image_base64
    logger.info("Inference completed for process_id: %s", process_id)


@app.route('/infer', methods=['POST'])
def infer():
    if 'input_image' not in request.files:
        logger.error("No image file provided in request.")
        return jsonify({
            "status": "error",
            "message": "No input_image file provided"
        }), 400

    file = request.files['input_image']

    if file.filename == '':
        logger.error("No selected file in form-data.")
        return jsonify({
            "status": "error",
            "message": "No selected file"
        }), 400

    # Decode the upload fully and normalize it to RGB now, because inference runs
    # later on a background thread, after the request (and its file stream) is gone.
    input_image = Image.open(file).convert("RGB")

    seed = request.form.get("seed", 42, type=int)
    randomize_seed = request.form.get("randomize_seed", 'true').lower() == 'true'
    num_inference_steps = request.form.get("num_inference_steps", 28, type=int)
    upscale_factor = request.form.get("upscale_factor", 4, type=int)
    controlnet_conditioning_scale = request.form.get("controlnet_conditioning_scale", 0.6, type=float)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
        logger.info("Seed randomized to: %d", seed)

    process_id = str(random.randint(1000, 9999))
    logger.info("Process started with process_id: %s", process_id)

    # Register the job as pending; the worker replaces None with the result.
    app.config['image_outputs'][process_id] = None

    executor.submit(run_inference, process_id, input_image, upscale_factor, seed, num_inference_steps, controlnet_conditioning_scale)

    return jsonify({
        "process_id": process_id,
        "message": "Processing started"
    })


@app.route('/status', methods=['POST'])
def status():
    # Tolerate missing or malformed JSON bodies instead of raising a 500.
    data = request.get_json(silent=True) or {}
    process_id = data.get('process_id')

    if not process_id:
        logger.error("Process ID not provided in request.")
        return jsonify({
            "status": "error",
            "message": "Process ID is required"
        }), 400

    if process_id not in app.config['image_outputs']:
        logger.error("Invalid process ID: %s", process_id)
        return jsonify({
            "status": "error",
            "message": "Invalid process ID"
        }), 404

    image_base64 = app.config['image_outputs'][process_id]
    if image_base64 is None:
        logger.info("Process ID %s is still in progress.", process_id)
        return jsonify({
            "status": "in_progress"
        })
    else:
        logger.info("Process ID %s completed successfully.", process_id)
        return jsonify({
            "status": "completed",
            "output_image": image_base64
        })


if __name__ == '__main__':
    app.run(debug=True)
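
# Example client usage (a sketch, not part of the service): assumes the app is
# running on Flask's default http://127.0.0.1:5000, that the third-party
# "requests" package is installed, and that "input.jpg" is a hypothetical local file.
#
#   import time
#   import requests
#
#   with open("input.jpg", "rb") as f:
#       resp = requests.post(
#           "http://127.0.0.1:5000/infer",
#           files={"input_image": f},
#           data={"upscale_factor": 4, "num_inference_steps": 28},
#       )
#   process_id = resp.json()["process_id"]
#
#   # Poll /status until the background worker stores the base64-encoded JPEG.
#   while True:
#       payload = requests.post(
#           "http://127.0.0.1:5000/status", json={"process_id": process_id}
#       ).json()
#       if payload["status"] == "completed":
#           image_base64 = payload["output_image"]
#           break
#       time.sleep(2)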