#ifndef GGML_METAL_IMPL
#define GGML_METAL_IMPL

// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
//   however, be careful of int overflows when using those in the kernel implementation
//
// - strides (e.g. nb00) use uint64_t
//
// - field order and field types define the memory layout of each struct, so they
//   must not be reordered or resized casually
//
// NOTE(review): this header declares no includes of its own; the fixed-width integer
//               types and `bool` must already be in scope wherever it is included.
// NOTE(review): by the naming pattern, the 0x/1x suffixes presumably refer to the
//               first/second input and the unsuffixed ne0/nb0 to the output - confirm
//               against the kernels that consume these structs.

// kernel arguments: concat
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    int32_t  dim;
} ggml_metal_kargs_concat;

// kernel arguments: bin
// same leading fields as ggml_metal_kargs_concat, with a trailing `offs` instead of `dim`
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    uint64_t offs;
} ggml_metal_kargs_bin;

// kernel arguments: repeat
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
} ggml_metal_kargs_repeat;

// kernel arguments: cpy
// note: element counters are 64-bit here
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    int64_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int64_t  ne0;
    int64_t  ne1;
    int64_t  ne2;
    int64_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
} ggml_metal_kargs_cpy;

// kernel arguments: set
typedef struct {
    int64_t  ne10;
    int64_t  ne11;
    int64_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    uint64_t offs;
    bool     inplace;
} ggml_metal_kargs_set;

// kernel arguments: rope
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    int32_t  n_past;
    int32_t  n_dims;
    int32_t  n_ctx_orig;
    float    freq_base;
    float    freq_scale;
    float    ext_factor;
    float    attn_factor;
    float    beta_fast;
    float    beta_slow;
} ggml_metal_kargs_rope;

// kernel arguments: flash_attn_ext
typedef struct {
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne11;
    int32_t  ne_12_2; // assume K and V are same shape
    int32_t  ne_12_3;
    uint64_t nb_12_1;
    uint64_t nb_12_2;
    uint64_t nb_12_3;
    uint64_t nb31;
    int32_t  ne1;
    int32_t  ne2;
    float    scale;
    float    max_bias;
    float    m0;
    float    m1;
    uint16_t n_head_log2; // 16-bit here; ggml_metal_kargs_soft_max declares the same name as uint32_t
    float    logit_softcap;
} ggml_metal_kargs_flash_attn_ext;

// kernel arguments: mul_mm
typedef struct {
    int32_t  ne00;
    int32_t  ne02;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int16_t  r2;
    int16_t  r3;
} ggml_metal_kargs_mul_mm;

// kernel arguments: mul_mv
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int16_t  r2;
    int16_t  r3;
} ggml_metal_kargs_mul_mv;

// kernel arguments: mul_mv_ext
// same fields as ggml_metal_kargs_mul_mv, followed by nsg, nxpsg and r1ptg
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int16_t  r2;
    int16_t  r3;
    int16_t  nsg;
    int16_t  nxpsg;
    int16_t  r1ptg;
} ggml_metal_kargs_mul_mv_ext;

// kernel arguments: mul_mm_id
typedef struct {
    int32_t  nei0;
    int32_t  nei1;
    uint64_t nbi1;
    int32_t  ne00;
    int32_t  ne02;
    uint64_t nb01;
    uint64_t nb02;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    int32_t  ne0;
    int32_t  ne1;
} ggml_metal_kargs_mul_mm_id;

// kernel arguments: mul_mv_id
typedef struct {
    int32_t  nei0;
    int32_t  nei1;
    uint64_t nbi1;
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    int32_t  ne0;
    int32_t  ne1;
    uint64_t nb1;
} ggml_metal_kargs_mul_mv_id;

// kernel arguments: norm
typedef struct {
    int32_t  ne00;
    int32_t  ne00_4;
    uint64_t nb01;
    float    eps;
} ggml_metal_kargs_norm;

// kernel arguments: rms_norm (field-for-field identical to ggml_metal_kargs_norm)
typedef struct {
    int32_t  ne00;
    int32_t  ne00_4;
    uint64_t nb01;
    float    eps;
} ggml_metal_kargs_rms_norm;

// kernel arguments: group_norm
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    int32_t  n_groups;
    float    eps;
} ggml_metal_kargs_group_norm;

