jadechoghari commited on
Commit
47af768
·
verified ·
1 Parent(s): 95ddb7e

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +7 -0
modeling.py CHANGED
@@ -8,6 +8,7 @@ import torch
8
  from .dino_wrapper2 import DinoWrapper
9
  from .transformer import TriplaneTransformer
10
  from .synthesizer_part import TriplaneSynthesizer
 
11
 
12
  class CameraEmbedder(nn.Module):
13
  def __init__(self, raw_dim: int, embed_dim: int):
@@ -46,6 +47,8 @@ class LRMGenerator(PreTrainedModel):
46
  def __init__(self, config: LRMGeneratorConfig):
47
  super().__init__(config)
48
 
 
 
49
  self.encoder_feat_dim = config.encoder_feat_dim
50
  self.camera_embed_dim = config.camera_embed_dim
51
 
@@ -67,6 +70,10 @@ class LRMGenerator(PreTrainedModel):
67
  )
68
 
69
  def forward(self, image, camera):
 
 
 
 
70
  assert image.shape[0] == camera.shape[0], "Batch size mismatch"
71
  N = image.shape[0]
72
 
 
8
  from .dino_wrapper2 import DinoWrapper
9
  from .transformer import TriplaneTransformer
10
  from .synthesizer_part import TriplaneSynthesizer
11
+ from .processor import LRMImageProcessor
12
 
13
  class CameraEmbedder(nn.Module):
14
  def __init__(self, raw_dim: int, embed_dim: int):
 
47
  def __init__(self, config: LRMGeneratorConfig):
48
  super().__init__(config)
49
 
50
+ self.image_processor = LRMImageProcessor(source_size=512)
51
+
52
  self.encoder_feat_dim = config.encoder_feat_dim
53
  self.camera_embed_dim = config.camera_embed_dim
54
 
 
70
  )
71
 
72
  def forward(self, image, camera):
73
+
74
+ # we use image processor directly in the forward pass
75
+ processed_image, source_camera = self.image_processor(image)
76
+
77
  assert image.shape[0] == camera.shape[0], "Batch size mismatch"
78
  N = image.shape[0]
79