Spaces · Runtime error

Commit 2afc3f4
Parent(s): initial commit
Files changed:
- .chroma/chroma-collections.parquet +3 -0
- .chroma/chroma-embeddings.parquet +3 -0
- .chroma/index/id_to_uuid_6dc7068f-5504-4150-a2c1-7f17eb4a2ced.pkl +3 -0
- .chroma/index/index_6dc7068f-5504-4150-a2c1-7f17eb4a2ced.bin +3 -0
- .chroma/index/index_metadata_6dc7068f-5504-4150-a2c1-7f17eb4a2ced.pkl +3 -0
- .chroma/index/uuid_to_id_6dc7068f-5504-4150-a2c1-7f17eb4a2ced.pkl +3 -0
- .gitattributes +34 -0
- .gitignore +195 -0
- README.md +12 -0
- app.py +196 -0
- requirements.txt +6 -0
- thinkcol-logo.png +0 -0
- util.py +235 -0
.chroma/chroma-collections.parquet
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:02c564513891989cdab60ecfe81c42b09048a2d0e33d99b00ff302f9bd1567f2
+size 532
.chroma/chroma-embeddings.parquet
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:677722053474bf1aab4c939cf46e3bd7994e14982acf2a02ebaf80418a2e2d8f
+size 1433761
.chroma/index/id_to_uuid_6dc7068f-5504-4150-a2c1-7f17eb4a2ced.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4300adaeb3576278eb1ef7d7ca332603326b53dd4a79465793c69fe3829f7527
+size 3466
.chroma/index/index_6dc7068f-5504-4150-a2c1-7f17eb4a2ced.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:810ec276acabcfb03d602423bf0f8316297a76ebb66600d75d7a6aa0b1b20014
+size 686100
.chroma/index/index_metadata_6dc7068f-5504-4150-a2c1-7f17eb4a2ced.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3d262ae2323e12d22ad009e5a8d374fccdffc0c065aec98f4c87d0accddc4264
+size 73
.chroma/index/uuid_to_id_6dc7068f-5504-4150-a2c1-7f17eb4a2ced.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:34484296518ecc1d0f9998024eab730a31210b722d21d547400cf8eb6803fb9e
+size 4049
.gitattributes
ADDED
@@ -0,0 +1,34 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,195 @@
+# Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode
+# Edit at https://www.toptal.com/developers/gitignore?templates=python,visualstudiocode
+
+### Python ###
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+### Python Patch ###
+# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+poetry.toml
+
+# ruff
+.ruff_cache/
+
+# LSP config files
+pyrightconfig.json
+
+### VisualStudioCode ###
+.vscode/*
+!.vscode/settings.json
+!.vscode/tasks.json
+!.vscode/launch.json
+!.vscode/extensions.json
+!.vscode/*.code-snippets
+
+# Local History for Visual Studio Code
+.history/
+
+# Built Visual Studio Code Extensions
+*.vsix
+
+### VisualStudioCode Patch ###
+# Ignore all local history of files
+.history
+.ionide
+
+# End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode
README.md
ADDED
@@ -0,0 +1,12 @@
+---
+title: Doc Search
+emoji: 😻
+colorFrom: pink
+colorTo: green
+sdk: gradio
+sdk_version: 3.32.0
+app_file: app.py
+pinned: false
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,196 @@
+# based on https://github.com/hwchase17/langchain-gradio-template/blob/master/app.py
+import collections
+import os
+from queue import Queue
+from time import sleep
+from typing import Any, Dict, List, Optional, Tuple
+
+import gradio as gr
+from anyio.from_thread import start_blocking_portal
+from langchain import PromptTemplate
+from langchain.callbacks.manager import AsyncCallbackManager
+from langchain.chains import LLMChain
+from langchain.chat_models import ChatOpenAI, PromptLayerChatOpenAI
+from langchain.memory import ConversationBufferMemory
+from langchain.prompts import PromptTemplate
+from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING
+from langchain.prompts.chat import (ChatPromptTemplate,
+                                    HumanMessagePromptTemplate)
+from langchain.schema import HumanMessage
+from langchain.vectorstores import Chroma
+from langchain.docstore.document import Document
+
+from util import SyncStreamingLLMCallbackHandler, CustomOpenAIEmbeddings
+
+
+def I(x):
+    "Identity function; does nothing."
+    return x
+
+class PreprocessingPromptTemplate(PromptTemplate):
+    arg_preprocessing: Dict = {}  # this is probably the wrong type
+    def format(self, **kwargs: Any) -> str:
+        """Format the prompt with the inputs.
+
+        Args:
+            kwargs: Any arguments to be passed to the prompt template.
+
+        Returns:
+            A formatted string.
+
+        Example:
+
+        .. code-block:: python
+
+            prompt.format(variable1="foo")
+        """
+        kwargs = self._merge_partial_and_user_variables(**kwargs)
+        kwargs = self._preprocess_args(kwargs)
+        return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
+
+    def _preprocess_args(self, args: dict):
+        return {k: self.arg_preprocessing.get(k, I)(v) for k, v in args.items()}
+
+
+def top_results_to_string(x: List[Tuple[Document, float]]):
+    return "\n~~~\n".join(f"Result {i} Title: {doc.metadata['title']}\nResult {i} Content: {doc.page_content}" for i, (doc, score) in enumerate(x, 1))
+
+
+PROMPT = """You are a helpful AI assistant that summarizes search results for users.
+---
+A user has searched for the following query:
+{query}
+---
+The search engine returned the following 5 search results:
+{top_results}
+---
+Based on the search results, answer the user's query, and use the same language as the user's query.
+Say which search result you used.
+Do not use information other than the search results.
+Say 'No answer found.' if there are no relevant results.
+Afterwards, say how confident you are in your answer as a percentage.
+"""
+PROMPT_TEMPLATE = PreprocessingPromptTemplate(template=PROMPT, input_variables=['query', 'top_results'])
+PROMPT_TEMPLATE.arg_preprocessing['top_results'] = top_results_to_string
+
+# TODO give relevance value in prompt
+# TODO ask gpt to say which sources it used
+
+
+
+# TODO azure?
+COLLECTION = Chroma(
+    embedding_function=CustomOpenAIEmbeddings(api_key=os.environ.get("OPENAI_API_KEY", None)),
+    persist_directory="./.chroma",
+    collection_name="CUHK",
+)
+# COLLECTION = CHROMA_CLIENT.get_collection(name='CUHK')
+
+
+def load_chain(api_type):
+    shared_args = {
+        "temperature": 0,
+        "model_name": "gpt-3.5-turbo",
+        "pl_tags": ["cuhk-demo"],
+        "streaming": True,
+    }
+    if api_type == "OpenAI":
+        chat = PromptLayerChatOpenAI(
+            **shared_args,
+            api_key=os.environ.get("OPENAI_API_KEY", None),
+        )
+    elif api_type == "Azure OpenAI":
+        chat = PromptLayerChatOpenAI(
+            api_type="azure",
+            api_key=os.environ.get("AZURE_OPENAI_API_KEY", None),
+            api_base=os.environ.get("AZURE_OPENAI_API_BASE", None),
+            api_version=os.environ.get("AZURE_OPENAI_API_VERSION", "2023-03-15-preview"),
+            engine=os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", None),
+            **shared_args
+        )
+
+    chain = LLMChain(llm=chat, prompt=PROMPT_TEMPLATE)
+    return chat, chain
+
+def initialize_chain(api_type):
+    "Runs at app start"
+    chat, chain = load_chain(api_type)
+    return chat, chain
+
+def change_chain(api_type, old_chain):
+    chat, chain = load_chain(api_type)
+    return chat, chain
+
+def find_top_results(query):
+    results = COLLECTION.similarity_search_with_score(query, k=4)  # TODO filter by device (windows, mac, android, ios)
+
+    output = "\n".join(f"1. [{d.metadata['title']}]({d.metadata['url']}) <small>(dist: {s})</small>" for d, s in results)
+    return results, output
+
+def ask_gpt(chain, query, top_results):  # top_results: List[Tuple[Document, float]]
+    q = Queue()
+    job_done = object()
+    def task():
+        chain.run(
+            query=query,
+            top_results=top_results,
+            callbacks=[SyncStreamingLLMCallbackHandler(q)],
+        )
+        q.put(job_done)
+        return
+
+    with start_blocking_portal() as portal:
+        portal.start_task_soon(task)
+
+        content = ""
+        while True:
+            next_token = q.get(True, timeout=15)
+            if next_token is job_done:
+                break
+            content += next_token
+            yield content
+
+
+demo = gr.Blocks(css="""
+#sidebar {
+    max-width: 300px;
+}
+""")
+with demo:
+    with gr.Row():
+        # sidebar
+        with gr.Column(elem_id="sidebar"):
+            api_type = gr.Radio(
+                ["OpenAI", "Azure OpenAI"],
+                value="OpenAI",
+                label="Server",
+                info="You can try changing this if responses are slow."
+            )
+
+        # main
+        with gr.Column():
+            # Company img
+            gr.HTML(r'<div style="display: flex; justify-content: center; align-items: center"><a href="https://thinkcol.com/"><img src="./file=thinkcol-logo.png" alt="ThinkCol" width="357" height="87" /></a></div>')
+
+            chat = gr.State()
+            chain = gr.State()
+
+            query = gr.Textbox(label="Search Query:")
+            top_results_data = gr.State()
+            top_results = gr.Markdown(label="Search Results")
+            response = gr.Textbox(label="AI Response")
+
+
+    load_event = demo.load(initialize_chain, [api_type], [chat, chain])
+
+    query_event = query.submit(find_top_results, [query], [top_results_data, top_results])
+    ask_event = query_event.then(ask_gpt, [chain, query, top_results_data], [response])
+    api_type.change(change_chain,
+                    [api_type, chain],
+                    [chat, chain],
+                    cancels=[load_event, query_event, ask_event])
+
+
+demo.queue()
+if __name__ == "__main__":
+    demo.launch()
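
For context, a minimal usage sketch of PreprocessingPromptTemplate: the function registered in arg_preprocessing (here top_results_to_string) runs inside format(), converting raw (Document, score) tuples into prompt text before the template is filled. This sketch is not part of the commit, uses made-up sample data, and assumes app.py's globals are importable (importing app.py also constructs COLLECTION, so an OpenAI key must be configured).

# Usage sketch (not part of the commit); assumes app.py is importable.
from langchain.docstore.document import Document
from app import PROMPT_TEMPLATE

# Hypothetical sample result, shaped like Chroma's similarity_search_with_score output.
sample_results = [
    (Document(page_content="Connect to the campus Wi-Fi via eduroam.",
              metadata={"title": "Wi-Fi Setup", "url": "https://example.edu/wifi"}), 0.21),
]
# format() first applies top_results_to_string to `top_results`,
# then fills the {query} and {top_results} slots of PROMPT.
print(PROMPT_TEMPLATE.format(query="How do I join the Wi-Fi?", top_results=sample_results))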
requirements.txt
ADDED
@@ -0,0 +1,6 @@
+gradio
+openai
+promptlayer
+langchain
+chromadb
+tiktoken
thinkcol-logo.png
ADDED
(binary image: ThinkCol logo)
util.py
ADDED
@@ -0,0 +1,235 @@
+from typing import Any, Dict, List, Optional, Union
+from types import GeneratorType
+from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
+from langchain.schema import AgentAction, AgentFinish, LLMResult
+from langchain.embeddings.openai import embed_with_retry, OpenAIEmbeddings
+from pydantic import Extra, Field, root_validator
+import numpy as np
+
+class StreamingLLMCallbackHandler(AsyncCallbackHandler):
+    """Callback handler for streaming LLM responses to a queue."""
+
+    def __init__(self, q):
+        self.q = q
+
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        self.q.put(token)
+
+
+class SyncStreamingLLMCallbackHandler(BaseCallbackHandler):
+    """Callback handler for streaming LLM responses to a queue."""
+
+    def __init__(self, q):
+        self.q = q
+
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        """Do nothing."""
+        pass
+
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        self.q.put(token)
+
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        """Do nothing."""
+        pass
+
+    def on_llm_error(
+        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+    ) -> None:
+        """Do nothing."""
+        pass
+
+    def on_chain_start(
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+    ) -> None:
+        """Do nothing."""
+        pass
+
+    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+        """Do nothing."""
+        pass
+
+    def on_chain_error(
+        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+    ) -> None:
+        """Do nothing."""
+        pass
+
+    def on_tool_start(
+        self,
+        serialized: Dict[str, Any],
+        input_str: str,
+        **kwargs: Any,
+    ) -> None:
+        """Do nothing."""
+        pass
+
+    def on_tool_end(
+        self,
+        output: str,
+        color: Optional[str] = None,
+        observation_prefix: Optional[str] = None,
+        llm_prefix: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Do nothing."""
+        pass
+
+    def on_tool_error(
+        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+    ) -> None:
+        """Do nothing."""
+        pass
+
+    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
+        """Run on agent action."""
+        pass
+
+    def on_agent_finish(
+        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
+    ) -> None:
+        """Run on agent end."""
+        pass
+
+
+def concatenate_generators(*args):
+    final_outputs = ""
+    for g in args:
+        if isinstance(g, GeneratorType):
+            for v in g:
+                yield final_outputs + v
+            result = v
+        else:
+            yield final_outputs + g
+            result = g
+        final_outputs += result
+
+
+class CustomOpenAIEmbeddings(OpenAIEmbeddings):
+    """
+    A version of OpenAIEmbeddings that allows extra args
+    to be passed to OpenAI functions.
+    Based on langchain's ChatOpenAI.
+    """
+
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.ignore
+
+    @root_validator(pre=True)
+    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """Build extra kwargs from additional params that were passed in."""
+        all_required_field_names = {field.alias for field in cls.__fields__.values()}
+
+        extra = values.get("model_kwargs", {})
+        for field_name in list(values):
+            if field_name in extra:
+                raise ValueError(f"Found {field_name} supplied twice.")
+            if field_name not in all_required_field_names:
+                # logger.warning(
+                #     f"""WARNING! {field_name} is not default parameter.
+                #     {field_name} was transferred to model_kwargs.
+                #     Please confirm that {field_name} is what you intended."""
+                # )
+                extra[field_name] = values.pop(field_name)
+
+        disallowed_model_kwargs = all_required_field_names | {"model"}
+        invalid_model_kwargs = disallowed_model_kwargs.intersection(extra.keys())
+        if invalid_model_kwargs:
+            raise ValueError(
+                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
+                f"Instead they were passed in as part of `model_kwargs` parameter."
+            )
+
+        values["model_kwargs"] = extra
+        return values
+
+    # use extra args in calls
+
+    # please refer to
+    # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
+    def _get_len_safe_embeddings(
+        self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
+    ) -> List[List[float]]:
+        embeddings: List[List[float]] = [[] for _ in range(len(texts))]
+        try:
+            import tiktoken
+
+            tokens = []
+            indices = []
+            encoding = tiktoken.model.encoding_for_model(self.model)
+            for i, text in enumerate(texts):
+                if self.model.endswith("001"):
+                    # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
+                    # replace newlines, which can negatively affect performance.
+                    text = text.replace("\n", " ")
+                token = encoding.encode(
+                    text,
+                    allowed_special=self.allowed_special,
+                    disallowed_special=self.disallowed_special,
+                )
+                for j in range(0, len(token), self.embedding_ctx_length):
+                    tokens += [token[j : j + self.embedding_ctx_length]]
+                    indices += [i]
+
+            batched_embeddings = []
+            _chunk_size = chunk_size or self.chunk_size
+            for i in range(0, len(tokens), _chunk_size):
+                response = embed_with_retry(
+                    self,
+                    input=tokens[i : i + _chunk_size],
+                    engine=self.deployment,
+                    request_timeout=self.request_timeout,
+                    **self.model_kwargs,
+                )
+                batched_embeddings += [r["embedding"] for r in response["data"]]
+
+            results: List[List[List[float]]] = [[] for _ in range(len(texts))]
+            num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
+            for i in range(len(indices)):
+                results[indices[i]].append(batched_embeddings[i])
+                num_tokens_in_batch[indices[i]].append(len(tokens[i]))
+
+            for i in range(len(texts)):
+                _result = results[i]
+                if len(_result) == 0:
+                    average = embed_with_retry(
+                        self,
+                        input="",
+                        engine=self.deployment,
+                        request_timeout=self.request_timeout,
+                        **self.model_kwargs,
+                    )["data"][0]["embedding"]
+                else:
+                    average = np.average(
+                        _result, axis=0, weights=num_tokens_in_batch[i]
+                    )
+                embeddings[i] = (average / np.linalg.norm(average)).tolist()
+
+            return embeddings
+
+        except ImportError:
+            raise ValueError(
+                "Could not import tiktoken python package. "
+                "This is needed in order to use OpenAIEmbeddings. "
+                "Please install it with `pip install tiktoken`."
+            )
+
+    def _embedding_func(self, text: str, *, engine: str) -> List[float]:
+        """Call out to OpenAI's embedding endpoint."""
+        # handle large input text
+        if len(text) > self.embedding_ctx_length:
+            return self._get_len_safe_embeddings([text], engine=engine)[0]
+        else:
+            if self.model.endswith("001"):
+                # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
+                # replace newlines, which can negatively affect performance.
+                text = text.replace("\n", " ")
+            return embed_with_retry(
+                self, input=[text], engine=engine, request_timeout=self.request_timeout,
+                **self.model_kwargs,
+            )["data"][0]["embedding"]
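
For context, a small self-contained sketch (not part of the commit) of the queue-based streaming pattern that SyncStreamingLLMCallbackHandler enables: the handler's on_llm_new_token puts each token on a Queue, and a consumer loop drains the queue until a sentinel object arrives, mirroring ask_gpt in app.py. A plain thread stands in for chain.run here.

# Streaming sketch (not part of the commit): drain tokens from the handler's queue.
from queue import Queue
from threading import Thread

q = Queue()
job_done = object()  # sentinel signalling that the producer has finished

def fake_llm_run(out_q):
    # Stand-in for chain.run(..., callbacks=[SyncStreamingLLMCallbackHandler(out_q)]),
    # which would call on_llm_new_token for every generated token.
    for token in ["No ", "answer ", "found."]:
        out_q.put(token)
    out_q.put(job_done)

Thread(target=fake_llm_run, args=(q,)).start()

content = ""
while True:
    next_token = q.get(True, timeout=15)
    if next_token is job_done:
        break
    content += next_token
print(content)  # -> "No answer found."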