File size: 11,284 Bytes
db6a3b7
3057b36
7d475c1
db6a3b7
690b53e
db6a3b7
9880f3d
7d475c1
db6a3b7
 
9880f3d
db6a3b7
 
9880f3d
db6a3b7
f4648fc
 
db6a3b7
ee210e2
 
 
 
f4648fc
 
 
ce0691d
 
 
 
 
f4648fc
 
 
 
 
 
 
ce0691d
 
 
 
 
 
 
 
 
 
f4648fc
 
d7b1815
ee210e2
 
 
 
 
bd46f72
a898014
 
db894f7
a898014
 
db6a3b7
a898014
9880f3d
 
 
 
 
 
 
 
 
 
 
 
 
a898014
9880f3d
ee210e2
 
9880f3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a898014
9880f3d
3057b36
a898014
bd46f72
 
db894f7
a898014
db894f7
bd46f72
 
 
 
 
 
 
 
 
 
 
7d475c1
15fe7bc
 
a898014
 
db6a3b7
7d475c1
a898014
9880f3d
db6a3b7
ee210e2
 
8fb8605
 
 
 
ee210e2
 
8fb8605
 
 
ee210e2
 
8fb8605
ee210e2
 
 
 
 
 
 
db6a3b7
3057b36
9880f3d
a898014
690b53e
a898014
db6a3b7
 
 
 
 
 
 
 
 
 
 
7d475c1
ee210e2
7d475c1
 
ee210e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a898014
2e78ab8
db6a3b7
ee210e2
db6a3b7
 
 
 
 
 
 
2e7f188
a898014
db6a3b7
 
 
 
ee210e2
db6a3b7
 
 
a898014
 
ee210e2
a898014
 
 
db6a3b7
 
 
 
a898014
2e78ab8
db6a3b7
 
 
 
 
 
 
 
 
 
 
 
2e78ab8
db6a3b7
 
 
 
 
 
 
 
 
 
ee210e2
 
 
 
 
 
 
db6a3b7
 
 
ee210e2
c666caf
 
 
 
ee210e2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
import gradio as gr
import spaces
from gradio_litmodel3d import LitModel3D
import os
os.environ['SPCONV_ALGO'] = 'native'
from typing import *
import torch
import numpy as np
import imageio
import uuid
from easydict import EasyDict as edict
from PIL import Image
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.representations import Gaussian, MeshExtractResult
from trellis.utils import render_utils, postprocessing_utils
from transformers import pipeline as translation_pipeline
from diffusers import FluxPipeline

# Scratch directory for preprocessed images, preview videos and exported GLBs.
TMP_DIR = "/tmp/Trellis-demo"
# Largest signed 32-bit integer; upper bound for the seed slider and random seeds.
MAX_SEED = int(np.iinfo(np.int32).max)
os.makedirs(name=TMP_DIR, exist_ok=True)

def initialize_models():
    """Load every model the app uses and publish them as module globals.

    Sets the globals ``pipeline`` (TRELLIS image-to-3D, moved to CUDA),
    ``translator`` (Helsinki-NLP Korean -> English translation pipeline) and
    ``flux_pipe`` (FLUX.1-dev in bfloat16 with the game-asset LoRA fused in,
    moved to CUDA).

    Raises:
        ValueError: if the ``HF_TOKEN`` environment variable is not set.
    """
    global pipeline, translator, flux_pipe

    # A Hugging Face token is required; it is passed to the FLUX downloads below.
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN environment variable is not set. Please set your Hugging Face token.")

    # TRELLIS image-to-3D pipeline.
    pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
    pipeline.cuda()

    # Korean -> English translator used for text-to-image prompts.
    translator = translation_pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

    # FLUX pipeline. diffusers documents `token` as the authentication keyword;
    # `use_auth_token` is the older, deprecated spelling.
    flux_pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
        token=hf_token,
    )
    # Load the LoRA with the same token, then bake it into the base weights.
    flux_pipe.load_lora_weights(
        "gokaygokay/Flux-Game-Assets-LoRA-v2",
        token=hf_token,
    )
    flux_pipe.fuse_lora(lora_scale=1.0)
    flux_pipe.to(device="cuda", dtype=torch.bfloat16)

