# Listto3d / app.py
# Author: Rogerjs
# Commit 9e91681 (verified): "Create app.py"
import os
import shutil
import time
import uuid  # For unique filenames
import zipfile

# --- LLM/Model Setup ---
import gradio as gr
import torch  # dtype/device selection for the diffusers pipeline
from diffusers import DiffusionPipeline, EulerDiscreteScheduler, StableDiffusionPipeline  # For image generation
from gradio_client import Client as GradioClient, handle_file  # For 3D generation
from huggingface_hub import InferenceClient  # For prompt refinement via API
from transformers import pipeline as transformers_pipeline  # For local list generation
# --- Configuration ---
# Model and Space IDs can be overridden with an environment variable of the same
# name (e.g. a Space "Variable"); otherwise the defaults below are used.
LIST_GENERATION_MODEL = os.environ.get("LIST_GENERATION_MODEL", "google/flan-t5-base")  # Or another suitable small model
PROMPT_REFINEMENT_MODEL_API = os.environ.get(
    "PROMPT_REFINEMENT_MODEL_API", "mistralai/Mixtral-8x7B-Instruct-v0.1"
)  # Or another instruct model via Inference API
IMAGE_GENERATION_MODEL = os.environ.get(
    "IMAGE_GENERATION_MODEL", "stabilityai/stable-diffusion-xl-base-1.0"
)  # Or "runwayml/stable-diffusion-v1-5"
HUNYUAN_SPACE_ID = os.environ.get("HUNYUAN_SPACE_ID", "tencent/Hunyuan3D-2")

# Output layout (relative to the working directory).
OUTPUT_DIR = "outputs"
MODELS_SUBDIR = "3d_models"
IMAGES_SUBDIR = "image_previews"
ZIP_FILENAME = "3d_collection.zip"
# --- Initialize Clients/Pipelines (can be slow, consider loading on demand if needed) ---
# Hugging Face token for the Inference API and the Hunyuan Space. HUGGINGFACE_TOKEN
# (the name this app documents) is checked first, then the conventional HF_TOKEN.
HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN") or os.environ.get("HF_TOKEN")

# Basic List Generator (local)
try:
    list_generator = transformers_pipeline("text2text-generation", model=LIST_GENERATION_MODEL)
except Exception as e:
    print(f"Warning: Could not load local list generator {LIST_GENERATION_MODEL}: {e}")
    list_generator = None

# Prompt Refiner (API)
try:
    if not HF_TOKEN:
        print("Warning: HUGGINGFACE_TOKEN not set. Inference API calls might be rate-limited or fail.")
    prompt_refiner = InferenceClient(model=PROMPT_REFINEMENT_MODEL_API, token=HF_TOKEN)
except Exception as e:
    print(f"Warning: Could not initialize InferenceClient for {PROMPT_REFINEMENT_MODEL_API}: {e}")
    prompt_refiner = None

# Image Generator (local with diffusers). A GPU is strongly recommended; on CPU an
# SDXL checkpoint in float32 is slow and memory-hungry, so a smaller model or an
# image-generation API may be preferable there.
try:
    # Half precision is only used on CUDA; CPU inference falls back to float32.
    _image_device = "cuda" if torch.cuda.is_available() else "cpu"
    _image_dtype = torch.float16 if _image_device == "cuda" else torch.float32
    # DiffusionPipeline reads the repo's model_index.json and instantiates the
    # matching pipeline class (e.g. StableDiffusionXLPipeline for SDXL), so both
    # SD 1.x and SDXL model IDs work here.
    image_pipeline = DiffusionPipeline.from_pretrained(
        IMAGE_GENERATION_MODEL, torch_dtype=_image_dtype, use_safetensors=True
    )
    image_pipeline.to(_image_device)
    image_pipeline.scheduler = EulerDiscreteScheduler.from_config(image_pipeline.scheduler.config)
except Exception as e:
    print(f"Warning: Could not load diffusers pipeline {IMAGE_GENERATION_MODEL}. Image generation might fail: {e}")
    image_pipeline = None

# 3D Generator Client. The token (if any) is forwarded so requests are attributed
# to the account instead of running anonymously.
try:
    hunyuan_client = GradioClient(HUNYUAN_SPACE_ID, hf_token=HF_TOKEN)
except Exception as e:
    print(f"Error initializing GradioClient for {HUNYUAN_SPACE_ID}: {e}")
    hunyuan_client = None
