ginipick committed on
Commit
77ed819
·
verified ·
1 Parent(s): d945073

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -15
app.py CHANGED
@@ -37,9 +37,9 @@ class NagWanTransformer3DModel(nn.Module):
37
  self,
38
  in_channels: int = 4,
39
  out_channels: int = 4,
40
- hidden_size: int = 1280,
41
- num_layers: int = 2,
42
- num_heads: int = 8,
43
  ):
44
  super().__init__()
45
  self.in_channels = in_channels
@@ -57,15 +57,15 @@ class NagWanTransformer3DModel(nn.Module):
57
  })()
58
 
59
  # Simple conv layers for demo
60
- self.conv_in = nn.Conv3d(in_channels, 64, kernel_size=3, padding=1)
61
- self.conv_mid = nn.Conv3d(64, 64, kernel_size=3, padding=1)
62
- self.conv_out = nn.Conv3d(64, out_channels, kernel_size=3, padding=1)
63
 
64
  # Time embedding
65
  self.time_embed = nn.Sequential(
66
  nn.Linear(1, hidden_size),
67
  nn.SiLU(),
68
- nn.Linear(hidden_size, 64),
69
  )
70
 
71
  @staticmethod
@@ -428,10 +428,10 @@ DEFAULT_STEPS = 2
428
  DEFAULT_SEED = 2025
429
  DEFAULT_H_SLIDER_VALUE = 128
430
  DEFAULT_W_SLIDER_VALUE = 128
431
- NEW_FORMULA_MAX_AREA = 256.0 * 256.0
432
 
433
- SLIDER_MIN_H, SLIDER_MAX_H = 128, 512
434
- SLIDER_MIN_W, SLIDER_MAX_W = 128, 512
435
  MAX_SEED = np.iinfo(np.int32).max
436
 
437
  FIXED_FPS = 16
@@ -493,9 +493,9 @@ print("Creating simplified NAG transformer model...")
493
  transformer = NagWanTransformer3DModel(
494
  in_channels=4,
495
  out_channels=4,
496
- hidden_size=1280,
497
- num_layers=2, # Reduced for demo
498
- num_heads=8
499
  )
500
 
501
  print("Creating pipeline...")
@@ -511,8 +511,7 @@ pipe = NAGWanPipeline(
511
  beta_end=0.012,
512
  beta_schedule="scaled_linear",
513
  clip_sample=False,
514
- set_alpha_to_one=False,
515
- steps_offset=1,
516
  )
517
  )
518
 
 
37
  self,
38
  in_channels: int = 4,
39
  out_channels: int = 4,
40
+ hidden_size: int = 64,
41
+ num_layers: int = 1,
42
+ num_heads: int = 4,
43
  ):
44
  super().__init__()
45
  self.in_channels = in_channels
 
57
  })()
58
 
59
  # Simple conv layers for demo
60
+ self.conv_in = nn.Conv3d(in_channels, hidden_size, kernel_size=3, padding=1)
61
+ self.conv_mid = nn.Conv3d(hidden_size, hidden_size, kernel_size=3, padding=1)
62
+ self.conv_out = nn.Conv3d(hidden_size, out_channels, kernel_size=3, padding=1)
63
 
64
  # Time embedding
65
  self.time_embed = nn.Sequential(
66
  nn.Linear(1, hidden_size),
67
  nn.SiLU(),
68
+ nn.Linear(hidden_size, hidden_size),
69
  )
70
 
71
  @staticmethod
 
428
  DEFAULT_SEED = 2025
429
  DEFAULT_H_SLIDER_VALUE = 128
430
  DEFAULT_W_SLIDER_VALUE = 128
431
+ NEW_FORMULA_MAX_AREA = 128.0 * 128.0
432
 
433
+ SLIDER_MIN_H, SLIDER_MAX_H = 128, 256
434
+ SLIDER_MIN_W, SLIDER_MAX_W = 128, 256
435
  MAX_SEED = np.iinfo(np.int32).max
436
 
437
  FIXED_FPS = 16
 
493
  transformer = NagWanTransformer3DModel(
494
  in_channels=4,
495
  out_channels=4,
496
+ hidden_size=64, # Reduced from 1280 for demo
497
+ num_layers=1, # Reduced for demo
498
+ num_heads=4 # Reduced for demo
499
  )
500
 
501
  print("Creating pipeline...")
 
511
  beta_end=0.012,
512
  beta_schedule="scaled_linear",
513
  clip_sample=False,
514
+ prediction_type="epsilon",
 
515
  )
516
  )
517