|
#ifndef __COMMON_HPP__ |
|
#define __COMMON_HPP__ |
|
|
|
#include "ggml_extend.hpp" |
|
|
|
class DownSampleBlock : public GGMLBlock { |
|
protected: |
|
int channels; |
|
int out_channels; |
|
bool vae_downsample; |
|
|
|
public: |
|
DownSampleBlock(int channels, |
|
int out_channels, |
|
bool vae_downsample = false) |
|
: channels(channels), |
|
out_channels(out_channels), |
|
vae_downsample(vae_downsample) { |
|
if (vae_downsample) { |
|
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {2, 2}, {0, 0})); |
|
} else { |
|
blocks["op"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {2, 2}, {1, 1})); |
|
} |
|
} |
|
|
|
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { |
|
|
|
if (vae_downsample) { |
|
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]); |
|
|
|
x = ggml_pad(ctx, x, 1, 1, 0, 0); |
|
x = conv->forward(ctx, x); |
|
} else { |
|
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["op"]); |
|
|
|
x = conv->forward(ctx, x); |
|
} |
|
return x; |
|
} |
|
}; |
|
|
|
class UpSampleBlock : public GGMLBlock { |
|
protected: |
|
int channels; |
|
int out_channels; |
|
|
|
public: |
|
UpSampleBlock(int channels, |
|
int out_channels) |
|
: channels(channels), |
|
out_channels(out_channels) { |
|
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {1, 1}, {1, 1})); |
|
} |
|
|
|
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { |
|
|
|
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]); |
|
|
|
x = ggml_upscale(ctx, x, 2); |
|
x = conv->forward(ctx, x); |
|
return x; |
|
} |
|
}; |
|
|
|
class ResBlock : public GGMLBlock { |
|
protected: |
|
|
|
int64_t channels; |
|
int64_t emb_channels; |
|
int64_t out_channels; |
|
std::pair<int, int> kernel_size; |
|
int dims; |
|
bool skip_t_emb; |
|
bool exchange_temb_dims; |
|
|
|
std::shared_ptr<GGMLBlock> conv_nd(int dims, |
|
int64_t in_channels, |
|
int64_t out_channels, |
|
std::pair<int, int> kernel_size, |
|
std::pair<int, int> padding) { |
|
GGML_ASSERT(dims == 2 || dims == 3); |
|
if (dims == 3) { |
|
return std::shared_ptr<GGMLBlock>(new Conv3dnx1x1(in_channels, out_channels, kernel_size.first, 1, padding.first)); |
|
} else { |
|
return std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, kernel_size, {1, 1}, padding)); |
|
} |
|
} |
|
|
|
public: |
|
ResBlock(int64_t channels, |
|
int64_t emb_channels, |
|
int64_t out_channels, |
|
std::pair<int, int> kernel_size = {3, 3}, |
|
int dims = 2, |
|
bool exchange_temb_dims = false, |
|
bool skip_t_emb = false) |
|
: channels(channels), |
|
emb_channels(emb_channels), |
|
out_channels(out_channels), |
|
kernel_size(kernel_size), |
|
dims(dims), |
|
skip_t_emb(skip_t_emb), |
|
exchange_temb_dims(exchange_temb_dims) { |
|
std::pair<int, int> padding = {kernel_size.first / 2, kernel_size.second / 2}; |
|
blocks["in_layers.0"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(channels)); |
|
|
|
blocks["in_layers.2"] = conv_nd(dims, channels, out_channels, kernel_size, padding); |
|
|
|
if (!skip_t_emb) { |
|
|
|
blocks["emb_layers.1"] = std::shared_ptr<GGMLBlock>(new Linear(emb_channels, out_channels)); |
|
} |
|
|
|
blocks["out_layers.0"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(out_channels)); |
|
|
|
|
|
blocks["out_layers.3"] = conv_nd(dims, out_channels, out_channels, kernel_size, padding); |
|
|
|
if (out_channels != channels) { |
|
blocks["skip_connection"] = conv_nd(dims, channels, out_channels, {1, 1}, {0, 0}); |
|
} |
|
} |
|
|
|
virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* emb = NULL) { |
|
|
|
|
|
|
|
|
|
auto in_layers_0 = std::dynamic_pointer_cast<GroupNorm32>(blocks["in_layers.0"]); |
|
auto in_layers_2 = std::dynamic_pointer_cast<UnaryBlock>(blocks["in_layers.2"]); |
|
auto out_layers_0 = std::dynamic_pointer_cast<GroupNorm32>(blocks["out_layers.0"]); |
|
auto out_layers_3 = std::dynamic_pointer_cast<UnaryBlock>(blocks["out_layers.3"]); |
|
|
|
if (emb == NULL) { |
|
GGML_ASSERT(skip_t_emb); |
|
} |
|
|
|
|
|
auto h = in_layers_0->forward(ctx, x); |
|
h = ggml_silu_inplace(ctx, h); |
|
h = in_layers_2->forward(ctx, h); |
|
|
|
|
|
if (!skip_t_emb) { |
|
auto emb_layer_1 = std::dynamic_pointer_cast<Linear>(blocks["emb_layers.1"]); |
|
|
|
auto emb_out = ggml_silu(ctx, emb); |
|
emb_out = emb_layer_1->forward(ctx, emb_out); |
|
|
|
if (dims == 2) { |
|
emb_out = ggml_reshape_4d(ctx, emb_out, 1, 1, emb_out->ne[0], emb_out->ne[1]); |
|
} else { |
|
emb_out = ggml_reshape_4d(ctx, emb_out, 1, emb_out->ne[0], emb_out->ne[1], emb_out->ne[2]); |
|
if (exchange_temb_dims) { |
|
|
|
emb_out = ggml_cont(ctx, ggml_permute(ctx, emb_out, 0, 2, 1, 3)); |
|
} |
|
} |
|
|
|
h = ggml_add(ctx, h, emb_out); |
|
} |
|
|
|
|
|
h = out_layers_0->forward(ctx, h); |
|
h = ggml_silu_inplace(ctx, h); |
|
|
|
h = out_layers_3->forward(ctx, h); |
|
|
|
|
|
if (out_channels != channels) { |
|
auto skip_connection = std::dynamic_pointer_cast<UnaryBlock>(blocks["skip_connection"]); |
|
x = skip_connection->forward(ctx, x); |
|
} |
|
|
|
h = ggml_add(ctx, h, x); |
|
return h; |
|
} |
|
}; |
|
|
|
class GEGLU : public GGMLBlock { |
|
protected: |
|
int64_t dim_in; |
|
int64_t dim_out; |
|
|
|
void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") { |
|
enum ggml_type wtype = (tensor_types.find(prefix + "proj.weight") != tensor_types.end()) ? tensor_types[prefix + "proj.weight"] : GGML_TYPE_F32; |
|
enum ggml_type bias_wtype = GGML_TYPE_F32; |
|
params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2); |
|
params["proj.bias"] = ggml_new_tensor_1d(ctx, bias_wtype, dim_out * 2); |
|
} |
|
|
|
public: |
|
GEGLU(int64_t dim_in, int64_t dim_out) |
|
: dim_in(dim_in), dim_out(dim_out) {} |
|
|
|
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { |
|
|
|
|
|
struct ggml_tensor* w = params["proj.weight"]; |
|
struct ggml_tensor* b = params["proj.bias"]; |
|
|
|
auto x_w = ggml_view_2d(ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], 0); |
|
auto x_b = ggml_view_1d(ctx, b, b->ne[0] / 2, 0); |
|
auto gate_w = ggml_view_2d(ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], w->nb[1] * w->ne[1] / 2); |
|
auto gate_b = ggml_view_1d(ctx, b, b->ne[0] / 2, b->nb[0] * b->ne[0] / 2); |
|
|
|
auto x_in = x; |
|
x = ggml_nn_linear(ctx, x_in, x_w, x_b); |
|
auto gate = ggml_nn_linear(ctx, x_in, gate_w, gate_b); |
|
|
|
gate = ggml_gelu_inplace(ctx, gate); |
|
|
|
x = ggml_mul(ctx, x, gate); |
|
|
|
return x; |
|
} |
|
}; |
|
|
|
class FeedForward : public GGMLBlock { |
|
public: |
|
FeedForward(int64_t dim, |
|
int64_t dim_out, |
|
int64_t mult = 4) { |
|
int64_t inner_dim = dim * mult; |
|
|
|
blocks["net.0"] = std::shared_ptr<GGMLBlock>(new GEGLU(dim, inner_dim)); |
|
|
|
blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim_out)); |
|
} |
|
|
|
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { |
|
|
|
|
|
|
|
auto net_0 = std::dynamic_pointer_cast<GEGLU>(blocks["net.0"]); |
|
auto net_2 = std::dynamic_pointer_cast<Linear>(blocks["net.2"]); |
|
|
|
x = net_0->forward(ctx, x); |
|
x = net_2->forward(ctx, x); |
|
return x; |
|
} |
|
}; |
|
|
|
class CrossAttention : public GGMLBlock { |
|
protected: |
|
int64_t query_dim; |
|
int64_t context_dim; |
|
int64_t n_head; |
|
int64_t d_head; |
|
bool flash_attn; |
|
|
|
public: |
|
CrossAttention(int64_t query_dim, |
|
int64_t context_dim, |
|
int64_t n_head, |
|
int64_t d_head, |
|
bool flash_attn = false) |
|
: n_head(n_head), |
|
d_head(d_head), |
|
query_dim(query_dim), |
|
context_dim(context_dim), |
|
flash_attn(flash_attn) { |
|
int64_t inner_dim = d_head * n_head; |
|
|
|
blocks["to_q"] = std::shared_ptr<GGMLBlock>(new Linear(query_dim, inner_dim, false)); |
|
blocks["to_k"] = std::shared_ptr<GGMLBlock>(new Linear(context_dim, inner_dim, false)); |
|
blocks["to_v"] = std::shared_ptr<GGMLBlock>(new Linear(context_dim, inner_dim, false)); |
|
|
|
blocks["to_out.0"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, query_dim)); |
|
|
|
} |
|
|
|
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) { |
|
|
|
|
|
|
|
auto to_q = std::dynamic_pointer_cast<Linear>(blocks["to_q"]); |
|
auto to_k = std::dynamic_pointer_cast<Linear>(blocks["to_k"]); |
|
auto to_v = std::dynamic_pointer_cast<Linear>(blocks["to_v"]); |
|
auto to_out_0 = std::dynamic_pointer_cast<Linear>(blocks["to_out.0"]); |
|
|
|
int64_t n = x->ne[2]; |
|
int64_t n_token = x->ne[1]; |
|
int64_t n_context = context->ne[1]; |
|
int64_t inner_dim = d_head * n_head; |
|
|
|
auto q = to_q->forward(ctx, x); |
|
auto k = to_k->forward(ctx, context); |
|
auto v = to_v->forward(ctx, context); |
|
|
|
x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false, false, flash_attn); |
|
|
|
x = to_out_0->forward(ctx, x); |
|
return x; |
|
} |
|
}; |
|
|
|
class BasicTransformerBlock : public GGMLBlock { |
|
protected: |
|
int64_t n_head; |
|
int64_t d_head; |
|
bool ff_in; |
|
|
|
public: |
|
BasicTransformerBlock(int64_t dim, |
|
int64_t n_head, |
|
int64_t d_head, |
|
int64_t context_dim, |
|
bool ff_in = false, |
|
bool flash_attn = false) |
|
: n_head(n_head), d_head(d_head), ff_in(ff_in) { |
|
|
|
|
|
|
|
|
|
|
|
blocks["attn1"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, dim, n_head, d_head, flash_attn)); |
|
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, context_dim, n_head, d_head, flash_attn)); |
|
blocks["ff"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim)); |
|
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim)); |
|
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim)); |
|
blocks["norm3"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim)); |
|
|
|
if (ff_in) { |
|
blocks["norm_in"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim)); |
|
blocks["ff_in"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim)); |
|
} |
|
} |
|
|
|
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) { |
|
|
|
|
|
|
|
|
|
auto attn1 = std::dynamic_pointer_cast<CrossAttention>(blocks["attn1"]); |
|
auto attn2 = std::dynamic_pointer_cast<CrossAttention>(blocks["attn2"]); |
|
auto ff = std::dynamic_pointer_cast<FeedForward>(blocks["ff"]); |
|
auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]); |
|
auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]); |
|
auto norm3 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm3"]); |
|
|
|
if (ff_in) { |
|
auto norm_in = std::dynamic_pointer_cast<LayerNorm>(blocks["norm_in"]); |
|
auto ff_in = std::dynamic_pointer_cast<FeedForward>(blocks["ff_in"]); |
|
|
|
auto x_skip = x; |
|
x = norm_in->forward(ctx, x); |
|
x = ff_in->forward(ctx, x); |
|
|
|
x = ggml_add(ctx, x, x_skip); |
|
} |
|
|
|
auto r = x; |
|
x = norm1->forward(ctx, x); |
|
x = attn1->forward(ctx, x, x); |
|
x = ggml_add(ctx, x, r); |
|
r = x; |
|
x = norm2->forward(ctx, x); |
|
x = attn2->forward(ctx, x, context); |
|
x = ggml_add(ctx, x, r); |
|
r = x; |
|
x = norm3->forward(ctx, x); |
|
x = ff->forward(ctx, x); |
|
x = ggml_add(ctx, x, r); |
|
|
|
return x; |
|
} |
|
}; |
|
|
|
class SpatialTransformer : public GGMLBlock { |
|
protected: |
|
int64_t in_channels; |
|
int64_t n_head; |
|
int64_t d_head; |
|
int64_t depth = 1; |
|
int64_t context_dim = 768; |
|
|
|
public: |
|
SpatialTransformer(int64_t in_channels, |
|
int64_t n_head, |
|
int64_t d_head, |
|
int64_t depth, |
|
int64_t context_dim, |
|
bool flash_attn = false) |
|
: in_channels(in_channels), |
|
n_head(n_head), |
|
d_head(d_head), |
|
depth(depth), |
|
context_dim(context_dim) { |
|
|
|
|
|
int64_t inner_dim = n_head * d_head; |
|
blocks["norm"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels)); |
|
blocks["proj_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, inner_dim, {1, 1})); |
|
|
|
for (int i = 0; i < depth; i++) { |
|
std::string name = "transformer_blocks." + std::to_string(i); |
|
blocks[name] = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim, false, flash_attn)); |
|
} |
|
|
|
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1})); |
|
} |
|
|
|
virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) { |
|
|
|
|
|
auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]); |
|
auto proj_in = std::dynamic_pointer_cast<Conv2d>(blocks["proj_in"]); |
|
auto proj_out = std::dynamic_pointer_cast<Conv2d>(blocks["proj_out"]); |
|
|
|
auto x_in = x; |
|
int64_t n = x->ne[3]; |
|
int64_t h = x->ne[1]; |
|
int64_t w = x->ne[0]; |
|
int64_t inner_dim = n_head * d_head; |
|
|
|
x = norm->forward(ctx, x); |
|
x = proj_in->forward(ctx, x); |
|
|
|
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 2, 0, 3)); |
|
x = ggml_reshape_3d(ctx, x, inner_dim, w * h, n); |
|
|
|
for (int i = 0; i < depth; i++) { |
|
std::string name = "transformer_blocks." + std::to_string(i); |
|
auto transformer_block = std::dynamic_pointer_cast<BasicTransformerBlock>(blocks[name]); |
|
|
|
x = transformer_block->forward(ctx, x, context); |
|
} |
|
|
|
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); |
|
x = ggml_reshape_4d(ctx, x, w, h, inner_dim, n); |
|
|
|
|
|
x = proj_out->forward(ctx, x); |
|
|
|
x = ggml_add(ctx, x, x_in); |
|
return x; |
|
} |
|
}; |
|
|
|
class AlphaBlender : public GGMLBlock { |
|
protected: |
|
void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") { |
|
|
|
enum ggml_type wtype = GGML_TYPE_F32; |
|
params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1); |
|
} |
|
|
|
float get_alpha() { |
|
|
|
|
|
float alpha = ggml_backend_tensor_get_f32(params["mix_factor"]); |
|
return sigmoid(alpha); |
|
} |
|
|
|
public: |
|
AlphaBlender() { |
|
|
|
|
|
|
|
} |
|
|
|
struct ggml_tensor* forward(struct ggml_context* ctx, |
|
struct ggml_tensor* x_spatial, |
|
struct ggml_tensor* x_temporal) { |
|
|
|
float alpha = get_alpha(); |
|
auto x = ggml_add(ctx, |
|
ggml_scale(ctx, x_spatial, alpha), |
|
ggml_scale(ctx, x_temporal, 1.0f - alpha)); |
|
return x; |
|
} |
|
}; |
|
|
|
class VideoResBlock : public ResBlock { |
|
public: |
|
VideoResBlock(int channels, |
|
int emb_channels, |
|
int out_channels, |
|
std::pair<int, int> kernel_size = {3, 3}, |
|
int64_t video_kernel_size = 3, |
|
int dims = 2) |
|
: ResBlock(channels, emb_channels, out_channels, kernel_size, dims) { |
|
blocks["time_stack"] = std::shared_ptr<GGMLBlock>(new ResBlock(out_channels, emb_channels, out_channels, kernel_size, 3, true)); |
|
blocks["time_mixer"] = std::shared_ptr<GGMLBlock>(new AlphaBlender()); |
|
} |
|
|
|
struct ggml_tensor* forward(struct ggml_context* ctx, |
|
struct ggml_tensor* x, |
|
struct ggml_tensor* emb, |
|
int num_video_frames) { |
|
|
|
|
|
|
|
auto time_stack = std::dynamic_pointer_cast<ResBlock>(blocks["time_stack"]); |
|
auto time_mixer = std::dynamic_pointer_cast<AlphaBlender>(blocks["time_mixer"]); |
|
|
|
x = ResBlock::forward(ctx, x, emb); |
|
|
|
int64_t T = num_video_frames; |
|
int64_t B = x->ne[3] / T; |
|
int64_t C = x->ne[2]; |
|
int64_t H = x->ne[1]; |
|
int64_t W = x->ne[0]; |
|
|
|
x = ggml_reshape_4d(ctx, x, W * H, C, T, B); |
|
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); |
|
auto x_mix = x; |
|
|
|
emb = ggml_reshape_4d(ctx, emb, emb->ne[0], T, B, emb->ne[3]); |
|
|
|
x = time_stack->forward(ctx, x, emb); |
|
|
|
x = time_mixer->forward(ctx, x_mix, x); |
|
|
|
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); |
|
x = ggml_reshape_4d(ctx, x, W, H, C, T * B); |
|
|
|
return x; |
|
} |
|
}; |
|
|
|
#endif |
|
|