Spaces:
Running
on
A100
Running
on
A100
Commit
•
8079453
1
Parent(s):
a5e4f9a
Update lora.py
Browse files
lora.py
CHANGED
@@ -114,7 +114,7 @@ class LoRAModule(torch.nn.Module):
|
|
114 |
|
115 |
lx = self.lora_up(lx)
|
116 |
|
117 |
-
return org_forwarded + lx * self.multiplier
|
118 |
|
119 |
|
120 |
class LoRAInfModule(LoRAModule):
|
@@ -219,7 +219,12 @@ class LoRAInfModule(LoRAModule):
|
|
219 |
|
220 |
def default_forward(self, x):
|
221 |
# print("default_forward", self.lora_name, x.size())
|
222 |
-
|
|
|
|
|
|
|
|
|
|
|
223 |
|
224 |
def forward(self, x):
|
225 |
if not self.enabled:
|
|
|
114 |
|
115 |
lx = self.lora_up(lx)
|
116 |
|
117 |
+
return org_forwarded + lx * self.multiplier * scale
|
118 |
|
119 |
|
120 |
class LoRAInfModule(LoRAModule):
|
|
|
219 |
|
220 |
def default_forward(self, x):
|
221 |
# print("default_forward", self.lora_name, x.size())
|
222 |
+
org_forward = self.org_forward(x)
|
223 |
+
lora_up_down = self.lora_up(self.lora_down(x))
|
224 |
+
print(org_forward)
|
225 |
+
print(lora_up_down)
|
226 |
+
print(self.multiplier)
|
227 |
+
return org_forward + lora_up_down * self.multiplier #* self.scale
|
228 |
|
229 |
def forward(self, x):
|
230 |
if not self.enabled:
|