File size: 7,899 Bytes
1774ce2
 
fc3eb0b
1774ce2
e24d56e
 
 
 
 
 
 
 
 
 
1219780
 
248ebfb
e24d56e
 
75a1da3
e24d56e
 
 
 
 
 
 
 
 
 
 
 
 
 
1774ce2
248ebfb
 
 
 
 
b2b3a97
248ebfb
 
b2b3a97
248ebfb
 
 
 
 
 
 
 
 
aa621d3
 
 
 
e24d56e
248ebfb
 
b2b3a97
248ebfb
 
b2b3a97
248ebfb
 
b2b3a97
de87ab7
e24d56e
e8dc013
aa621d3
e8dc013
 
 
 
 
248ebfb
1774ce2
834340a
e8dc013
 
834340a
 
e8dc013
 
834340a
e8dc013
 
 
834340a
e8dc013
 
834340a
 
e8dc013
 
 
834340a
e8dc013
 
834340a
e8dc013
 
834340a
 
 
 
 
e8dc013
 
834340a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8dc013
 
 
834340a
1774ce2
e24d56e
248ebfb
834340a
 
 
 
 
 
 
 
 
 
 
e24d56e
e8dc013
e24d56e
248ebfb
e8dc013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400fd40
e8dc013
 
 
 
 
 
 
 
 
 
 
 
248ebfb
e8dc013
 
 
248ebfb
e8dc013
 
 
 
 
248ebfb
 
 
 
 
 
aa621d3
e8dc013
aa621d3
248ebfb
e24d56e
e8dc013
1774ce2
e8dc013
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
import os
import shlex
import subprocess
import sys

# Install pinned dependencies into the interpreter that is running this file.
# ``sys.executable -m pip`` is used instead of whichever ``pip`` happens to be
# first on PATH, so the cp310 wheels below are installed into this exact
# environment. Wheel paths are relative to the current working directory.
# Installs run in order, so the pinned pip 24.0 performs the wheel installs.
_PIP_INSTALL_ARGS = (
    ("pip==24.0",),
    (
        "package/onnxruntime_gpu-1.17.0-cp310-cp310-manylinux_2_28_x86_64.whl",
        "--force-reinstall",
        "--no-deps",
    ),
    (
        "package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl",
        "--force-reinstall",
        "--no-deps",
    ),
)
for _pip_args in _PIP_INSTALL_ARGS:
    # check=True: abort start-up if any install fails.
    subprocess.run([sys.executable, "-m", "pip", "install", *_pip_args], check=True)

# Script-only start-up: download the model checkpoints and configure torch.
# Nothing here runs when this module is imported rather than executed.
if __name__ == "__main__":
    from huggingface_hub import snapshot_download

    # Mirror the "public-data/Unique3D" model repository into ./ckpt.
    snapshot_download(
        "public-data/Unique3D",
        repo_type="model",
        local_dir="./ckpt",
    )

    import sys

    # Append the working directory to the module search path (presumably so
    # the local ``gradio_app`` package imported below resolves — confirm).
    sys.path.append(os.curdir)

    import torch

    # Precision settings — keep this statement order.
    # NOTE(review): in some torch versions assigning ``allow_tf32`` also sets
    # the float32 matmul precision (True -> "high"), overriding the "medium"
    # chosen on the line before — confirm which setting is intended.
    torch.set_float32_matmul_precision('medium')
    torch.backends.cuda.matmul.allow_tf32 = True
    # Turn off gradient tracking globally.
    torch.set_grad_enabled(False)

import fire
import gradio as gr
from gradio_app.gradio_3dgen import create_ui as create_3d_ui
from gradio_app.all_models import model_zoo

# ===============================
# Text-to-IMAGE 관련 API 함수 정의
# ===============================
def text_to_image(height, width, steps, scales, prompt, seed):
    """
    주어진 파라미터를 이용해 외부 API의 /process_and_save_image 엔드포인트를 호출하여 이미지를 생성한다.
    """
    from gradio_client import Client
    client = Client(os.getenv("CLIENT_API"))  # 기본값 설정
    result = client.predict(
        height,
        width,
        steps,
        scales,
        prompt,
        seed,
        api_name="/process_and_save_image"
    )
    if isinstance(result, dict):
        return result.get("url", None)
    else:
        return result

def update_random_seed():
    """
    외부 API의 /update_random_seed 엔드포인트를 호출하여 새로운 랜덤 시드 값을 가져온다.
    """
    from gradio_client import Client
    client = Client(os.getenv("CLIENT_API"))  # 기본값 설정
    return client.predict(api_name="/update_random_seed")


# Application title: used as the page title (``gr.Blocks(title=...)``) and,
# prefixed with "# ", as the Markdown heading at the top of the UI in ``launch``.
_TITLE = '''✨ 3D LLAMA Studio'''
# Markdown introduction rendered directly under the title in ``launch``.
_DESCRIPTION = '''
### Welcome to 3D Llama Studio - Your Advanced 3D Generation Platform

This platform offers two powerful features:
1. **Text/Image to 3D**: Generate detailed 3D models from text descriptions or reference images
2. **Text to Styled Image**: Create artistic images that can be used for 3D generation

*Note: Both English and Korean prompts are supported (영어와 한글 프롬프트 모두 지원됩니다)*
'''