def translate_if_korean(text):
    """Translate *text* to English if it contains any Hangul syllable.

    A character counts as Hangul when it lies in the Hangul Syllables block
    U+AC00..U+D7A3. Text without such a character is returned unchanged and
    the module-level ``translator`` is not called.

    Args:
        text: The user's prompt.

    Returns:
        The translation produced by ``translator``, or *text* itself when it
        contains no Hangul syllable.
    """
    # Compare against escaped code points so the check does not depend on the
    # source file's encoding. Mis-decoded literals turn into multi-character
    # strings and make ord() raise TypeError.
    if any("\uac00" <= char <= "\ud7a3" for char in text):
        return translator(text)[0]['translation_text']
    return text

def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
    """Run TRELLIS preprocessing on *image* and cache the result on disk.

    The processed image is saved as ``<TMP_DIR>/<trial_id>.png``;
    ``image_to_3d`` later reopens it from that path using the returned id.

    Args:
        image: The image supplied by the user (or an example).

    Returns:
        ``(trial_id, processed_image)``, where ``trial_id`` is a fresh UUID4 string.
    """
    new_id = f"{uuid.uuid4()}"
    cleaned = pipeline.preprocess_image(image)
    cleaned.save(f"{TMP_DIR}/{new_id}.png")
    return new_id, cleaned

def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
    """Copy a generated Gaussian and mesh into a plain dict of NumPy arrays.

    The dict is stored in a ``gr.State`` by the UI, and ``unpack_state``
    rebuilds GPU objects from it.

    Args:
        gs: Gaussian representation; its ``init_params`` are copied verbatim,
            and its parameter tensors are moved to CPU and converted to NumPy.
        mesh: Extracted mesh; its ``vertices`` and ``faces`` are converted the same way.
        trial_id: Identifier reused later as the GLB file name.

    Returns:
        ``{'gaussian': {...}, 'mesh': {'vertices', 'faces'}, 'trial_id': trial_id}``.
    """
    gaussian_state = dict(gs.init_params)
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        gaussian_state[attr] = getattr(gs, attr).cpu().numpy()

    mesh_state = {
        'vertices': mesh.vertices.cpu().numpy(),
        'faces': mesh.faces.cpu().numpy(),
    }
    return {'gaussian': gaussian_state, 'mesh': mesh_state, 'trial_id': trial_id}


def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
    """Rebuild the Gaussian and mesh stored by ``pack_state`` as CUDA tensors.

    Args:
        state: The dict produced by ``pack_state``.

    Returns:
        ``(gaussian, mesh, trial_id)``. ``mesh`` is an ``edict`` exposing
        ``vertices`` and ``faces`` CUDA tensors, and ``trial_id`` is returned
        exactly as it was stored.
    """
    saved = state['gaussian']
    # 'mininum_kernel_size' (sic) is the keyword name the Gaussian constructor takes.
    ctor_kwargs = {
        key: saved[key]
        for key in ('aabb', 'sh_degree', 'mininum_kernel_size',
                    'scaling_bias', 'opacity_bias', 'scaling_activation')
    }
    gs = Gaussian(**ctor_kwargs)
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        setattr(gs, attr, torch.tensor(saved[attr], device='cuda'))

    mesh_state = state['mesh']
    mesh = edict(
        vertices=torch.tensor(mesh_state['vertices'], device='cuda'),
        faces=torch.tensor(mesh_state['faces'], device='cuda'),
    )
    return gs, mesh, state['trial_id']

