|
#ifndef GGML_METAL_IMPL |
|
#define GGML_METAL_IMPL |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Plain-data parameter record; the name suggests the arguments of a "concat"
// kernel (consumer not visible in this header).
// Member order and types determine the byte layout, so keep them as declared.
// NOTE(review): neXY / nbXY look like per-dimension extents and byte strides
// (00..03 and 10..13 for two inputs, 0..3 for the result) - confirm with the caller.
typedef struct {
    int32_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne10, ne11, ne12, ne13;
    uint64_t nb10, nb11, nb12, nb13;
    int32_t  ne0, ne1, ne2, ne3;
    uint64_t nb0, nb1, nb2, nb3;
    int32_t  dim;
} ggml_metal_kargs_concat;
|
|
|
// Plain-data parameter record; the name suggests arguments for binary
// (two-operand) kernels - consumer not visible in this header.
// Member order and types determine the byte layout, so keep them as declared.
// NOTE(review): neXY / nbXY look like per-dimension extents and byte strides;
// the meaning of `offs` cannot be determined from here - confirm with the caller.
typedef struct {
    int32_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne10, ne11, ne12, ne13;
    uint64_t nb10, nb11, nb12, nb13;
    int32_t  ne0, ne1, ne2, ne3;
    uint64_t nb0, nb1, nb2, nb3;
    uint64_t offs;
} ggml_metal_kargs_bin;
|
|
|
// Plain-data parameter record; the name suggests the arguments of a "repeat"
// kernel (consumer not visible in this header).
// Member order and types determine the byte layout, so keep them as declared.
// NOTE(review): ne0X/nb0X presumably describe the source and neX/nbX the
// destination (extents and byte strides) - confirm with the caller.
typedef struct {
    int32_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne0, ne1, ne2, ne3;
    uint64_t nb0, nb1, nb2, nb3;
} ggml_metal_kargs_repeat;
|
|
|
// Plain-data parameter record; the name suggests the arguments of a copy
// kernel (consumer not visible in this header).
// All members are 64-bit. Order and types determine the byte layout.
// NOTE(review): ne0X/nb0X presumably describe the source and neX/nbX the
// destination (extents and byte strides) - confirm with the caller.
typedef struct {
    int64_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int64_t  ne0, ne1, ne2, ne3;
    uint64_t nb0, nb1, nb2, nb3;
} ggml_metal_kargs_cpy;
|
|
|
// Plain-data parameter record; the name suggests the arguments of a "set"
// kernel (consumer not visible in this header).
// Member order and types determine the byte layout, so keep them as declared.
// `bool` must be provided by the including translation unit (e.g. <stdbool.h>
// in C); this header does not visibly include it.
// NOTE(review): meaning of `offs` and `inplace` is not shown here - confirm with the caller.
typedef struct {
    int64_t  ne10, ne11, ne12;
    uint64_t nb10, nb11, nb12, nb13;
    uint64_t nb1, nb2, nb3;
    uint64_t offs;
    bool     inplace;
} ggml_metal_kargs_set;
|
|
|
// Plain-data parameter record; the name suggests the arguments of a "rope"
// kernel (consumer not visible in this header).
// Member order and types determine the byte layout, so keep them as declared.
// NOTE(review): ne/nb members look like extents and byte strides; the scalar
// tail looks like rotary-embedding tuning values - confirm with the caller.
typedef struct {
    int32_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne0, ne1, ne2, ne3;
    uint64_t nb0, nb1, nb2, nb3;
    int32_t  n_past, n_dims, n_ctx_orig;
    float    freq_base, freq_scale;
    float    ext_factor, attn_factor;
    float    beta_fast, beta_slow;
} ggml_metal_kargs_rope;
|
|
|
// Plain-data parameter record; the name suggests the arguments of a
// "flash_attn_ext" kernel (consumer not visible in this header).
// Member order and types determine the byte layout, so keep them as declared.
// NOTE(review): the `_12_` infix presumably means the value is shared by
// operands 1 and 2 - confirm with the caller.
// NOTE(review): n_head_log2 is uint16_t here, while ggml_metal_kargs_soft_max
// declares a member of the same name as uint32_t; verify this is intentional.
typedef struct {
    int32_t  ne01, ne02, ne03;
    uint64_t nb01, nb02, nb03;
    int32_t  ne11;
    int32_t  ne_12_2, ne_12_3;
    uint64_t nb_12_1, nb_12_2, nb_12_3;
    uint64_t nb31;
    int32_t  ne1, ne2;
    float    scale, max_bias;
    float    m0, m1;
    uint16_t n_head_log2;
    float    logit_softcap;
} ggml_metal_kargs_flash_attn_ext;
|
|
|
// Plain-data parameter record; the name suggests the arguments of a "mul_mm"
// kernel (consumer not visible in this header).
// Member order and types determine the byte layout, so keep them as declared.
// NOTE(review): ne/nb members look like extents and byte strides; the meaning
// of r2/r3 is not shown here - confirm with the caller.
typedef struct {
    int32_t  ne00, ne02;
    uint64_t nb01, nb02, nb03;
    int32_t  ne12;
    uint64_t nb10, nb11, nb12, nb13;
    int32_t  ne0, ne1;
    int16_t  r2, r3;
} ggml_metal_kargs_mul_mm;
|
|
|
// Plain-data parameter record; the name suggests the arguments of a "mul_mv"
// kernel (consumer not visible in this header).
// Member order and types determine the byte layout, so keep them as declared.
// NOTE(review): ne/nb members look like extents and byte strides; the meaning
// of r2/r3 is not shown here - confirm with the caller.
typedef struct {
    int32_t  ne00, ne01, ne02;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne10, ne11, ne12;
    uint64_t nb10, nb11, nb12, nb13;
    int32_t  ne0, ne1;
    int16_t  r2, r3;
} ggml_metal_kargs_mul_mv;
|
|
|
// Plain-data parameter record; the name suggests the arguments of a
// "mul_mv_ext" kernel (consumer not visible in this header).
// It has the same leading members as ggml_metal_kargs_mul_mv, followed by
// three extra 16-bit values.
// Member order and types determine the byte layout, so keep them as declared.
// NOTE(review): the meaning of r2/r3/nsg/nxpsg/r1ptg is not shown here -
// confirm with the caller.
typedef struct {
    int32_t  ne00, ne01, ne02;
    uint64_t nb00, nb01, nb02, nb03;
    int32_t  ne10, ne11, ne12;
    uint64_t nb10, nb11, nb12, nb13;
    int32_t  ne0, ne1;
    int16_t  r2, r3;
    int16_t  nsg, nxpsg, r1ptg;
} ggml_metal_kargs_mul_mv_ext;
|
|
|
// Plain-data parameter record; the name suggests the arguments of a
// "mul_mm_id" kernel (consumer not visible in this header).
// Member order and types determine the byte layout, so keep them as declared.
// NOTE(review): nei*/nbi* presumably describe an index ("ids") operand, and
// the other ne/nb members extents and byte strides - confirm with the caller.
typedef struct {
    int32_t  nei0, nei1;
    uint64_t nbi1;
    int32_t  ne00, ne02;
    uint64_t nb01, nb02;
    int32_t  ne11, ne12, ne13;
    uint64_t nb10, nb11, nb12;
    int32_t  ne0, ne1;
} ggml_metal_kargs_mul_mm_id;
|
|
|
// Plain-data parameter record; the name suggests the arguments of a
// "mul_mv_id" kernel (consumer not visible in this header).
// Member order and types determine the byte layout, so keep them as declared.
// NOTE(review): nei*/nbi* presumably describe an index ("ids") operand, and
// the other ne/nb members extents and byte strides - confirm with the caller.
typedef struct {
    int32_t  nei0, nei1;
    uint64_t nbi1;
    int32_t  ne00, ne01, ne02;
    uint64_t nb00, nb01, nb02;
    int32_t  ne10, ne11, ne12, ne13;
    uint64_t nb10, nb11, nb12;
    int32_t  ne0, ne1;
    uint64_t nb1;
} ggml_metal_kargs_mul_mv_id;
|
|
|
// Plain-data parameter record; the name suggests the arguments of a "norm"
// kernel (consumer not visible in this header).
// Member order and types determine the byte layout, so keep them as declared.
// NOTE(review): ne00_4 presumably is ne00 / 4 (row length in 4-wide vectors)
// and eps a numerical-stability epsilon - confirm with the caller.
typedef struct {
    int32_t  ne00, ne00_4;
    uint64_t nb01;
    float    eps;
} ggml_metal_kargs_norm;
|
|
|
// Plain-data parameter record; the name suggests the arguments of an
// "rms_norm" kernel (consumer not visible in this header).
// Member order and types determine the byte layout, so keep them as declared.
// NOTE(review): ne00_4 presumably is ne00 / 4 (row length in 4-wide vectors)
// and eps a numerical-stability epsilon - confirm with the caller.
typedef struct {
    int32_t  ne00, ne00_4;
    uint64_t nb01;
    float    eps;
} ggml_metal_kargs_rms_norm;
|
|
|
// Plain-data parameter record; the name suggests the arguments of a
// "group_norm" kernel (consumer not visible in this header).
// Member order and types determine the byte layout, so keep them as declared.
// NOTE(review): ne/nb members look like extents and byte strides, eps like a
// numerical-stability epsilon - confirm with the caller.
typedef struct {
    int64_t  ne00, ne01, ne02;
    uint64_t nb00, nb01, nb02;
    int32_t  n_groups;
    float    eps;
} ggml_metal_kargs_group_norm;
|
|
|
// Plain-data parameter record; the name suggests the arguments of a
// "conv_transpose_1d" kernel (consumer not visible in this header).
// Member order and types determine the byte layout, so keep them as declared.
// NOTE(review): IC/IL/K/s0 presumably are input channels, input length,
// kernel size and stride - confirm with the caller.
typedef struct {
    int32_t  IC, IL, K, s0;
    uint64_t nb0, nb1;
} ggml_metal_kargs_conv_transpose_1d;
|
|
|
// Plain-data parameter record; the name suggests the arguments of an "im2col"
// kernel (consumer not visible in this header).
// Member order and types determine the byte layout, so keep them as declared.
// NOTE(review): s*/p*/d* presumably are stride, padding and dilation per axis,
// and the I*/K* members input and kernel sizes - confirm with the caller.
typedef struct {
    uint64_t ofs0, ofs1;
    int32_t  IW, IH, CHW;
    int32_t  s0, s1;
    int32_t  p0, p1;
    int32_t  d0, d1;
    int32_t  N;
    int32_t  KH, KW, KHW;
} ggml_metal_kargs_im2col;
|
|
|
// Plain-data parameter record; the name suggests the arguments of a
// "sum_rows" kernel (consumer not visible in this header).
// All members are 64-bit. Order and types determine the byte layout.
// NOTE(review): neXY / nbXY look like per-dimension extents and byte strides -
// confirm with the caller.
typedef struct {
    int64_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int64_t  ne10, ne11, ne12, ne13;
    uint64_t nb10, nb11, nb12, nb13;
    int64_t  ne0, ne1, ne2, ne3;
    uint64_t nb0, nb1, nb2, nb3;
} ggml_metal_kargs_sum_rows;
|
|
|
// Plain-data parameter record; the name suggests the arguments of a
// "soft_max" kernel (consumer not visible in this header).
// Member order and types determine the byte layout, so keep them as declared.
// NOTE(review): how scale/max_bias/m0/m1/n_head_log2 are used is not shown
// here - confirm with the caller.
typedef struct {
    int64_t  ne00, ne01, ne02;
    float    scale, max_bias;
    float    m0, m1;
    uint32_t n_head_log2;
} ggml_metal_kargs_soft_max;
|
|
|
// Plain-data parameter record; the name suggests the arguments of a
// "diag_mask_inf" kernel (consumer not visible in this header).
// Member order and types determine the byte layout, so keep them as declared.
// n_past uses int32_t rather than plain `int` so that its width is pinned,
// like every other integer member in this header. Wherever `int` is 32 bits
// this is the same type, so layout and callers are unaffected.
// NOTE(review): ne00/ne01 look like extents of the first two dimensions -
// confirm with the caller.
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int32_t  n_past;
} ggml_metal_kargs_diag_mask_inf;
|
|
|
// Plain-data parameter record; the name suggests the arguments of an
// "ssm_conv" kernel (consumer not visible in this header).
// All members are 64-bit. Order and types determine the byte layout.
// NOTE(review): neXY / nbXY look like per-dimension extents and byte strides -
// confirm with the caller.
typedef struct {
    int64_t  ne00, ne01, ne02;
    uint64_t nb00, nb01, nb02;
    int64_t  ne10, ne11;
    uint64_t nb10, nb11;
    int64_t  ne0, ne1, ne2;
    uint64_t nb0, nb1, nb2;
} ggml_metal_kargs_ssm_conv;
|
|
|
// Plain-data parameter record; the name suggests the arguments of an
// "ssm_scan" kernel (consumer not visible in this header).
// All members are 64-bit. Order and types determine the byte layout.
// NOTE(review): nbXY presumably is the byte stride of dimension Y of operand X
// (operands 0..5) - confirm with the caller.
typedef struct {
    int64_t  d_state, d_inner;
    int64_t  n_seq_tokens, n_seqs;
    uint64_t nb00, nb01, nb02;
    uint64_t nb10, nb11, nb12, nb13;
    uint64_t nb20, nb21, nb22;
    uint64_t nb30, nb31;
    uint64_t nb40, nb41, nb42;
    uint64_t nb50, nb51, nb52;
} ggml_metal_kargs_ssm_scan;
|
|
|
// Plain-data parameter record; the name suggests the arguments of a
// "get_rows" kernel (consumer not visible in this header).
// All members are 64-bit. Order and types determine the byte layout.
// NOTE(review): neXY / nbXY look like per-dimension extents and byte strides -
// confirm with the caller.
typedef struct {
    int64_t  ne00;
    uint64_t nb01, nb02;
    int64_t  ne10;
    uint64_t nb10, nb11;
    uint64_t nb1, nb2;
} ggml_metal_kargs_get_rows;
|
|
|
// Plain-data parameter record; the name suggests the arguments of an
// "upscale" kernel (consumer not visible in this header).
// Member order and types determine the byte layout, so keep them as declared.
// NOTE(review): sf0..sf3 presumably are per-dimension scale factors, and the
// ne/nb members extents and byte strides - confirm with the caller.
typedef struct {
    int64_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int64_t  ne0, ne1, ne2, ne3;
    uint64_t nb0, nb1, nb2, nb3;
    float    sf0, sf1, sf2, sf3;
} ggml_metal_kargs_upscale;
|
|
|
// Plain-data parameter record; the name suggests the arguments of a "pad"
// kernel (consumer not visible in this header).
// All members are 64-bit. Order and types determine the byte layout.
// NOTE(review): ne0X/nb0X presumably describe the source and neX/nbX the
// destination (extents and byte strides) - confirm with the caller.
typedef struct {
    int64_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int64_t  ne0, ne1, ne2, ne3;
    uint64_t nb0, nb1, nb2, nb3;
} ggml_metal_kargs_pad;
|
|
|
// Plain-data parameter record; the name suggests the arguments of a
// "pad_reflect_1d" kernel (consumer not visible in this header).
// Member order and types determine the byte layout, so keep them as declared.
// NOTE(review): p0/p1 presumably are the padding amounts on the two sides,
// and the ne/nb members extents and byte strides - confirm with the caller.
typedef struct {
    int64_t  ne00, ne01, ne02, ne03;
    uint64_t nb00, nb01, nb02, nb03;
    int64_t  ne0, ne1, ne2, ne3;
    uint64_t nb0, nb1, nb2, nb3;
    int32_t  p0, p1;
} ggml_metal_kargs_pad_reflect_1d;
|
|
|
// Plain-data parameter record; the name suggests the arguments of a
// "timestep_embedding" kernel (consumer not visible in this header).
// Member order and types determine the byte layout, so keep them as declared.
// dim and max_period use int32_t rather than plain `int` so that their width
// is pinned, like the other integer members in this header. Wherever `int` is
// 32 bits this is the same type, so layout and callers are unaffected.
// NOTE(review): nb1 looks like a byte stride - confirm with the caller.
typedef struct {
    uint64_t nb1;
    int32_t  dim;
    int32_t  max_period;
} ggml_metal_kargs_timestep_embedding;
|
|
|
// Plain-data parameter record; the name suggests the arguments of a
// "leaky_relu" kernel (consumer not visible in this header).
// It holds a single float.
// NOTE(review): slope presumably is the negative-side slope - confirm with the caller.
typedef struct {
    float slope;
} ggml_metal_kargs_leaky_relu;
|
|
|
// Plain-data parameter record; the name suggests the arguments of an
// "argsort" kernel (consumer not visible in this header).
// Member order and types determine the byte layout, so keep them as declared.
// NOTE(review): ncols_pad presumably is ncols rounded up to a size the kernel
// needs - confirm with the caller.
typedef struct {
    int64_t ncols, ncols_pad;
} ggml_metal_kargs_argsort;
|
|
|
// Plain-data parameter record; the name suggests the arguments of an
// "arange" kernel (consumer not visible in this header).
// Member order and types determine the byte layout, so keep them as declared.
// NOTE(review): ne0 presumably is the number of output elements, with
// start/step defining the generated sequence - confirm with the caller.
typedef struct {
    int64_t ne0;
    float   start, step;
} ggml_metal_kargs_arange;
|
|
|
// Plain-data parameter record; the name suggests the arguments of a "pool_2d"
// kernel (consumer not visible in this header).
// Member order and types determine the byte layout, so keep them as declared.
// NOTE(review): k*/s*/p* presumably are kernel size, stride and padding per
// axis, and IH/IW/OH/OW input and output sizes - confirm with the caller.
typedef struct {
    int32_t k0, k1;
    int32_t s0, s1;
    int32_t p0, p1;
    int64_t IH, IW;
    int64_t OH, OW;
    int64_t parallel_elements;
} ggml_metal_kargs_pool_2d;
|
|
|
#endif |
|
|