|
from __future__ import annotations |
|
|
|
import logging |
|
from typing import Dict, List, Any |
|
|
|
|
|
from langchain.embeddings.base import Embeddings |
|
from langchain.pydantic_v1 import BaseModel, root_validator |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
class ZhipuAIEmbeddings(BaseModel, Embeddings):
    """Embeddings implementation backed by the ZhipuAI embeddings endpoint.

    Each text is embedded with one call to ``client.embeddings.create``.
    Only the synchronous methods are implemented; the asynchronous
    variants raise ``NotImplementedError``.
    """

    # A `zhipuai.ZhipuAI` client. When the caller does not supply one,
    # `validate_environment` creates it with `ZhipuAI()`.
    client: Any

    # Model name sent with every embedding request.
    model: str = "embedding-2"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Make sure ``values["client"]`` holds a ZhipuAI client.

        A client passed in by the caller is kept as-is; otherwise a new
        ``zhipuai.ZhipuAI()`` is instantiated with no arguments.

        Args:
            values (Dict): Field values collected for this model.

        Returns:
            Dict: The same mapping with ``"client"`` populated.

        Raises:
            ModuleNotFoundError: If a client has to be created and the
                ``zhipuai`` package is not installed.
        """
        if values.get("client") is None:
            # Imported lazily so that importing this module does not by
            # itself require the zhipuai package.
            from zhipuai import ZhipuAI

            values["client"] = ZhipuAI()
        return values

    def embed_query(self, text: str) -> List[float]:
        """Embed a single piece of text.

        Args:
            text (str): The text to embed.

        Returns:
            List[float]: The embedding taken from the first entry of the
            response's ``data`` list.
        """
        response = self.client.embeddings.create(
            model=self.model,
            input=text,
        )
        return response.data[0].embedding

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of texts.

        Issues one ``embed_query`` call per text, in input order.

        Args:
            texts (List[str]): The texts to embed.

        Returns:
            List[List[float]]: One embedding per input text, in the same
            order. An empty input yields an empty list.
        """
        return [self.embed_query(text) for text in texts]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronous document embedding (not supported).

        Raises:
            NotImplementedError: Always; use ``embed_documents`` instead.
        """
        raise NotImplementedError(
            "Please use `embed_documents`. "
            "Official does not support asynchronous requests"
        )

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronous query embedding (not supported).

        Raises:
            NotImplementedError: Always; use ``embed_query`` instead.
        """
        # The message points at the synchronous counterpart, mirroring
        # `aembed_documents`.
        raise NotImplementedError(
            "Please use `embed_query`. "
            "Official does not support asynchronous requests"
        )