@spaces.GPU
def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_strength: float, ss_sampling_steps: int, slat_guidance_strength: float, slat_sampling_steps: int) -> Tuple[dict, str]:
    """Generate a 3D asset from a preprocessed image and render a preview video.

    Args:
        trial_id: Id returned by ``preprocess_image``; the image is read from
            ``<TMP_DIR>/<trial_id>.png``.
        seed: Sampling seed, ignored when ``randomize_seed`` is true.
        randomize_seed: If true, draw a fresh seed from ``[0, MAX_SEED)``.
        ss_guidance_strength: ``cfg_strength`` for the sparse-structure sampler (stage 1).
        ss_sampling_steps: ``steps`` for the sparse-structure sampler.
        slat_guidance_strength: ``cfg_strength`` for the structured-latent sampler (stage 2).
        slat_sampling_steps: ``steps`` for the structured-latent sampler.

    Returns:
        ``(state, video_path)``. ``state`` is the ``pack_state`` dict for the
        first generated sample. ``video_path`` is a 15 fps MP4 whose frames are
        the Gaussian colour render concatenated (axis 1) with the mesh normal render.

    Raises:
        gr.Error: if no preprocessed image is available (``trial_id`` is empty).
    """
    if not trial_id:
        # The hidden trial_id textbox is set to '' when the image is cleared.
        # Without this guard the open below fails on "<TMP_DIR>/.png".
        raise gr.Error("Please upload an image first.")
    if randomize_seed:
        seed = np.random.randint(0, MAX_SEED)

    # Use a context manager so the cached PNG's file handle is released.
    with Image.open(f"{TMP_DIR}/{trial_id}.png") as image:
        outputs = pipeline.run(
            image,
            seed=seed,
            formats=["gaussian", "mesh"],
            preprocess_image=False,
            sparse_structure_sampler_params={
                "steps": ss_sampling_steps,
                "cfg_strength": ss_guidance_strength,
            },
            slat_sampler_params={
                "steps": slat_sampling_steps,
                "cfg_strength": slat_guidance_strength,
            },
        )

    gaussian = outputs['gaussian'][0]
    mesh = outputs['mesh'][0]
    color_frames = render_utils.render_video(gaussian, num_frames=120)['color']
    normal_frames = render_utils.render_video(mesh, num_frames=120)['normal']
    video = [np.concatenate([color, normal], axis=1) for color, normal in zip(color_frames, normal_frames)]

    # Each generation gets a fresh id, kept as a str like preprocess_image's ids.
    output_id = str(uuid.uuid4())
    video_path = f"{TMP_DIR}/{output_id}.mp4"
    os.makedirs(os.path.dirname(video_path), exist_ok=True)
    imageio.mimsave(video_path, video, fps=15)
    state = pack_state(gaussian, mesh, output_id)
    return state, video_path

@spaces.GPU
def generate_image_from_text(prompt, height, width, guidance_scale, num_steps):
    """Generate one image from a text prompt with the LoRA-fused FLUX pipeline.

    The prompt is passed through ``translate_if_korean`` and then suffixed with
    the fixed style tags ``"wbgmsst, 3D, white background"``.

    Args:
        prompt: User prompt text.
        height: Output height in pixels.
        width: Output width in pixels.
        guidance_scale: Guidance scale forwarded to ``flux_pipe``.
        num_steps: Forwarded to ``flux_pipe`` as ``num_inference_steps``.

    Returns:
        The first entry of the pipeline output's ``images`` list.
    """
    style_suffix = "wbgmsst, 3D, white background"
    full_prompt = ", ".join([translate_if_korean(prompt), style_suffix])

    with torch.inference_mode():
        result = flux_pipe(
            prompt=[full_prompt],
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            num_inference_steps=num_steps,
        )
    return result.images[0]

@spaces.GPU
def extract_glb(state: dict, mesh_simplify: float, texture_size: int) -> Tuple[str, str]:
    """Export the packed 3D asset in *state* as ``<TMP_DIR>/<trial_id>.glb``.

    Args:
        state: Dict produced by ``pack_state``.
        mesh_simplify: Passed to ``to_glb`` as ``simplify``.
        texture_size: Passed to ``to_glb`` as ``texture_size``.

    Returns:
        The GLB path twice: once for the 3D viewer and once for the download button.
    """
    gaussian, mesh, asset_id = unpack_state(state)
    glb_path = f"{TMP_DIR}/{asset_id}.glb"
    postprocessing_utils.to_glb(
        gaussian,
        mesh,
        simplify=mesh_simplify,
        texture_size=texture_size,
        verbose=False,
    ).export(glb_path)
    return glb_path, glb_path

def activate_button() -> gr.Button:
    """Return a Button update that makes the target button clickable."""
    enabled = gr.Button(interactive=True)
    return enabled

def deactivate_button() -> gr.Button:
    """Return a Button update that disables the target button."""
    disabled = gr.Button(interactive=False)
    return disabled


