Anurag181011 committed (verified)
Commit 022cc77 · 1 Parent(s): 12afb40

Update src/attentionhacked_garmnet.py

Files changed (1):
  1. src/attentionhacked_garmnet.py +5 -1
src/attentionhacked_garmnet.py CHANGED
@@ -55,6 +55,7 @@ def _chunked_feed_forward(
 class GatedSelfAttentionDense(nn.Module):
     r"""
     A gated self-attention dense layer that combines visual features and object features.
+
     Parameters:
         query_dim (`int`): The number of channels in the query.
         context_dim (`int`): The number of channels in the context.
@@ -96,6 +97,7 @@ class GatedSelfAttentionDense(nn.Module):
 class BasicTransformerBlock(nn.Module):
     r"""
     A basic Transformer block.
+
     Parameters:
         dim (`int`): The number of channels in the input and output.
         num_attention_heads (`int`): The number of heads to use for multi-head attention.
@@ -408,6 +410,7 @@ class BasicTransformerBlock(nn.Module):
 class TemporalBasicTransformerBlock(nn.Module):
     r"""
     A basic Transformer block for video like data.
+
     Parameters:
         dim (`int`): The number of channels in the input and output.
         time_mix_inner_dim (`int`): The number of channels for temporal attention.
@@ -609,6 +612,7 @@ class SkipFFTransformerBlock(nn.Module):
 class FeedForward(nn.Module):
     r"""
     A feed-forward layer.
+
     Parameters:
         dim (`int`): The number of channels in the input.
         dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
@@ -663,4 +667,4 @@ class FeedForward(nn.Module):
             hidden_states = module(hidden_states, scale)
         else:
             hidden_states = module(hidden_states)
-        return hidden_states
+        return hidden_states
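The four added blank lines follow the usual Python docstring convention: a one-line summary, then a blank line before the `Parameters` field list, which Sphinx-style docstring parsers generally expect. A minimal before/after sketch (hypothetical class, not taken from the file):

# Before: the summary runs straight into "Parameters:", so a docstring
# parser may fold the field list into the summary paragraph.
class Before:
    r"""
    A basic Transformer block.
    Parameters:
        dim (`int`): The number of channels in the input and output.
    """

# After: the blank line marks where the summary ends and the fields begin.
class After:
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
    """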
 
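The final hunk restores the `return hidden_states` statement at the end of `FeedForward.forward`; without it the method would implicitly return `None`. A simplified, runnable sketch of that pattern, assuming a plain `nn.ModuleList` in place of the file's actual activation/dropout stack (the `scale` branch from the hunk is omitted here):

import torch
import torch.nn as nn

class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
    """

    def __init__(self, dim: int, dim_out: int = None, mult: int = 4):
        super().__init__()
        dim_out = dim_out if dim_out is not None else dim
        # Stand-in for the diffusers-style linear/activation stack.
        self.net = nn.ModuleList([
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Linear(dim * mult, dim_out),
        ])

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for module in self.net:
            hidden_states = module(hidden_states)
        # The line restored by this commit: without it, forward() returns None.
        return hidden_states

# Usage: a quick shape check.
ff = FeedForward(dim=64)
out = ff(torch.randn(2, 10, 64))
assert out.shape == (2, 10, 64)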