huzefa11 committed
Commit 0c4c4f3 · verified · 1 Parent(s): 637bebb

Update model.py

Files changed (1)
  1. model.py +63 -44
model.py CHANGED
@@ -1,12 +1,13 @@
+# model.py
 # Merge image encoder and fuse module to create an ID Encoder
-# send multiple ID images, we can directly obtain the updated text encoder containing a stacked ID embedding
+# Allows multiple ID images to update the text encoder with a stacked ID embedding.
 
 import torch
 import torch.nn as nn
 from transformers.models.clip.modeling_clip import CLIPVisionModelWithProjection
 from transformers.models.clip.configuration_clip import CLIPVisionConfig
-from transformers import PretrainedConfig
 
+# Vision backbone configuration for the CLIP-based encoder
 VISION_CONFIG_DICT = {
     "hidden_size": 1024,
     "intermediate_size": 4096,
@@ -17,10 +18,11 @@ VISION_CONFIG_DICT = {
 }
 
 class MLP(nn.Module):
+    """Simple MLP block with optional residual connection."""
     def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True):
         super().__init__()
         if use_residual:
-            assert in_dim == out_dim
+            assert in_dim == out_dim, "Input and output dimensions must match when using residual."
         self.layernorm = nn.LayerNorm(in_dim)
         self.fc1 = nn.Linear(in_dim, hidden_dim)
         self.fc2 = nn.Linear(hidden_dim, out_dim)
@@ -34,11 +36,11 @@ class MLP(nn.Module):
         x = self.act_fn(x)
         x = self.fc2(x)
         if self.use_residual:
-            x = x + residual
+            x += residual
         return x
 
-
 class FuseModule(nn.Module):
+    """Module that fuses prompt embeddings with ID embeddings."""
     def __init__(self, embed_dim):
         super().__init__()
         self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False)
@@ -46,68 +48,85 @@ class FuseModule(nn.Module):
         self.layer_norm = nn.LayerNorm(embed_dim)
 
     def fuse_fn(self, prompt_embeds, id_embeds):
-        stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
-        stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
-        stacked_id_embeds = self.mlp2(stacked_id_embeds)
-        stacked_id_embeds = self.layer_norm(stacked_id_embeds)
-        return stacked_id_embeds
-
-    def forward(
-        self,
-        prompt_embeds,
-        id_embeds,
-        class_tokens_mask,
-    ) -> torch.Tensor:
-        # id_embeds shape: [b, max_num_inputs, 1, 2048]
+        """Performs two-step fusion of prompt and ID embeddings."""
+        stacked = torch.cat([prompt_embeds, id_embeds], dim=-1)
+        fused = self.mlp1(stacked) + prompt_embeds
+        fused = self.mlp2(fused)
+        return self.layer_norm(fused)
+
+    def forward(self, prompt_embeds, id_embeds, class_tokens_mask):
+        """
+        Args:
+            prompt_embeds (Tensor): Text encoder embeddings [batch, seq_len, embed_dim]
+            id_embeds (Tensor): ID embeddings [batch, max_inputs, 1, embed_dim]
+            class_tokens_mask (Tensor): Mask indicating which tokens to replace [batch, seq_len]
+
+        Returns:
+            Tensor: Updated prompt embeddings.
+        """
         id_embeds = id_embeds.to(prompt_embeds.dtype)
-        num_inputs = class_tokens_mask.sum().unsqueeze(0)  # TODO: check for training case
         batch_size, max_num_inputs = id_embeds.shape[:2]
-        # seq_length: 77
         seq_length = prompt_embeds.shape[1]
-        # flat_id_embeds shape: [b*max_num_inputs, 1, 2048]
-        flat_id_embeds = id_embeds.view(
-            -1, id_embeds.shape[-2], id_embeds.shape[-1]
-        )
-        # valid_id_mask [b*max_num_inputs]
-        valid_id_mask = (
-            torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :]
-            < num_inputs[:, None]
-        )
+
+        num_inputs = class_tokens_mask.sum(dim=1)
+
+        flat_id_embeds = id_embeds.view(-1, id_embeds.shape[-2], id_embeds.shape[-1])
+        valid_id_mask = (torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :] < num_inputs[:, None])
+
         valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()]
 
-        prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1])
-        class_tokens_mask = class_tokens_mask.view(-1)
+        prompt_embeds_flat = prompt_embeds.view(-1, prompt_embeds.shape[-1])
+        class_tokens_mask_flat = class_tokens_mask.view(-1)
+
         valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1])
-        # slice out the image token embeddings
-        image_token_embeds = prompt_embeds[class_tokens_mask]
-        stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds)
-        assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}"
-        prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype))
-        updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1)
+
+        image_token_embeds = prompt_embeds_flat[class_tokens_mask_flat]
+        stacked_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds)
+
+        assert class_tokens_mask_flat.sum() == stacked_embeds.shape[0], (
+            f"Mismatch between mask sum and stacked embeds: {class_tokens_mask_flat.sum()} vs {stacked_embeds.shape[0]}"
+        )
+
+        prompt_embeds_flat.masked_scatter_(class_tokens_mask_flat[:, None], stacked_embeds.to(prompt_embeds.dtype))
+        updated_prompt_embeds = prompt_embeds_flat.view(batch_size, seq_length, -1)
+
         return updated_prompt_embeds
 
 class PhotoMakerIDEncoder(CLIPVisionModelWithProjection):
+    """ID Encoder combining vision features and text prompts."""
     def __init__(self):
         super().__init__(CLIPVisionConfig(**VISION_CONFIG_DICT))
         self.visual_projection_2 = nn.Linear(1024, 1280, bias=False)
-        self.fuse_module = FuseModule(2048)
+        self.fuse_module = FuseModule(embed_dim=2048)
 
     def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask):
+        """
+        Args:
+            id_pixel_values (Tensor): Images [batch, num_inputs, channels, height, width]
+            prompt_embeds (Tensor): Text embeddings [batch, seq_len, embed_dim]
+            class_tokens_mask (Tensor): Mask of class tokens to update
+
+        Returns:
+            Tensor: Updated text embeddings incorporating ID image features.
+        """
         b, num_inputs, c, h, w = id_pixel_values.shape
         id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)
 
-        shared_id_embeds = self.vision_model(id_pixel_values)[1]
+        vision_outputs = self.vision_model(id_pixel_values)
+        shared_id_embeds = vision_outputs[1]  # Use pooled output
+
         id_embeds = self.visual_projection(shared_id_embeds)
         id_embeds_2 = self.visual_projection_2(shared_id_embeds)
 
         id_embeds = id_embeds.view(b, num_inputs, 1, -1)
-        id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
+        id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
 
-        id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
-        updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask)
+        combined_id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
 
-        return updated_prompt_embeds
+        updated_prompt_embeds = self.fuse_module(prompt_embeds, combined_id_embeds, class_tokens_mask)
 
+        return updated_prompt_embeds
 
 if __name__ == "__main__":
-    PhotoMakerIDEncoder()
+    encoder = PhotoMakerIDEncoder()
+    print("PhotoMakerIDEncoder initialized successfully.")