# --- Helper Functions ---
def generate_list_local(theme, count):
    """Ask the local text2text model for up to *count* distinct items of *theme*.

    Args:
        theme: Category to list items for (e.g. "reptiles").
        count: Maximum number of items to return.

    Returns:
        A list of item names, in the model's order, de-duplicated
        case-insensitively and truncated to *count*. When the model is not
        loaded or generation raises, a single-element list whose entry starts
        with ``"Error:"`` is returned instead.
    """
    if not list_generator:
        return ["Error: List generator model not loaded."]
    if count < 1:
        return []
    prompt = f"Generate a comma-separated list of {count} distinct types of {theme}."
    try:
        result = list_generator(prompt, max_length=200)[0]['generated_text']
    except Exception as e:
        print(f"Error generating list: {e}")
        return [f"Error: {e}"]

    # Small seq2seq models sometimes separate entries with newlines and often
    # repeat entries; duplicates would collapse in the prompt dict and show up
    # twice as selection choices, so keep only the first occurrence.
    items = []
    seen = set()
    for raw_item in result.replace("\n", ",").split(","):
        item = raw_item.strip()
        key = item.casefold()
        if item and key not in seen:
            seen.add(key)
            items.append(item)
    return items[:count]  # Ensure we don't exceed the requested count
def refine_prompt_api(item_name):
    """Turn *item_name* into a descriptive text-to-image prompt via the Inference API.

    Args:
        item_name: The item to describe.

    Returns:
        The model's prompt with surrounding whitespace and quotes removed. If
        the client was never created, ``"A 3D model of a <item>"`` is returned.
        If the API call raises or yields an empty answer,
        ``"A high quality 3D model of a <item>"`` is returned.
    """
    if not prompt_refiner:
        return f"A 3D model of a {item_name}"  # Fallback basic prompt
    prompt = (
        f"Create a detailed, descriptive prompt for generating a highly realistic image of a single '{item_name}'. "
        "Focus on visual details suitable for a text-to-image AI. "
        "Only output the prompt itself."
    )
    try:
        refined = prompt_refiner.text_generation(prompt, max_new_tokens=100)
    except Exception as e:
        print(f"Error refining prompt for '{item_name}': {e}")
        # Fallback to a simpler prompt for 3D generation if refinement fails
        return f"A high quality 3D model of a {item_name}"

    # Clean up potential API artifacts (whitespace, wrapping quotes).
    refined = (refined or "").strip().strip('"').strip()
    if not refined:
        # An empty prompt would be passed straight to image and 3D generation.
        print(f"Empty refined prompt for '{item_name}'; using fallback.")
        return f"A high quality 3D model of a {item_name}"
    return refined
def generate_image_local(refined_prompt, output_path):
    """Render *refined_prompt* with the local diffusers pipeline and save it.

    Args:
        refined_prompt: Text prompt passed to the pipeline.
        output_path: File path the first generated image is saved to.

    Returns:
        *output_path* on success; None when the pipeline is unavailable or
        generation/saving raised (the error is printed).
    """
    pipe = image_pipeline
    if not pipe:
        print("Image generation pipeline not available.")
        # TODO: optionally write a placeholder image to output_path instead.
        return None

    # Sampling settings; adjust inference steps/guidance as needed.
    sampling = {"num_inference_steps": 25, "guidance_scale": 7.5}
    try:
        output = pipe(refined_prompt, **sampling)
        first_image = output.images[0]
        first_image.save(output_path)
    except Exception as err:
        print(f"Error generating image for prompt '{refined_prompt}': {err}")
        return None
    return output_path
def _extract_model_path(result):
    """Return the first existing local file path among the first two entries of *result*, or None.

    Entries may be plain path strings or dicts that wrap the path under
    ``"value"`` or ``"path"``. NOTE(review): the assumption that the model file
    is among the first two outputs of ``/generation_all`` comes from the original
    code — confirm against the Space's API page.
    """
    candidates = result if isinstance(result, (list, tuple)) else [result]
    for entry in candidates[:2]:
        # Unwrap component-update dicts ({"value": ...}) and file dicts ({"path": ...}).
        while isinstance(entry, dict):
            entry = entry.get("value", entry.get("path"))
        if isinstance(entry, str) and os.path.isfile(entry):
            return entry
    return None


