sanbu commited on
Commit
e99712e
·
1 Parent(s): 7a82ea8

Update space

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. models.py +86 -0
app.py CHANGED
@@ -1,11 +1,11 @@
1
  import os
2
  import gradio as gr
3
  from dotenv import load_dotenv
4
- from tianji.knowledges.langchain_onlinellm.models import ZhipuAIEmbeddings, ZhipuLLM
5
  from langchain_chroma import Chroma
6
  from langchain_community.document_loaders import DirectoryLoader, TextLoader
7
  from langchain_text_splitters import RecursiveCharacterTextSplitter
8
- from langchain_core.runnables import RunnablePassthrough
9
  from langchain_core.output_parsers import StrOutputParser
10
  from langchain import hub
11
  from huggingface_hub import snapshot_download
 
1
  import os
2
  import gradio as gr
3
  from dotenv import load_dotenv
4
+ from models import ZhipuAIEmbeddings, ZhipuLLM
5
  from langchain_chroma import Chroma
6
  from langchain_community.document_loaders import DirectoryLoader, TextLoader
7
  from langchain_text_splitters import RecursiveCharacterTextSplitter
8
+ from langchain_core.runnables import RunnablePassthrough
9
  from langchain_core.output_parsers import StrOutputParser
10
  from langchain import hub
11
  from huggingface_hub import snapshot_download
models.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.language_models.llms import LLM
2
+ from langchain_core.callbacks.manager import CallbackManagerForLLMRun
3
+ from langchain.embeddings.base import Embeddings
4
+ from typing import Any, Dict, List, Optional
5
+ import os
6
+ from zhipuai import ZhipuAI
7
+ from langchain.pydantic_v1 import BaseModel, root_validator
8
+
9
+
10
+ class ZhipuLLM(LLM):
11
+ """A custom chat model for ZhipuAI."""
12
+
13
+ client: Any = None
14
+
15
+ def __init__(self):
16
+ super().__init__()
17
+ print("Initializing model...")
18
+ self.client = ZhipuAI(api_key=os.environ.get("ZHIPUAI_API_KEY"))
19
+ print("Model initialization complete")
20
+
21
+ def _call(
22
+ self,
23
+ prompt: str,
24
+ stop: Optional[List[str]] = None,
25
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
26
+ **kwargs: Any,
27
+ ) -> str:
28
+ """Run the LLM on the given input."""
29
+
30
+ response = self.client.chat.completions.create(
31
+ model="glm-4-flash",
32
+ messages=[
33
+ {"role": "user", "content": prompt},
34
+ ],
35
+ )
36
+ return response.choices[0].message.content
37
+
38
+ @property
39
+ def _identifying_params(self) -> Dict[str, Any]:
40
+ """Return a dictionary of identifying parameters."""
41
+ return {"model_name": "ZhipuAI"}
42
+
43
+ @property
44
+ def _llm_type(self) -> str:
45
+ """Get the type of language model used by this chat model."""
46
+ return "ZhipuAI"
47
+
48
+
49
+ class ZhipuAIEmbeddings(BaseModel, Embeddings):
50
+ """`Zhipuai Embeddings` embedding models."""
51
+
52
+ zhipuai_api_key: Optional[str] = None
53
+
54
+ @root_validator()
55
+ def validate_environment(cls, values: Dict) -> Dict:
56
+ values["zhupuai_api_key"] = values.get("zhupuai_api_key") or os.getenv(
57
+ "ZHIPUAI_API_KEY"
58
+ )
59
+ try:
60
+ import zhipuai
61
+
62
+ zhipuai.api_key = values["zhupuai_api_key"]
63
+ values["client"] = zhipuai.ZhipuAI()
64
+ except ImportError:
65
+ raise ValueError(
66
+ "Zhipuai package not found, please install it with `pip install zhipuai`"
67
+ )
68
+ return values
69
+
70
+ def _embed(self, texts: str) -> List[float]:
71
+ try:
72
+ resp = self.client.embeddings.create(
73
+ model="embedding-3",
74
+ input=texts,
75
+ )
76
+ except Exception as e:
77
+ raise ValueError(f"Error raised by inference endpoint: {e}")
78
+ embeddings = resp.data[0].embedding
79
+ return embeddings
80
+
81
+ def embed_query(self, text: str) -> List[float]:
82
+ resp = self.embed_documents([text])
83
+ return resp[0]
84
+
85
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
86
+ return [self._embed(text) for text in texts]