"""3D Asset Collection Generator (Gradio app).

Workflow:
1. Generate a themed list of items with a local text2text model.
2. Refine a text-to-image prompt for each item via the Hugging Face Inference API.
3. Optionally render image previews with a local diffusers pipeline.
4. Request 3D models for the user-selected items from the Hunyuan3D-2 Space.
5. Bundle the generated models into a ZIP download.
"""

import os
import shutil
import time
import uuid  # For unique filenames
import zipfile

import gradio as gr
import torch

# --- LLM/Model Setup ---
from transformers import pipeline as transformers_pipeline  # For local list generation
from huggingface_hub import InferenceClient  # For prompt refinement via API
from diffusers import DiffusionPipeline, StableDiffusionPipeline, EulerDiscreteScheduler  # For image generation
from gradio_client import Client as GradioClient, handle_file  # For 3D generation

# --- Configuration ---
# Consider making these configurable in the UI later
LIST_GENERATION_MODEL = "google/flan-t5-base"  # Or another suitable small model
PROMPT_REFINEMENT_MODEL_API = "mistralai/Mixtral-8x7B-Instruct-v0.1"  # Or another instruct model via Inference API
IMAGE_GENERATION_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"  # Or "runwayml/stable-diffusion-v1-5"
HUNYUAN_SPACE_ID = "tencent/Hunyuan3D-2"
HUNYUAN_MAX_RETRIES = 1  # Extra attempts per item after the first failure
OUTPUT_DIR = "outputs"
MODELS_SUBDIR = "3d_models"
IMAGES_SUBDIR = "image_previews"
ZIP_FILENAME = "3d_collection.zip"

# Use half precision only on GPU: many CPU kernels do not implement float16.
TORCH_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.float16 if TORCH_DEVICE == "cuda" else torch.float32

# --- Initialize Clients/Pipelines (can be slow, consider loading on demand if needed) ---
# Use the HF token from Space secrets if available/needed for the Inference API.
# HF_TOKEN (the conventional Hugging Face variable name) is accepted as a fallback.
HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN") or os.environ.get("HF_TOKEN")

# Basic List Generator (local)
try:
    list_generator = transformers_pipeline("text2text-generation", model=LIST_GENERATION_MODEL)
except Exception as e:
    print(f"Warning: Could not load local list generator {LIST_GENERATION_MODEL}: {e}")
    list_generator = None

# Prompt Refiner (API)
if not HF_TOKEN:
    print("Warning: HUGGINGFACE_TOKEN not set. Inference API calls might be rate-limited or fail.")
try:
    prompt_refiner = InferenceClient(model=PROMPT_REFINEMENT_MODEL_API, token=HF_TOKEN)
except Exception as e:
    print(f"Warning: Could not initialize InferenceClient for {PROMPT_REFINEMENT_MODEL_API}: {e}")
    prompt_refiner = None

# Image Generator (local diffusers; very slow without a GPU).
# DiffusionPipeline reads the repo's model_index.json and builds the matching pipeline
# class, so both SDXL and SD 1.x repos load correctly. StableDiffusionPipeline cannot
# load an SDXL checkpoint.
try:
    image_pipeline = DiffusionPipeline.from_pretrained(
        IMAGE_GENERATION_MODEL, torch_dtype=TORCH_DTYPE, use_safetensors=True
    )
    image_pipeline.scheduler = EulerDiscreteScheduler.from_config(image_pipeline.scheduler.config)
    image_pipeline.to(TORCH_DEVICE)
except Exception as e:
    print(f"Warning: Could not load diffusers pipeline {IMAGE_GENERATION_MODEL}. Image generation might fail: {e}")
    image_pipeline = None

# 3D Generator Client
try:
    hunyuan_client = GradioClient(HUNYUAN_SPACE_ID)
except Exception as e:
    print(f"Error initializing GradioClient for {HUNYUAN_SPACE_ID}: {e}")
    hunyuan_client = None


# --- Helper Functions ---
def _safe_filename(name):
    """Return *name* with every non-alphanumeric character replaced by ``_``."""
    return "".join(c if c.isalnum() else "_" for c in name)


