import os
import shutil
import time
import uuid
import zipfile

import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from gradio_client import Client as GradioClient
from huggingface_hub import InferenceClient
from transformers import pipeline as transformers_pipeline
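
# Pipeline overview: a local flan-t5 model drafts a themed item list, the
# Inference API (Mixtral) refines each item into an image prompt, a local
# SDXL pipeline renders previews, and the public Hunyuan3D-2 Space turns the
# selected prompts into downloadable 3D meshes.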

LIST_GENERATION_MODEL = "google/flan-t5-base"
PROMPT_REFINEMENT_MODEL_API = "mistralai/Mixtral-8x7B-Instruct-v0.1"
IMAGE_GENERATION_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
HUNYUAN_SPACE_ID = "tencent/Hunyuan3D-2"
OUTPUT_DIR = "outputs"
MODELS_SUBDIR = "3d_models"
IMAGES_SUBDIR = "image_previews"
ZIP_FILENAME = "3d_collection.zip"

HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", None)
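
# Each loader below fails soft (leaving its handle as None) so the UI can
# still start and surface a warning instead of crashing on import.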
try:
    list_generator = transformers_pipeline("text2text-generation", model=LIST_GENERATION_MODEL)
except Exception as e:
    print(f"Warning: Could not load local list generator {LIST_GENERATION_MODEL}: {e}")
    list_generator = None

try:
    if not HF_TOKEN:
        print("Warning: HUGGINGFACE_TOKEN not set. Inference API calls might be rate-limited or fail.")
    prompt_refiner = InferenceClient(model=PROMPT_REFINEMENT_MODEL_API, token=HF_TOKEN)
except Exception as e:
    print(f"Warning: Could not initialize InferenceClient for {PROMPT_REFINEMENT_MODEL_API}: {e}")
    prompt_refiner = None

try:
    # stabilityai/stable-diffusion-xl-base-1.0 is an SDXL checkpoint, so it
    # needs the XL pipeline class; float16 only works on GPU, so fall back
    # to float32 on CPU.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    image_pipeline = StableDiffusionXLPipeline.from_pretrained(
        IMAGE_GENERATION_MODEL, torch_dtype=dtype, use_safetensors=True
    )
    image_pipeline.scheduler = EulerDiscreteScheduler.from_config(image_pipeline.scheduler.config)
    if torch.cuda.is_available():
        image_pipeline = image_pipeline.to("cuda")
except Exception as e:
    print(f"Warning: Could not load diffusers pipeline {IMAGE_GENERATION_MODEL}. Image generation might fail: {e}")
    image_pipeline = None

try:
    hunyuan_client = GradioClient(HUNYUAN_SPACE_ID)
except Exception as e:
    print(f"Error initializing GradioClient for {HUNYUAN_SPACE_ID}: {e}")
    hunyuan_client = None


def generate_list_local(theme, count):
    if not list_generator:
        return ["Error: List generator model not loaded."]
    prompt = f"Generate a comma-separated list of {count} distinct types of {theme}."
    try:
        result = list_generator(prompt, max_length=200)[0]['generated_text']
        items = [item.strip() for item in result.split(',') if item.strip()]
        return items[:count]
    except Exception as e:
        print(f"Error generating list: {e}")
        return [f"Error: {e}"]


def refine_prompt_api(item_name):
    if not prompt_refiner:
        return f"A 3D model of a {item_name}"
    prompt = f"Create a detailed, descriptive prompt for generating a highly realistic image of a single '{item_name}'. Focus on visual details suitable for a text-to-image AI. Only output the prompt itself."
    try:
        refined = prompt_refiner.text_generation(prompt, max_new_tokens=100)
        refined = refined.strip().strip('"')
        return refined
    except Exception as e:
        print(f"Error refining prompt for '{item_name}': {e}")
        return f"A high quality 3D model of a {item_name}"


def generate_image_local(refined_prompt, output_path):
    if not image_pipeline:
        print("Image generation pipeline not available.")
        return None
    try:
        image = image_pipeline(refined_prompt, num_inference_steps=25, guidance_scale=7.5).images[0]
        image.save(output_path)
        return output_path
    except Exception as e:
        print(f"Error generating image for prompt '{refined_prompt}': {e}")
        return None
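

# The keyword arguments in the predict() call below mirror the Hunyuan3D-2
# Space's /generation_all endpoint as exposed through gradio_client; if the
# Space changes its API, re-inspect it with
# GradioClient(HUNYUAN_SPACE_ID).view_api().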
def generate_3d_model_hunyuan(refined_prompt_for_3d, output_dir, item_name_safe):
    if not hunyuan_client:
        print("Hunyuan 3D client not available.")
        return None, "Client not initialized"

    print(f"Requesting 3D model for: {refined_prompt_for_3d}")
    try:
        result_tuple = hunyuan_client.predict(
            caption=refined_prompt_for_3d,
            image=None,
            mv_image_front=None,
            mv_image_back=None,
            mv_image_left=None,
            mv_image_right=None,
            steps=30,
            guidance_scale=5,
            seed=1234,
            octree_resolution=256,
            check_box_rembg=True,
            num_chunks=8000,
            randomize_seed=True,
            api_name="/generation_all"
        )

        # The endpoint returns a tuple; the mesh filepath is usually the
        # first element, but fall back to the second if needed.
        raw_filepath = None
        if result_tuple and len(result_tuple) > 0 and isinstance(result_tuple[0], str):
            raw_filepath = result_tuple[0]
        elif result_tuple and len(result_tuple) > 1 and isinstance(result_tuple[1], str):
            print("Using second element from result tuple for filepath.")
            raw_filepath = result_tuple[1]

        if raw_filepath and os.path.exists(raw_filepath):
            # gradio_client has already downloaded the output file to a local
            # temp path (handle_file is for preparing *inputs*, not for
            # downloading results), so just copy it into our output directory
            # under a stable name.
            print(f"Job completed. Raw result path: {raw_filepath}")
            os.makedirs(output_dir, exist_ok=True)
            file_ext = os.path.splitext(raw_filepath)[1] or ".glb"
            final_path = os.path.join(output_dir, f"{item_name_safe}{file_ext}")
            shutil.copy(raw_filepath, final_path)
            print(f"Model saved to: {final_path}")
            return final_path, "Success"
        else:
            error_msg = f"Job for '{refined_prompt_for_3d}' did not return a usable local filepath."
            print(error_msg)
            print(f"Full result tuple: {result_tuple}")
            return None, error_msg

    except Exception as e:
        error_msg = f"Error calling Hunyuan3D API for '{refined_prompt_for_3d}': {e}"
        print(error_msg)
        return None, str(e)


def create_zip(files_to_zip, zip_filepath):
    with zipfile.ZipFile(zip_filepath, 'w') as zf:
        for file_path in files_to_zip:
            if file_path and os.path.exists(file_path):
                zf.write(file_path, os.path.basename(file_path))
    return zip_filepath


with gr.Blocks() as demo:
    gr.Markdown("# 3D Asset Collection Generator")
    gr.Markdown("Generate a list based on a theme, refine prompts, preview images, and generate selected 3D models using Hunyuan3D-2.")

    # gr.Warning/gr.Error only display when called inside an event handler,
    # so problems detected at startup are rendered as static notices instead.
    if not HF_TOKEN:
        gr.Markdown("**Warning:** Hugging Face Token not found. Prompt refinement quality/rate limits may be affected. Consider adding HUGGINGFACE_TOKEN to Space secrets.")
    if not image_pipeline:
        gr.Markdown("**Warning:** Local image generation model failed to load. Image previews will be skipped. Check Space hardware/logs.")
    if not hunyuan_client:
        gr.Markdown("**Error:** Failed to connect to the Hunyuan3D-2 Space. 3D generation will not work.")

    # Session state: generated items, their refined prompts, preview image
    # paths, rows selected for 3D generation, and finished model files.
    list_items_state = gr.State([])
    refined_prompts_state = gr.State({})
    image_paths_state = gr.State({})
    selected_items_state = gr.State([])
    generated_3d_files_state = gr.State([])

    with gr.Row():
        theme_input = gr.Textbox(label="Theme", placeholder="e.g., reptiles, kitchen appliances, medieval weapons")
        count_input = gr.Number(label="Number of Items", value=5, minimum=1, step=1)

    generate_list_button = gr.Button("1. Generate List & Refine Prompts")
    list_output_display = gr.Markdown("List will appear here...")

    generate_images_button = gr.Button("2. Generate Image Previews", visible=False)

    image_gallery = gr.Gallery(label="Image Previews", visible=False, elem_id="image_gallery")

    selection_data = gr.Dataset(
        components=[gr.Textbox(visible=False), gr.Textbox(visible=False), gr.Textbox(visible=False)],
        headers=["Item Name", "Image", "Prompt"],
        label="Select Items for 3D Generation",
        visible=False
    )

    generate_3d_button = gr.Button("3. Generate 3D Models for Selected Items", visible=False)
    status_output = gr.Markdown("")
    final_zip_output = gr.File(label="Download 3D Model Collection (ZIP)", visible=False)
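
    # Step 1 handler. Handlers below return dicts keyed by component, which
    # lets each call update only the outputs it names (including States).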
    def run_list_and_refine(theme, count):
        if not theme:
            return {list_output_display: "Please enter a theme.", generate_images_button: gr.Button(visible=False)}

        os.makedirs(os.path.join(OUTPUT_DIR, IMAGES_SUBDIR), exist_ok=True)
        os.makedirs(os.path.join(OUTPUT_DIR, MODELS_SUBDIR), exist_ok=True)

        gr.Info("Generating list...")
        items = generate_list_local(theme, int(count))
        if not items or "Error:" in items[0]:
            return {list_output_display: f"Failed to generate list: {items[0] if items else 'Unknown error'}",
                    generate_images_button: gr.Button(visible=False)}

        gr.Info("Refining prompts via API...")
        refined_prompts = {}
        output_md = "### Generated List & Refined Prompts:\n\n"
        for item in items:
            refined = refine_prompt_api(item)
            refined_prompts[item] = refined
            output_md += f"* **{item}:** {refined}\n"

        # Return new State values through the outputs dict; assigning to
        # .value directly does not update the running session.
        return {
            list_output_display: output_md,
            generate_images_button: gr.Button(visible=True),
            list_items_state: items,
            refined_prompts_state: refined_prompts,
        }

    generate_list_button.click(
        fn=run_list_and_refine,
        inputs=[theme_input, count_input],
        outputs=[list_output_display, generate_images_button, list_items_state, refined_prompts_state]
    )
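
    # Step 2 handler: one preview image per item; degrades to a text-only
    # selection table if the local SDXL pipeline is unavailable.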
    def run_image_generation(items, refined_prompts_dict):
        if not image_pipeline:
            gr.Warning("Image pipeline not loaded. Skipping image previews.")
            selection_samples = [[item, "N/A", refined_prompts_dict.get(item, "")] for item in items]
            return {
                image_gallery: gr.Gallery(visible=False),
                selection_data: gr.Dataset(samples=selection_samples, visible=True),
                generate_3d_button: gr.Button(visible=True),
                image_paths_state: {},
            }

        gr.Info("Generating image previews... (this may take a while)")
        image_paths = {}
        gallery_images = []
        selection_samples = []
        img_dir = os.path.join(OUTPUT_DIR, IMAGES_SUBDIR)

        for item in items:
            refined_prompt = refined_prompts_dict.get(item, f"Image of {item}")
            safe_item_name = "".join(c if c.isalnum() else "_" for c in item)
            img_filename = f"{safe_item_name}_{uuid.uuid4()}.png"
            img_path = os.path.join(img_dir, img_filename)

            generated_path = generate_image_local(refined_prompt, img_path)
            if generated_path:
                image_paths[item] = generated_path
                gallery_images.append(generated_path)
                selection_samples.append([item, generated_path, refined_prompt])
            else:
                selection_samples.append([item, "Failed", refined_prompt])

        return {
            image_gallery: gr.Gallery(value=gallery_images, visible=True),
            selection_data: gr.Dataset(samples=selection_samples, visible=True),
            generate_3d_button: gr.Button(visible=True),
            image_paths_state: image_paths,
        }

    generate_images_button.click(
        fn=run_image_generation,
        inputs=[list_items_state, refined_prompts_state],
        outputs=[image_gallery, selection_data, generate_3d_button, image_paths_state]
    )
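
    # The original wiring passed gr.SelectData into a .click handler, which
    # Gradio does not support; selection has to be captured on the Dataset's
    # own .select event. A minimal sketch: gr.Dataset has no native
    # multi-select, so clicking a row toggles it in selected_items_state.
    # This assumes evt.value carries the clicked sample row
    # ([item_name, image, prompt]).
    def toggle_selection(evt: gr.SelectData, current_selection):
        row = evt.value
        if row in current_selection:
            current_selection = [r for r in current_selection if r != row]
        else:
            current_selection = current_selection + [row]
        names = ", ".join(r[0] for r in current_selection) if current_selection else "(none)"
        return current_selection, f"Selected for 3D generation: {names}"

    selection_data.select(
        fn=toggle_selection,
        inputs=[selected_items_state],
        outputs=[selected_items_state, status_output]
    )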
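
    # Step 3 handler: a generator, so status updates stream to the UI as each
    # model finishes; every update is yielded (a plain return from a
    # generator would be dropped by Gradio).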
    def run_3d_generation(selected_rows):
        if not hunyuan_client:
            yield {status_output: "Hunyuan3D client not initialized. Cannot generate.", final_zip_output: gr.File(visible=False)}
            return

        if not selected_rows:
            yield {status_output: "Please select items from the table above before generating 3D models.", final_zip_output: gr.File(visible=False)}
            return

        generated_files = []
        status_messages = ["### 3D Generation Status:\n"]
        model_dir = os.path.join(OUTPUT_DIR, MODELS_SUBDIR)

        total_selected = len(selected_rows)
        for i, (item_name, _, refined_prompt) in enumerate(selected_rows):
            current_status = f"({i+1}/{total_selected}) Generating model for: **{item_name}**..."
            print(current_status)
            status_messages.append(f"* {current_status}")
            yield {status_output: "\n".join(status_messages), final_zip_output: gr.File(visible=False)}

            prompt_for_3d = refined_prompt
            item_name_safe = "".join(c if c.isalnum() else "_" for c in item_name)

            max_retries = 1
            attempts = 0
            model_path = None
            last_error = "Unknown error"

            while attempts <= max_retries:
                attempts += 1
                if attempts > 1:
                    status_messages.append(f"  * Retrying ({attempts-1}/{max_retries})...")
                    yield {status_output: "\n".join(status_messages)}
                    time.sleep(2)

                model_path, msg = generate_3d_model_hunyuan(prompt_for_3d, model_dir, item_name_safe)
                last_error = msg
                if model_path:
                    generated_files.append(model_path)
                    status_messages.append("  * Success! Model saved.")
                    break
                else:
                    status_messages.append(f"  * Attempt {attempts} failed: {msg}")

            if not model_path:
                status_messages.append(f"  * **Failed** after {attempts} attempt(s). Last error: {last_error}")

            yield {status_output: "\n".join(status_messages)}

        if generated_files:
            status_messages.append("\nCreating ZIP archive...")
            yield {status_output: "\n".join(status_messages)}
            zip_path = os.path.join(OUTPUT_DIR, ZIP_FILENAME)
            final_zip = create_zip(generated_files, zip_path)
            status_messages.append(f"\n**Collection ready!** Download '{ZIP_FILENAME}' below.")
            yield {
                status_output: "\n".join(status_messages),
                final_zip_output: gr.File(value=final_zip, visible=True),
                generated_3d_files_state: generated_files,
            }
        else:
            status_messages.append("\nNo 3D models were successfully generated.")
            yield {status_output: "\n".join(status_messages), final_zip_output: gr.File(visible=False)}

    generate_3d_button.click(
        fn=run_3d_generation,
        inputs=[selected_items_state],
        outputs=[status_output, final_zip_output, generated_3d_files_state]
    )


demo.queue().launch(debug=True)