// kernel arguments: conv_transpose_1d
typedef struct {
    int32_t  IC;
    int32_t  IL;
    int32_t  K;
    int32_t  s0;
    uint64_t nb0;
    uint64_t nb1;
} ggml_metal_kargs_conv_transpose_1d;

// kernel arguments: im2col
typedef struct {
    uint64_t ofs0;
    uint64_t ofs1;
    int32_t  IW;
    int32_t  IH;
    int32_t  CHW;
    int32_t  s0;
    int32_t  s1;
    int32_t  p0;
    int32_t  p1;
    int32_t  d0;
    int32_t  d1;
    int32_t  N;
    int32_t  KH;
    int32_t  KW;
    int32_t  KHW; // KH * KW, pre-computed on CPU to save GPU resources
} ggml_metal_kargs_im2col;

// kernel arguments: sum_rows
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    int64_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int64_t  ne10;
    int64_t  ne11;
    int64_t  ne12;
    int64_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int64_t  ne0;
    int64_t  ne1;
    int64_t  ne2;
    int64_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
} ggml_metal_kargs_sum_rows;

// kernel arguments: soft_max
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    float    scale;
    float    max_bias;
    float    m0;
    float    m1;
    uint32_t n_head_log2;
} ggml_metal_kargs_soft_max;

// kernel arguments: diag_mask_inf
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int32_t  n_past; // fixed-width, like every other integer field in this header
} ggml_metal_kargs_diag_mask_inf;

// kernel arguments: ssm_conv
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    int64_t  ne10;
    int64_t  ne11;
    uint64_t nb10;
    uint64_t nb11;
    int64_t  ne0;
    int64_t  ne1;
    int64_t  ne2;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
} ggml_metal_kargs_ssm_conv;

// kernel arguments: ssm_scan
typedef struct {
    int64_t  d_state;
    int64_t  d_inner;
    int64_t  n_seq_tokens;
    int64_t  n_seqs;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    uint64_t nb20;
    uint64_t nb21;
    uint64_t nb22;
    uint64_t nb30;
    uint64_t nb31;
    uint64_t nb40;
    uint64_t nb41;
    uint64_t nb42;
    uint64_t nb50;
    uint64_t nb51;
    uint64_t nb52;
} ggml_metal_kargs_ssm_scan;

// kernel arguments: get_rows
typedef struct {
    int64_t  ne00;
    uint64_t nb01;
    uint64_t nb02;
    int64_t  ne10;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb1;
    uint64_t nb2;
} ggml_metal_kargs_get_rows;

// kernel arguments: upscale
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    int64_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int64_t  ne0;
    int64_t  ne1;
    int64_t  ne2;
    int64_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    float    sf0;
    float    sf1;
    float    sf2;
    float    sf3;
} ggml_metal_kargs_upscale;

// kernel arguments: pad
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    int64_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int64_t  ne0;
    int64_t  ne1;
    int64_t  ne2;
    int64_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
} ggml_metal_kargs_pad;

// kernel arguments: pad_reflect_1d
// same fields as ggml_metal_kargs_pad, followed by p0 and p1
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    int64_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int64_t  ne0;
    int64_t  ne1;
    int64_t  ne2;
    int64_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    int32_t  p0;
    int32_t  p1;
} ggml_metal_kargs_pad_reflect_1d;

// kernel arguments: timestep_embedding
typedef struct {
    uint64_t nb1;
    int32_t  dim;        // fixed-width, like every other integer field in this header
    int32_t  max_period; // fixed-width, like every other integer field in this header
} ggml_metal_kargs_timestep_embedding;

// kernel arguments: leaky_relu
typedef struct {
    float    slope;
} ggml_metal_kargs_leaky_relu;

// kernel arguments: argsort
typedef struct {
    int64_t  ncols;
    int64_t  ncols_pad;
} ggml_metal_kargs_argsort;

// kernel arguments: arange
typedef struct {
    int64_t  ne0;
    float    start;
    float    step;
} ggml_metal_kargs_arange;

// kernel arguments: pool_2d
typedef struct {
    int32_t  k0;
    int32_t  k1;
    int32_t  s0;
    int32_t  s1;
    int32_t  p0;
    int32_t  p1;
    int64_t  IH;
    int64_t  IW;
    int64_t  OH;
    int64_t  OW;
    int64_t  parallel_elements;
} ggml_metal_kargs_pool_2d;

#endif // GGML_METAL_IMPL