File size: 5,650 Bytes
431c7a8
 
 
c83db8b
431c7a8
8168982
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3db2f15
0b67988
 
431c7a8
 
8168982
 
 
 
 
 
50b44f7
fe7389a
 
 
4386ef2
53b232b
8168982
 
 
 
 
 
3506dfe
8168982
 
 
 
 
 
 
e4ffaa9
8168982
 
 
e77b448
 
 
 
 
 
 
 
8168982
 
 
 
 
 
fbf7ae4
8168982
 
 
 
 
 
fbf7ae4
8168982
 
 
 
e4ffaa9
fbf7ae4
8168982
 
 
 
fbf7ae4
 
8168982
 
fbf7ae4
 
8168982
fbf7ae4
8168982
fbf7ae4
8168982
fbf7ae4
8168982
fbf7ae4
 
8168982
fbf7ae4
8168982
fbf7ae4
 
8168982
fbf7ae4
 
8168982
fbf7ae4
 
8168982
 
3db2f15
 
431c7a8
 
 
0c4832c
 
 
 
 
e4ffaa9
431c7a8
 
5c828ac
 
 
3506dfe
 
5c828ac
 
 
 
c561654
 
530d878
 
 
 
5c828ac
 
53b232b
 
5c828ac
4be8881
5c828ac
 
 
 
 
 
0b67988
5c828ac
431c7a8
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from pydantic import Field
from typing import Optional
import logging
import os
import boto3
import json
import shlex
import subprocess
import tempfile
import time
import base64
import gradio as gr
import numpy as np
import rembg
import spaces
import torch
from PIL import Image
from functools import partial
import io
from io import BytesIO
from botocore.exceptions import NoCredentialsError, PartialCredentialsError
import datetime

# FastAPI application; the /process_image/ route is registered on it below.
app = FastAPI()

# Install the bundled torchmcubes wheel at import time. This must run before the
# `tsr` imports that follow (presumably tsr needs torchmcubes -- confirm).
# The exit status is not checked.
_TORCHMCUBES_INSTALL_CMD = 'pip install wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl'
subprocess.run(shlex.split(_TORCHMCUBES_INSTALL_CMD))

from tsr.system import TSR
from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation


# Run on the first GPU when CUDA is available, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the pretrained TripoSR weights once at import time and move them to `device`.
model = TSR.from_pretrained(
    "stabilityai/TripoSR",
    config_name="config.yaml",
    weight_name="model.ckpt",
)
# Renderer chunk size; presumably a memory/speed trade-off -- confirm in tsr.
model.renderer.set_chunk_size(131072)
model.to(device)

# Shared rembg session reused by preprocess() for background removal.
rembg_session = rembg.new_session()

# AWS credentials come from the ACCESS / SECRET environment variables; all three
# clients share the same credentials and region.
ACCESS = os.getenv("ACCESS")
SECRET = os.getenv("SECRET")
_AWS_CLIENT_KWARGS = dict(
    aws_access_key_id=ACCESS,
    aws_secret_access_key=SECRET,
    region_name='us-east-1',
)
bedrock = boto3.client(service_name='bedrock', **_AWS_CLIENT_KWARGS)
bedrock_runtime = boto3.client(service_name='bedrock-runtime', **_AWS_CLIENT_KWARGS)
s3_client = boto3.client('s3', **_AWS_CLIENT_KWARGS)




def upload_file_to_s3(file_path, bucket_name, object_name=None):
    """Upload a local file to S3 using the module-level ``s3_client``.

    Args:
        file_path: Path of the local file to upload.
        bucket_name: Name of the destination bucket.
        object_name: Destination object key. When omitted, the base name of
            ``file_path`` is used, because ``upload_file`` requires a key.

    Returns:
        True once ``s3_client.upload_file`` returns. Errors are not caught
        here; any exception raised by the upload propagates to the caller.
    """
    if object_name is None:
        object_name = os.path.basename(file_path)

    s3_client.upload_file(file_path, bucket_name, object_name)

    return True


def check_input_image(input_image):
    """Validate that an image was supplied.

    Returns None for any non-None input; raises ``gr.Error`` when
    ``input_image`` is None.
    """
    if input_image is not None:
        return
    raise gr.Error("No image uploaded!")

def preprocess(input_image, do_remove_background, foreground_ratio):
    """Prepare an uploaded PIL image for TripoSR inference.

    Args:
        input_image: The decoded upload (a PIL image).
        do_remove_background: If true, convert to RGB, strip the background
            with the shared ``rembg_session``, rescale the foreground with
            ``resize_foreground`` and composite it over mid-grey.
        foreground_ratio: Forwarded to ``resize_foreground``; only used when
            ``do_remove_background`` is true.

    Returns:
        The processed image. Without background removal, an RGBA input is
        composited over mid-grey and any other mode is returned unchanged.
    """
    def fill_background(image):
        # Alpha-composite over a 0.5 grey background. Expects four channels,
        # since channel index 3 is read as alpha.
        pixels = np.array(image).astype(np.float32) / 255.0
        alpha = pixels[:, :, 3:4]
        blended = pixels[:, :, :3] * alpha + (1 - alpha) * 0.5
        return Image.fromarray((blended * 255.0).astype(np.uint8))

    if do_remove_background:
        image = input_image.convert("RGB")
        image = remove_background(image, rembg_session)
        image = resize_foreground(image, foreground_ratio)
        image = fill_background(image)
    else:
        image = input_image
        if image.mode == "RGBA":
            image = fill_background(image)

    # torch.cuda.synchronize() raises when CUDA is unavailable, but this module
    # falls back to device="cpu". Only touch the CUDA runtime when it exists.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
    return image