def generate_3d_model_hunyuan(refined_prompt_for_3d, output_dir, item_name_safe):
    """Generate a 3D model for a text prompt via the Hunyuan3D-2 Space.

    Args:
        refined_prompt_for_3d: Caption sent to the Space's ``/generation_all`` endpoint.
        output_dir: Directory the resulting model file is moved into (created if needed).
        item_name_safe: Filesystem-safe base name for the saved model file.

    Returns:
        ``(final_path, "Success")`` on success, otherwise ``(None, error_message)``.
    """
    if not hunyuan_client:
        print("Hunyuan 3D client not available.")
        return None, "Client not initialized"
    print(f"Requesting 3D model for: {refined_prompt_for_3d}")
    try:
        result_tuple = hunyuan_client.predict(
            caption=refined_prompt_for_3d,
            # Leave image and mv_image inputs as None for text-to-3D
            image=None,
            mv_image_front=None,
            mv_image_back=None,
            mv_image_left=None,
            mv_image_right=None,
            # Default values from API docs (can be overridden)
            steps=30,
            guidance_scale=5,
            seed=1234,  # Ignored by the Space when randomize_seed=True
            octree_resolution=256,
            check_box_rembg=True,
            num_chunks=8000,
            randomize_seed=True,
            api_name="/generation_all",
        )
    except Exception as e:
        error_msg = f"Error calling Hunyuan3D API for '{refined_prompt_for_3d}': {e}"
        print(error_msg)
        return None, str(e)

    raw_filepath = _extract_model_path(result_tuple)
    if not raw_filepath:
        error_msg = f"Job for '{refined_prompt_for_3d}' did not return a valid filepath in expected tuple elements."
        print(error_msg)
        print(f"Full result tuple: {result_tuple}")
        return None, error_msg

    print(f"Job completed. Raw result path: {raw_filepath}")
    file_ext = os.path.splitext(raw_filepath)[1] or ".glb"  # Assume glb if unknown
    final_path = os.path.join(output_dir, f"{item_name_safe}{file_ext}")
    try:
        os.makedirs(output_dir, exist_ok=True)
        # gradio_client has already downloaded the output into its own temp
        # directory. That directory may be on another filesystem, where
        # os.rename fails; shutil.move falls back to copy + delete there.
        shutil.move(raw_filepath, final_path)
    except OSError as e:
        error_msg = f"Failed to save model file '{raw_filepath}' to '{final_path}': {e}"
        print(error_msg)
        return None, error_msg
    print(f"Model saved to: {final_path}")
    return final_path, "Success"