# Custom CSS for a light theme, passed to ``gr.Blocks(css=...)`` in ``launch``.
# Selectors match ``elem_classes`` assigned to components there
# (``main-title``, ``input-box``, ``button-primary``, ``button-secondary``).
# NOTE(review): no visible ``elem_classes`` uses ``slider-label`` or
# ``textbox-input``, and ``.tabs`` presumably targets Gradio's own tab
# container — confirm these rules still apply.
custom_css = """
.gradio-container {
    background-color: #ffffff;
    color: #333333;
}
.tabs {
    background-color: #f8f9fa;
    border-radius: 10px;
    padding: 10px;
    margin: 10px 0;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.input-box {
    background-color: #ffffff;
    border: 1px solid #e0e0e0;
    border-radius: 8px;
    padding: 15px;
    margin: 10px 0;
    box-shadow: 0 1px 3px rgba(0,0,0,0.05);
}
.button-primary {
    background-color: #4a90e2 !important;
    border: none !important;
    color: white !important;
    transition: all 0.3s ease;
}
.button-primary:hover {
    background-color: #357abd !important;
    transform: translateY(-1px);
}
.button-secondary {
    background-color: #f0f0f0 !important;
    border: 1px solid #e0e0e0 !important;
    color: #333333 !important;
    transition: all 0.3s ease;
}
.button-secondary:hover {
    background-color: #e0e0e0 !important;
}
.main-title {
    color: #2c3e50;
    font-weight: bold;
    margin-bottom: 20px;
}
.slider-label {
    color: #2c3e50;
    font-weight: 500;
}
.textbox-input {
    border: 1px solid #e0e0e0 !important;
    background-color: #ffffff !important;
}
"""

# Gradio UI construction and server start-up.
def _build_text_to_image_tab():
    """Build the "Text to Styled Image" tab: settings, prompt, seed, output, events.

    Components attach to whichever Gradio ``Blocks``/``Tab`` context is active
    when this runs, so it must be called inside one.
    """
    with gr.Group(elem_classes="input-box"):
        gr.Markdown("### Image Generation Settings")
        with gr.Row():
            with gr.Column():
                height_slider = gr.Slider(
                    label="Image Height",
                    minimum=256,
                    maximum=2048,
                    step=64,
                    value=1024,
                    info="Select image height (pixels)",
                )
                width_slider = gr.Slider(
                    label="Image Width",
                    minimum=256,
                    maximum=2048,
                    step=64,
                    value=1024,
                    info="Select image width (pixels)",
                )
            with gr.Column():
                steps_slider = gr.Slider(
                    label="Generation Steps",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=8,
                    info="More steps = higher quality but slower",
                )
                scales_slider = gr.Slider(
                    label="Guidance Scale",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.1,
                    value=3.5,
                    info="How closely to follow the prompt",
                )

        prompt_text = gr.Textbox(
            label="Image Description",
            placeholder="Enter your prompt here (English or Korean)",
            lines=3,
            elem_classes="input-box",
        )

        with gr.Row():
            # Empty (None) seed is forwarded to the remote API as-is.
            seed_number = gr.Number(
                label="Seed (Empty = Random)",
                value=None,
                elem_classes="input-box",
            )
            update_seed_button = gr.Button(
                "🎲 Random Seed",
                elem_classes="button-secondary",
            )

        generate_button = gr.Button(
            "🚀 Generate Image",
            elem_classes="button-primary",
        )

    with gr.Group(elem_classes="input-box"):
        gr.Markdown("### Generated Result")
        image_output = gr.Image(label="Output Image")

    # Ask the remote API for a fresh seed and put it in the seed field.
    update_seed_button.click(
        fn=update_random_seed,
        inputs=[],
        outputs=seed_number,
    )
    # Generate from the current slider / prompt / seed values.
    generate_button.click(
        fn=text_to_image,
        inputs=[height_slider, width_slider, steps_slider, scales_slider, prompt_text, seed_number],
        outputs=image_output,
    )


def launch(share=True):
    """Initialise the models, build the two-tab Gradio app, and launch it.

    Args:
        share: Forwarded to ``demo.launch``. Defaults to True (the previously
            hard-coded value); because this function is the ``fire`` entry
            point it can be toggled from the command line.
    """
    model_zoo.init_models()

    theme = gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="slate",
        neutral_hue="slate",
        font=["Inter", "Arial", "sans-serif"],
    )
    with gr.Blocks(title=_TITLE, css=custom_css, theme=theme) as demo:
        with gr.Row():
            gr.Markdown("# " + _TITLE, elem_classes="main-title")
        gr.Markdown(_DESCRIPTION)

        with gr.Tabs():
            with gr.Tab("🎨 Text to Styled Image", elem_classes="tab"):
                _build_text_to_image_tab()

            with gr.Tab("🎯 Image to 3D", elem_classes="tab"):
                create_3d_ui("wkl")

    # queue() enables request queuing before the server starts.
    demo.queue().launch(share=share)

if __name__ == "__main__":
    # Expose ``launch`` as the command-line entry point via python-fire.
    fire.Fire(component=launch)