File size: 4,654 Bytes
1774ce2
 
 
e24d56e
 
 
 
 
 
 
 
 
 
1219780
 
248ebfb
e24d56e
 
75a1da3
e24d56e
 
 
 
 
 
 
 
 
 
 
 
 
 
1774ce2
248ebfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa621d3
 
 
 
 
e24d56e
248ebfb
 
 
 
 
 
 
 
 
 
 
 
aa621d3
e24d56e
aa621d3
 
248ebfb
1774ce2
 
248ebfb
e24d56e
248ebfb
 
 
e24d56e
248ebfb
e24d56e
248ebfb
 
 
aa621d3
248ebfb
aa621d3
248ebfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa621d3
 
 
248ebfb
aa621d3
e24d56e
 
1774ce2
e24d56e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import shlex
import subprocess

subprocess.run(shlex.split("pip install pip==24.0"), check=True)
subprocess.run(
    shlex.split(
        "pip install package/onnxruntime_gpu-1.17.0-cp310-cp310-manylinux_2_28_x86_64.whl --force-reinstall --no-deps"
    ), check=True
)
subprocess.run(
    shlex.split(
        "pip install package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
    ), check=True
)

# 모델 체크포인트 다운로드 및 torch 설정
if __name__ == "__main__":
    from huggingface_hub import snapshot_download

    snapshot_download("public-data/Unique3D", repo_type="model", local_dir="./ckpt")

    import os
    import sys
    sys.path.append(os.curdir)
    import torch
    torch.set_float32_matmul_precision('medium')
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_grad_enabled(False)

import fire
import gradio as gr
from gradio_app.gradio_3dgen import create_ui as create_3d_ui
from gradio_app.all_models import model_zoo

# ===============================
# Text-to-IMAGE 관련 API 함수 정의
# ===============================
def text_to_image(height, width, steps, scales, prompt, seed):
    """
    주어진 파라미터를 이용해 외부 API (http://211.233.58.201:7971/)의
    /process_and_save_image 엔드포인트를 호출하여 이미지를 생성한다.
    """
    from gradio_client import Client
    client = Client("http://211.233.58.201:7971/")
    result = client.predict(
        height,
        width,
        steps,
        scales,
        prompt,
        seed,
        api_name="/process_and_save_image"
    )
    # 결과가 dict이면 "url" 키의 값을, 아니라면 그대로 반환합니다.
    if isinstance(result, dict):
        return result.get("url", None)
    else:
        return result

def update_random_seed():
    """
    외부 API의 /update_random_seed 엔드포인트를 호출하여
    새로운 랜덤 시드 값을 가져온다.
    """
    from gradio_client import Client
    client = Client("http://211.233.58.201:7971/")
    return client.predict(api_name="/update_random_seed")

# ===============================
# UI 타이틀 및 설명
# ===============================
_TITLE = '''3D-llama'''
_DESCRIPTION = '''
Text와 이미지를 이용하여 3D 모델을 생성할 수 있습니다.

'''

def launch():
    # 3D 모델 초기화
    model_zoo.init_models()
    
    # Gradio Blocks 생성 (두 탭 포함)
    with gr.Blocks(title=_TITLE) as demo:
        with gr.Row():
            gr.Markdown('# ' + _TITLE)
        gr.Markdown(_DESCRIPTION)
        
        # 탭 생성: 기존의 Text-to-3D와 새로 추가한 Text-to-IMAGE
        with gr.Tabs():

                
            with gr.Tab("Text to 3D Style IMAGE"):
                # 이미지 생성을 위한 파라미터 입력 컴포넌트 구성
                with gr.Row():
                    height_slider = gr.Slider(label="Height", minimum=256, maximum=2048, step=1, value=1024)
                    width_slider = gr.Slider(label="Width", minimum=256, maximum=2048, step=1, value=1024)
                with gr.Row():
                    steps_slider = gr.Slider(label="Inference Steps", minimum=1, maximum=100, step=1, value=8)
                    scales_slider = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=3.5)
                prompt_text = gr.Textbox(label="Image Description", placeholder="Enter prompt here", lines=2)
                seed_number = gr.Number(label="Seed (optional, leave empty for random)", value=None)
                
                # 'Update Random Seed' 버튼을 누르면 API를 통해 새로운 시드값을 받아 입력란 업데이트
                update_seed_button = gr.Button("Update Random Seed")
                update_seed_button.click(fn=update_random_seed, inputs=[], outputs=seed_number)
                
                generate_button = gr.Button("Generate Image")
                image_output = gr.Image(label="Generated Image")
                
                # 'Generate Image' 버튼 클릭 시 text_to_image 함수를 호출하여 결과 이미지를 출력
                generate_button.click(
                    fn=text_to_image,
                    inputs=[height_slider, width_slider, steps_slider, scales_slider, prompt_text, seed_number],
                    outputs=image_output
                )

            with gr.Tab("Image to 3D"):
                create_3d_ui("wkl")
                
    # 공유 링크를 생성하기 위해 share=True 설정
    demo.queue().launch(share=True)
    
if __name__ == '__main__':
    fire.Fire(launch)