import logging
import warnings
import os
import io
import base64
import asyncio

import gradio as gr
import torch
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from PIL import Image
from huggingface_hub import snapshot_download, login
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from concurrent.futures import ThreadPoolExecutor
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI app for image processing
app = FastAPI()
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
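# Note: allow_origins=["*"] accepts requests from any origin, which is
# convenient for development but worth restricting to known frontends in production.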
# ThreadPoolExecutor for managing image processing threads
executor = ThreadPoolExecutor()

# Determine the device (GPU or CPU)
if torch.cuda.is_available():
    device = "cuda"
    logger.info("CUDA is available. Using GPU.")
else:
    device = "cpu"
    logger.info("CUDA is not available. Using CPU.")

# Log in to the Hugging Face Hub if a token is available
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
if huggingface_token:
    login(token=huggingface_token)
    logger.info("Hugging Face token found and logged in.")
else:
    logger.warning("Hugging Face token not found in environment variables.")
# Download the base model using snapshot_download
model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    ignore_patterns=["*.md", ".gitattributes"],
    local_dir="FLUX.1-dev",
    token=huggingface_token
)
logger.info("Model downloaded to: %s", model_path)
# Load pipeline
logger.info("Loading ControlNet model.")
controlnet = FluxControlNetModel.from_pretrained(
    "jasperai/Flux.1-dev-Controlnet-Upscaler", torch_dtype=torch.bfloat16
).to(device)
logger.info("ControlNet model loaded successfully.")

logger.info("Loading pipeline.")
pipe = FluxControlNetPipeline.from_pretrained(
    model_path, controlnet=controlnet, torch_dtype=torch.bfloat16
).to(device)
logger.info("Pipeline loaded successfully.")

MAX_SEED = 1000000
MAX_PIXEL_BUDGET = 1024 * 1024
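# Example: a 512x512 input at upscale_factor=4 targets a 2048x2048 output,
# i.e. 4,194,304 pixels, four times MAX_PIXEL_BUDGET (1,048,576), so
# process_input() below shrinks the input before inference.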
def process_input(input_image, upscale_factor):
    w, h = input_image.size
    aspect_ratio = w / h
    was_resized = False
    # Shrink the input if the requested output would exceed the pixel budget
    if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
        warnings.warn("Requested output image is too large. Resizing to fit within pixel budget.")
        input_image = input_image.resize(
            (
                int(aspect_ratio * MAX_PIXEL_BUDGET**0.5 // upscale_factor),
                int(MAX_PIXEL_BUDGET**0.5 // aspect_ratio // upscale_factor),
            )
        )
        was_resized = True
    # Snap dimensions to multiples of 8, as the pipeline requires
    w, h = input_image.size
    w = w - w % 8
    h = h - h % 8
    return input_image.resize((w, h)), was_resized
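# Illustrative sketch of process_input behaviour (not executed at import time):
#   img = Image.new("RGB", (512, 512))
#   resized, was_resized = process_input(img, upscale_factor=4)
#   # 512*512*4**2 exceeds MAX_PIXEL_BUDGET, so the image is shrunk,
#   # was_resized is True, and both dimensions are snapped to multiples of 8.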
def run_inference(input_image, upscale_factor, seed, num_inference_steps, controlnet_conditioning_scale):
    logger.info("Processing inference.")
    input_image, was_resized = process_input(input_image, upscale_factor)
    # Upscale the conditioning image to the target output size for ControlNet
    w, h = input_image.size
    control_image = input_image.resize((w * upscale_factor, h * upscale_factor))
    # Seed the random generator for reproducible inference
    generator = torch.Generator().manual_seed(seed)
    # Perform inference using the pipeline
    logger.info("Running pipeline.")
    image = pipe(
        prompt="",
        control_image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=num_inference_steps,
        guidance_scale=3.5,
        height=control_image.size[1],
        width=control_image.size[0],
        generator=generator,
    ).images[0]
    # Resize the output back to the requested dimensions if the input was shrunk
    if was_resized:
        original_size = (input_image.width * upscale_factor, input_image.height * upscale_factor)
        image = image.resize(original_size)
    # Convert the output image to base64
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
    # Return both forms: the PIL image for the Gradio outputs below and the
    # base64 string for the JSON API
    return image, image_base64
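# Callers holding only the base64 string can recover a PIL image with, e.g.:
#   Image.open(io.BytesIO(base64.b64decode(image_base64)))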
@app.post("/infer")
async def infer(input_image: UploadFile = File(...),
                upscale_factor: int = 4,
                seed: int = 42,
                num_inference_steps: int = 28,
                controlnet_conditioning_scale: float = 0.6):
    logger.info("Received request for inference.")
    # Read the uploaded image and normalize to RGB so it can be saved as JPEG
    contents = await input_image.read()
    image = Image.open(io.BytesIO(contents)).convert("RGB")
    # Run the blocking inference in the thread pool; executor.submit() returns a
    # concurrent.futures.Future, which is not awaitable, so go through the event loop
    loop = asyncio.get_running_loop()
    _, base64_image = await loop.run_in_executor(
        executor, run_inference, image, upscale_factor, seed,
        num_inference_steps, controlnet_conditioning_scale)
    return JSONResponse(content={"base64_image": base64_image})
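# Hypothetical client call, assuming the API is served separately from the
# Gradio UI (e.g. with `uvicorn app:app --port 8000`). The scalar parameters
# are query parameters here, not form fields:
#   curl -X POST "http://localhost:8000/infer?upscale_factor=4&seed=42" \
#        -F "input_image=@photo.jpg"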
def run_gradio_app():
    # Use a name other than `app` so the FastAPI instance above is not shadowed
    with gr.Blocks() as demo:
        gr.Markdown("## Image Upscaler using ControlNet")
        # Define the inputs and outputs, with defaults matching the API endpoint
        input_image = gr.Image(type="pil", label="Input Image")
        upscale_factor = gr.Slider(minimum=1, maximum=8, step=1, value=4, label="Upscale Factor")
        seed = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=42, label="Seed")
        num_inference_steps = gr.Slider(minimum=1, maximum=100, step=1, value=28, label="Inference Steps")
        controlnet_conditioning_scale = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.6, label="ControlNet Conditioning Scale")
        output_image = gr.Image(type="pil", label="Output Image")
        output_base64 = gr.Textbox(label="Base64 String", interactive=False)
        # Create a button to trigger the processing
        submit_button = gr.Button("Upscale Image")
        # run_inference returns (PIL image, base64 string), matching the two outputs
        submit_button.click(run_inference,
                            inputs=[input_image, upscale_factor, seed, num_inference_steps, controlnet_conditioning_scale],
                            outputs=[output_image, output_base64])
    demo.launch()
if __name__ == "__main__":
    run_gradio_app()