ginipick commited on
Commit
1618f2a
·
verified ·
1 Parent(s): f60eddf

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +273 -0
app.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ # ZeroGPU 환경 설정 - 가장 먼저 실행되어야 함!
3
+ os.environ['CUDA_VISIBLE_DEVICES'] = ''
4
+ os.environ['ZEROGPU'] = '1' # ZeroGPU 환경임을 표시
5
+
6
+ import spaces # spaces import는 환경 설정 후에
7
+ import shlex
8
+ import subprocess
9
+
10
+ # 라이브러리 버전 호환성 문제 해결
11
+ subprocess.run(shlex.split("pip install pip==24.0"), check=True)
12
+
13
+ # transformers와 diffusers 버전 업데이트
14
+ subprocess.run(shlex.split("pip install transformers==4.44.0 --upgrade"), check=True)
15
+ subprocess.run(shlex.split("pip install diffusers==0.30.0 --upgrade"), check=True)
16
+
17
+ subprocess.run(
18
+ shlex.split(
19
+ "pip install package/onnxruntime_gpu-1.17.0-cp310-cp310-manylinux_2_28_x86_64.whl --force-reinstall --no-deps"
20
+ ), check=True
21
+ )
22
+ subprocess.run(
23
+ shlex.split(
24
+ "pip install package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
25
+ ), check=True
26
+ )
27
+
28
+ # 모델 체크포인트 다운로드 및 torch 설정
29
+ if __name__ == "__main__":
30
+ from huggingface_hub import snapshot_download
31
+
32
+ snapshot_download("public-data/Unique3D", repo_type="model", local_dir="./ckpt")
33
+
34
+ import os
35
+ import sys
36
+ sys.path.append(os.curdir)
37
+ import torch
38
+ torch.set_float32_matmul_precision('medium')
39
+ torch.backends.cuda.matmul.allow_tf32 = True
40
+ torch.set_grad_enabled(False)
41
+
42
+ import fire
43
+ import gradio as gr
44
+ from gradio_app.gradio_3dgen import create_ui as create_3d_ui
45
+ from gradio_app.all_models import model_zoo
46
+
47
+ # ===============================
48
+ # Text-to-IMAGE 관련 API 함수 정의
49
+ # ===============================
50
+ @spaces.GPU(duration=60) # GPU 사용 시간 60초로 설정
51
+ def text_to_image(height, width, steps, scales, prompt, seed):
52
+ """
53
+ 주어진 파라미터를 이용해 외부 API의 /process_and_save_image 엔드포인트를 호출하여 이미지를 생성한다.
54
+ """
55
+ # GPU가 할당된 상태에서 실행
56
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
57
+
58
+ from gradio_client import Client
59
+ client = Client(os.getenv("CLIENT_API"))
60
+ result = client.predict(
61
+ height,
62
+ width,
63
+ steps,
64
+ scales,
65
+ prompt,
66
+ seed,
67
+ api_name="/process_and_save_image"
68
+ )
69
+ if isinstance(result, dict):
70
+ return result.get("url", None)
71
+ else:
72
+ return result
73
+
74
+ def update_random_seed():
75
+ """
76
+ 외부 API의 /update_random_seed 엔드포인트를 호출하여 새로운 랜덤 시드 값을 가져온다.
77
+ """
78
+ from gradio_client import Client
79
+ client = Client(os.getenv("CLIENT_API"))
80
+ return client.predict(api_name="/update_random_seed")
81
+
82
+ # 3D 생성 함수를 위한 래퍼 (GPU 데코레이터 적용)
83
+ @spaces.GPU(duration=120) # 3D 생성은 더 많은 시간 필요
84
+ def generate_3d_wrapper(*args, **kwargs):
85
+ """3D 생성 함수를 GPU 환경에서 실행"""
86
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
87
+ # 실제 3D 생성 로직이 여기서 실행됨
88
+ # model_zoo의 함수들이 여기서 호출될 것임
89
+ return model_zoo.generate_3d(*args, **kwargs)
90
+
91
+ _TITLE = '''✨ 3D LLAMA Studio'''
92
+ _DESCRIPTION = '''
93
+ ### Welcome to 3D Llama Studio - Your Advanced 3D Generation Platform
94
+ This platform offers two powerful features:
95
+ 1. **Text/Image to 3D**: Generate detailed 3D models from text descriptions or reference images
96
+ 2. **Text to Styled Image**: Create artistic images that can be used for 3D generation
97
+ *Note: Both English and Korean prompts are supported (영어와 한글 프롬프트 모두 지원됩니다)*
98
+ **Running on ZeroGPU** 🚀
99
+ '''
100
+
101
+ # CSS 스타일 밝은 테마로 수정
102
+ custom_css = """
103
+ .gradio-container {
104
+ background-color: #ffffff;
105
+ color: #333333;
106
+ }
107
+ .tabs {
108
+ background-color: #f8f9fa;
109
+ border-radius: 10px;
110
+ padding: 10px;
111
+ margin: 10px 0;
112
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
113
+ }
114
+ .input-box {
115
+ background-color: #ffffff;
116
+ border: 1px solid #e0e0e0;
117
+ border-radius: 8px;
118
+ padding: 15px;
119
+ margin: 10px 0;
120
+ box-shadow: 0 1px 3px rgba(0,0,0,0.05);
121
+ }
122
+ .button-primary {
123
+ background-color: #4a90e2 !important;
124
+ border: none !important;
125
+ color: white !important;
126
+ transition: all 0.3s ease;
127
+ }
128
+ .button-primary:hover {
129
+ background-color: #357abd !important;
130
+ transform: translateY(-1px);
131
+ }
132
+ .button-secondary {
133
+ background-color: #f0f0f0 !important;
134
+ border: 1px solid #e0e0e0 !important;
135
+ color: #333333 !important;
136
+ transition: all 0.3s ease;
137
+ }
138
+ .button-secondary:hover {
139
+ background-color: #e0e0e0 !important;
140
+ }
141
+ .main-title {
142
+ color: #2c3e50;
143
+ font-weight: bold;
144
+ margin-bottom: 20px;
145
+ }
146
+ .slider-label {
147
+ color: #2c3e50;
148
+ font-weight: 500;
149
+ }
150
+ .textbox-input {
151
+ border: 1px solid #e0e0e0 !important;
152
+ background-color: #ffffff !important;
153
+ }
154
+ .zerogpu-badge {
155
+ background-color: #4CAF50;
156
+ color: white;
157
+ padding: 5px 10px;
158
+ border-radius: 5px;
159
+ font-size: 14px;
160
+ margin-left: 10px;
161
+ }
162
+ """
163
+
164
+ # Gradio 테마 설정 수정
165
+ def launch():
166
+ # CPU 모드로 모델 초기화
167
+ os.environ['CUDA_VISIBLE_DEVICES'] = ''
168
+ model_zoo.init_models()
169
+
170
+ with gr.Blocks(
171
+ title=_TITLE,
172
+ css=custom_css,
173
+ theme=gr.themes.Soft(
174
+ primary_hue="blue",
175
+ secondary_hue="slate",
176
+ neutral_hue="slate",
177
+ font=["Inter", "Arial", "sans-serif"]
178
+ )
179
+ ) as demo:
180
+
181
+ with gr.Row():
182
+ with gr.Column():
183
+ gr.Markdown('# ' + _TITLE, elem_classes="main-title")
184
+ with gr.Column():
185
+ gr.HTML('<span class="zerogpu-badge">ZeroGPU Enabled</span>')
186
+ gr.Markdown(_DESCRIPTION)
187
+
188
+ with gr.Tabs() as tabs:
189
+ with gr.Tab("🎨 Text to Styled Image", elem_classes="tab"):
190
+ with gr.Group(elem_classes="input-box"):
191
+ gr.Markdown("### Image Generation Settings")
192
+ with gr.Row():
193
+ with gr.Column():
194
+ height_slider = gr.Slider(
195
+ label="Image Height",
196
+ minimum=256,
197
+ maximum=2048,
198
+ step=64,
199
+ value=1024,
200
+ info="Select image height (pixels)"
201
+ )
202
+ width_slider = gr.Slider(
203
+ label="Image Width",
204
+ minimum=256,
205
+ maximum=2048,
206
+ step=64,
207
+ value=1024,
208
+ info="Select image width (pixels)"
209
+ )
210
+ with gr.Column():
211
+ steps_slider = gr.Slider(
212
+ label="Generation Steps",
213
+ minimum=1,
214
+ maximum=100,
215
+ step=1,
216
+ value=8,
217
+ info="More steps = higher quality but slower"
218
+ )
219
+ scales_slider = gr.Slider(
220
+ label="Guidance Scale",
221
+ minimum=1.0,
222
+ maximum=10.0,
223
+ step=0.1,
224
+ value=3.5,
225
+ info="How closely to follow the prompt"
226
+ )
227
+
228
+ prompt_text = gr.Textbox(
229
+ label="Image Description",
230
+ placeholder="Enter your prompt here (English or Korean)",
231
+ lines=3,
232
+ elem_classes="input-box"
233
+ )
234
+
235
+ with gr.Row():
236
+ seed_number = gr.Number(
237
+ label="Seed (Empty = Random)",
238
+ value=None,
239
+ elem_classes="input-box"
240
+ )
241
+ update_seed_button = gr.Button(
242
+ "🎲 Random Seed",
243
+ elem_classes="button-secondary"
244
+ )
245
+
246
+ generate_button = gr.Button(
247
+ "🚀 Generate Image",
248
+ elem_classes="button-primary"
249
+ )
250
+
251
+ with gr.Group(elem_classes="input-box"):
252
+ gr.Markdown("### Generated Result")
253
+ image_output = gr.Image(label="Output Image")
254
+
255
+ update_seed_button.click(
256
+ fn=update_random_seed,
257
+ inputs=[],
258
+ outputs=seed_number
259
+ )
260
+
261
+ generate_button.click(
262
+ fn=text_to_image,
263
+ inputs=[height_slider, width_slider, steps_slider, scales_slider, prompt_text, seed_number],
264
+ outputs=image_output
265
+ )
266
+
267
+ with gr.Tab("🎯 Image to 3D", elem_classes="tab"):
268
+ create_3d_ui("wkl")
269
+
270
+ demo.queue().launch(share=True)
271
+
272
+ if __name__ == '__main__':
273
+ fire.Fire(launch)