# difix / app.py
# shsolanki's picture
# Difix Model Release
# 841615f
import gradio as gr
import numpy as np
from PIL import Image
import torch
import os
from pipeline_difix import DifixPipeline
from diffusers.utils import load_image
import gradio.themes as gr_themes
from pathlib import Path
import logging
import time
# Logging setup: every record goes both to the console and to a log file
# that is opened in append mode.
_log_handlers = [
    logging.StreamHandler(),
    logging.FileHandler('/tmp/difix3d_app.log', mode='a'),
]
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=_log_handlers,
)
# Module-level logger used by the rest of this file.
logger = logging.getLogger(__name__)
# --- Model and inference configuration ---

# Model identifier handed to DifixPipeline.from_pretrained().
MODEL_NAME = "nvidia/difix"
# Text prompt passed on every pipeline call.
DEFAULT_PROMPT = "remove degradation"

# Size of the dummy warmup image; also shown in the UI footer.
DEFAULT_WIDTH, DEFAULT_HEIGHT = 1024, 576

# Sampling settings forwarded to the pipeline unchanged.
DEFAULT_NUM_INFERENCE_STEPS = 1
DEFAULT_TIMESTEP = 199
DEFAULT_GUIDANCE_SCALE = 0.0

# Shared pipeline instance; populated by initialize_pipeline().
pipe = None
# Startup diagnostics: record the model id, the working directory and, when a
# GPU is present, the name and total memory of CUDA device 0.
logger.info("=== Difix Demo Starting ===")
logger.info(f"MODEL_NAME: {MODEL_NAME}")
logger.info(f"Current working directory: {os.getcwd()}")
logger.info(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Query the same device index for both lines so they describe one GPU.
    logger.info(f"CUDA Device: {torch.cuda.get_device_name(0)}")
    logger.info(f"CUDA Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
# Login to Hugging Face using a token from the environment or /etc/hf_env
def _read_hf_token_from_file(path):
    """Return the HF_TOKEN value from a shell-style env file, or None.

    Both ``HF_TOKEN=value`` and ``export HF_TOKEN=value`` lines are accepted.
    The first matching line wins.  Surrounding whitespace and one pair of
    matching quotes around the value are removed, so ``HF_TOKEN="abc"``
    yields ``abc`` rather than a token that still contains the quotes.
    """
    with open(path, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if line.startswith('export '):
                line = line[len('export '):].lstrip()
            if not line.startswith('HF_TOKEN='):
                continue
            value = line.split('=', 1)[1].strip()
            if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
                value = value[1:-1]
            # An empty value is treated the same as "no token found".
            return value or None
    return None


try:
    # Imported here so a missing huggingface_hub package is reported by the
    # handler below instead of aborting startup.
    from huggingface_hub import login

    hf_token = os.getenv('HF_TOKEN')
    # If not in environment, try reading from /etc/hf_env
    if not hf_token and os.path.exists('/etc/hf_env'):
        hf_token = _read_hf_token_from_file('/etc/hf_env')
    if hf_token:
        login(hf_token)
        logger.info("Successfully authenticated with Hugging Face")
    else:
        logger.warning("HF_TOKEN not found in environment or /etc/hf_env")
except Exception as e:
    # Authentication is best-effort: log the failure and keep starting up.
    logger.error(f"Failed to authenticate with Hugging Face: {str(e)}")
def initialize_pipeline():
    """Load the Difix pipeline into the global ``pipe`` and run a warmup pass.

    The pipeline is loaded from ``MODEL_NAME`` and moved to CUDA when a GPU is
    available (otherwise it stays on CPU).  It is then called once on a dummy
    ``DEFAULT_WIDTH`` x ``DEFAULT_HEIGHT`` image as a warmup; a warmup failure
    is only logged as a warning.

    Raises:
        Exception: whatever loading or device placement raised.  The error is
            logged with its traceback and re-raised.  If loading or device
            placement fails, the global ``pipe`` is not assigned.
    """
    global pipe
    logger.info("Starting pipeline initialization...")
    start_time = time.time()
    try:
        logger.info(f"Loading DifixPipeline from {MODEL_NAME}...")
        # NOTE(review): trust_remote_code=True is forwarded to from_pretrained;
        # confirm it is actually required for this model.
        loaded_pipe = DifixPipeline.from_pretrained(MODEL_NAME, trust_remote_code=True)
        logger.info("DifixPipeline loaded successfully")
        if torch.cuda.is_available():
            logger.info("Moving pipeline to CUDA...")
            loaded_pipe.to("cuda")
            logger.info("Pipeline moved to CUDA")
        else:
            logger.warning("CUDA not available, using CPU")
        # Publish the pipeline only after loading and device placement have
        # succeeded, so a failed start does not leave a half-initialized
        # object where process_image() would pick it up.
        pipe = loaded_pipe
        init_time = time.time() - start_time
        logger.info(f"Pipeline initialization completed in {init_time:.2f} seconds")

        # Warmup with dummy data
        logger.info("Starting pipeline warmup...")
        warmup_start = time.time()
        try:
            # Create dummy image with the model's expected resolution
            dummy_image = Image.new('RGB', (DEFAULT_WIDTH, DEFAULT_HEIGHT), color='red')
            logger.info(f"Created dummy image: {dummy_image.size}")
            # The output is discarded; indexing .images[0] just forces the
            # call to produce a result.
            _ = pipe(
                DEFAULT_PROMPT,
                image=dummy_image,
                num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS,
                timesteps=[DEFAULT_TIMESTEP],
                guidance_scale=DEFAULT_GUIDANCE_SCALE,
            ).images[0]
            warmup_time = time.time() - warmup_start
            logger.info(f"Pipeline warmup completed successfully in {warmup_time:.2f} seconds")
        except Exception as e:
            # Warmup is best-effort: the pipeline stays available.
            logger.warning(f"Pipeline warmup failed: {e}")
    except Exception as e:
        # logger.exception keeps the original message and adds the traceback.
        logger.exception(f"Pipeline initialization failed: {e}")
        raise
def process_image(image):
    """Run the Difix pipeline on one image to remove artifacts.

    Args:
        image: A PIL image, or a numpy array (converted with
            ``Image.fromarray``).  ``None`` is reported as an error.

    Returns:
        A ``(output_image, error_message)`` tuple.  On success the first
        element is the first image returned by the pipeline and the second is
        ``None``.  On failure the first element is ``None`` and the second is
        a human-readable error string.  Exceptions raised while processing
        are caught, logged with their traceback, and reported through the
        error string rather than propagated.
    """
    if image is None:
        return None, "Error: No image provided"
    if pipe is None:
        return None, "Error: Pipeline not initialized"
    try:
        # Convert numpy array to PIL Image if needed
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # Ensure image is in RGB mode
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # Process the image using Difix pipeline
        result = pipe(
            DEFAULT_PROMPT,
            image=image,
            num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS,
            timesteps=[DEFAULT_TIMESTEP],
            guidance_scale=DEFAULT_GUIDANCE_SCALE,
        )
        return result.images[0], None
    except Exception as e:
        error_msg = f"Error processing image: {str(e)}"
        # Same message as before, now with the traceback for diagnosis.
        logger.exception(error_msg)
        return None, error_msg
def gradio_interface(image):
    """Gradio callback that restores ``image`` and returns the result.

    Delegates to ``process_image``.  When that reports an error, the message
    is passed to ``gr.Warning`` and ``None`` is returned; otherwise the
    processed image is returned.
    """
    result, error = process_image(image)
    if error:
        gr.Warning(error)
        return None
    return result
# Initialize pipeline at startup.  A failure does not stop the app: it is
# recorded in `model_status`, which is shown in the page's article section.
logger.info("=== Starting Pipeline Initialization ===")
try:
    initialize_pipeline()
except Exception as e:
    model_status = f"❌ Pipeline initialization failed: {e}"
    logger.error(f"Pipeline initialization error: {e}")
else:
    model_status = "✅ Pipeline loaded successfully"
    logger.info("Pipeline initialization successful")
logger.info("=== Creating UI Components ===")
# Article content for Difix: an HTML fragment rendered at the top of the page.
# It embeds `model_status`, so it must be built after startup initialization.
article = (
    "<p style='font-size: 1.1em;'>"
    "This demo showcases <strong>Difix</strong>, a single-step image diffusion model trained to enhance and remove artifacts in rendered novel views caused by underconstrained regions of 3D representation."
    "</p>"
    "<p><strong style='color: #76B900; font-size: 1.2em;'>Key Features:</strong></p>"
    "<ul style='font-size: 1.1em;'>"
    " <li>Single-step diffusion-based artifact removal for 3D novel views</li>"
    " <li>Enhancement of underconstrained 3D regions (1024x576 default)</li>"
    "</ul>"
    "<p style='font-size: 1.1em;'>"
    f"<strong>Model Status:</strong> {model_status}"
    "</p>"
    "<p style='font-size: 1.0em; color: #666;'>"
    "Upload an image to see the restoration capabilities of Difix+. The model will automatically process your image and return an enhanced version."
    "</p>"
    "<p style='text-align: center;'>"
    # Link labels use real emoji; the previous text was mis-decoded UTF-8.
    "<a href='https://github.com/nv-tlabs/Difix3D' target='_blank'>🧑‍💻 GitHub Repository</a> | "
    "<a href='https://arxiv.org/abs/2503.01774' target='_blank'>📄 Research Paper</a> | "
    "<a href='https://huggingface.co/nvidia/difix' target='_blank'>🤗 Hugging Face Model</a>"
    "</p>"
)
logger.info("Creating theme...")
# Modern green-inspired theme similar to NVIDIA: a green primary palette, a
# blue secondary palette, slate neutrals and the Inter font.

# Primary palette, from lightest (c50) to darkest (c950) green.
_primary_green = gr_themes.Color(
    c50="#E6F7E6",
    c100="#CCF2CC",
    c200="#99E699",
    c300="#66D966",
    c400="#33CC33",
    c500="#00B300",  # Primary green
    c600="#009900",
    c700="#007A00",
    c800="#005C00",
    c900="#003D00",
    c950="#002600",
)

# Secondary palette: light-blue accent shades.
_secondary_blue = gr_themes.Color(
    c50="#F0F8FF",
    c100="#E0F0FF",
    c200="#C0E0FF",
    c300="#A0D0FF",
    c400="#80C0FF",
    c500="#4A90E2",  # Secondary blue
    c600="#3A80D2",
    c700="#2A70C2",
    c800="#1A60B2",
    c900="#0A50A2",
    c950="#004092",
)

_base_theme = gr_themes.Default(
    primary_hue=_primary_green,
    secondary_hue=_secondary_blue,
    neutral_hue="slate",
    font=[gr_themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
)

# Overrides for page background, block styling and the primary button.
difix_theme = _base_theme.set(
    body_background_fill="*neutral_50",
    block_background_fill="white",
    block_border_width="1px",
    block_border_color="*neutral_200",
    block_radius="8px",
    block_shadow="0 2px 4px rgba(0,0,0,0.1)",
    button_primary_background_fill="*primary_500",
    button_primary_background_fill_hover="*primary_600",
    button_primary_text_color="white",
)
logger.info("Creating Gradio interface...")
# Create Gradio interface with Blocks for better control.
# Layout: title, article, then a two-column row (input image, button and
# examples on the left; output image on the right), then a footer that
# lists the inference settings.
with gr.Blocks(theme=difix_theme, title="Difix") as demo:
    gr.Markdown("<h1 style='text-align: center; margin: 0 auto; color: #00B300;'>🎨 Difix</h1>")
    gr.HTML(article)
    gr.Markdown("---")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="pil",
                label="📤 Upload Image for Restoration",
                height=400,
            )
            process_btn = gr.Button(
                "🚀 Fix Image",
                variant="primary",
                size="lg",
            )
            gr.Examples(
                examples=["assets/example1.png", "assets/example2.png"],
                inputs=[input_image],
                label="📋 Example Images",
            )
        with gr.Column(scale=1):
            output_image = gr.Image(
                type="pil",
                label="✨ Fixed Image",
                height=400,
            )
    gr.Markdown("---")
    gr.Markdown(
        "<p style='text-align: center; color: #666; font-size: 0.9em;'>"
        f"Model: {MODEL_NAME} | "
        f"Resolution: {DEFAULT_WIDTH}×{DEFAULT_HEIGHT} | "
        f"Prompt: '{DEFAULT_PROMPT}' | "
        f"Steps: {DEFAULT_NUM_INFERENCE_STEPS} | "
        f"Timestep: {DEFAULT_TIMESTEP} | "
        f"Guidance Scale: {DEFAULT_GUIDANCE_SCALE}"
        "</p>"
    )
    # Event handlers: the button sends the input image through
    # gradio_interface and shows the result in the output image.
    process_btn.click(
        fn=gradio_interface,
        inputs=[input_image],
        outputs=[output_image],
        api_name="restore_image",
    )

logger.info("Configuring queue...")
# Configure queueing for better performance.
# NOTE(review): with a concurrency limit of 2, two requests may call the
# single shared `pipe` at the same time; confirm the pipeline tolerates
# concurrent calls, or lower the limit to 1.
demo.queue(
    default_concurrency_limit=2,  # Process up to 2 requests simultaneously
    max_size=20,  # Maximum 20 users can wait in queue
)
logger.info("=== Gradio Interface Created Successfully ===")
if __name__ == "__main__":
    # Launch settings are named once so the log line below always matches
    # what is actually passed to demo.launch().
    server_name = "0.0.0.0"
    server_port = 7860
    max_threads = 10

    logger.info("=== Starting Gradio Launch ===")
    logger.info(f"Server config: {server_name}:{server_port}, max_threads={max_threads}")

    # Set up file access for assets directory.  The existence check is done
    # once so the log message and the allowed_paths value cannot disagree.
    assets_path = Path("assets").absolute()
    allowed_paths = []
    if assets_path.exists():
        logger.info(f"Setting up file access for assets directory: {assets_path}")
        allowed_paths.append(str(assets_path))

    demo.launch(
        server_name=server_name,
        server_port=server_port,
        max_threads=max_threads,
        # Allow access to assets directory (empty list when it is missing)
        allowed_paths=allowed_paths,
    )