import os
import torch
import gradio as gr
import spaces
from PIL import Image
from diffusers import DiffusionPipeline
from huggingface_hub import snapshot_download
from test_ccsr_tile import load_pipeline
import argparse
from accelerate import Accelerator
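# Note: load_pipeline (from test_ccsr_tile.py) and the color-fix helpers imported
# later from myutils are local modules expected to sit next to app.py; presumably
# they come from the CCSR code base this Space is built on.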
# Global container for lazily initialized model state
class ModelContainer:
    def __init__(self):
        self.pipeline = None
        self.generator = None
        self.accelerator = None
        self.is_initialized = False

model_container = ModelContainer()

class Args:
    """Simple attribute bag used in place of an argparse.Namespace."""
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
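# Illustrative only (not executed): Args just exposes keyword arguments as
# attributes, mirroring what argparse would produce, e.g.
#   cfg = Args(process_size=512, upscale=4)
#   cfg.process_size // cfg.upscale   # -> 128, the threshold used in process_image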
@spaces.GPU
def initialize_models():
    """Initialize models only if they haven't been initialized yet"""
    if model_container.is_initialized:
        return True

    try:
        # Download model repository (only once)
        model_path = snapshot_download(
            repo_id="NightRaven109/CCSRModels",
            token=os.environ['Read2']  # access token expected in the environment
        )

        # Set up default arguments
        args = Args(
            pretrained_model_path=os.path.join(model_path, "stable-diffusion-2-1-base"),
            controlnet_model_path=os.path.join(model_path, "Controlnet"),
            vae_model_path=os.path.join(model_path, "vae"),
            mixed_precision="fp16",
            tile_vae=False,
            sample_method="ddpm",
            vae_encoder_tile_size=1024,
            vae_decoder_tile_size=224
        )

        # Initialize accelerator
        model_container.accelerator = Accelerator(
            mixed_precision=args.mixed_precision,
        )

        # Load pipeline
        model_container.pipeline = load_pipeline(
            args, model_container.accelerator,
            enable_xformers_memory_efficient_attention=False
        )

        # Set models to eval mode
        model_container.pipeline.unet.eval()
        model_container.pipeline.controlnet.eval()
        model_container.pipeline.vae.eval()
        model_container.pipeline.text_encoder.eval()

        # Move pipeline to CUDA once
        model_container.pipeline = model_container.pipeline.to("cuda")

        # Initialize generator
        model_container.generator = torch.Generator("cuda")

        # Set initialization flag
        model_container.is_initialized = True
        return True

    except Exception as e:
        print(f"Error initializing models: {str(e)}")
        return False
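# initialize_models() is idempotent: once is_initialized is set, repeated calls
# return True immediately, so process_image can safely call it on every request.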
@torch.no_grad()  # Inference only; no gradients needed
@spaces.GPU
def process_image(
    input_image,
    prompt="clean, texture, high-resolution, 8k",
    negative_prompt="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",
    guidance_scale=2.5,
    conditioning_scale=1.0,
    num_inference_steps=6,
    seed=None,
    upscale_factor=4,
    color_fix_method="adain"
):
    # Initialize models if not already done
    if not model_container.is_initialized:
        if not initialize_models():
            return None

    try:
        # Create args object
        args = Args(
            added_prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            conditioning_scale=conditioning_scale,
            num_inference_steps=num_inference_steps,
            seed=seed,
            upscale=upscale_factor,
            process_size=512,
            align_method=color_fix_method,
            t_max=0.6666,
            t_min=0.0,
            tile_diffusion=False,
            tile_diffusion_size=None,
            tile_diffusion_stride=None,
            start_steps=999,
            start_point='lr',
            use_vae_encode_condition=True,
            sample_times=1
        )

        # Set seed if provided (gr.Number may deliver a float, so cast before seeding)
        if seed is not None:
            model_container.generator.manual_seed(int(seed))
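        # When no seed is given, the CUDA generator keeps whatever state it had from
        # previous calls, so unseeded runs are intentionally non-reproducible.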
        # Process input image
        validation_image = Image.fromarray(input_image)
        ori_width, ori_height = validation_image.size

        # Resize logic: upscale small inputs so the shorter side reaches process_size
        # after the x{upscale} resize, then snap both sides to a multiple of 8
        # (the VAE's downsampling factor).
        resize_flag = False
        if ori_width < args.process_size // args.upscale or ori_height < args.process_size // args.upscale:
            scale = (args.process_size // args.upscale) / min(ori_width, ori_height)
            validation_image = validation_image.resize((round(scale * ori_width), round(scale * ori_height)))
            resize_flag = True

        validation_image = validation_image.resize((validation_image.size[0] * args.upscale, validation_image.size[1] * args.upscale))
        validation_image = validation_image.resize((validation_image.size[0] // 8 * 8, validation_image.size[1] // 8 * 8))
        width, height = validation_image.size
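        # Worked example (illustrative numbers, not from the source): a 100x80 input
        # with upscale=4 and process_size=512 has min side 80 < 512 // 4 = 128, so
        # scale = 128 / 80 = 1.6 -> 160x128; the x4 resize gives 640x512, and both
        # sides are already multiples of 8, so the pipeline runs at 640x512.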
        # Generate image
        inference_time, output = model_container.pipeline(
            args.t_max,
            args.t_min,
            args.tile_diffusion,
            args.tile_diffusion_size,
            args.tile_diffusion_stride,
            args.added_prompt,
            validation_image,
            num_inference_steps=args.num_inference_steps,
            generator=model_container.generator,
            height=height,
            width=width,
            guidance_scale=args.guidance_scale,
            negative_prompt=args.negative_prompt,
            conditioning_scale=args.conditioning_scale,
            start_steps=args.start_steps,
            start_point=args.start_point,
            use_vae_encode_condition=True,
        )
        image = output.images[0]

        # Apply color fixing if specified
        if args.align_method != "none":
            from myutils.wavelet_color_fix import wavelet_color_fix, adain_color_fix
            fix_func = wavelet_color_fix if args.align_method == "wavelet" else adain_color_fix
            image = fix_func(image, validation_image)

        if resize_flag:
            image = image.resize((ori_width * args.upscale, ori_height * args.upscale))

        return image

    except Exception as e:
        print(f"Error processing image: {str(e)}")
        import traceback
        traceback.print_exc()
        return None
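# Minimal programmatic usage sketch (illustrative; the file name is assumed and the
# Gradio UI below is the real entry point):
#   import numpy as np
#   arr = np.array(Image.open("input.png").convert("RGB"))
#   result = process_image(arr, seed=42, upscale_factor=4)
#   if result is not None:
#       result.save("output.png")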
# Define default values
DEFAULT_VALUES = {
    "prompt": "clean, texture, high-resolution, 8k",
    "negative_prompt": "blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",
    "guidance_scale": 3,
    "conditioning_scale": 1.0,
    "num_steps": 6,
    "seed": None,
    "upscale_factor": 4,
    "color_fix_method": "adain"
}
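# Note: the UI always passes the current widget values, so these defaults take
# precedence over the keyword defaults in process_image (e.g. guidance_scale is 3
# here but 2.5 in the function signature).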
# Define example data
EXAMPLES = [
    [
        "examples/1.png",  # Input image path
        "clean, texture, high-resolution, 8k",  # Prompt
        "blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",  # Negative prompt
        3.0,  # Guidance scale
        1.0,  # Conditioning scale
        6,    # Num steps
        42,   # Seed
        4,    # Upscale factor
        "wavelet"  # Color fix method
    ],
    [
        "examples/22.png",
        "clean, texture, high-resolution, 8k",
        "blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",
        3.0,
        1.0,
        6,
        123,
        4,
        "wavelet"
    ],
    [
        "examples/4.png",
        "clean, texture, high-resolution, 8k",
        "blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",
        3.0,
        1.0,
        6,
        123,
        4,
        "wavelet"
    ],
    [
        "examples/9D03D7F206775949.png",
        "clean, texture, high-resolution, 8k",
        "blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",
        3.0,
        1.0,
        6,
        123,
        4,
        "wavelet"
    ],
    [
        "examples/3.jpeg",
        "clean, texture, high-resolution, 8k",
        "blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",
        2.5,
        1.0,
        6,
        456,
        4,
        "wavelet"
    ]
]
# Create interface components
with gr.Blocks(title="Controllable Conditional Super-Resolution") as demo:
    gr.Markdown("## Controllable Conditional Super-Resolution")
    gr.Markdown("Upload an image to enhance its resolution using CCSR.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image")

            with gr.Accordion("Advanced Options", open=False):
                prompt = gr.Textbox(label="Prompt", value=DEFAULT_VALUES["prompt"])
                negative_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_VALUES["negative_prompt"])
                guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, value=DEFAULT_VALUES["guidance_scale"], label="Guidance Scale")
                conditioning_scale = gr.Slider(minimum=0.1, maximum=2.0, value=DEFAULT_VALUES["conditioning_scale"], label="Conditioning Scale")
                num_steps = gr.Slider(minimum=1, maximum=50, value=DEFAULT_VALUES["num_steps"], step=1, label="Number of Steps")
                seed = gr.Number(label="Seed", value=DEFAULT_VALUES["seed"])
                upscale_factor = gr.Slider(minimum=1, maximum=8, value=DEFAULT_VALUES["upscale_factor"], step=1, label="Upscale Factor")
                color_fix_method = gr.Dropdown(
                    choices=["none", "wavelet", "adain"],
                    label="Color Fix Method",
                    value=DEFAULT_VALUES["color_fix_method"]
                )

            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Generated Image")

    # Add examples
    gr.Examples(
        examples=EXAMPLES,
        inputs=[
            input_image, prompt, negative_prompt, guidance_scale,
            conditioning_scale, num_steps, seed, upscale_factor,
            color_fix_method
        ],
        outputs=output_image,
        fn=process_image,
        cache_examples=True  # Cache the results for faster loading
    )
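    # With cache_examples=True, Gradio runs process_image on each example when the
    # app starts and then serves the cached outputs in the UI.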
    # Define submit action
    submit_btn.click(
        fn=process_image,
        inputs=[
            input_image, prompt, negative_prompt, guidance_scale,
            conditioning_scale, num_steps, seed, upscale_factor,
            color_fix_method
        ],
        outputs=output_image
    )

    # Define clear action that resets to default values
    def reset_to_defaults():
        return [
            None,  # input_image
            DEFAULT_VALUES["prompt"],
            DEFAULT_VALUES["negative_prompt"],
            DEFAULT_VALUES["guidance_scale"],
            DEFAULT_VALUES["conditioning_scale"],
            DEFAULT_VALUES["num_steps"],
            DEFAULT_VALUES["seed"],
            DEFAULT_VALUES["upscale_factor"],
            DEFAULT_VALUES["color_fix_method"]
        ]

    clear_btn.click(
        fn=reset_to_defaults,
        inputs=None,
        outputs=[
            input_image, prompt, negative_prompt, guidance_scale,
            conditioning_scale, num_steps, seed, upscale_factor,
            color_fix_method
        ]
    )
if __name__ == "__main__":
    demo.launch()
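    # On a ZeroGPU Space the @spaces.GPU decorator handles GPU allocation; for other
    # deployments, demo.queue().launch() is a reasonable way to serialize GPU requests.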