import { pipeline } from "@xenova/transformers";
import { Embeddings, EmbeddingsParams } from "langchain/embeddings/base";

/**
 * Constructor options for {@link XenovaTransformersEmbeddings}.
 */
export interface XenovaTransformersEmbeddingsParams extends EmbeddingsParams {
  /** Model identifier handed to `pipeline`. Defaults to "Xenova/all-MiniLM-L6-v2". */
  model?: string;
}

/**
 * Embeddings implementation backed by a `@xenova/transformers` pipeline.
 *
 * The pipeline is created lazily on the first embedding request and then
 * reused for every subsequent call on the same instance.
 */
export class XenovaTransformersEmbeddings
  extends Embeddings
  implements XenovaTransformersEmbeddingsParams
{
  /** Model identifier handed to `pipeline`. */
  model: string;

  /** Lazily created pipeline instance; unset until the first `_embed` call. */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- pipeline output is consumed untyped here
  client: any;

  /**
   * @param fields Optional settings; `model` falls back to
   *   "Xenova/all-MiniLM-L6-v2" when omitted.
   */
  constructor(fields?: XenovaTransformersEmbeddingsParams) {
    super(fields ?? {});
    this.model = fields?.model ?? "Xenova/all-MiniLM-L6-v2";
  }

  /**
   * Embeds every text with mean pooling and normalization enabled.
   *
   * @param texts Texts to embed; each one is sent to the pipeline separately.
   * @returns One plain `number[]` vector per input text, in input order.
   */
  async _embed(texts: string[]): Promise<number[][]> {
    if (!this.client) {
      // NOTE(review): "embeddings" is presumably accepted as a task alias for
      // "feature-extraction" — confirm against the installed
      // @xenova/transformers version.
      this.client = await pipeline("embeddings", this.model);
    }

    // `caller` comes from the Embeddings base class (presumably adding
    // retry/concurrency handling — confirm against the langchain version).
    return this.caller.call(async () =>
      Promise.all(
        texts.map(async (text) => {
          const output = await this.client(text, {
            pooling: "mean",
            normalize: true,
          });
          // `output.data` is presumably a typed array (e.g. Float32Array);
          // copy it into a plain array so the declared `number[]` contract
          // holds for Array.isArray checks and JSON serialization.
          return Array.from<number>(output.data);
        })
      )
    );
  }

  /**
   * Embeds a single text.
   *
   * @param document Text to embed.
   * @returns The embedding vector for `document`.
   */
  async embedQuery(document: string): Promise<number[]> {
    const embeddings = await this._embed([document]);
    return embeddings[0];
  }

  /**
   * Embeds a batch of texts.
   *
   * @param documents Texts to embed.
   * @returns One embedding vector per document, in input order.
   */
  embedDocuments(documents: string[]): Promise<number[][]> {
    return this._embed(documents);
  }
}