def generate_list_local(theme, count):
    """Ask the local text2text model for up to *count* distinct items of *theme*.

    Returns a list of item names. If the model is unavailable or generation fails,
    returns a single-element list whose entry starts with ``"Error:"``.
    """
    if not list_generator:
        return ["Error: List generator model not loaded."]
    prompt = f"Generate a comma-separated list of {count} distinct types of {theme}."
    try:
        result = list_generator(prompt, max_length=200)[0]["generated_text"]
    except Exception as e:
        print(f"Error generating list: {e}")
        return [f"Error: {e}"]
    # dict.fromkeys drops repeats while keeping first-seen order. Duplicates would
    # otherwise collide later as dict keys, checkbox choices and output filenames.
    items = list(dict.fromkeys(item.strip() for item in result.split(",") if item.strip()))
    return items[:count]  # Ensure we don't exceed the requested count


def refine_prompt_api(item_name):
    """Return a detailed text-to-image prompt for *item_name*.

    Falls back to a simple template prompt when the API client is unavailable,
    the call fails, or it returns nothing usable.
    """
    if not prompt_refiner:
        return f"A 3D model of a {item_name}"  # Fallback basic prompt
    prompt = (
        f"Create a detailed, descriptive prompt for generating a highly realistic image of a single "
        f"'{item_name}'. Focus on visual details suitable for a text-to-image AI. "
        f"Only output the prompt itself."
    )
    try:
        refined = prompt_refiner.text_generation(prompt, max_new_tokens=100)
    except Exception as e:
        print(f"Error refining prompt for '{item_name}': {e}")
        # Fallback to a simpler prompt for 3D generation if refinement fails
        return f"A high quality 3D model of a {item_name}"
    # Clean up surrounding whitespace/quotes the model sometimes adds.
    refined = refined.strip().strip('"')
    return refined or f"A high quality 3D model of a {item_name}"


