Update src/attentionhacked_garmnet.py
src/attentionhacked_garmnet.py CHANGED
@@ -55,6 +55,7 @@ def _chunked_feed_forward(
 class GatedSelfAttentionDense(nn.Module):
     r"""
     A gated self-attention dense layer that combines visual features and object features.
+
     Parameters:
         query_dim (`int`): The number of channels in the query.
         context_dim (`int`): The number of channels in the context.
@@ -96,6 +97,7 @@ class GatedSelfAttentionDense(nn.Module):
 class BasicTransformerBlock(nn.Module):
     r"""
     A basic Transformer block.
+
     Parameters:
         dim (`int`): The number of channels in the input and output.
         num_attention_heads (`int`): The number of heads to use for multi-head attention.
@@ -408,6 +410,7 @@ class BasicTransformerBlock(nn.Module):
 class TemporalBasicTransformerBlock(nn.Module):
     r"""
     A basic Transformer block for video like data.
+
     Parameters:
         dim (`int`): The number of channels in the input and output.
         time_mix_inner_dim (`int`): The number of channels for temporal attention.
@@ -609,6 +612,7 @@ class SkipFFTransformerBlock(nn.Module):
 class FeedForward(nn.Module):
     r"""
     A feed-forward layer.
+
     Parameters:
         dim (`int`): The number of channels in the input.
         dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
@@ -663,4 +667,4 @@ class FeedForward(nn.Module):
                 hidden_states = module(hidden_states, scale)
             else:
                 hidden_states = module(hidden_states)
-        return hidden_states
+        return hidden_states
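The hunks above only touch docstring spacing; the classes themselves are the attention building blocks this file copies from diffusers. As a minimal usage sketch of the FeedForward layer documented in the last two hunks, assuming it keeps the upstream diffusers constructor (dim, optional dim_out, mult, activation_fn) and that the import path below matches the file's actual location (both are assumptions, not taken from this commit):

# Sketch only: the import path is assumed from the file location
# src/attentionhacked_garmnet.py, and the sizes are illustrative.
import torch
from src.attentionhacked_garmnet import FeedForward

dim = 320
ff = FeedForward(dim)  # dim_out omitted, so the output keeps `dim` channels,
                       # per the docstring: "If not given, defaults to `dim`."

hidden_states = torch.randn(2, 77, dim)   # (batch, sequence, channels)
out = ff(hidden_states)                   # runs the ff.net loop shown in the hunk at line 667
print(out.shape)                          # torch.Size([2, 77, 320])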