with gr.Blocks() as demo:
    gr.Markdown("""
    # 3D Asset Creation & Text-to-Image Generation
    """)

    with gr.Tabs():
        # Tab 1: upload an image, generate a 3D asset, then extract a GLB.
        with gr.TabItem("Image to 3D"):
            with gr.Row():
                with gr.Column():
                    image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil", height=300)

                    with gr.Accordion(label="Generation Settings", open=False):
                        seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                        randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                        gr.Markdown("Stage 1: Sparse Structure Generation")
                        with gr.Row():
                            ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                            ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                        gr.Markdown("Stage 2: Structured Latent Generation")
                        with gr.Row():
                            slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                            slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)

                    generate_btn = gr.Button("Generate")

                    with gr.Accordion(label="GLB Extraction Settings", open=False):
                        mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
                        texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)

                    # Disabled until a 3D asset has been generated (see handlers below).
                    extract_glb_btn = gr.Button("Extract GLB", interactive=False)

                with gr.Column():
                    video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
                    model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
                    download_glb = gr.DownloadButton(label="Download GLB", interactive=False)

        # Tab 2: text prompt -> FLUX image.
        with gr.TabItem("Text to Image"):
            with gr.Row():
                with gr.Column():
                    text_prompt = gr.Textbox(
                        label="Text Prompt",
                        placeholder="Enter your image description...",
                        lines=3
                    )

                    with gr.Row():
                        txt2img_height = gr.Slider(256, 1024, value=512, step=64, label="Height")
                        txt2img_width = gr.Slider(256, 1024, value=512, step=64, label="Width")

                    with gr.Row():
                        guidance_scale = gr.Slider(1.0, 20.0, value=7.5, label="Guidance Scale")
                        num_steps = gr.Slider(1, 50, value=20, label="Number of Steps")

                    generate_txt2img_btn = gr.Button("Generate Image")

                with gr.Column():
                    txt2img_output = gr.Image(label="Generated Image")

    # Hidden per-session values: the id of the cached preprocessed image and
    # the packed 3D asset produced by image_to_3d.
    trial_id = gr.Textbox(visible=False)
    output_buf = gr.State()

    # Example images. os.listdir returns entries in arbitrary order, so sort
    # them to keep the gallery order stable across machines and restarts.
    with gr.Row():
        examples = gr.Examples(
            examples=[
                f'assets/example_image/{image}'
                for image in sorted(os.listdir("assets/example_image"))
            ],
            inputs=[image_prompt],
            fn=preprocess_image,
            outputs=[trial_id, image_prompt],
            run_on_click=True,
            examples_per_page=64,
        )

    # Handlers
    image_prompt.upload(
        preprocess_image,
        inputs=[image_prompt],
        outputs=[trial_id, image_prompt],
    )

    # Clearing the image also forgets the cached preprocessed-image id.
    image_prompt.clear(
        lambda: '',
        outputs=[trial_id],
    )

    generate_btn.click(
        image_to_3d,
        inputs=[trial_id, seed, randomize_seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
        outputs=[output_buf, video_output],
    ).then(
        activate_button,
        outputs=[extract_glb_btn],
    )

    video_output.clear(
        deactivate_button,
        outputs=[extract_glb_btn],
    )

    extract_glb_btn.click(
        extract_glb,
        inputs=[output_buf, mesh_simplify, texture_size],
        outputs=[model_output, download_glb],
    ).then(
        activate_button,
        outputs=[download_glb],
    )

    model_output.clear(
        deactivate_button,
        outputs=[download_glb],
    )

    # Text-to-image handler.
    generate_txt2img_btn.click(
        generate_image_from_text,
        inputs=[text_prompt, txt2img_height, txt2img_width, guidance_scale, num_steps],
        outputs=[txt2img_output]
    )

# Launch the Gradio app
if __name__ == "__main__":
    initialize_models()  # load every model before serving requests
    try:
        # Warm-up call so the background-removal model (rembg) used by
        # preprocessing is loaded before the first user request.
        pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
    except Exception as exc:
        # Best effort: a failed warm-up must not block launching, but report it
        # instead of swallowing it (a bare except would also eat KeyboardInterrupt).
        print(f"Warning: preprocessing warm-up failed: {exc!r}")
    demo.launch()