from __future__ import annotations

import math
import os
import subprocess
from pathlib import Path

import gradio as gr
import pygltflib
import trimesh


def convert_formats(path_input, target_ext):
    """
    Returns a path to the input 3D model stored in the requested format.

    When the input already carries the requested extension, it is returned unchanged. Otherwise the converted
    file lives next to the input (same base name, new extension); a file already present at that location is
    reused instead of converting again.
    :param path_input: path to user input
    :param target_ext: target extension, without the leading dot
    :return: path to the input 3D model stored in target format.
    """
    stem, current_suffix = os.path.splitext(path_input)
    wanted_suffix = "." + target_ext

    # Nothing to do: the input is already in the requested format
    if current_suffix == wanted_suffix:
        return path_input

    converted_path = stem + wanted_suffix

    # A previous call has already produced this file
    if os.path.exists(converted_path):
        return converted_path

    mesh = trimesh.load_mesh(path_input)
    mesh.export(converted_path)
    return converted_path


def add_lights(path_input, path_output, num_lights=3, intensity=2.0):
    """
    Adds white directional lights to a glTF/GLB model and saves the result under the output path.

    The lights are declared through the KHR_lights_punctual extension. Every light gets its own node, rotated
    about the Y axis in equal angular steps, and these nodes become children of the first root node of the
    scene. Lights already declared in the input are kept; the new ones are appended after them.
    :param path_input: path to the input model
    :param path_output: path under which the model with lights is saved
    :param num_lights: number of lights to add (default max num lights in Babylon.js is 4)
    :param intensity: intensity assigned to every added light
    :raises ValueError: if num_lights is less than 1, or the model has no scene root node to attach lights to
    """
    if num_lights < 1:
        raise ValueError("num_lights must be at least 1")

    glb = pygltflib.GLTF2().load(path_input)

    # The default scene index is optional in glTF; fall back to the first scene when it is not set
    scene_index = glb.scene if glb.scene is not None else 0
    if scene_index >= len(glb.scenes) or not glb.scenes[scene_index].nodes:
        raise ValueError(f"No root node to attach lights to in {path_input}")
    root_node = glb.nodes[glb.scenes[scene_index].nodes[0]]

    if "KHR_lights_punctual" not in glb.extensionsUsed:
        glb.extensionsUsed.append("KHR_lights_punctual")

    # Append to the lights the file may already declare, so that light indices referenced by existing nodes
    # keep pointing at the same lights
    lights = glb.extensions.setdefault("KHR_lights_punctual", {}).setdefault("lights", [])
    first_light_index = len(lights)
    lights.extend(
        {
            "type": "directional",
            "color": [1.0, 1.0, 1.0],
            "intensity": intensity,
        }
        for _ in range(num_lights)
    )

    angle_step = 2 * math.pi / num_lights
    first_node_index = len(glb.nodes)
    for i in range(num_lights):
        angle = i * angle_step
        # NOTE(review): nodes are added as plain dicts, exactly as before; pygltflib.Node instances would be
        # the typed alternative - confirm before switching
        glb.nodes.append({
            # Quaternion (x, y, z, w) of a rotation by `angle` about the Y axis
            "rotation": [
                0.0,
                math.sin(angle / 2),
                0.0,
                math.cos(angle / 2),
            ],
            "extensions": {
                "KHR_lights_punctual": {
                    "light": first_light_index + i,
                }
            },
        })
    light_node_indices = list(range(first_node_index, first_node_index + num_lights))

    # `children` may be missing, None, or empty; only extend a list that already has entries
    if getattr(root_node, "children", None):
        root_node.children.extend(light_node_indices)
    else:
        root_node.children = light_node_indices

    glb.save(path_output)


class Model3D(gr.Model3D):
    """
    Gradio Model3D variant that accepts arbitrary 3D formats supported by trimesh.

    Values are converted to GLB before being handed over to the stock Gradio component.
    """

    def postprocess(self, y: str | Path | None) -> dict[str, str] | None:
        # None is passed through untouched; anything else is converted to GLB first
        glb_path = None if y is None else convert_formats(y, "glb")
        return super().postprocess(glb_path)


