|
import html
import logging
import os
import time
from pathlib import Path

import gradio as gr
import gradio.themes as gr_themes
import numpy as np
import torch
from diffusers.utils import load_image
from PIL import Image

from pipeline_difix import DifixPipeline
|
|
|
|
|
# Root logging setup: every record goes to the console and is also
# appended to a log file under /tmp.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
_console_handler = logging.StreamHandler()
_file_handler = logging.FileHandler('/tmp/difix3d_app.log', mode='a')

logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    handlers=[_console_handler, _file_handler],
)

# Module-level logger used by the rest of this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
# Model identifier and the fixed inference settings passed to the pipeline.
MODEL_NAME = "nvidia/difix"
DEFAULT_PROMPT = "remove degradation"
DEFAULT_HEIGHT = 576
DEFAULT_WIDTH = 1024
DEFAULT_TIMESTEP = 199
DEFAULT_GUIDANCE_SCALE = 0.0
DEFAULT_NUM_INFERENCE_STEPS = 1

# Assigned by initialize_pipeline(); None means no pipeline is available.
pipe = None

# Startup diagnostics.
logger.info("=== Difix Demo Starting ===")
logger.info(f"MODEL_NAME: {MODEL_NAME}")
logger.info(f"Current working directory: {os.getcwd()}")

_cuda_ready = torch.cuda.is_available()
logger.info(f"CUDA Available: {_cuda_ready}")
if _cuda_ready:
    logger.info(f"CUDA Device: {torch.cuda.get_device_name()}")
    _total_gpu_bytes = torch.cuda.get_device_properties(0).total_memory
    logger.info(f"CUDA Memory: {_total_gpu_bytes / 1e9:.1f} GB")
