import logging
import os
import tempfile
import time

import numpy as np
import rembg
import torch
from PIL import Image
from functools import partial

from tsr.system import TSR
from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation

import argparse

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# NOTE(review): the arguments below were reconstructed because the call was
# left unterminated; they are the upstream TripoSR release defaults — confirm
# they match the checkpoint/config actually deployed.
model = TSR.from_pretrained(
    "stabilityai/TripoSR",
    config_name="config.yaml",
    weight_name="model.ckpt",
)

# adjust the chunk size to balance between speed and memory usage
# NOTE(review): 8192 is the upstream TripoSR default; tune for the target GPU.
model.renderer.set_chunk_size(8192)
# `generate` calls `model(..., device=device)`, so the weights must live there too.
model.to(device)

# Shared rembg session, reused by every `preprocess` call that removes backgrounds.
rembg_session = rembg.new_session()

def preprocess(input_image, do_remove_background, foreground_ratio):
    """Prepare an input image for mesh reconstruction.

    Args:
        input_image: Source PIL image.
        do_remove_background: When true, convert to RGB, strip the background
            with rembg, rescale the foreground with ``resize_foreground`` and
            composite the result onto a mid-gray background. When false, the
            image is used as-is, except that RGBA images are composited onto
            mid-gray.
        foreground_ratio: Forwarded to ``resize_foreground``; only used when
            ``do_remove_background`` is true.

    Returns:
        The processed image: a new RGB PIL image whenever compositing happened,
        otherwise ``input_image`` itself.
    """

    def fill_background(image):
        # Alpha-composite onto a uniform 0.5 (mid-gray) background.
        # Assumes a 4-channel (RGBA) image: channel 3 is read as alpha.
        image = np.array(image).astype(np.float32) / 255.0
        image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
        image = Image.fromarray((image * 255.0).astype(np.uint8))
        return image

    if do_remove_background:
        image = input_image.convert("RGB")
        # NOTE(review): presumably returns RGBA with the background made
        # transparent — required by fill_background below; confirm in tsr.utils.
        image = remove_background(image, rembg_session)
        image = resize_foreground(image, foreground_ratio)
        image = fill_background(image)
    else:
        image = input_image
        if image.mode == "RGBA":
            image = fill_background(image)
    return image

def generate(image, mc_resolution, formats=("obj", "glb"), path="output.obj"):
    """Reconstruct a mesh from a preprocessed image and export it to disk.

    Args:
        image: Preprocessed image, passed straight to the module-level ``model``.
        mc_resolution: Forwarded to ``model.extract_mesh`` as ``resolution``.
        formats: File extensions (without the dot) to export, in order.
        path: Output path template; its extension is swapped for each entry in
            ``formats`` (``"output.obj"`` -> ``"output.obj"``, ``"output.glb"``).
            A path with no extension simply gets each format's extension appended.

    Returns:
        List of the written file paths, in the same order as ``formats``.
    """
    # Inference only: disable autograd bookkeeping to save memory.
    with torch.no_grad():
        scene_codes = model(image, device=device)
        mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
    mesh = to_gradio_3d_orientation(mesh)

    # splitext only touches the final extension, unlike str.replace(".obj", ...),
    # which would also rewrite e.g. a "models.obj/" directory component.
    base, _ = os.path.splitext(path)
    rv = []
    for fmt in formats:
        mesh_path = f"{base}.{fmt}"
        # NOTE(review): assumes a trimesh-style mesh whose export() infers the
        # file type from the extension — confirm against tsr's mesh type.
        mesh.export(mesh_path)
        rv.append(mesh_path)
    return rv

def run_example(image_pil):
    """Run the demo pipeline on an image, keeping its background.

    Preprocesses with ``do_remove_background=False`` and a foreground ratio of
    0.9, then requests OBJ and GLB exports at marching-cubes resolution 256.

    Returns:
        ``(preprocessed_image, obj_path, glb_path)``.
    """
    prepared = preprocess(image_pil, do_remove_background=False, foreground_ratio=0.9)
    mesh_files = generate(prepared, mc_resolution=256, formats=["obj", "glb"])
    obj_file, glb_file = mesh_files
    return prepared, obj_file, glb_file

def generate_obj_from_image(image_pil, path="output.obj"):
    """Remove the background from an image and export the reconstructed mesh as OBJ.

    Args:
        image_pil: Source PIL image.
        path: Destination for the ``.obj`` file.

    Returns:
        The path of the written ``.obj`` file, or ``None`` if preprocessing or
        mesh generation raised (the error is printed, not propagated).
    """
    try:
        # Preprocess the image with background removal and a foreground ratio of 0.9
        print("Preprocessing image")
        preprocessed = preprocess(image_pil, True, 0.9)
        print("Generating mesh")
        # Generate the mesh and export only the .obj file
        mesh_paths = generate(preprocessed, 256, ["obj"], path)
    except Exception as e:
        # Deliberate best-effort boundary: report the failure and signal it with None.
        print(f"Error generating mesh: {e}")
        return None
    # Return the path to the .obj file
    return mesh_paths[0]

if __name__ == "__main__":
    # run a test: reconstruct a mesh from a local image file
    image_path = "output.png"
    image = Image.open(image_path)
    generate_obj_from_image(image, "output.obj")
    # move the .obj file to the output directory