def breathe_new_life_into_3d_model(path_input, prompt):
    """
    Repaints the input 3D model according to the text prompt and returns a path to the result.

    The input is converted to PLY and handed to the external repainting script together with the prompt and
    an output directory created next to the input (input path + ".output"). The GLB produced there is then
    extended with lights and stored under a separate "_vis" file name.

    @inproceedings{wang2023breathing,
        title={Breathing New Life into 3D Assets with Generative Repainting},
        author={Wang, Tianfu and Kanakis, Menelaos and Schindler, Konrad and Van Gool, Luc and Obukhov, Anton},
        booktitle={Proceedings of the British Machine Vision Conference (BMVC)},
        year={2023},
        publisher={BMVA Press}
    }
    :param path_input: path to the input 3D model
    :param prompt: text prompt passed to the repainting script
    :return: path to the repainted model with lights added
    :raises ValueError: if no input model is given
    :raises RuntimeError: if the repainting script exits with a non-zero code
    """
    # Fail early with a clear message instead of a TypeError from the path concatenation below
    if not path_input:
        raise ValueError("No input 3D model provided")
    path_input = os.fspath(path_input)

    path_output_dir = path_input + ".output"
    os.makedirs(path_output_dir, exist_ok=True)

    path_input_ply = convert_formats(path_input, "ply")

    cmd = [
        "bash",
        "/repainting_3d_assets/code/scripts/conda_run.sh",
        "/repainting_3d_assets",
        path_input_ply,
        path_output_dir,
        prompt,
    ]
    # The output of the script is deliberately not captured: it goes straight to this process' stdout and
    # stderr. Consequently, there is nothing to print from the result object on failure.
    result = subprocess.run(cmd, env=os.environ)

    if result.returncode != 0:
        raise RuntimeError(f"Processing failed (exit code {result.returncode})")

    path_output_glb = os.path.join(path_output_dir, "model_draco.glb")
    path_output_glb_vis = os.path.join(path_output_dir, "model_draco_vis.glb")
    add_lights(path_output_glb, path_output_glb_vis)

    return path_output_glb_vis


def run():
    """
    Builds the Gradio interface of the repainting demo and serves it on all network interfaces, port 7860.
    """
    # Settings shared by the input and the output 3D viewports
    viewport_options = {
        "camera_position": (30.0, 90.0, 3.0),
        "elem_classes": "viewport",
    }
    viewport_css = """
            .viewport {
                aspect-ratio: 16/9;
            }
        """
    header_html = """
        <p align="center">
        <a title="Website" href="https://www.obukhov.ai/repainting_3d_assets" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
            <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
        </a>
        <a title="arXiv" href="https://arxiv.org/abs/2309.08523" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
            <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
        </a>
        <a title="Github" href="https://github.com/kongdai123/repainting_3d_assets" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
            <img src="https://img.shields.io/github/stars/kongdai123/repainting_3d_assets?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
        </a>
        <a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
            <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
        </a>
        </p>
        <p align="justify">
        Repaint your 3D models with a text prompt, guided by a method from our BMVC'2023 Oral paper 'Breathing New Life 
        into 3D Assets with Generative Repainting'. Simply drop a model into the left pane, specify your repainting 
        preferences, and wait for the outcome (~20 min). Explore precomputed examples at the bottom, or follow the 
        Project Website badge for additional precomputed models and comparison with other repainting techniques.    
        </p>
    """

    # Components are created in the order: input model, prompt, output model
    input_model = Model3D(label="Input Model", **viewport_options)
    prompt_box = gr.Textbox(label="Text Prompt")
    output_model = gr.Model3D(label="Repainted Model", **viewport_options)

    # The example model is looked up relative to this file
    example_model = os.path.join(os.path.dirname(__file__), "files/horse.ply")

    interface = gr.Interface(
        fn=breathe_new_life_into_3d_model,
        inputs=[input_model, prompt_box],
        outputs=[output_model],
        examples=[[example_model, "pastel superhero unicorn"]],
        cache_examples=True,
        title="Repainting 3D Assets",
        description=header_html,
        thumbnail="thumbnail.jpg",
        css=viewport_css,
        allow_flagging="never",
    )

    queued_interface = interface.queue()
    queued_interface.launch(server_name="0.0.0.0", server_port=7860)


# Start the demo only when this file is executed as a script, not when it is imported
if __name__ == "__main__":
    run()