Update modeling.py
modeling.py  +7 -0
@@ -8,6 +8,7 @@ import torch
 from .dino_wrapper2 import DinoWrapper
 from .transformer import TriplaneTransformer
 from .synthesizer_part import TriplaneSynthesizer
+from .processor import LRMImageProcessor
 
 class CameraEmbedder(nn.Module):
     def __init__(self, raw_dim: int, embed_dim: int):
@@ -46,6 +47,8 @@ class LRMGenerator(PreTrainedModel):
     def __init__(self, config: LRMGeneratorConfig):
         super().__init__(config)
 
+        self.image_processor = LRMImageProcessor(source_size=512)
+
         self.encoder_feat_dim = config.encoder_feat_dim
         self.camera_embed_dim = config.camera_embed_dim
 
@@ -67,6 +70,10 @@ class LRMGenerator(PreTrainedModel):
         )
 
     def forward(self, image, camera):
+
+        # we use image processor directly in the forward pass
+        processed_image, source_camera = self.image_processor(image)
+
         assert image.shape[0] == camera.shape[0], "Batch size mismatch"
         N = image.shape[0]
 
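For readers trying the updated model, below is a minimal usage sketch, not part of this commit. It assumes LRMGeneratorConfig default values, an importable module path, and illustrative tensor shapes (a 512x512 RGB batch and a flattened 16-dim camera vector), none of which are shown in the diff itself.

import torch
from modeling import LRMGenerator, LRMGeneratorConfig  # import path is an assumption

# Build the model; config field values are assumptions, not taken from this diff.
model = LRMGenerator(LRMGeneratorConfig())
model.eval()

image = torch.rand(1, 3, 512, 512)   # raw RGB batch; forward now preprocesses it via LRMImageProcessor(source_size=512)
camera = torch.rand(1, 16)           # flattened camera parameters; raw_dim=16 is an assumption

with torch.no_grad():
    outputs = model(image, camera)   # after this change, preprocessing happens inside forward

The practical effect of the change is that callers can pass raw inputs and let forward handle image preprocessing and source-camera construction in a single call, rather than running LRMImageProcessor as a separate step.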