Spaces:
Build error
Build error
add quick_gelu
Browse files- stable_diffusion.py +7 -3
stable_diffusion.py
CHANGED
@@ -28,6 +28,9 @@ def apply_seq(seqs, x):
|
|
28 |
def gelu(self):
|
29 |
return 0.5 * self * (1 + torch.tanh(self * 0.7978845608 * (1 + 0.044715 * self * self)))
|
30 |
|
|
|
|
|
|
|
31 |
class Normalize(Module):
|
32 |
def __init__(self, in_channels, num_groups=32, name="normalize"):
|
33 |
super(Normalize, self).__init__()
|
@@ -275,7 +278,7 @@ class GEGLU(Module):
|
|
275 |
|
276 |
def forward(self, x):
|
277 |
x, gate = self.proj(x).chunk(2, dim=-1)
|
278 |
-
return x *
|
279 |
|
280 |
class FeedForward(Module):
|
281 |
def __init__(self, dim, mult=4, name="FeedForward"):
|
@@ -523,7 +526,7 @@ class CLIPMLP(Module):
|
|
523 |
|
524 |
def forward(self, hidden_states):
|
525 |
hidden_states = self.fc1(hidden_states)
|
526 |
-
hidden_states =
|
527 |
hidden_states = self.fc2(hidden_states)
|
528 |
return hidden_states
|
529 |
|
@@ -926,6 +929,7 @@ def text2img(phrase, steps, model_file, guidance_scale, img_width, img_height, s
|
|
926 |
try:
|
927 |
args = Args(phrase, steps, None, guidance_scale, img_width, img_height, seed, device, model_file)
|
928 |
im = Text2img.instance(args).forward(args.phrase)
|
|
|
929 |
finally:
|
930 |
pass
|
931 |
return im
|
@@ -954,4 +958,4 @@ if __name__ == "__main__":
|
|
954 |
|
955 |
im = text2img(args.phrase, args.steps, args.model_file, args.scale, args.img_width, args.img_height, args.seed, args.device_type)
|
956 |
print(f"saving {args.out}")
|
957 |
-
im.save(args.out)
|
|
|
28 |
def gelu(self):
|
29 |
return 0.5 * self * (1 + torch.tanh(self * 0.7978845608 * (1 + 0.044715 * self * self)))
|
30 |
|
31 |
+
def quick_gelu(x):
|
32 |
+
return x * torch.sigmoid(x * 1.702)
|
33 |
+
|
34 |
class Normalize(Module):
|
35 |
def __init__(self, in_channels, num_groups=32, name="normalize"):
|
36 |
super(Normalize, self).__init__()
|
|
|
278 |
|
279 |
def forward(self, x):
|
280 |
x, gate = self.proj(x).chunk(2, dim=-1)
|
281 |
+
return x * quick_gelu(gate)
|
282 |
|
283 |
class FeedForward(Module):
|
284 |
def __init__(self, dim, mult=4, name="FeedForward"):
|
|
|
526 |
|
527 |
def forward(self, hidden_states):
|
528 |
hidden_states = self.fc1(hidden_states)
|
529 |
+
hidden_states = quick_gelu(hidden_states)
|
530 |
hidden_states = self.fc2(hidden_states)
|
531 |
return hidden_states
|
532 |
|
|
|
929 |
try:
|
930 |
args = Args(phrase, steps, None, guidance_scale, img_width, img_height, seed, device, model_file)
|
931 |
im = Text2img.instance(args).forward(args.phrase)
|
932 |
+
im = Text2img.instance(args).decode_latent2img(im)
|
933 |
finally:
|
934 |
pass
|
935 |
return im
|
|
|
958 |
|
959 |
im = text2img(args.phrase, args.steps, args.model_file, args.scale, args.img_width, args.img_height, args.seed, args.device_type)
|
960 |
print(f"saving {args.out}")
|
961 |
+
im.save(args.out)
|