def generate_image_local(refined_prompt, output_path):
    """Render *refined_prompt* to a PNG at *output_path*.

    Returns the path on success, or None if the pipeline is unavailable or
    rendering fails.
    """
    if not image_pipeline:
        print("Image generation pipeline not available.")
        return None
    try:
        parent = os.path.dirname(output_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        # Adjust inference steps/guidance as needed
        image = image_pipeline(refined_prompt, num_inference_steps=25, guidance_scale=7.5).images[0]
        image.save(output_path)
        return output_path
    except Exception as e:
        print(f"Error generating image for prompt '{refined_prompt}': {e}")
        return None


def _extract_filepath(value):
    """Best-effort conversion of one gradio_client result element to a file path.

    Elements are normally plain path strings. A Space function that returns a component
    update may instead produce a dict such as ``{"value": path, ...}``. The common keys
    are handled here (NOTE(review): confirm against the Space's current API).
    """
    if isinstance(value, str):
        return value or None
    if isinstance(value, dict):
        for key in ("value", "path", "name"):
            candidate = value.get(key)
            if isinstance(candidate, str) and candidate:
                return candidate
    return None


def generate_3d_model_hunyuan(refined_prompt_for_3d, output_dir, item_name_safe):
    """Request a text-to-3D model from the Hunyuan3D-2 Space and save it locally.

    Args:
        refined_prompt_for_3d: Text caption sent to the Space.
        output_dir: Directory the model file is copied into (created if needed).
        item_name_safe: Filename stem for the saved model.

    Returns:
        ``(path, "Success")`` on success, otherwise ``(None, error_message)``.
    """
    if not hunyuan_client:
        print("Hunyuan 3D client not available.")
        return None, "Client not initialized"

    print(f"Requesting 3D model for: {refined_prompt_for_3d}")
    try:
        # Use defaults for most parameters initially
        result = hunyuan_client.predict(
            caption=refined_prompt_for_3d,
            # Leave image and mv_image inputs as None for text-to-3D
            image=None,
            mv_image_front=None,
            mv_image_back=None,
            mv_image_left=None,
            mv_image_right=None,
            # Default values from API docs (can be overridden)
            steps=30,
            guidance_scale=5,
            seed=1234,
            octree_resolution=256,
            check_box_rembg=True,
            num_chunks=8000,
            randomize_seed=True,
            api_name="/generation_all",  # Crucial!
        )
    except Exception as e:
        error_msg = f"Error calling Hunyuan3D API for '{refined_prompt_for_3d}': {e}"
        print(error_msg)
        return None, str(e)

    # An endpoint with a single output returns the bare value rather than a tuple.
    result_tuple = result if isinstance(result, (list, tuple)) else (result,)

    # --- VERIFICATION NEEDED ---
    # Check the actual return tuple structure. The model path is assumed to be the
    # first element, falling back to the second.
    raw_filepath = None
    for index in (0, 1):
        if index < len(result_tuple):
            raw_filepath = _extract_filepath(result_tuple[index])
            if raw_filepath:
                if index == 1:
                    print("Using second element from result tuple for filepath.")
                break
    # --- END VERIFICATION NEEDED ---

    if not raw_filepath or not os.path.isfile(raw_filepath):
        error_msg = f"Job for '{refined_prompt_for_3d}' did not return a valid filepath in expected tuple elements."
        print(error_msg)
        # Inspect the full result for debugging
        print(f"Full result tuple: {result_tuple}")
        return None, error_msg

    print(f"Job completed. Raw result path: {raw_filepath}")
    # predict() has already downloaded the output file into gradio_client's temporary
    # download directory. Copy it into our output folder under a meaningful name.
    # (handle_file is only for preparing *upload* inputs and has no download option.)
    try:
        os.makedirs(output_dir, exist_ok=True)
        file_ext = os.path.splitext(raw_filepath)[1] or ".glb"  # Assume glb if unknown
        final_path = os.path.join(output_dir, f"{item_name_safe}{file_ext}")
        shutil.copyfile(raw_filepath, final_path)
    except OSError as e:
        error_msg = f"Could not save model for '{refined_prompt_for_3d}': {e}"
        print(error_msg)
        return None, error_msg
    print(f"Model saved to: {final_path}")
    return final_path, "Success"


def create_zip(files_to_zip, zip_filepath):
    """Write every existing file in *files_to_zip* into *zip_filepath* (flat layout).

    Missing/None entries are skipped. Only the first file with a given basename is
    added, so the archive never contains duplicate entry names. Returns *zip_filepath*.
    """
    parent = os.path.dirname(zip_filepath)
    if parent:
        os.makedirs(parent, exist_ok=True)
    written = set()
    with zipfile.ZipFile(zip_filepath, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for file_path in files_to_zip:
            if not file_path or not os.path.exists(file_path):
                continue
            arcname = os.path.basename(file_path)
            if arcname in written:
                continue
            written.add(arcname)
            zf.write(file_path, arcname)
    return zip_filepath


# --- Gradio Interface & Logic ---
with gr.Blocks() as demo:
    gr.Markdown("# 3D Asset Collection Generator")
    gr.Markdown(
        "Generate a list based on a theme, refine prompts, preview images, and generate "
        "selected 3D models using Hunyuan3D-2."
    )

    # Outside an event handler, gr.Warning only reaches the server log. gr.Error is an
    # exception class, so constructing it without raising does nothing. Startup
    # problems are therefore rendered inline.
    startup_notices = []
    if not HF_TOKEN:
        startup_notices.append(
            "**Warning:** Hugging Face Token not found. Prompt refinement quality/rate limits may be "
            "affected. Consider adding HUGGINGFACE_TOKEN to Space secrets."
        )
    if not image_pipeline:
        startup_notices.append(
            "**Warning:** Local Image Generation model failed to load. Image previews will be skipped. "
            "Check Space hardware/logs."
        )
    if not hunyuan_client:
        startup_notices.append(
            "**Error:** Failed to connect to the Hunyuan3D-2 Space. 3D generation will not work."
        )
    if startup_notices:
        gr.Markdown("\n\n".join(startup_notices))

    # Per-session state. Handlers update it by *returning* new values for these
    # components. Assigning `.value` inside a handler does not change the running
    # session's state.
    list_items_state = gr.State([])
    refined_prompts_state = gr.State({})  # Dict: {item_name: refined_prompt}
    image_paths_state = gr.State({})  # Dict: {item_name: image_path}
    generated_3d_files_state = gr.State([])  # List of paths to successfully generated models

    with gr.Row():
        theme_input = gr.Textbox(label="Theme", placeholder="e.g., reptiles, kitchen appliances, medieval weapons")
        count_input = gr.Number(label="Number of Items", value=5, minimum=1, step=1, precision=0)
    generate_list_button = gr.Button("1. Generate List & Refine Prompts")
    list_output_display = gr.Markdown("List will appear here...")

    generate_images_button = gr.Button("2. Generate Image Previews", visible=False)  # Hidden initially
    image_gallery = gr.Gallery(label="Image Previews", visible=False, elem_id="image_gallery")

    # A button click carries no selection event, so selection is tracked by a CheckboxGroup
    # whose value (the chosen item names) is passed straight to the 3D handler.
    selection_data = gr.CheckboxGroup(label="Select Items for 3D Generation", choices=[], visible=False)

    generate_3d_button = gr.Button("3. Generate 3D Models for Selected Items", visible=False)  # Hidden initially
    status_output = gr.Markdown("")  # For progress updates
    final_zip_output = gr.File(label="Download 3D Model Collection (ZIP)", visible=False)

    # --- Event Logic ---
    def run_list_and_refine(theme, count):
        """Step 1: generate the item list and one refined prompt per item."""
        # Previews, selection and ZIP from a previous run belong to an older list.
        reset = {
            image_gallery: gr.Gallery(value=None, visible=False),
            selection_data: gr.CheckboxGroup(choices=[], value=[], visible=False),
            generate_3d_button: gr.Button(visible=False),
            final_zip_output: gr.File(value=None, visible=False),
        }
        if not theme or not theme.strip():
            return {**reset, list_output_display: "Please enter a theme.", generate_images_button: gr.Button(visible=False)}

        # Ensure output dirs exist
        os.makedirs(os.path.join(OUTPUT_DIR, IMAGES_SUBDIR), exist_ok=True)
        os.makedirs(os.path.join(OUTPUT_DIR, MODELS_SUBDIR), exist_ok=True)

        count = max(1, int(count or 1))
        gr.Info("Generating list...")
        items = generate_list_local(theme.strip(), count)
        if not items or items[0].startswith("Error:"):
            return {
                **reset,
                list_output_display: f"Failed to generate list: {items[0] if items else 'Unknown error'}",
                generate_images_button: gr.Button(visible=False),
            }

        gr.Info("Refining prompts via API...")
        refined_prompts = {}
        md_lines = ["### Generated List & Refined Prompts:\n"]
        for item in items:
            refined = refine_prompt_api(item)
            refined_prompts[item] = refined
            md_lines.append(f"* **{item}:** {refined}")

        return {
            **reset,
            list_output_display: "\n".join(md_lines) + "\n",
            generate_images_button: gr.Button(visible=True),  # Enable next step
            list_items_state: items,
            refined_prompts_state: refined_prompts,
        }

    generate_list_button.click(
        fn=run_list_and_refine,
        inputs=[theme_input, count_input],
        outputs=[
            list_output_display,
            generate_images_button,
            list_items_state,
            refined_prompts_state,
            image_gallery,
            selection_data,
            generate_3d_button,
            final_zip_output,
        ],
    )

    def run_image_generation(items, refined_prompts_dict):
        """Step 2: render previews (if possible) and offer the items for selection."""
        refined_prompts_dict = refined_prompts_dict or {}
        if not items:
            gr.Warning("No items to preview. Generate a list first.")
            return {
                image_gallery: gr.Gallery(visible=False),
                selection_data: gr.CheckboxGroup(choices=[], value=[], visible=False),
                generate_3d_button: gr.Button(visible=False),
                image_paths_state: {},
            }

        if not image_pipeline:
            # Skip image generation but still allow proceeding to 3D generation.
            gr.Warning("Image pipeline not loaded. Skipping image previews.")
            return {
                image_gallery: gr.Gallery(visible=False),
                selection_data: gr.CheckboxGroup(choices=items, value=[], visible=True),
                generate_3d_button: gr.Button(visible=True),
                image_paths_state: {},
            }

        gr.Info("Generating image previews... (this may take a while)")
        image_paths = {}
        gallery_images = []  # (path, caption) pairs
        img_dir = os.path.join(OUTPUT_DIR, IMAGES_SUBDIR)
        os.makedirs(img_dir, exist_ok=True)
        for item in items:
            refined_prompt = refined_prompts_dict.get(item, f"Image of {item}")
            img_path = os.path.join(img_dir, f"{_safe_filename(item)}_{uuid.uuid4()}.png")
            generated_path = generate_image_local(refined_prompt, img_path)
            if generated_path:
                image_paths[item] = generated_path
                gallery_images.append((generated_path, item))

        return {
            image_gallery: gr.Gallery(value=gallery_images, visible=bool(gallery_images)),
            selection_data: gr.CheckboxGroup(choices=items, value=[], visible=True),
            generate_3d_button: gr.Button(visible=True),
            image_paths_state: image_paths,
        }

    generate_images_button.click(
        fn=run_image_generation,
        inputs=[list_items_state, refined_prompts_state],
        outputs=[image_gallery, selection_data, generate_3d_button, image_paths_state],
    )

    def run_3d_generation(selected_items, refined_prompts_dict):
        """Step 3: generate a 3D model per selected item, streaming progress, then zip them.

        This is a generator. Every update, including the final ZIP, is *yielded*,
        because a generator's ``return`` value is not delivered to the UI.
        """
        hidden_zip = gr.File(value=None, visible=False)
        if not hunyuan_client:
            yield {status_output: "Hunyuan3D client not initialized. Cannot generate.", final_zip_output: hidden_zip}
            return
        if not selected_items:
            yield {
                status_output: "Please select items from the list above before generating 3D models.",
                final_zip_output: hidden_zip,
            }
            return

        refined_prompts_dict = refined_prompts_dict or {}
        generated_files = []
        status_messages = ["### 3D Generation Status:\n"]
        model_dir = os.path.join(OUTPUT_DIR, MODELS_SUBDIR)
        total_selected = len(selected_items)

        for position, item_name in enumerate(selected_items, start=1):
            current_status = f"({position}/{total_selected}) Generating model for: **{item_name}**..."
            print(current_status)
            status_messages.append(f"* {current_status}")
            yield {status_output: "\n".join(status_messages), final_zip_output: hidden_zip}

            prompt_for_3d = refined_prompts_dict.get(item_name) or f"A high quality 3D model of a {item_name}"
            item_name_safe = _safe_filename(item_name)

            last_error = "Unknown error"
            attempts_allowed = HUNYUAN_MAX_RETRIES + 1
            for attempt in range(1, attempts_allowed + 1):
                if attempt > 1:
                    status_messages.append(f"  * Retrying ({attempt - 1}/{HUNYUAN_MAX_RETRIES})...")
                    yield {status_output: "\n".join(status_messages)}
                    time.sleep(2)  # Brief pause before retry
                model_path, msg = generate_3d_model_hunyuan(prompt_for_3d, model_dir, item_name_safe)
                last_error = msg
                if model_path:
                    generated_files.append(model_path)
                    status_messages.append("  * Success! Model saved.")
                    break
                status_messages.append(f"  * Attempt {attempt} failed: {msg}")
            else:
                # Loop finished without `break`: every attempt failed.
                status_messages.append(f"  * **Failed** after {attempts_allowed} attempt(s). Last error: {last_error}")
            yield {status_output: "\n".join(status_messages)}

        if not generated_files:
            status_messages.append("\nNo 3D models were successfully generated.")
            yield {
                status_output: "\n".join(status_messages),
                final_zip_output: hidden_zip,
                generated_3d_files_state: [],
            }
            return

        status_messages.append("\nCreating ZIP archive...")
        yield {status_output: "\n".join(status_messages)}
        # One directory per run so concurrent sessions do not overwrite each other's ZIP.
        zip_path = os.path.join(OUTPUT_DIR, uuid.uuid4().hex, ZIP_FILENAME)
        final_zip = create_zip(generated_files, zip_path)
        status_messages.append(f"\n**Collection ready!** Download '{ZIP_FILENAME}' below.")
        yield {
            status_output: "\n".join(status_messages),
            final_zip_output: gr.File(value=final_zip, visible=True),
            generated_3d_files_state: generated_files,
        }

    generate_3d_button.click(
        fn=run_3d_generation,
        inputs=[selection_data, refined_prompts_state],
        outputs=[status_output, final_zip_output, generated_3d_files_state],
    )

# Launch the Gradio app
demo.queue().launch(debug=True)  # Enable queue for longer processes, debug for detailed errors