Fix incorrect image embedding when running with a single GPU and 24GB VRAM
#3
by xdedss - opened
- modeling_internvl.py +16 -0
modeling_internvl.py CHANGED
@@ -114,13 +114,29 @@ class CrossAttention(nn.Module):
         k_bias = self.k_bias
         v_bias = self.v_bias
 
+        # simulate module forward hooks to let accelerate load the actual weight
+        # see https://github.com/huggingface/accelerate/blob/1f7a79b428749f45187ec69485f2c966fe21926e/src/accelerate/hooks.py#L163
+        simulate_hooks = hasattr(self.q, '_hf_hook')
+
+        if simulate_hooks:
+            self.q._hf_hook.pre_forward(self.q, x)
         q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
+        if simulate_hooks:
+            self.q._hf_hook.post_forward(self.q, x)
         q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
 
+        if simulate_hooks:
+            self.k._hf_hook.pre_forward(self.k, k)
         k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
+        if simulate_hooks:
+            self.k._hf_hook.post_forward(self.k, k)
         k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
 
+        if simulate_hooks:
+            self.v._hf_hook.pre_forward(self.v, v)
         v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
+        if simulate_hooks:
+            self.v._hf_hook.post_forward(self.v, v)
         v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
 
         q = q * self.scale
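Background on the fix: when the checkpoint is dispatched by accelerate (e.g. loaded with `device_map="auto"` on a single 24GB GPU, so part of the model is offloaded), each offloaded submodule carries an `_hf_hook` whose `pre_forward` moves the real weights onto the execution device and whose `post_forward` moves them back off. Those hooks only fire through `nn.Module.__call__`; `CrossAttention` reads `self.q.weight` directly inside `F.linear`, so it computes with the offload placeholder instead of the loaded weight and produces a wrong image embedding. Below is a minimal sketch of that mechanism using a stand-in hook rather than accelerate's actual `AlignDevicesHook`; the `LoggingHook` class and the toy `nn.Linear` are illustrative, not part of the PR:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate.hooks import ModelHook, add_hook_to_module

class LoggingHook(ModelHook):
    """Stand-in for accelerate's AlignDevicesHook, which (un)loads
    offloaded weights around each forward pass."""
    def pre_forward(self, module, *args, **kwargs):
        print("pre_forward: real weights would be loaded here")
        return args, kwargs

    def post_forward(self, module, output):
        print("post_forward: weights would be offloaded again")
        return output

linear = add_hook_to_module(nn.Linear(4, 4), LoggingHook())
x = torch.randn(1, 4)

linear(x)                   # hook fires: __call__ goes through the wrapped forward
F.linear(x, linear.weight)  # hook bypassed: nothing is printed

# The patch replicates the wrapper from accelerate's hooks.py around the
# functional call, guarded so plain (non-dispatched) loads still work:
if hasattr(linear, "_hf_hook"):
    linear._hf_hook.pre_forward(linear, x)
out = F.linear(x, linear.weight, linear.bias)
if hasattr(linear, "_hf_hook"):
    linear._hf_hook.post_forward(linear, out)
```

Calling `self.q(x)` would run the hook automatically, but here the biases presumably live in separate `q_bias`/`k_bias`/`v_bias` parameters rather than on the `Linear` modules, which forces the functional form and hence the manual `pre_forward`/`post_forward` pairs; the `hasattr` guard keeps models loaded without accelerate dispatch unaffected.

For reference, a hypothetical load that takes the offloading path and so exercises the patched code (the repo id is illustrative, not stated in this PR):

```python
import torch
from transformers import AutoModel

# Hypothetical repro setup: any InternVL checkpoint that accelerate has to
# partially offload on a 24GB GPU exercises the patched code path.
model = AutoModel.from_pretrained(
    "OpenGVLab/InternVL-14B-224px",  # illustrative repo id
    torch_dtype=torch.float16,
    device_map="auto",               # offloads some blocks on a 24GB GPU
    trust_remote_code=True,          # loads the custom modeling_internvl.py
)
```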