zhuzhu2dandan
commited on
Upload 2 files
Browse files- gitlab_push_file.ipynb +72 -0
- shou_xin.safetensors +3 -0
gitlab_push_file.ipynb
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from ...modules import sparse as sp
|
6 |
+
from .base import SparseTransformerBase
|
7 |
+
|
8 |
+
|
9 |
+
class SLatEncoder(SparseTransformerBase):
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
resolution: int,
|
13 |
+
in_channels: int,
|
14 |
+
model_channels: int,
|
15 |
+
latent_channels: int,
|
16 |
+
num_blocks: int,
|
17 |
+
num_heads: Optional[int] = None,
|
18 |
+
num_head_channels: Optional[int] = 64,
|
19 |
+
mlp_ratio: float = 4,
|
20 |
+
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
|
21 |
+
window_size: int = 8,
|
22 |
+
pe_mode: Literal["ape", "rope"] = "ape",
|
23 |
+
use_fp16: bool = False,
|
24 |
+
use_checkpoint: bool = False,
|
25 |
+
qk_rms_norm: bool = False,
|
26 |
+
):
|
27 |
+
super().__init__(
|
28 |
+
in_channels=in_channels,
|
29 |
+
model_channels=model_channels,
|
30 |
+
num_blocks=num_blocks,
|
31 |
+
num_heads=num_heads,
|
32 |
+
num_head_channels=num_head_channels,
|
33 |
+
mlp_ratio=mlp_ratio,
|
34 |
+
attn_mode=attn_mode,
|
35 |
+
window_size=window_size,
|
36 |
+
pe_mode=pe_mode,
|
37 |
+
use_fp16=use_fp16,
|
38 |
+
use_checkpoint=use_checkpoint,
|
39 |
+
qk_rms_norm=qk_rms_norm,
|
40 |
+
)
|
41 |
+
self.resolution = resolution
|
42 |
+
self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels)
|
43 |
+
|
44 |
+
self.initialize_weights()
|
45 |
+
if use_fp16:
|
46 |
+
self.convert_to_fp16()
|
47 |
+
|
48 |
+
def initialize_weights(self) -> None:
|
49 |
+
super().initialize_weights()
|
50 |
+
# Zero-out output layers:
|
51 |
+
nn.init.constant_(self.out_layer.weight, 0)
|
52 |
+
nn.init.constant_(self.out_layer.bias, 0)
|
53 |
+
|
54 |
+
def forward(self, x: sp.SparseTensor, sample_posterior=True, return_raw=False):
|
55 |
+
h = super().forward(x)
|
56 |
+
h = h.type(x.dtype)
|
57 |
+
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
|
58 |
+
h = self.out_layer(h)
|
59 |
+
|
60 |
+
# Sample from the posterior distribution
|
61 |
+
mean, logvar = h.feats.chunk(2, dim=-1)
|
62 |
+
if sample_posterior:
|
63 |
+
std = torch.exp(0.5 * logvar)
|
64 |
+
z = mean + std * torch.randn_like(std)
|
65 |
+
else:
|
66 |
+
z = mean
|
67 |
+
z = h.replace(z)
|
68 |
+
|
69 |
+
if return_raw:
|
70 |
+
return z, mean, logvar
|
71 |
+
else:
|
72 |
+
return z
|
shou_xin.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e59714179016388dee044952d92a733f5cdf462d815be67fe87259e88fdb4703
|
3 |
+
size 171969400
|