import gradio as gr
import os
import shutil # For copying downloaded model files into place
import zipfile
import time
import uuid # For unique filenames

# --- LLM/Model Setup ---
import torch # Needed for dtype/device selection below
from transformers import pipeline as transformers_pipeline # For local list generation
from huggingface_hub import InferenceClient # For prompt refinement via API
from diffusers import DiffusionPipeline, EulerDiscreteScheduler # For image generation
from gradio_client import Client as GradioClient # For 3D generation

# --- Configuration ---
# Consider making these configurable in the UI later
LIST_GENERATION_MODEL = "google/flan-t5-base" # Or another suitable small model
PROMPT_REFINEMENT_MODEL_API = "mistralai/Mixtral-8x7B-Instruct-v0.1" # Or another instruct model via Inference API
IMAGE_GENERATION_MODEL = "stabilityai/stable-diffusion-xl-base-1.0" # Or "runwayml/stable-diffusion-v1-5"
HUNYUAN_SPACE_ID = "tencent/Hunyuan3D-2"
OUTPUT_DIR = "outputs"
MODELS_SUBDIR = "3d_models"
IMAGES_SUBDIR = "image_previews"
ZIP_FILENAME = "3d_collection.zip"

# --- Initialize Clients/Pipelines (can be slow, consider loading on demand if needed) ---

# Use HF Token from Space secrets if available/needed for Inference API
HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", None)

# Basic List Generator (local)
try:
    list_generator = transformers_pipeline("text2text-generation", model=LIST_GENERATION_MODEL)
except Exception as e:
    print(f"Warning: Could not load local list generator {LIST_GENERATION_MODEL}: {e}")
    list_generator = None

# Prompt Refiner (API)
try:
    if not HF_TOKEN:
        print("Warning: HUGGINGFACE_TOKEN not set. Inference API calls might be rate-limited or fail.")
    prompt_refiner = InferenceClient(model=PROMPT_REFINEMENT_MODEL_API, token=HF_TOKEN)
except Exception as e:
    print(f"Warning: Could not initialize InferenceClient for {PROMPT_REFINEMENT_MODEL_API}: {e}")
    prompt_refiner = None

# Image Generator (Local with Diffusers - requires GPU on Space for reasonable speed)
# Or consider an Image Gen API service if running on CPU hardware
try:
    # Using XL as an example - adjust based on available hardware.
    # DiffusionPipeline auto-detects the right pipeline class, so this works
    # for both SDXL and SD 1.5 checkpoints.
    image_pipeline = DiffusionPipeline.from_pretrained(
        IMAGE_GENERATION_MODEL,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, # fp16 only helps on GPU
        use_safetensors=True,
    )
    if torch.cuda.is_available():
        image_pipeline.to("cuda") # Move to GPU when the Space has one
    image_pipeline.scheduler = EulerDiscreteScheduler.from_config(image_pipeline.scheduler.config)
except Exception as e:
    print(f"Warning: Could not load diffusers pipeline {IMAGE_GENERATION_MODEL}. Image generation might fail: {e}")
    image_pipeline = None


# 3D Generator Client
try:
    hunyuan_client = GradioClient(HUNYUAN_SPACE_ID)
except Exception as e:
    print(f"Error initializing GradioClient for {HUNYUAN_SPACE_ID}: {e}")
    hunyuan_client = None
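
# Development aid: the endpoint name, parameter list, and result-tuple layout
# assumed in generate_3d_model_hunyuan() below come from the Space's API docs
# and should be re-verified if the Space changes. gradio_client can print the
# live signatures (sketch only; uncomment to inspect during development):
#
#     if hunyuan_client:
#         hunyuan_client.view_api() # prints every endpoint's inputs/outputs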

# --- Helper Functions ---

def generate_list_local(theme, count):
    if not list_generator:
        return ["Error: List generator model not loaded."]
    prompt = f"Generate a comma-separated list of {count} distinct types of {theme}."
    try:
        result = list_generator(prompt, max_length=200)[0]['generated_text']
        items = [item.strip() for item in result.split(',') if item.strip()]
        return items[:count] # Ensure we don't exceed the requested count
    except Exception as e:
        print(f"Error generating list: {e}")
        return [f"Error: {e}"]

def refine_prompt_api(item_name):
    if not prompt_refiner:
        return f"A 3D model of a {item_name}" # Fallback basic prompt
    prompt = f"Create a detailed, descriptive prompt for generating a highly realistic image of a single '{item_name}'. Focus on visual details suitable for a text-to-image AI. Only output the prompt itself."
    try:
        refined = prompt_refiner.text_generation(prompt, max_new_tokens=100)
        # Clean up potential API artifacts if necessary
        refined = refined.strip().strip('"')
        return refined
    except Exception as e:
        print(f"Error refining prompt for '{item_name}': {e}")
        # Fallback to a simpler prompt for 3D generation if refinement fails
        return f"A high quality 3D model of a {item_name}"

def generate_image_local(refined_prompt, output_path):
    if not image_pipeline:
        print("Image generation pipeline not available.")
        # Create a placeholder image or return None
        # Example: from PIL import Image; img = Image.new('RGB', (60, 30), color = 'red'); img.save(output_path); return output_path
        return None
    try:
        # Adjust inference steps/guidance as needed
        image = image_pipeline(refined_prompt, num_inference_steps=25, guidance_scale=7.5).images[0]
        image.save(output_path)
        return output_path
    except Exception as e:
        print(f"Error generating image for prompt '{refined_prompt}': {e}")
        return None

def generate_3d_model_hunyuan(refined_prompt_for_3d, output_dir, item_name_safe):
    if not hunyuan_client:
        print("Hunyuan 3D client not available.")
        return None, "Client not initialized"

    print(f"Requesting 3D model for: {refined_prompt_for_3d}")
    # Use defaults for most parameters initially
    try:
        result_tuple = hunyuan_client.predict(
                caption=refined_prompt_for_3d,
                # Leave image and mv_image inputs as None for text-to-3D
                image=None,
                mv_image_front=None,
                mv_image_back=None,
                mv_image_left=None,
                mv_image_right=None,
                # Default values from API docs (can be overridden)
                steps=30,
                guidance_scale=5,
                seed=1234, # Ignored when randomize_seed=True below
                octree_resolution=256,
                check_box_rembg=True,
                num_chunks=8000,
                randomize_seed=True,
                api_name="/generation_all" # Crucial!
        )

        # --- VERIFICATION NEEDED ---
        # Check the actual return tuple structure. Assuming file path is first or second.
        # Let's try the first element (index 0). If it's None or not a path, try index 1.
        raw_filepath = None
        if result_tuple and len(result_tuple) > 0 and isinstance(result_tuple[0], str):
             raw_filepath = result_tuple[0]
        elif result_tuple and len(result_tuple) > 1 and isinstance(result_tuple[1], str):
             print("Using second element from result tuple for filepath.")
             raw_filepath = result_tuple[1]
        # --- END VERIFICATION NEEDED ---

        if raw_filepath:
            print(f"Job completed. Raw result path: {raw_filepath}")
            os.makedirs(output_dir, exist_ok=True)

            # gradio_client downloads file outputs to a local temp directory,
            # so raw_filepath should already be a local path; copy it into the
            # output dir under a meaningful name. (gradio_client's handle_file
            # is for *input* files and does not download results.)
            if os.path.exists(raw_filepath):
                file_ext = os.path.splitext(raw_filepath)[1] # Get extension (.glb, .obj?)
                if not file_ext:
                    file_ext = ".glb" # Assume glb if unknown
                final_path = os.path.join(output_dir, f"{item_name_safe}{file_ext}")
                shutil.copy2(raw_filepath, final_path)
                print(f"Model saved to: {final_path}")
                return final_path, "Success"
            else:
                error_msg = f"Result path does not exist locally: {raw_filepath}"
                print(error_msg)
                return None, error_msg
        else:
            error_msg = f"Job for '{refined_prompt_for_3d}' did not return a valid filepath in expected tuple elements."
            print(error_msg)
            # You might want to inspect the full result_tuple here for debugging
            print(f"Full result tuple: {result_tuple}")
            return None, error_msg

    except Exception as e:
        error_msg = f"Error calling Hunyuan3D API for '{refined_prompt_for_3d}': {e}"
        print(error_msg)
        return None, str(e)

def create_zip(files_to_zip, zip_filepath):
    with zipfile.ZipFile(zip_filepath, 'w') as zf:
        for file_path in files_to_zip:
            if file_path and os.path.exists(file_path):
                zf.write(file_path, os.path.basename(file_path))
    return zip_filepath
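
# Example (paths are illustrative):
#   create_zip(["outputs/3d_models/iguana.glb"], "outputs/3d_collection.zip")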


# --- Gradio Interface & Logic ---

with gr.Blocks() as demo:
    gr.Markdown("# 3D Asset Collection Generator")
    gr.Markdown("Generate a list based on a theme, refine prompts, preview images, and generate selected 3D models using Hunyuan3D-2.")
    # gr.Warning / gr.Error only display inside event handlers, so surface
    # build-time problems as static Markdown here instead.
    if not HF_TOKEN:
        gr.Markdown("**Warning:** Hugging Face Token not found. Prompt refinement quality/rate limits may be affected. Consider adding HUGGINGFACE_TOKEN to Space secrets.")
    if not image_pipeline:
        gr.Markdown("**Warning:** Local image generation model failed to load. Image previews will be skipped. Check Space hardware/logs.")
    if not hunyuan_client:
        gr.Markdown("**Error:** Failed to connect to the Hunyuan3D-2 Space. 3D generation will not work.")


    # State to hold intermediate results.
    # gr.State handles simple values and lists/dicts alike, but updates only
    # persist when the new value is returned as a handler output (not assigned
    # to .value inside the handler).
    list_items_state = gr.State([])
    refined_prompts_state = gr.State({}) # Dict: {item_name: refined_prompt}
    image_paths_state = gr.State({}) # Dict: {item_name: image_path}
    selected_items_state = gr.State([]) # Selected rows: [item_name, img_path, prompt]
    generated_3d_files_state = gr.State([]) # List of paths to successfully generated models


    with gr.Row():
        theme_input = gr.Textbox(label="Theme", placeholder="e.g., reptiles, kitchen appliances, medieval weapons")
        count_input = gr.Number(label="Number of Items", value=5, minimum=1, step=1)

    generate_list_button = gr.Button("1. Generate List & Refine Prompts")
    list_output_display = gr.Markdown("List will appear here...") # Or use gr.DataFrame

    generate_images_button = gr.Button("2. Generate Image Previews", visible=False) # Hidden initially
    # Use Gallery for display, Dataset for selection tracking
    image_gallery = gr.Gallery(label="Image Previews", visible=False, elem_id="image_gallery")
    # Dataset to hold data for selection (item_name, image_path, refined_prompt)
    selection_data = gr.Dataset(components=[gr.Textbox(visible=False), gr.Textbox(visible=False), gr.Textbox(visible=False)], # item, img_path, prompt
                                 headers=["Item Name", "Image", "Prompt"],
                                 label="Select Items for 3D Generation",
                                 visible=False)


    generate_3d_button = gr.Button("3. Generate 3D Models for Selected Items", visible=False) # Hidden initially
    status_output = gr.Markdown("") # For progress updates
    final_zip_output = gr.File(label="Download 3D Model Collection (ZIP)", visible=False)


    # --- Event Logic ---

    def run_list_and_refine(theme, count):
        if not theme:
            return {list_output_display: "Please enter a theme.", generate_images_button: gr.Button(visible=False)}

        # Ensure output dirs exist
        os.makedirs(os.path.join(OUTPUT_DIR, IMAGES_SUBDIR), exist_ok=True)
        os.makedirs(os.path.join(OUTPUT_DIR, MODELS_SUBDIR), exist_ok=True)


        gr.Info("Generating list...")
        items = generate_list_local(theme, int(count))
        if not items or "Error:" in items[0]:
            return {list_output_display: f"Failed to generate list: {items[0] if items else 'Unknown error'}",
                    generate_images_button: gr.Button(visible=False)}

        gr.Info("Refining prompts via API...")
        refined_prompts = {}
        output_md = "### Generated List & Refined Prompts:\n\n"
        for item in items:
            refined = refine_prompt_api(item)
            refined_prompts[item] = refined
            output_md += f"*   **{item}:** {refined}\n"

        # Enable next step. State components are updated by returning them as
        # outputs; assigning to .value inside a handler does not persist.
        return {
            list_output_display: output_md,
            generate_images_button: gr.Button(visible=True), # Show image gen button
            list_items_state: items,
            refined_prompts_state: refined_prompts
        }

    generate_list_button.click(
        fn=run_list_and_refine,
        inputs=[theme_input, count_input],
        outputs=[list_output_display, generate_images_button, list_items_state, refined_prompts_state] # Update state too
    )


    def run_image_generation(items, refined_prompts_dict):
        if not image_pipeline:
             # Skip image generation if pipeline not loaded
             gr.Warning("Image pipeline not loaded. Skipping image previews.")
             # Prepare data for selection without images
             selection_samples = [[item, "N/A", refined_prompts_dict.get(item, "")] for item in items]
             return {
                 image_gallery: gr.Gallery(visible=False),
                 selection_data: gr.Dataset(samples=selection_samples, visible=True),
                 generate_3d_button: gr.Button(visible=True), # Allow proceeding without previews
                 image_paths_state: {} # Clear image paths via the state output
             }

        gr.Info("Generating image previews... (this may take a while)")
        image_paths = {}
        gallery_images = []
        selection_samples = [] # For the Dataset component

        img_dir = os.path.join(OUTPUT_DIR, IMAGES_SUBDIR)

        for item in items:
            refined_prompt = refined_prompts_dict.get(item, f"Image of {item}") # Get refined prompt
            safe_item_name = "".join(c if c.isalnum() else "_" for c in item)
            img_filename = f"{safe_item_name}_{uuid.uuid4()}.png"
            img_path = os.path.join(img_dir, img_filename)

            generated_path = generate_image_local(refined_prompt, img_path)

            if generated_path:
                image_paths[item] = generated_path
                gallery_images.append(generated_path)
                selection_samples.append([item, generated_path, refined_prompt])
            else:
                # Handle image generation failure - maybe add placeholder info
                selection_samples.append([item, "Failed", refined_prompt])
                # Optionally add a placeholder to gallery_images too

        # Show gallery and selection dataset; pass image paths back as a state output
        return {
            image_gallery: gr.Gallery(value=gallery_images, visible=True),
            selection_data: gr.Dataset(samples=selection_samples, visible=True),
            generate_3d_button: gr.Button(visible=True), # Show 3D gen button
            image_paths_state: image_paths
        }

    generate_images_button.click(
        fn=run_image_generation,
        inputs=[list_items_state, refined_prompts_state],
        outputs=[image_gallery, selection_data, generate_3d_button, image_paths_state] # Update state
    )

    # Handler for when the user clicks rows in the Dataset.
    # gr.SelectData is only delivered on a component's .select event (not on a
    # button .click), so record the user's picks into selected_items_state here
    # and let the 3D button read that state. gr.Dataset has no built-in
    # multi-select, so clicks toggle rows in and out of the selection.
    def record_selection(selected, evt: gr.SelectData):
        sample = evt.value # the clicked row: [item_name, img_path, prompt]
        if sample in selected:
            selected = [s for s in selected if s != sample] # Click again to deselect
        else:
            selected = selected + [sample]
        names = ", ".join(s[0] for s in selected) if selected else "(none)"
        return selected, f"Selected for 3D generation: {names}"

    selection_data.select(
        fn=record_selection,
        inputs=[selected_items_state],
        outputs=[selected_items_state, status_output]
    )
    def run_3d_generation(selected_items_info):
        # Note: this is a generator (it yields progress updates), so every
        # result -- including early exits -- must be yielded; a `return value`
        # inside a generator is silently discarded.
        if not hunyuan_client:
            yield {status_output: "Hunyuan3D client not initialized. Cannot generate.", final_zip_output: gr.File(visible=False)}
            return

        if not selected_items_info:
            yield {status_output: "Please select items from the table above before generating 3D models.", final_zip_output: gr.File(visible=False)}
            return

        # Each selected entry is [item_name, img_path, refined_prompt]

        generated_files = []
        status_messages = ["### 3D Generation Status:\n"]

        model_dir = os.path.join(OUTPUT_DIR, MODELS_SUBDIR)

        total_selected = len(selected_items_info)
        for i, (item_name, _, refined_prompt) in enumerate(selected_items_info):
            current_status = f"({i+1}/{total_selected}) Generating model for: **{item_name}**..."
            print(current_status)
            status_messages.append(f"*   {current_status}")
            # Update UI status progressively
            yield {status_output: "\n".join(status_messages), final_zip_output: gr.File(visible=False)}

            # Adapt prompt slightly for 3D if desired, or use the image prompt directly
            prompt_for_3d = refined_prompt # Or customize: f"A high quality 3D model of {item_name}, {refined_prompt}"

            item_name_safe = "".join(c if c.isalnum() else "_" for c in item_name)

            # --- Retry Logic Placeholder ---
            max_retries = 1 # Example: allow 1 retry
            attempts = 0
            model_path = None
            last_error = "Unknown error"

            while attempts <= max_retries:
                attempts += 1
                if attempts > 1:
                     status_messages.append(f"    *   Retrying ({attempts-1}/{max_retries})...")
                     yield {status_output: "\n".join(status_messages)}
                     time.sleep(2) # Brief pause before retry

                model_path, msg = generate_3d_model_hunyuan(prompt_for_3d, model_dir, item_name_safe)
                last_error = msg
                if model_path:
                    generated_files.append(model_path)
                    status_messages.append(f"    *   Success! Model saved.")
                    break # Exit retry loop on success
                else:
                    status_messages.append(f"    *   Attempt {attempts} failed: {msg}")

            if not model_path:
                 status_messages.append(f"    *   **Failed** after {attempts} attempt(s). Last error: {last_error}")
            # --- End Retry Logic ---

            # Update UI status after each item
            yield {status_output: "\n".join(status_messages)}


        if generated_files:
            status_messages.append("\nCreating ZIP archive...")
            yield {status_output: "\n".join(status_messages)}
            zip_path = os.path.join(OUTPUT_DIR, ZIP_FILENAME)
            final_zip = create_zip(generated_files, zip_path)
            status_messages.append(f"\n**Collection ready!** Download '{ZIP_FILENAME}' below.")
            yield {status_output: "\n".join(status_messages),
                   final_zip_output: gr.File(value=final_zip, visible=True),
                   generated_3d_files_state: generated_files} # Store final paths via the state output
        else:
            status_messages.append("\nNo 3D models were successfully generated.")
            yield {status_output: "\n".join(status_messages), final_zip_output: gr.File(visible=False)}


    # Link the button click to the generator function. Selection is tracked by
    # the Dataset's .select handler above, so the button only needs the state.
    generate_3d_button.click(
        fn=run_3d_generation,
        inputs=[selected_items_state],
        outputs=[status_output, final_zip_output, generated_3d_files_state] # Update state
    )


# Launch the Gradio app
demo.queue().launch(debug=True) # Enable queue for longer processes, debug for detailed errors
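
# To run outside a Space (assumed dependency set; a Space launches app.py itself):
#   pip install gradio gradio_client transformers diffusers torch huggingface_hub
#   python app.py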