xfh commited on
Commit
46dc4a6
·
1 Parent(s): 42da032

add quick_gelu

Browse files
Files changed (1) hide show
  1. stable_diffusion.py +7 -3
stable_diffusion.py CHANGED
@@ -28,6 +28,9 @@ def apply_seq(seqs, x):
28
  def gelu(self):
29
  return 0.5 * self * (1 + torch.tanh(self * 0.7978845608 * (1 + 0.044715 * self * self)))
30
 
 
 
 
31
  class Normalize(Module):
32
  def __init__(self, in_channels, num_groups=32, name="normalize"):
33
  super(Normalize, self).__init__()
@@ -275,7 +278,7 @@ class GEGLU(Module):
275
 
276
  def forward(self, x):
277
  x, gate = self.proj(x).chunk(2, dim=-1)
278
- return x * gelu(gate)
279
 
280
  class FeedForward(Module):
281
  def __init__(self, dim, mult=4, name="FeedForward"):
@@ -523,7 +526,7 @@ class CLIPMLP(Module):
523
 
524
  def forward(self, hidden_states):
525
  hidden_states = self.fc1(hidden_states)
526
- hidden_states = gelu(hidden_states)
527
  hidden_states = self.fc2(hidden_states)
528
  return hidden_states
529
 
@@ -926,6 +929,7 @@ def text2img(phrase, steps, model_file, guidance_scale, img_width, img_height, s
926
  try:
927
  args = Args(phrase, steps, None, guidance_scale, img_width, img_height, seed, device, model_file)
928
  im = Text2img.instance(args).forward(args.phrase)
 
929
  finally:
930
  pass
931
  return im
@@ -954,4 +958,4 @@ if __name__ == "__main__":
954
 
955
  im = text2img(args.phrase, args.steps, args.model_file, args.scale, args.img_width, args.img_height, args.seed, args.device_type)
956
  print(f"saving {args.out}")
957
- im.save(args.out)
 
28
  def gelu(self):
29
  return 0.5 * self * (1 + torch.tanh(self * 0.7978845608 * (1 + 0.044715 * self * self)))
30
 
31
+ def quick_gelu(x):
32
+ return x * torch.sigmoid(x * 1.702)
33
+
34
  class Normalize(Module):
35
  def __init__(self, in_channels, num_groups=32, name="normalize"):
36
  super(Normalize, self).__init__()
 
278
 
279
  def forward(self, x):
280
  x, gate = self.proj(x).chunk(2, dim=-1)
281
+ return x * quick_gelu(gate)
282
 
283
  class FeedForward(Module):
284
  def __init__(self, dim, mult=4, name="FeedForward"):
 
526
 
527
  def forward(self, hidden_states):
528
  hidden_states = self.fc1(hidden_states)
529
+ hidden_states = quick_gelu(hidden_states)
530
  hidden_states = self.fc2(hidden_states)
531
  return hidden_states
532
 
 
929
  try:
930
  args = Args(phrase, steps, None, guidance_scale, img_width, img_height, seed, device, model_file)
931
  im = Text2img.instance(args).forward(args.phrase)
932
+ im = Text2img.instance(args).decode_latent2img(im)
933
  finally:
934
  pass
935
  return im
 
958
 
959
  im = text2img(args.phrase, args.steps, args.model_file, args.scale, args.img_width, args.img_height, args.seed, args.device_type)
960
  print(f"saving {args.out}")
961
+ im.save(args.out)