def create_zip(files_to_zip, zip_filepath):
    """Write every existing path in *files_to_zip* into a flat ZIP at *zip_filepath*.

    Args:
        files_to_zip: Iterable of file paths; falsy or non-existent entries are skipped.
        zip_filepath: Destination archive path (overwritten if present).

    Returns:
        *zip_filepath*.

    Entries are stored under their base name, deflate-compressed. When two
    inputs share a base name, later ones get a numeric suffix (``name_1.ext``)
    so the archive never contains duplicate members.
    """
    used_names = set()
    with zipfile.ZipFile(zip_filepath, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for file_path in files_to_zip:
            if not file_path or not os.path.exists(file_path):
                continue
            stem, ext = os.path.splitext(os.path.basename(file_path))
            arcname = stem + ext
            suffix = 1
            while arcname in used_names:
                arcname = f"{stem}_{suffix}{ext}"
                suffix += 1
            used_names.add(arcname)
            zf.write(file_path, arcname)
    return zip_filepath
# --- Gradio Interface & Logic ---
with gr.Blocks() as demo:
    gr.Markdown("# 3D Asset Collection Generator")
    gr.Markdown("Generate a list based on a theme, refine prompts, preview images, and generate selected 3D models using Hunyuan3D-2.")

    # gr.Warning/gr.Info only produce toasts while an event handler is running,
    # and gr.Error is an exception class that has to be raised, so calling them
    # while the layout is being built shows nothing. Render startup problems as Markdown.
    startup_notices = []
    if not HF_TOKEN:
        startup_notices.append("**Warning:** Hugging Face Token not found. Prompt refinement quality/rate limits may be affected. Consider adding HUGGINGFACE_TOKEN to Space secrets.")
    if not image_pipeline:
        startup_notices.append("**Warning:** Local Image Generation model failed to load. Image previews will be skipped. Check Space hardware/logs.")
    if not hunyuan_client:
        startup_notices.append("**Error:** Failed to connect to the Hunyuan3D-2 Space. 3D generation will not work.")
    if startup_notices:
        gr.Markdown("\n\n".join(startup_notices))

    # Per-session state. Handlers update these by *returning* new values (each
    # State is listed in its event's `outputs`); assigning `State.value` inside a
    # handler does not update the current session's state.
    list_items_state = gr.State([])  # List of item names from step 1
    refined_prompts_state = gr.State({})  # Dict: {item_name: refined_prompt}
    image_paths_state = gr.State({})  # Dict: {item_name: image_path}
    generated_3d_files_state = gr.State([])  # Paths of successfully generated models

    with gr.Row():
        theme_input = gr.Textbox(label="Theme", placeholder="e.g., reptiles, kitchen appliances, medieval weapons")
        count_input = gr.Number(label="Number of Items", value=5, minimum=1, step=1, precision=0)
    generate_list_button = gr.Button("1. Generate List & Refine Prompts")
    list_output_display = gr.Markdown("List will appear here...")

    generate_images_button = gr.Button("2. Generate Image Previews", visible=False)  # Hidden until step 1 succeeds
    image_gallery = gr.Gallery(label="Image Previews", visible=False, elem_id="image_gallery")
    # A CheckboxGroup allows multi-selection and passes the checked item names
    # directly to the click handler (a Dataset only reports one clicked sample).
    selection_data = gr.CheckboxGroup(choices=[], label="Select Items for 3D Generation", visible=False)

    generate_3d_button = gr.Button("3. Generate 3D Models for Selected Items", visible=False)  # Hidden until step 2
    status_output = gr.Markdown("")  # For progress updates
    final_zip_output = gr.File(label="Download 3D Model Collection (ZIP)", visible=False)

    # --- Event Logic ---
    def run_list_and_refine(theme, count):
        """Step 1: generate the item list for *theme* and a refined prompt per item.

        Returns a dict update for the list display, the step-2 button, and the
        list/prompt session states.
        """
        if not theme or not theme.strip():
            return {list_output_display: "Please enter a theme.", generate_images_button: gr.Button(visible=False)}
        try:
            n_items = int(count)
        except (TypeError, ValueError):  # e.g. the Number box was cleared (None)
            n_items = 0
        if n_items < 1:
            return {list_output_display: "Please enter a number of items (at least 1).",
                    generate_images_button: gr.Button(visible=False)}

        # Ensure output dirs exist
        os.makedirs(os.path.join(OUTPUT_DIR, IMAGES_SUBDIR), exist_ok=True)
        os.makedirs(os.path.join(OUTPUT_DIR, MODELS_SUBDIR), exist_ok=True)

        gr.Info("Generating list...")
        items = generate_list_local(theme.strip(), n_items)
        if not items or items[0].startswith("Error:"):
            return {list_output_display: f"Failed to generate list: {items[0] if items else 'Unknown error'}",
                    generate_images_button: gr.Button(visible=False)}

        gr.Info("Refining prompts via API...")
        refined_prompts = {}
        output_md = "### Generated List & Refined Prompts:\n\n"
        for item in items:
            refined = refine_prompt_api(item)
            refined_prompts[item] = refined
            output_md += f"* **{item}:** {refined}\n"

        return {
            list_output_display: output_md,
            generate_images_button: gr.Button(visible=True),  # Enable next step
            list_items_state: items,
            refined_prompts_state: refined_prompts,
        }

    generate_list_button.click(
        fn=run_list_and_refine,
        inputs=[theme_input, count_input],
        outputs=[list_output_display, generate_images_button, list_items_state, refined_prompts_state],
    )

    def run_image_generation(items, refined_prompts_dict):
        """Step 2: render a preview image per item and offer the items for selection.

        Returns a dict update for the gallery, the selection checkboxes, the
        step-3 button, and the image-path session state.
        """
        if not items:
            gr.Warning("No items yet. Run step 1 first.")
            return {
                image_gallery: gr.Gallery(visible=False),
                selection_data: gr.CheckboxGroup(visible=False),
                generate_3d_button: gr.Button(visible=False),
                image_paths_state: {},
            }
        refined_prompts_dict = refined_prompts_dict or {}
        selection_update = gr.CheckboxGroup(choices=list(items), value=[], visible=True)

        if not image_pipeline:
            # Skip previews but still allow proceeding to 3D generation.
            gr.Warning("Image pipeline not loaded. Skipping image previews.")
            return {
                image_gallery: gr.Gallery(visible=False),
                selection_data: selection_update,
                generate_3d_button: gr.Button(visible=True),
                image_paths_state: {},
            }

        gr.Info("Generating image previews... (this may take a while)")
        image_paths = {}
        gallery_images = []  # (image_path, caption) pairs
        img_dir = os.path.join(OUTPUT_DIR, IMAGES_SUBDIR)
        os.makedirs(img_dir, exist_ok=True)
        for item in items:
            refined_prompt = refined_prompts_dict.get(item, f"Image of {item}")
            safe_item_name = "".join(c if c.isalnum() else "_" for c in item)
            # uuid suffix keeps files from different runs/sessions from colliding.
            img_path = os.path.join(img_dir, f"{safe_item_name}_{uuid.uuid4()}.png")
            generated_path = generate_image_local(refined_prompt, img_path)
            if generated_path:
                image_paths[item] = generated_path
                gallery_images.append((generated_path, item))

        return {
            image_gallery: gr.Gallery(value=gallery_images, visible=True),
            selection_data: selection_update,
            generate_3d_button: gr.Button(visible=True),
            image_paths_state: image_paths,
        }

    generate_images_button.click(
        fn=run_image_generation,
        inputs=[list_items_state, refined_prompts_state],
        outputs=[image_gallery, selection_data, generate_3d_button, image_paths_state],
    )

    def run_3d_generation(selected_items, refined_prompts_dict):
        """Step 3: generate a 3D model for each checked item and bundle them in a ZIP.

        This is a generator so the status Markdown updates while the remote jobs
        run. Gradio ignores a generator's `return` value, so every UI update,
        including the final one, is yielded.
        """
        if not hunyuan_client:
            yield {status_output: "Hunyuan3D client not initialized. Cannot generate.", final_zip_output: gr.File(visible=False)}
            return
        if not selected_items:
            yield {status_output: "Please select items from the list above before generating 3D models.", final_zip_output: gr.File(visible=False)}
            return

        refined_prompts_dict = refined_prompts_dict or {}
        # Per-run directory so concurrent sessions / repeated runs don't overwrite
        # each other's models or the shared ZIP.
        run_id = uuid.uuid4().hex
        model_dir = os.path.join(OUTPUT_DIR, MODELS_SUBDIR, run_id)
        generated_files = []
        status_messages = ["### 3D Generation Status:\n"]
        total_selected = len(selected_items)
        max_retries = 1  # Up to max_retries + 1 attempts per item

        for index, item_name in enumerate(selected_items, start=1):
            current_status = f"({index}/{total_selected}) Generating model for: **{item_name}**..."
            print(current_status)
            status_messages.append(f"* {current_status}")
            yield {status_output: "\n".join(status_messages), final_zip_output: gr.File(visible=False)}

            # Use the image prompt for 3D as well; fall back to a template if missing.
            prompt_for_3d = refined_prompts_dict.get(item_name) or f"A high quality 3D model of a {item_name}"
            item_name_safe = "".join(c if c.isalnum() else "_" for c in item_name)

            last_error = "Unknown error"
            for attempt in range(1, max_retries + 2):
                if attempt > 1:
                    status_messages.append(f"  * Retrying ({attempt - 1}/{max_retries})...")
                    yield {status_output: "\n".join(status_messages)}
                    time.sleep(2)  # Brief pause before retry
                model_path, last_error = generate_3d_model_hunyuan(prompt_for_3d, model_dir, item_name_safe)
                if model_path:
                    generated_files.append(model_path)
                    status_messages.append("  * Success! Model saved.")
                    break
                status_messages.append(f"  * Attempt {attempt} failed: {last_error}")
            else:  # Loop finished without a successful attempt
                status_messages.append(f"  * **Failed** after {max_retries + 1} attempt(s). Last error: {last_error}")
            yield {status_output: "\n".join(status_messages)}

        if not generated_files:
            status_messages.append("\nNo 3D models were successfully generated.")
            yield {status_output: "\n".join(status_messages), final_zip_output: gr.File(visible=False)}
            return

        status_messages.append("\nCreating ZIP archive...")
        yield {status_output: "\n".join(status_messages)}
        zip_dir = os.path.join(OUTPUT_DIR, run_id)
        os.makedirs(zip_dir, exist_ok=True)
        final_zip = create_zip(generated_files, os.path.join(zip_dir, ZIP_FILENAME))
        status_messages.append(f"\n**Collection ready!** Download '{ZIP_FILENAME}' below.")
        yield {
            status_output: "\n".join(status_messages),
            final_zip_output: gr.File(value=final_zip, visible=True),
            generated_3d_files_state: generated_files,
        }

    generate_3d_button.click(
        fn=run_3d_generation,
        inputs=[selection_data, refined_prompts_state],
        outputs=[status_output, final_zip_output, generated_3d_files_state],
    )
# Launch the Gradio app only when this file is executed as a script, so that
# importing the module (e.g. for tests or tooling) does not start a server.
if __name__ == "__main__":
    # queue() is needed for long-running and streaming (generator) handlers;
    # debug=True prints detailed errors.
    demo.queue().launch(debug=True)