|
|
|
|
|
def _parse_hf_token(env_path):
    """Return the HF_TOKEN value from a shell-style env file, or None.

    Both ``HF_TOKEN=value`` and ``export HF_TOKEN=value`` lines are
    recognised. The first matching line wins. Surrounding whitespace and one
    pair of matching single or double quotes are stripped from the value,
    because shell env files commonly quote their values and the quotes are
    not part of the token.

    Args:
        env_path: Path of the env file to read.

    Returns:
        The token string, or None when no non-empty HF_TOKEN entry exists.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    with open(env_path, 'r') as env_file:
        for raw_line in env_file:
            entry = raw_line.strip()
            if entry.startswith('export '):
                entry = entry[len('export '):].lstrip()
            if not entry.startswith('HF_TOKEN='):
                continue
            value = entry.split('=', 1)[1].strip()
            if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
                value = value[1:-1]
            return value or None
    return None


# Best-effort authentication: any failure here is logged and the app keeps
# starting without a Hugging Face login.
try:
    from huggingface_hub import login

    # The environment variable takes precedence; the file is only a fallback.
    hf_token = os.getenv('HF_TOKEN')
    if not hf_token and os.path.exists('/etc/hf_env'):
        hf_token = _parse_hf_token('/etc/hf_env')

    if hf_token:
        login(hf_token)
        logger.info("Successfully authenticated with Hugging Face")
    else:
        logger.warning("HF_TOKEN not found in environment or /etc/hf_env")
except Exception as e:
    logger.error(f"Failed to authenticate with Hugging Face: {str(e)}")
|
|
|
def _warm_up_pipeline(pipeline):
    """Run one inference on a solid-colour dummy image; failures are only logged."""
    logger.info("Starting pipeline warmup...")
    warmup_start = time.time()
    try:
        dummy_image = Image.new('RGB', (DEFAULT_WIDTH, DEFAULT_HEIGHT), color='red')
        logger.info(f"Created dummy image: {dummy_image.size}")

        _ = pipeline(
            DEFAULT_PROMPT,
            image=dummy_image,
            num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS,
            timesteps=[DEFAULT_TIMESTEP],
            guidance_scale=DEFAULT_GUIDANCE_SCALE,
        ).images[0]

        warmup_time = time.time() - warmup_start
        logger.info(f"Pipeline warmup completed successfully in {warmup_time:.2f} seconds")
    except Exception as e:
        # Warmup is optional: a failure here must not abort start-up.
        logger.warning(f"Pipeline warmup failed: {e}")


def initialize_pipeline():
    """Initialize the Difix pipeline and perform warmup.

    Loads ``DifixPipeline`` from ``MODEL_NAME``, moves it to CUDA when CUDA
    is available, publishes it in the module-level ``pipe`` and then runs a
    warmup inference.

    ``pipe`` is assigned only after loading and device placement have both
    succeeded, so a failed initialization never leaves a partially set-up
    pipeline visible to request handlers.

    Raises:
        Exception: Whatever loading or device placement raised; it is logged
            with its traceback and re-raised. Warmup errors are logged as
            warnings and not raised.
    """
    global pipe

    logger.info("Starting pipeline initialization...")
    start_time = time.time()

    try:
        logger.info(f"Loading DifixPipeline from {MODEL_NAME}...")
        loaded = DifixPipeline.from_pretrained(MODEL_NAME, trust_remote_code=True)
        logger.info("DifixPipeline loaded successfully")

        if torch.cuda.is_available():
            logger.info("Moving pipeline to CUDA...")
            loaded.to("cuda")
            logger.info("Pipeline moved to CUDA")
        else:
            logger.warning("CUDA not available, using CPU")
    except Exception as e:
        # logger.exception keeps the original message and adds the traceback.
        logger.exception(f"Pipeline initialization failed: {e}")
        raise

    pipe = loaded

    init_time = time.time() - start_time
    logger.info(f"Pipeline initialization completed in {init_time:.2f} seconds")

    _warm_up_pipeline(loaded)
|
|
|
def process_image(image):
    """
    Run the Difix pipeline on one image.

    Accepts a PIL image or a NumPy array (converted with ``Image.fromarray``)
    and converts it to RGB when it is in another mode.

    Returns a ``(result, error)`` pair: ``(output_image, None)`` on success,
    or ``(None, message)`` when no image was given, the pipeline is not
    initialized, or processing raised an exception.
    """
    if image is None:
        return None, "Error: No image provided"

    if pipe is None:
        reason = "Pipeline not initialized"
        return None, f"Error: {reason}"

    try:
        pil_input = Image.fromarray(image) if isinstance(image, np.ndarray) else image
        rgb_input = pil_input if pil_input.mode == 'RGB' else pil_input.convert('RGB')

        outcome = pipe(
            DEFAULT_PROMPT,
            image=rgb_input,
            num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS,
            timesteps=[DEFAULT_TIMESTEP],
            guidance_scale=DEFAULT_GUIDANCE_SCALE,
        )
        return outcome.images[0], None
    except Exception as exc:
        failure = f"Error processing image: {str(exc)}"
        logger.error(failure)
        return None, failure
|
|
|
def gradio_interface(image):
    """Gradio callback: return the restored image, or warn and return None.

    Delegates to ``process_image``; when it reports an error the message is
    shown through ``gr.Warning`` and the output is None.
    """
    restored, failure = process_image(image)
    if not failure:
        return restored

    gr.Warning(failure)
    return None
|
|
|
|
|
# Load the model at import time. A failure is recorded in model_status (shown
# in the page header) instead of aborting, so the UI still comes up.
logger.info("=== Starting Pipeline Initialization ===")
try:
    initialize_pipeline()
except Exception as e:
    model_status = f"❌ Pipeline initialization failed: {e}"
    logger.error(f"Pipeline initialization error: {e}")
else:
    model_status = "✅ Pipeline loaded successfully"
    logger.info("Pipeline initialization successful")
|
|
|
logger.info("=== Creating UI Components ===")

# model_status may embed an arbitrary exception message. Escape it so that
# characters such as "<" or "&" are displayed literally instead of being
# interpreted as markup inside the HTML block below.
_status_html = html.escape(model_status)

# Introductory HTML shown at the top of the page (rendered via gr.HTML).
article = (
    "<p style='font-size: 1.1em;'>"
    "This demo showcases <strong>Difix</strong>, a single-step image diffusion model trained to enhance and remove artifacts in rendered novel views caused by underconstrained regions of 3D representation."
    "</p>"
    "<p><strong style='color: #76B900; font-size: 1.2em;'>Key Features:</strong></p>"
    "<ul style='font-size: 1.1em;'>"
    "    <li>Single-step diffusion-based artifact removal for 3D novel views</li>"
    "    <li>Enhancement of underconstrained 3D regions (1024x576 default)</li>"
    "</ul>"
    "<p style='font-size: 1.1em;'>"
    f"<strong>Model Status:</strong> {_status_html}"
    "</p>"
    "<p style='font-size: 1.0em; color: #666;'>"
    "Upload an image to see the restoration capabilities of Difix+. The model will automatically process your image and return an enhanced version."
    "</p>"
    "<p style='text-align: center;'>"
    "<a href='https://github.com/nv-tlabs/Difix3D' target='_blank'>🧑‍💻 GitHub Repository</a> | "
    "<a href='https://arxiv.org/abs/2503.01774' target='_blank'>📄 Research Paper</a> | "
    "<a href='https://huggingface.co/nvidia/difix' target='_blank'>🤗 Hugging Face Model</a>"
    "</p>"
)
|
|
|
logger.info("Creating theme...")

# Shade tables for the primary and secondary hues, keyed by the keyword
# names that gr_themes.Color accepts.
_PRIMARY_SHADES = {
    "c50": "#E6F7E6",
    "c100": "#CCF2CC",
    "c200": "#99E699",
    "c300": "#66D966",
    "c400": "#33CC33",
    "c500": "#00B300",
    "c600": "#009900",
    "c700": "#007A00",
    "c800": "#005C00",
    "c900": "#003D00",
    "c950": "#002600",
}
_SECONDARY_SHADES = {
    "c50": "#F0F8FF",
    "c100": "#E0F0FF",
    "c200": "#C0E0FF",
    "c300": "#A0D0FF",
    "c400": "#80C0FF",
    "c500": "#4A90E2",
    "c600": "#3A80D2",
    "c700": "#2A70C2",
    "c800": "#1A60B2",
    "c900": "#0A50A2",
    "c950": "#004092",
}
_FONT_STACK = [gr_themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"]

# Per-variable overrides applied on top of the default theme.
_THEME_OVERRIDES = {
    "body_background_fill": "*neutral_50",
    "block_background_fill": "white",
    "block_border_width": "1px",
    "block_border_color": "*neutral_200",
    "block_radius": "8px",
    "block_shadow": "0 2px 4px rgba(0,0,0,0.1)",
    "button_primary_background_fill": "*primary_500",
    "button_primary_background_fill_hover": "*primary_600",
    "button_primary_text_color": "white",
}

difix_theme = gr_themes.Default(
    primary_hue=gr_themes.Color(**_PRIMARY_SHADES),
    secondary_hue=gr_themes.Color(**_SECONDARY_SHADES),
    neutral_hue="slate",
    font=_FONT_STACK,
).set(**_THEME_OVERRIDES)
|
|
|
logger.info("Creating Gradio interface...")

# Static pieces of the page, built up front so the layout below stays short.
_EXAMPLE_IMAGES = ["assets/example1.png", "assets/example2.png"]
_FOOTER_HTML = (
    "<p style='text-align: center; color: #666; font-size: 0.9em;'>"
    f"Model: {MODEL_NAME} | "
    f"Resolution: {DEFAULT_WIDTH}Γ{DEFAULT_HEIGHT} | "
    f"Prompt: '{DEFAULT_PROMPT}' | "
    f"Steps: {DEFAULT_NUM_INFERENCE_STEPS} | "
    f"Timestep: {DEFAULT_TIMESTEP} | "
    f"Guidance Scale: {DEFAULT_GUIDANCE_SCALE}"
    "</p>"
)

with gr.Blocks(theme=difix_theme, title="Difix") as demo:
    # Header: title followed by the introductory article.
    gr.Markdown("<h1 style='text-align: center; margin: 0 auto; color: #00B300;'>π¨ Difix</h1>")
    gr.HTML(article)

    gr.Markdown("---")

    with gr.Row():
        # Left column: upload, trigger button and example inputs.
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="pil",
                label="π€ Upload Image for Restoration",
                height=400,
            )
            process_btn = gr.Button("π Fix Image", variant="primary", size="lg")
            gr.Examples(
                examples=_EXAMPLE_IMAGES,
                inputs=[input_image],
                label="π Example Images",
            )

        # Right column: the restored result.
        with gr.Column(scale=1):
            output_image = gr.Image(
                type="pil",
                label="β¨ Fixed Image",
                height=400,
            )

    gr.Markdown("---")
    gr.Markdown(_FOOTER_HTML)

    # Clicking the button runs the restoration callback on the uploaded image.
    process_btn.click(
        fn=gradio_interface,
        inputs=[input_image],
        outputs=[output_image],
        api_name="restore_image",
    )
|
|
|
logger.info("Configuring queue...")

# Request queue settings passed to demo.queue().
_QUEUE_SETTINGS = {
    "default_concurrency_limit": 2,
    "max_size": 20,
}
demo.queue(**_QUEUE_SETTINGS)

logger.info("=== Gradio Interface Created Successfully ===")
|
|
|
if __name__ == "__main__":
    # Single source of truth for the server settings, so the log line below
    # cannot drift from what is actually passed to launch().
    server_host = "0.0.0.0"
    server_port = 7860
    max_threads = 10

    logger.info("=== Starting Gradio Launch ===")
    logger.info(f"Server config: {server_host}:{server_port}, max_threads={max_threads}")

    # Expose the assets directory to Gradio only when it exists; the check is
    # done once so the log message and the launch argument always agree.
    assets_path = Path("assets").absolute()
    allowed_paths = []
    if assets_path.exists():
        logger.info(f"Setting up file access for assets directory: {assets_path}")
        allowed_paths.append(str(assets_path))

    demo.launch(
        server_name=server_host,
        server_port=server_port,
        max_threads=max_threads,
        allowed_paths=allowed_paths,
    )
|
|