def _new_temp_path(suffix):
    """Create an empty temp file with *suffix* that persists after close; return its path."""
    # Closing the handle right away avoids leaking a file descriptor per call;
    # delete=False keeps the file on disk for the exporter to overwrite.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        return handle.name


def generate(image, mc_resolution, formats=("obj", "glb")):
    """Run TripoSR on a preprocessed image and export the mesh to temp files.

    Args:
        image: Preprocessed input, as produced by ``preprocess``.
        mc_resolution: Resolution passed to ``model.extract_mesh``.
        formats: Currently unused; both OBJ and GLB are always written. Kept
            for interface compatibility (now an immutable default).

    Returns:
        ``(obj_path, glb_path)``: paths of persistent temp files. The caller
        is responsible for deleting them.
    """
    scene_codes = model(image, device=device)
    mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
    mesh = to_gradio_3d_orientation(mesh)

    glb_path = _new_temp_path(".glb")
    mesh.export(glb_path)

    # Mirror along X before writing the OBJ. apply_scale's result is not
    # reassigned, so it presumably mutates `mesh` in place; it runs after the
    # GLB export, so the GLB keeps the unmirrored orientation.
    obj_path = _new_temp_path(".obj")
    mesh.apply_scale([-1, 1, 1])
    mesh.export(obj_path)

    # synchronize() raises without CUDA, and the module may run on CPU.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
    return obj_path, glb_path

    

@app.post("/process_image/")
async def process_image(
    file: UploadFile = File(...),
    seed: int = Form(...),  # required, but not used by this handler
    enhance_image: bool = Form(...),  # required, but not used by this handler
    do_remove_background: bool = Form(...),  # run rembg-based background removal first
    foreground_ratio: float = Form(...),  # forwarded to preprocess(); range not validated here
    mc_resolution: int = Form(...),  # forwarded to generate(); range not validated here
    auth: str = Form(...),  # shared secret checked against the AUTHORIZE env var
    text_prompt: Optional[str] = Form(None),  # optional, not used by this handler
):
    """Turn an uploaded image into a 3D mesh and publish the results to S3.

    Uploads the preprocessed PNG, the OBJ mesh and the GLB mesh to the
    ``framebucket3d`` bucket under a shared timestamped name.

    Returns:
        ``{"img_path", "obj_path", "glb_path"}`` with public S3 URLs on success,
        ``{"Authentication": "Failed"}`` when ``auth`` does not match, or
        ``{"Internal Server Error": False}`` if a mesh upload reports failure.

    NOTE(review): preprocessing and inference are synchronous and run directly
    inside this async handler, so they block the event loop while they run.
    """
    import contextlib
    import secrets

    expected = os.getenv("AUTHORIZE")
    # Constant-time comparison so the secret cannot be probed via response
    # timing. An unset AUTHORIZE rejects every request, like the plain == did.
    if expected is None or not secrets.compare_digest(auth.encode(), expected.encode()):
        return {"Authentication": "Failed"}

    bucket = 'framebucket3d'
    image_bytes = await file.read()
    image_pil = Image.open(BytesIO(image_bytes))

    # Local files created for this request; all are removed once uploaded so
    # that each request does not leave three files behind on disk.
    temp_paths = []
    try:
        preprocessed = preprocess(image_pil, do_remove_background, foreground_ratio)
        mesh_name_obj, mesh_name_glb = generate(preprocessed, mc_resolution)
        temp_paths.extend([mesh_name_obj, mesh_name_glb])

        timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')
        object_name = f'object_{timestamp}.obj'
        object_name_2 = f'object_{timestamp}.glb'
        object_name_3 = f"object_{timestamp}.png"

        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as png_handle:
            image_path = png_handle.name
        temp_paths.append(image_path)
        preprocessed.save(image_path)
        upload_file_to_s3(image_path, bucket, object_name_3)

        if (upload_file_to_s3(mesh_name_obj, bucket, object_name)
                and upload_file_to_s3(mesh_name_glb, bucket, object_name_2)):
            return {
                "img_path": f"https://{bucket}.s3.amazonaws.com/{object_name_3}",
                "obj_path": f"https://{bucket}.s3.amazonaws.com/{object_name}",
                "glb_path": f"https://{bucket}.s3.amazonaws.com/{object_name_2}"
            }
        return {"Internal Server Error": False}
    finally:
        for path in temp_paths:
            # Best-effort cleanup: a file that is already gone or unremovable
            # must not replace the real response or exception.
            with contextlib.suppress(OSError):
                os.remove(path)

if __name__ == "__main__":
    # Script entry point: serve the FastAPI app on all interfaces, port 7860.
    import uvicorn as _uvicorn

    _uvicorn.run(app, host="0.0.0.0", port=7860)