briankchan committed

Commit 2afc3f4 · 0 Parent(s)

Initial commit
.chroma/chroma-collections.parquet ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:02c564513891989cdab60ecfe81c42b09048a2d0e33d99b00ff302f9bd1567f2
+ size 532
.chroma/chroma-embeddings.parquet ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:677722053474bf1aab4c939cf46e3bd7994e14982acf2a02ebaf80418a2e2d8f
+ size 1433761
.chroma/index/id_to_uuid_6dc7068f-5504-4150-a2c1-7f17eb4a2ced.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4300adaeb3576278eb1ef7d7ca332603326b53dd4a79465793c69fe3829f7527
+ size 3466
.chroma/index/index_6dc7068f-5504-4150-a2c1-7f17eb4a2ced.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:810ec276acabcfb03d602423bf0f8316297a76ebb66600d75d7a6aa0b1b20014
+ size 686100
.chroma/index/index_metadata_6dc7068f-5504-4150-a2c1-7f17eb4a2ced.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3d262ae2323e12d22ad009e5a8d374fccdffc0c065aec98f4c87d0accddc4264
+ size 73
.chroma/index/uuid_to_id_6dc7068f-5504-4150-a2c1-7f17eb4a2ced.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:34484296518ecc1d0f9998024eab730a31210b722d21d547400cf8eb6803fb9e
+ size 4049
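The six `.chroma` files above are Git LFS pointer files: the repository tracks only a small text stub recording the LFS spec version, the SHA-256 object ID, and the byte size, while the actual Chroma index data lives in LFS storage. A minimal sketch of reading one of these stubs (`parse_lfs_pointer` is a hypothetical helper, not part of this commit):

```python
# Minimal sketch (hypothetical helper, not part of this commit): parse a
# Git LFS pointer file into its "version", "oid", and "size" fields.
from pathlib import Path

def parse_lfs_pointer(path: str) -> dict:
    fields = {}
    for line in Path(path).read_text().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

# Before `git lfs pull`, the checked-out file is just the stub, so e.g.
# parse_lfs_pointer(".chroma/chroma-collections.parquet") would return
# {"version": "https://git-lfs.github.com/spec/v1",
#  "oid": "sha256:02c56451...", "size": "532"}
```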
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
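Each `.gitattributes` line routes files matching a glob through the LFS filter and marks them as binary (`-text`); the `*.parquet`, `*.bin`, and `*.pkl` rules are what send the `.chroma` files above into LFS. A rough sketch of the matching, using Python's `fnmatch` as an approximation of git's glob rules:

```python
# Rough sketch: which paths would these .gitattributes rules route to LFS?
# fnmatch only approximates git's glob matching (e.g. it lets "*" cross "/").
from fnmatch import fnmatch

LFS_PATTERNS = ["*.parquet", "*.bin", "*.pkl"]  # subset of the rules above

paths = [
    ".chroma/chroma-collections.parquet",
    ".chroma/index/index_6dc7068f-5504-4150-a2c1-7f17eb4a2ced.bin",
    "app.py",
]
for p in paths:
    routed = any(fnmatch(p, pat) for pat in LFS_PATTERNS)
    print(p, "-> LFS" if routed else "-> regular git")
```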
.gitignore ADDED
@@ -0,0 +1,195 @@
+ # Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python,visualstudiocode
+
+ ### Python ###
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ ### Python Patch ###
+ # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+ poetry.toml
+
+ # ruff
+ .ruff_cache/
+
+ # LSP config files
+ pyrightconfig.json
+
+ ### VisualStudioCode ###
+ .vscode/*
+ !.vscode/settings.json
+ !.vscode/tasks.json
+ !.vscode/launch.json
+ !.vscode/extensions.json
+ !.vscode/*.code-snippets
+
+ # Local History for Visual Studio Code
+ .history/
+
+ # Built Visual Studio Code Extensions
+ *.vsix
+
+ ### VisualStudioCode Patch ###
+ # Ignore all local history of files
+ .history
+ .ionide
+
+ # End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode
README.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: Doc Search
+ emoji: 😻
+ colorFrom: pink
+ colorTo: green
+ sdk: gradio
+ sdk_version: 3.32.0
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,196 @@
+ # based on https://github.com/hwchase17/langchain-gradio-template/blob/master/app.py
+ import collections
+ import os
+ from queue import Queue
+ from time import sleep
+ from typing import Any, Callable, Dict, List, Optional, Tuple
+
+ import gradio as gr
+ from anyio.from_thread import start_blocking_portal
+ from langchain.callbacks.manager import AsyncCallbackManager
+ from langchain.chains import LLMChain
+ from langchain.chat_models import ChatOpenAI, PromptLayerChatOpenAI
+ from langchain.memory import ConversationBufferMemory
+ from langchain.prompts import PromptTemplate
+ from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING
+ from langchain.prompts.chat import (ChatPromptTemplate,
+                                     HumanMessagePromptTemplate)
+ from langchain.schema import HumanMessage
+ from langchain.vectorstores import Chroma
+ from langchain.docstore.document import Document
+
+ from util import SyncStreamingLLMCallbackHandler, CustomOpenAIEmbeddings
+
+
+ def I(x):
+     "Identity function; returns its argument unchanged."
+     return x
+
+ class PreprocessingPromptTemplate(PromptTemplate):
+     # maps input variable names to functions applied before formatting
+     arg_preprocessing: Dict[str, Callable] = {}
+     def format(self, **kwargs: Any) -> str:
+         """Format the prompt with the inputs.
+
+         Args:
+             kwargs: Any arguments to be passed to the prompt template.
+
+         Returns:
+             A formatted string.
+
+         Example:
+
+         .. code-block:: python
+
+             prompt.format(variable1="foo")
+         """
+         kwargs = self._merge_partial_and_user_variables(**kwargs)
+         kwargs = self._preprocess_args(kwargs)
+         return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
+
+     def _preprocess_args(self, args: dict):
+         return {k: self.arg_preprocessing.get(k, I)(v) for k, v in args.items()}
+
+
+ def top_results_to_string(x: List[Tuple[Document, float]]):
+     return "\n~~~\n".join(f"Result {i} Title: {doc.metadata['title']}\nResult {i} Content: {doc.page_content}" for i, (doc, score) in enumerate(x, 1))
+
+
+ PROMPT = """You are a helpful AI assistant that summarizes search results for users.
+ ---
+ A user has searched for the following query:
+ {query}
+ ---
+ The search engine returned the following search results:
+ {top_results}
+ ---
+ Based on the search results, answer the user's query, and use the same language as the user's query.
+ Say which search result you used.
+ Do not use information other than the search results.
+ Say 'No answer found.' if there are no relevant results.
+ Afterwards, say how confident you are in your answer as a percentage.
+ """
+ PROMPT_TEMPLATE = PreprocessingPromptTemplate(template=PROMPT, input_variables=['query', 'top_results'])
+ PROMPT_TEMPLATE.arg_preprocessing['top_results'] = top_results_to_string
+
+ # TODO give relevance value in prompt
+ # TODO ask gpt to say which sources it used
+
+
+ # TODO azure?
+ COLLECTION = Chroma(
+     embedding_function=CustomOpenAIEmbeddings(api_key=os.environ.get("OPENAI_API_KEY", None)),
+     persist_directory="./.chroma",
+     collection_name="CUHK",
+ )
+ # COLLECTION = CHROMA_CLIENT.get_collection(name='CUHK')
+
+
+ def load_chain(api_type):
+     shared_args = {
+         "temperature": 0,
+         "model_name": "gpt-3.5-turbo",
+         "pl_tags": ["cuhk-demo"],
+         "streaming": True,
+     }
+     if api_type == "OpenAI":
+         chat = PromptLayerChatOpenAI(
+             **shared_args,
+             api_key=os.environ.get("OPENAI_API_KEY", None),
+         )
+     elif api_type == "Azure OpenAI":
+         chat = PromptLayerChatOpenAI(
+             api_type="azure",
+             api_key=os.environ.get("AZURE_OPENAI_API_KEY", None),
+             api_base=os.environ.get("AZURE_OPENAI_API_BASE", None),
+             api_version=os.environ.get("AZURE_OPENAI_API_VERSION", "2023-03-15-preview"),
+             engine=os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", None),
+             **shared_args,
+         )
+
+     chain = LLMChain(llm=chat, prompt=PROMPT_TEMPLATE)
+     return chat, chain
+
+ def initialize_chain(api_type):
+     "Runs at app start"
+     chat, chain = load_chain(api_type)
+     return chat, chain
+
+ def change_chain(api_type, old_chain):
+     chat, chain = load_chain(api_type)
+     return chat, chain
+
+ def find_top_results(query):
+     results = COLLECTION.similarity_search_with_score(query, k=4)  # TODO filter by device (windows, mac, android, ios)
+
+     output = "\n".join(f"1. [{d.metadata['title']}]({d.metadata['url']}) <small>(dist: {s})</small>" for d, s in results)
+     return results, output
+
+ def ask_gpt(chain, query, top_results):  # top_results: List[Tuple[Document, float]]
+     q = Queue()
+     job_done = object()
+     def task():
+         chain.run(
+             query=query,
+             top_results=top_results,
+             callbacks=[SyncStreamingLLMCallbackHandler(q)],
+         )
+         q.put(job_done)
+         return
+
+     with start_blocking_portal() as portal:
+         portal.start_task_soon(task)
+
+         content = ""
+         while True:
+             next_token = q.get(True, timeout=15)
+             if next_token is job_done:
+                 break
+             content += next_token
+             yield content
+
+
+ demo = gr.Blocks(css="""
+ #sidebar {
+     max-width: 300px;
+ }
+ """)
+ with demo:
+     with gr.Row():
+         # sidebar
+         with gr.Column(elem_id="sidebar"):
+             api_type = gr.Radio(
+                 ["OpenAI", "Azure OpenAI"],
+                 value="OpenAI",
+                 label="Server",
+                 info="You can try changing this if responses are slow."
+             )
+
+         # main
+         with gr.Column():
+             # Company img
+             gr.HTML(r'<div style="display: flex; justify-content: center; align-items: center"><a href="https://thinkcol.com/"><img src="./file=thinkcol-logo.png" alt="ThinkCol" width="357" height="87" /></a></div>')
+
+             chat = gr.State()
+             chain = gr.State()
+
+             query = gr.Textbox(label="Search Query:")
+             top_results_data = gr.State()
+             top_results = gr.Markdown(label="Search Results")
+             response = gr.Textbox(label="AI Response")
+
+
+     load_event = demo.load(initialize_chain, [api_type], [chat, chain])
+
+     query_event = query.submit(find_top_results, [query], [top_results_data, top_results])
+     ask_event = query_event.then(ask_gpt, [chain, query, top_results_data], [response])
+     api_type.change(change_chain,
+                     [api_type, chain],
+                     [chat, chain],
+                     cancels=[load_event, query_event, ask_event])


+ demo.queue()
+ if __name__ == "__main__":
+     demo.launch()
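`ask_gpt` streams tokens to the UI with a producer/consumer queue: the chain runs as a background task and pushes each token onto a `Queue` through `SyncStreamingLLMCallbackHandler`, while the generator drains the queue and yields the growing response string, which Gradio re-renders on every yield. A minimal self-contained sketch of the same pattern, with a plain thread and a token list standing in for the anyio portal and the LLM call:

```python
# Minimal sketch of ask_gpt's streaming pattern; a plain thread and a token
# list stand in for the anyio portal and the chain.run(...) call.
from queue import Queue
from threading import Thread

def stream_tokens(tokens):
    q = Queue()
    job_done = object()  # sentinel marking the end of the stream

    def producer():
        for t in tokens:  # stands in for the callback handler's q.put(token)
            q.put(t)
        q.put(job_done)

    Thread(target=producer, daemon=True).start()

    content = ""
    while True:
        next_token = q.get(True, timeout=15)
        if next_token is job_done:
            break
        content += next_token
        yield content  # the UI re-renders the partial response on each yield

for partial in stream_tokens(["Hel", "lo", "!"]):
    print(partial)  # prints "Hel", then "Hello", then "Hello!"
```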
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ gradio
+ openai
+ promptlayer
+ langchain
+ chromadb
+ tiktoken
thinkcol-logo.png ADDED
util.py ADDED
@@ -0,0 +1,235 @@
+ from typing import Any, Dict, List, Optional, Union
+ from types import GeneratorType
+ from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
+ from langchain.schema import AgentAction, AgentFinish, LLMResult
+ from langchain.embeddings.openai import embed_with_retry, OpenAIEmbeddings
+ from pydantic import Extra, Field, root_validator
+ import numpy as np
+
+ class StreamingLLMCallbackHandler(AsyncCallbackHandler):
+     """Callback handler for streaming LLM responses to a queue."""
+
+     def __init__(self, q):
+         self.q = q
+
+     def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+         self.q.put(token)
+
+
+ class SyncStreamingLLMCallbackHandler(BaseCallbackHandler):
+     """Callback handler for streaming LLM responses to a queue."""
+
+     def __init__(self, q):
+         self.q = q
+
+     def on_llm_start(
+         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+     ) -> None:
+         """Do nothing."""
+         pass
+
+     def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+         self.q.put(token)
+
+     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+         """Do nothing."""
+         pass
+
+     def on_llm_error(
+         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+     ) -> None:
+         """Do nothing."""
+         pass
+
+     def on_chain_start(
+         self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+     ) -> None:
+         """Do nothing."""
+         pass
+
+     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+         """Do nothing."""
+         pass
+
+     def on_chain_error(
+         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+     ) -> None:
+         """Do nothing."""
+         pass
+
+     def on_tool_start(
+         self,
+         serialized: Dict[str, Any],
+         input_str: str,
+         **kwargs: Any,
+     ) -> None:
+         """Do nothing."""
+         pass
+
+     def on_tool_end(
+         self,
+         output: str,
+         color: Optional[str] = None,
+         observation_prefix: Optional[str] = None,
+         llm_prefix: Optional[str] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Do nothing."""
+         pass
+
+     def on_tool_error(
+         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+     ) -> None:
+         """Do nothing."""
+         pass
+
+     def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
+         """Run on agent action."""
+         pass
+
+     def on_agent_finish(
+         self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
+     ) -> None:
+         """Run on agent end."""
+         pass
+
+
+ def concatenate_generators(*args):
+     """Chain strings and string generators, yielding the running concatenation."""
+     final_outputs = ""
+     for g in args:
+         result = ""  # guards against an empty generator leaving `result` unset
+         if isinstance(g, GeneratorType):
+             for v in g:
+                 yield final_outputs + v
+                 result = v
+         else:
+             yield final_outputs + g
+             result = g
+         final_outputs += result
+
+
+ class CustomOpenAIEmbeddings(OpenAIEmbeddings):
+     """
+     A version of OpenAIEmbeddings that allows extra args
+     to be passed to OpenAI functions.
+     Based on langchain's ChatOpenAI.
+     """
+
+     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+
+     class Config:
+         """Configuration for this pydantic object."""
+
+         extra = Extra.ignore
+
+     @root_validator(pre=True)
+     def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+         """Build extra kwargs from additional params that were passed in."""
+         all_required_field_names = {field.alias for field in cls.__fields__.values()}
+
+         extra = values.get("model_kwargs", {})
+         for field_name in list(values):
+             if field_name in extra:
+                 raise ValueError(f"Found {field_name} supplied twice.")
+             if field_name not in all_required_field_names:
+                 # logger.warning(
+                 #     f"""WARNING! {field_name} is not default parameter.
+                 #     {field_name} was transferred to model_kwargs.
+                 #     Please confirm that {field_name} is what you intended."""
+                 # )
+                 extra[field_name] = values.pop(field_name)
+
+         disallowed_model_kwargs = all_required_field_names | {"model"}
+         invalid_model_kwargs = disallowed_model_kwargs.intersection(extra.keys())
+         if invalid_model_kwargs:
+             raise ValueError(
+                 f"Parameters {invalid_model_kwargs} should be specified explicitly. "
+                 f"Instead they were passed in as part of `model_kwargs` parameter."
+             )
+
+         values["model_kwargs"] = extra
+         return values
+
+     # use extra args in calls
+
+     # please refer to
+     # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
+     def _get_len_safe_embeddings(
+         self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
+     ) -> List[List[float]]:
+         embeddings: List[List[float]] = [[] for _ in range(len(texts))]
+         try:
+             import tiktoken
+
+             tokens = []
+             indices = []
+             encoding = tiktoken.model.encoding_for_model(self.model)
+             for i, text in enumerate(texts):
+                 if self.model.endswith("001"):
+                     # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
+                     # replace newlines, which can negatively affect performance.
+                     text = text.replace("\n", " ")
+                 token = encoding.encode(
+                     text,
+                     allowed_special=self.allowed_special,
+                     disallowed_special=self.disallowed_special,
+                 )
+                 for j in range(0, len(token), self.embedding_ctx_length):
+                     tokens += [token[j : j + self.embedding_ctx_length]]
+                     indices += [i]
+
+             batched_embeddings = []
+             _chunk_size = chunk_size or self.chunk_size
+             for i in range(0, len(tokens), _chunk_size):
+                 response = embed_with_retry(
+                     self,
+                     input=tokens[i : i + _chunk_size],
+                     engine=self.deployment,
+                     request_timeout=self.request_timeout,
+                     **self.model_kwargs,
+                 )
+                 batched_embeddings += [r["embedding"] for r in response["data"]]
+
+             results: List[List[List[float]]] = [[] for _ in range(len(texts))]
+             num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
+             for i in range(len(indices)):
+                 results[indices[i]].append(batched_embeddings[i])
+                 num_tokens_in_batch[indices[i]].append(len(tokens[i]))
+
+             for i in range(len(texts)):
+                 _result = results[i]
+                 if len(_result) == 0:
+                     average = embed_with_retry(
+                         self,
+                         input="",
+                         engine=self.deployment,
+                         request_timeout=self.request_timeout,
+                         **self.model_kwargs,
+                     )["data"][0]["embedding"]
+                 else:
+                     average = np.average(
+                         _result, axis=0, weights=num_tokens_in_batch[i]
+                     )
+                 embeddings[i] = (average / np.linalg.norm(average)).tolist()
+
+             return embeddings
+
+         except ImportError:
+             raise ValueError(
+                 "Could not import tiktoken python package. "
+                 "This is needed in order to use OpenAIEmbeddings. "
+                 "Please install it with `pip install tiktoken`."
+             )
+
+     def _embedding_func(self, text: str, *, engine: str) -> List[float]:
+         """Call out to OpenAI's embedding endpoint."""
+         # handle large input text
+         if len(text) > self.embedding_ctx_length:
+             return self._get_len_safe_embeddings([text], engine=engine)[0]
+         else:
+             if self.model.endswith("001"):
+                 # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
+                 # replace newlines, which can negatively affect performance.
+                 text = text.replace("\n", " ")
+             return embed_with_retry(
+                 self, input=[text], engine=engine, request_timeout=self.request_timeout,
+                 **self.model_kwargs,
+             )["data"][0]["embedding"]
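In `_get_len_safe_embeddings`, a text longer than the context window is split into token chunks, each chunk is embedded separately, and the chunk embeddings are combined by a token-count-weighted average that is re-normalized to unit length. A small sketch of just that combining step, with toy 3-dimensional vectors standing in for real embeddings:

```python
# Sketch of the combining step in _get_len_safe_embeddings: average the
# chunk embeddings weighted by chunk token counts, then re-normalize.
import numpy as np

chunk_embeddings = [
    np.array([1.0, 0.0, 0.0]),  # embedding of a 512-token chunk (toy values)
    np.array([0.0, 1.0, 0.0]),  # embedding of a 128-token chunk (toy values)
]
num_tokens = [512, 128]

average = np.average(chunk_embeddings, axis=0, weights=num_tokens)
unit = average / np.linalg.norm(average)

print(unit)                  # direction dominated by the longer chunk
print(np.linalg.norm(unit))  # 1.0
```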