FQiao commited on
Commit
984c5f5
·
verified ·
1 Parent(s): 2936be9

Update genstereo/GenStereo.py

Browse files
Files changed (1) hide show
  1. genstereo/GenStereo.py +17 -9
genstereo/GenStereo.py CHANGED
@@ -24,7 +24,7 @@ from .models import (
24
  UNet3DConditionModel,
25
  ReferenceAttentionControl
26
  )
27
- from .ops import get_viewport_matrix, forward_warper, convert_left_to_right, convert_left_to_right_torch
28
 
29
  class AdaptiveFusionLayer(nn.Module):
30
  def __init__(self):
@@ -47,8 +47,8 @@ class GenStereo():
47
  pretrained_model_path: str = ''
48
  checkpoint_name: str = ''
49
  half_precision_weights: bool = False
50
- height: int = 512
51
- width: int = 512
52
  num_inference_steps: int = 50
53
  guidance_scale: float = 1.5
54
  cfg: Config
@@ -88,18 +88,28 @@ class GenStereo():
88
  def __init__(
89
  self,
90
  cfg: Optional[Union[dict, DictConfig]] = None,
91
- device: Optional[str] = 'cuda:0'
 
92
  ) -> None:
93
  self.cfg = OmegaConf.structured(self.Config(**cfg))
94
  self.model_path = join(
95
  self.cfg.pretrained_model_path, self.cfg.checkpoint_name
96
  )
97
  self.device = device
 
98
  self.configure()
99
  self.transform_pixels = transforms.Compose([
100
  transforms.ToTensor(), # Converts image to Tensor
101
  transforms.Normalize([0.5], [0.5]) # Normalize to [-1, 1]
102
- ])
 
 
 
 
 
 
 
 
103
 
104
  def configure(self) -> None:
105
  print(f"Loading GenStereo...")
@@ -108,10 +118,6 @@ class GenStereo():
108
  self.dtype = (
109
  torch.float16 if self.cfg.half_precision_weights else torch.float32
110
  )
111
- self.viewport_mtx: Float[Tensor, 'B 4 4'] = get_viewport_matrix(
112
- self.cfg.width, self.cfg.height,
113
- batch_size=1, device=self.device
114
- ).to(self.dtype)
115
 
116
  # Load models.
117
  self.load_models()
@@ -276,6 +282,8 @@ class GenStereo():
276
  ).image_embeds
277
 
278
  image_prompt_embeds = clip_image_embeds.unsqueeze(1)
 
 
279
  uncond_image_prompt_embeds = torch.zeros_like(image_prompt_embeds)
280
 
281
  image_prompt_embeds = torch.cat(
 
24
  UNet3DConditionModel,
25
  ReferenceAttentionControl
26
  )
27
+ from .ops import convert_left_to_right, convert_left_to_right_torch
28
 
29
  class AdaptiveFusionLayer(nn.Module):
30
  def __init__(self):
 
47
  pretrained_model_path: str = ''
48
  checkpoint_name: str = ''
49
  half_precision_weights: bool = False
50
+ height: int = 768
51
+ width: int = 768
52
  num_inference_steps: int = 50
53
  guidance_scale: float = 1.5
54
  cfg: Config
 
88
  def __init__(
89
  self,
90
  cfg: Optional[Union[dict, DictConfig]] = None,
91
+ device: Optional[str] = 'cuda:0',
92
+ sd_version: Optional[str] = 'v2.1'
93
  ) -> None:
94
  self.cfg = OmegaConf.structured(self.Config(**cfg))
95
  self.model_path = join(
96
  self.cfg.pretrained_model_path, self.cfg.checkpoint_name
97
  )
98
  self.device = device
99
+ self.sd_version = sd_version
100
  self.configure()
101
  self.transform_pixels = transforms.Compose([
102
  transforms.ToTensor(), # Converts image to Tensor
103
  transforms.Normalize([0.5], [0.5]) # Normalize to [-1, 1]
104
+ ])
105
+ if self.sd_version == "v1.5":
106
+ self.cfg.height = 512
107
+ self.cfg.width = 512
108
+ elif self.sd_version == "v2.1":
109
+ self.cfg.height = 768
110
+ self.cfg.width = 768
111
+ else:
112
+ raise ValueError(f"Unknown SD version: {self.sd_version}")
113
 
114
  def configure(self) -> None:
115
  print(f"Loading GenStereo...")
 
118
  self.dtype = (
119
  torch.float16 if self.cfg.half_precision_weights else torch.float32
120
  )
 
 
 
 
121
 
122
  # Load models.
123
  self.load_models()
 
282
  ).image_embeds
283
 
284
  image_prompt_embeds = clip_image_embeds.unsqueeze(1)
285
+ if self.sd_version == "v2.1":
286
+ image_prompt_embeds = F.pad(image_prompt_embeds, (0, 256), "constant", 0) # Now shape is (bs, 1, 1024)
287
  uncond_image_prompt_embeds = torch.zeros_like(image_prompt_embeds)
288
 
289
  image_prompt_embeds = torch.cat(