File size: 2,548 Bytes
98bebfc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch
import numpy as np
from PIL import Image
import os
from pytorch3d.io import load_obj
import trimesh
from pytorch3d.structures import Meshes
# from rembg import remove

def remove_color(arr, threshold=80):
    """Make the background of an image transparent, keyed on its top-left pixel.

    A pixel counts as background when the sum of its absolute per-channel
    differences from ``arr[0, 0]`` is at most ``threshold``. Background pixels
    are painted 255 in every colour channel and get alpha 0; all other pixels
    keep their colour and get alpha 255.

    Args:
        arr: image indexed ``[row, col, channel]`` as a NumPy array or torch
            tensor with 3 channels, or 4 channels (the 4th is discarded).
            NumPy input is converted to an ``int32`` tensor; tensor input
            keeps its dtype for the colour channels.
        threshold: largest summed channel difference still treated as
            background (default 80, the previously hard-coded value).

    Returns:
        A new tensor with the 3 colour channels plus an appended alpha
        channel. The input is never modified.
    """
    if arr.shape[-1] == 4:
        arr = arr[..., :3]

    if isinstance(arr, torch.Tensor):
        # Copy so the in-place background fill below never writes into the
        # caller's tensor (slicing / passing through would alias it).
        arr = arr.clone()
    else:
        arr = torch.tensor(arr, dtype=torch.int32)

    # Compute differences in a signed, wide dtype: subtracting unsigned
    # tensors (e.g. uint8) would wrap around and misclassify pixels.
    signed = arr if arr.is_floating_point() else arr.to(torch.int64)
    base = signed[0, 0]
    diffs = torch.abs(signed - base).sum(dim=-1)
    background = diffs <= threshold

    arr[background] = 255
    alpha = (~background).unsqueeze(-1).int() * 255
    return torch.cat([arr, alpha], dim=-1)

def simple_remove_bkg_normal(imgs, rm_bkg_with_rembg, return_Image=False):
    """Strip the background from every image in ``imgs``.

    Per the original note, this is only meant for normal-map renders.

    Args:
        imgs: iterable of images (torch tensors, or whatever ``remove_color``
            / ``rembg.remove`` accept).
        rm_bkg_with_rembg: when true, use ``rembg.remove`` (imported lazily,
            only when an image is actually processed); otherwise key out the
            top-left pixel colour with ``remove_color``.
        return_Image: when true, return PIL images instead of uint8 tensors.

    Returns:
        A list holding one background-removed result per input image.
    """
    def _as_uint8_ndarray(tensor):
        # Tensor -> CPU uint8 NumPy array, the form PIL.Image.fromarray takes.
        return tensor.to(torch.uint8).detach().cpu().numpy()

    outputs = []
    for picture in imgs:
        if not rm_bkg_with_rembg:
            cutout = remove_color(picture)
        else:
            from rembg import remove
            if isinstance(picture, torch.Tensor):
                source = Image.fromarray(_as_uint8_ndarray(picture))
            else:
                source = picture
            cutout = torch.tensor(np.array(remove(source)), dtype=torch.uint8)

        if return_Image:
            outputs.append(Image.fromarray(_as_uint8_ndarray(cutout)))
        else:
            outputs.append(cutout.to(torch.uint8))

    return outputs


def load_glb(file_path):
    """Load a mesh file with trimesh and wrap its geometry in a PyTorch3D ``Meshes``.

    When trimesh returns a ``Scene`` (e.g. a multi-part .glb), all of its
    geometry is first merged into one mesh via ``Scene.dump(concatenate=True)``.
    Only vertices and faces are carried over; no texture or material data is
    passed to ``Meshes``.

    Args:
        file_path: path of the file handed to ``trimesh.load``.

    Returns:
        A single-element ``Meshes`` batch with float32 vertices and int64 faces.
    """
    loaded = trimesh.load(file_path)
    if isinstance(loaded, trimesh.Scene):
        # Collapse every geometry in the scene into one mesh.
        loaded = loaded.dump(concatenate=True)

    vertex_tensor = torch.tensor(loaded.vertices, dtype=torch.float32)
    face_tensor = torch.tensor(loaded.faces, dtype=torch.int64)
    return Meshes(verts=[vertex_tensor], faces=[face_tensor])

def load_obj_with_verts_faces(file_path, return_mesh=True):
    """Load an OBJ file with PyTorch3D's ``load_obj`` and return its geometry.

    Args:
        file_path: path of the OBJ file.
        return_mesh: when true, wrap the geometry in a single-element
            ``Meshes``; otherwise return the raw tensors.

    Returns:
        ``Meshes`` if ``return_mesh`` is true, else a ``(verts, faces)`` tuple
        with float32 vertices and int64 vertex-index faces.
    """
    verts, faces, _ = load_obj(file_path)

    # load_obj already yields tensors: torch.as_tensor converts only when the
    # dtype differs, avoiding the needless copy and the UserWarning that
    # torch.tensor(<tensor>) emits.
    verts = torch.as_tensor(verts, dtype=torch.float32)
    faces = torch.as_tensor(faces.verts_idx, dtype=torch.int64)

    if return_mesh:
        return Meshes(verts=[verts], faces=[faces])
    return verts, faces

def normalize_mesh(vertices, target_extent=2.0):
    """Centre vertices on their bounding-box centre and rescale uniformly.

    The longest side of the axis-aligned bounding box (computed over dim 0,
    i.e. one row per vertex) is scaled to ``target_extent``; with the default
    of 2.0 the mesh spans [-1, 1] along its longest axis.

    Args:
        vertices: tensor of vertex coordinates, one vertex per row.
        target_extent: desired length of the longest bounding-box side.

    Returns:
        A new tensor of normalized vertices. If every vertex coincides (zero
        extent) the vertices are only centred, instead of producing NaNs from
        a division by zero.
    """
    min_vals = torch.min(vertices, dim=0).values
    max_vals = torch.max(vertices, dim=0).values
    center = (max_vals + min_vals) / 2
    centered = vertices - center

    max_extent = torch.max(max_vals - min_vals)
    if max_extent <= 0:
        # Degenerate mesh (e.g. a single vertex): nothing to scale.
        return centered
    return centered * (target_extent / max_extent)