File size: 11,383 Bytes
1618f2a
 
 
 
 
89af645
51f4c93
 
 
89af645
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51f4c93
1618f2a
 
 
 
 
 
 
89af645
51f4c93
 
1618f2a
 
 
 
 
 
 
 
 
 
 
89af645
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1618f2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
import os

# ZeroGPU environment setup -- this must run before anything else.
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ['ZEROGPU'] = '1'  # flag that the app is running under ZeroGPU

# Force the use of safetensors.
os.environ['SAFETENSORS_FAST_GPU'] = '1'
os.environ['TRANSFORMERS_OFFLINE'] = '0'
os.environ['TRANSFORMERS_USE_SAFETENSORS'] = '1'

import sys
from types import ModuleType

# Do NOT pre-register a stub ``transformers`` module in ``sys.modules``.
# ``import`` returns whatever is already registered there, so a stub would
# shadow the installed library for the whole process and hide every real
# class (e.g. ``CLIPImageProcessor``) from this file and from the model code
# it imports. Missing Siglip classes are patched onto the real module after
# it is imported further down.
import spaces  # imported only after the environment setup above
import shlex
import subprocess
import sys


def _pip_install(spec):
    """Run ``pip install <spec>`` with the interpreter that is running this app.

    Invoking pip as ``sys.executable -m pip`` installs into the same
    environment this process imports from; a bare ``pip`` found on PATH is
    not guaranteed to belong to that interpreter.

    Args:
        spec: Argument string for ``pip install``, split with ``shlex``.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    subprocess.run([sys.executable, "-m", "pip", "install", *shlex.split(spec)], check=True)


# Pin pip to work around library version-compatibility problems.
_pip_install("pip==24.0")

# Install / upgrade safetensors.
_pip_install("safetensors --upgrade")

# Local wheel files under ./package; --no-deps stops pip from resolving
# their dependencies.
_pip_install(
    "package/onnxruntime_gpu-1.17.0-cp310-cp310-manylinux_2_28_x86_64.whl --force-reinstall --no-deps"
)
_pip_install(
    "package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
)

# Import the real transformers package and patch in any missing Siglip classes.
import transformers

# When this transformers build lacks the Siglip classes, install thin
# stand-ins that delegate to the CLIP equivalents.
if not hasattr(transformers, 'SiglipImageProcessor'):
    class SiglipImageProcessor:
        """Stand-in for ``transformers.SiglipImageProcessor``.

        Wraps a ``CLIPImageProcessor`` and forwards attribute access and
        calls to it.
        """

        def __init__(self, *args, **kwargs):
            # Fall back to CLIPImageProcessor with the same arguments.
            self._processor = transformers.CLIPImageProcessor(*args, **kwargs)

        @classmethod
        def from_pretrained(cls, *args, **kwargs):
            """Load a ``CLIPImageProcessor`` via ``from_pretrained`` and wrap it."""
            instance = cls.__new__(cls)
            instance._processor = transformers.CLIPImageProcessor.from_pretrained(*args, **kwargs)
            return instance

        def __call__(self, *args, **kwargs):
            # ``obj(...)`` looks up __call__ on the type and never reaches
            # __getattr__, so calls must be forwarded explicitly.
            return self._processor(*args, **kwargs)

        def __getattr__(self, name):
            # Only reached when normal lookup fails. Refuse the wrapped slot
            # itself so a half-built instance raises AttributeError instead
            # of recursing without bound.
            if name == '_processor':
                raise AttributeError(name)
            return getattr(self._processor, name)

    transformers.SiglipImageProcessor = SiglipImageProcessor

if not hasattr(transformers, 'SiglipVisionModel'):
    class SiglipVisionModel:
        """Stand-in for ``transformers.SiglipVisionModel``.

        Wraps a ``CLIPVisionModel`` and forwards attribute access and calls
        to it.
        """

        def __init__(self, *args, **kwargs):
            # Fall back to CLIPVisionModel with the same arguments.
            self._model = transformers.CLIPVisionModel(*args, **kwargs)

        @classmethod
        def from_pretrained(cls, *args, **kwargs):
            """Load a ``CLIPVisionModel`` via ``from_pretrained`` and wrap it."""
            instance = cls.__new__(cls)
            instance._model = transformers.CLIPVisionModel.from_pretrained(*args, **kwargs)
            return instance

        def __call__(self, *args, **kwargs):
            # Forward forward-passes explicitly (implicit dunder lookup skips
            # __getattr__).
            return self._model(*args, **kwargs)

        def __getattr__(self, name):
            # Guard the wrapped slot to avoid unbounded recursion on a
            # half-built instance.
            if name == '_model':
                raise AttributeError(name)
            return getattr(self._model, name)

    transformers.SiglipVisionModel = SiglipVisionModel

# Download the model checkpoints and configure torch (script entry only).
if __name__ == "__main__":
    import os
    import sys

    from huggingface_hub import snapshot_download

    # Fetch the Unique3D model repo into ./ckpt.
    snapshot_download(
        "public-data/Unique3D",
        repo_type="model",
        local_dir="./ckpt",
    )

    # Add the current working directory to the import path (presumably so
    # the project-local packages resolve -- confirm).
    sys.path.append(os.curdir)

    import torch

    # Allow reduced-precision float32 matmuls (incl. TF32) and turn off
    # autograd globally.
    torch.set_float32_matmul_precision('medium')
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_grad_enabled(False)

import fire
import gradio as gr
from gradio_app.gradio_3dgen import create_ui as create_3d_ui
from gradio_app.all_models import model_zoo

# ===============================
# Text-to-IMAGE 관련 API 함수 정의
# ===============================
@spaces.GPU(duration=60)  # request up to 60 s of GPU time per call
def text_to_image(height, width, steps, scales, prompt, seed):
    """Generate an image via the remote ``/process_and_save_image`` endpoint.

    The remote Gradio app is located through the ``CLIENT_API`` environment
    variable.

    Args:
        height: Image height in pixels.
        width: Image width in pixels.
        steps: Number of generation steps.
        scales: Guidance scale.
        prompt: Text prompt.
        seed: Seed forwarded unchanged to the endpoint; may be ``None``.

    Returns:
        The ``"url"`` entry when the endpoint returns a dict (``None`` if that
        key is absent), otherwise the raw result unchanged.

    Raises:
        RuntimeError: If ``CLIENT_API`` is unset or empty.
    """
    # Runs with a GPU allocated.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    api_url = os.getenv("CLIENT_API")
    if not api_url:
        raise RuntimeError(
            "CLIENT_API environment variable is not set; cannot reach the image generation API."
        )

    from gradio_client import Client
    result = Client(api_url).predict(
        height,
        width,
        steps,
        scales,
        prompt,
        seed,
        api_name="/process_and_save_image",
    )
    return result.get("url") if isinstance(result, dict) else result

def update_random_seed():
    """Fetch a fresh random seed from the remote ``/update_random_seed`` endpoint.

    The remote Gradio app is located through the ``CLIENT_API`` environment
    variable.

    Returns:
        Whatever the endpoint returns for the seed.

    Raises:
        RuntimeError: If ``CLIENT_API`` is unset or empty.
    """
    api_url = os.getenv("CLIENT_API")
    if not api_url:
        raise RuntimeError(
            "CLIENT_API environment variable is not set; cannot reach the seed API."
        )

    from gradio_client import Client
    return Client(api_url).predict(api_name="/update_random_seed")

# GPU-decorated wrapper around the 3D generation entry point.
@spaces.GPU(duration=120)  # 3D generation gets a longer GPU window
def generate_3d_wrapper(*args, **kwargs):
    """Forward every argument to ``model_zoo.generate_3d`` under a GPU allocation.

    Returns:
        Whatever ``model_zoo.generate_3d`` returns.
    """
    os.environ.update(CUDA_VISIBLE_DEVICES='0')
    generate = model_zoo.generate_3d
    return generate(*args, **kwargs)

_TITLE = '''✨ 3D LLAMA Studio'''
_DESCRIPTION = '''
### Welcome to 3D Llama Studio - Your Advanced 3D Generation Platform
This platform offers two powerful features:
1. **Text/Image to 3D**: Generate detailed 3D models from text descriptions or reference images
2. **Text to Styled Image**: Create artistic images that can be used for 3D generation
*Note: Both English and Korean prompts are supported (영어와 한글 프롬프트 모두 지원됩니다)*
**Running on ZeroGPU** 🚀
'''

# Light-theme CSS overrides passed to gr.Blocks(css=...).
# Classes applied via elem_classes in launch(): input-box, button-primary,
# button-secondary, main-title; zerogpu-badge is used in a raw HTML span.
# NOTE(review): .tabs, .slider-label and .textbox-input are not referenced by
# elem_classes in this file -- .tabs presumably targets Gradio's built-in
# Tabs container; confirm whether the other two are still needed.
custom_css = """
.gradio-container {
    background-color: #ffffff;
    color: #333333;
}
.tabs {
    background-color: #f8f9fa;
    border-radius: 10px;
    padding: 10px;
    margin: 10px 0;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.input-box {
    background-color: #ffffff;
    border: 1px solid #e0e0e0;
    border-radius: 8px;
    padding: 15px;
    margin: 10px 0;
    box-shadow: 0 1px 3px rgba(0,0,0,0.05);
}
.button-primary {
    background-color: #4a90e2 !important;
    border: none !important;
    color: white !important;
    transition: all 0.3s ease;
}
.button-primary:hover {
    background-color: #357abd !important;
    transform: translateY(-1px);
}
.button-secondary {
    background-color: #f0f0f0 !important;
    border: 1px solid #e0e0e0 !important;
    color: #333333 !important;
    transition: all 0.3s ease;
}
.button-secondary:hover {
    background-color: #e0e0e0 !important;
}
.main-title {
    color: #2c3e50;
    font-weight: bold;
    margin-bottom: 20px;
}
.slider-label {
    color: #2c3e50;
    font-weight: 500;
}
.textbox-input {
    border: 1px solid #e0e0e0 !important;
    background-color: #ffffff !important;
}
.zerogpu-badge {
    background-color: #4CAF50;
    color: white;
    padding: 5px 10px;
    border-radius: 5px;
    font-size: 14px;
    margin-left: 10px;
}
"""

def _build_header():
    """Render the title row, the ZeroGPU badge and the description markdown."""
    with gr.Row():
        with gr.Column():
            gr.Markdown('# ' + _TITLE, elem_classes="main-title")
        with gr.Column():
            gr.HTML('<span class="zerogpu-badge">ZeroGPU Enabled</span>')
    gr.Markdown(_DESCRIPTION)


def _build_text_to_image_tab():
    """Build the text-to-image controls and wire them to the remote API handlers.

    Must be called inside an active ``gr.Blocks`` context.
    """
    with gr.Group(elem_classes="input-box"):
        gr.Markdown("### Image Generation Settings")
        with gr.Row():
            with gr.Column():
                height_slider = gr.Slider(
                    label="Image Height",
                    minimum=256,
                    maximum=2048,
                    step=64,
                    value=1024,
                    info="Select image height (pixels)"
                )
                width_slider = gr.Slider(
                    label="Image Width",
                    minimum=256,
                    maximum=2048,
                    step=64,
                    value=1024,
                    info="Select image width (pixels)"
                )
            with gr.Column():
                steps_slider = gr.Slider(
                    label="Generation Steps",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=8,
                    info="More steps = higher quality but slower"
                )
                scales_slider = gr.Slider(
                    label="Guidance Scale",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.1,
                    value=3.5,
                    info="How closely to follow the prompt"
                )

        prompt_text = gr.Textbox(
            label="Image Description",
            placeholder="Enter your prompt here (English or Korean)",
            lines=3,
            elem_classes="input-box"
        )

        with gr.Row():
            seed_number = gr.Number(
                label="Seed (Empty = Random)",
                value=None,
                elem_classes="input-box"
            )
            update_seed_button = gr.Button(
                "🎲 Random Seed",
                elem_classes="button-secondary"
            )

        generate_button = gr.Button(
            "🚀 Generate Image",
            elem_classes="button-primary"
        )

    with gr.Group(elem_classes="input-box"):
        gr.Markdown("### Generated Result")
        image_output = gr.Image(label="Output Image")

    # Event wiring: both handlers proxy to the remote API.
    update_seed_button.click(
        fn=update_random_seed,
        inputs=[],
        outputs=seed_number
    )

    generate_button.click(
        fn=text_to_image,
        inputs=[height_slider, width_slider, steps_slider, scales_slider, prompt_text, seed_number],
        outputs=image_output
    )


def launch():
    """Initialise the models in CPU mode, build the Gradio UI and serve it.

    The UI has two tabs: a text-to-image tab backed by the remote API
    handlers, and an image-to-3D tab built by ``create_3d_ui``.
    """
    # Initialise models in CPU mode (GPU is attached per call by ZeroGPU).
    os.environ['CUDA_VISIBLE_DEVICES'] = ''
    model_zoo.init_models()

    with gr.Blocks(
        title=_TITLE,
        css=custom_css,
        theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="slate",
            neutral_hue="slate",
            font=["Inter", "Arial", "sans-serif"]
        )
    ) as demo:
        _build_header()

        with gr.Tabs():
            with gr.Tab("🎨 Text to Styled Image", elem_classes="tab"):
                _build_text_to_image_tab()

            with gr.Tab("🎯 Image to 3D", elem_classes="tab"):
                create_3d_ui("wkl")

    demo.queue().launch(share=True)

if __name__ == '__main__':
    # Expose launch() as the command-line entry point via Python Fire.
    fire.Fire(component=launch)