Add application file
This view is limited to 50 files because it contains too many changes.
- Dockerfile +25 -0
- config/logger.ini +20 -0
- config/metadata.yaml +25 -0
- lib/aicloudlibs-0.1.0-py3-none-any.whl +0 -0
- models/gpt3tokenizer.pkl +3 -0
- requirement.txt +19 -0
- requirements/requirement.txt +19 -0
- setup.py +42 -0
- src/llm_explain/.env +9 -0
- src/llm_explain/__init__.py +16 -0
- src/llm_explain/config/__init__.py +16 -0
- src/llm_explain/config/config.py +40 -0
- src/llm_explain/config/logger.py +164 -0
- src/llm_explain/exception/constants.py +44 -0
- src/llm_explain/exception/exception.py +29 -0
- src/llm_explain/exception/global_exception.py +60 -0
- src/llm_explain/exception/global_exception_handler.py +40 -0
- src/llm_explain/mappers/__init__.py +16 -0
- src/llm_explain/mappers/mappers.py +101 -0
- src/llm_explain/routing/__init__.py +16 -0
- src/llm_explain/routing/explain_router.py +204 -0
- src/llm_explain/service/__init__.py +16 -0
- src/llm_explain/service/responsible_ai_explain.py +600 -0
- src/llm_explain/service/service.py +230 -0
- src/llm_explain/utility/__init__.py +17 -0
- src/llm_explain/utility/azure.py +62 -0
- src/llm_explain/utility/config.json +2 -0
- src/llm_explain/utility/got.py +423 -0
- src/llm_explain/utility/graph_of_thoughts/__init__.py +16 -0
- src/llm_explain/utility/graph_of_thoughts/controller/__init__.py +18 -0
- src/llm_explain/utility/graph_of_thoughts/controller/controller.py +259 -0
- src/llm_explain/utility/graph_of_thoughts/language_models/__init__.py +20 -0
- src/llm_explain/utility/graph_of_thoughts/language_models/abstract_language_model.py +104 -0
- src/llm_explain/utility/graph_of_thoughts/language_models/azure.py +181 -0
- src/llm_explain/utility/graph_of_thoughts/language_models/chatgpt.py +167 -0
- src/llm_explain/utility/graph_of_thoughts/operations/__init__.py +31 -0
- src/llm_explain/utility/graph_of_thoughts/operations/graph_of_operations.py +78 -0
- src/llm_explain/utility/graph_of_thoughts/operations/operations.py +912 -0
- src/llm_explain/utility/graph_of_thoughts/operations/thought.py +129 -0
- src/llm_explain/utility/graph_of_thoughts/parser/__init__.py +18 -0
- src/llm_explain/utility/graph_of_thoughts/parser/parser.py +99 -0
- src/llm_explain/utility/graph_of_thoughts/prompter/__init__.py +18 -0
- src/llm_explain/utility/graph_of_thoughts/prompter/prompter.py +95 -0
- src/llm_explain/utility/prompt_utils.py +111 -0
- src/llm_explain/utility/prompts/base.py +281 -0
- src/llm_explain/utility/prompts/few_shot.py +218 -0
- src/llm_explain/utility/prompts/instructions.py +20 -0
- src/llm_explain/utility/prompts/output_format.py +122 -0
- src/llm_explain/utility/query_serper.py +143 -0
- src/llm_explain/utility/utility.py +316 -0
Dockerfile
ADDED
@@ -0,0 +1,25 @@
+FROM python:3.9
+
+RUN useradd -m -u 1000 user
+USER user
+ENV PATH="/home/user/.local/bin:$PATH"
+
+WORKDIR /app
+
+COPY --chown=user ./requirement.txt requirement.txt
+
+COPY --chown=user ./lib/aicloudlibs-0.1.0-py3-none-any.whl /lib/
+
+# COPY --chown=user ./lib/better_profanity-2.0.0-py3-none-any.whl /lib/
+
+# COPY --chown=user ./lib/privacy-1.0.9-py3-none-any.whl /lib/
+
+RUN pip install --no-cache-dir --upgrade -r requirement.txt
+
+COPY --chown=user . /app
+
+# Expose the port (default for Hugging Face is 7860)
+EXPOSE 7860
+
+# CMD to run the FastAPI app with Uvicorn
+CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "7860"]
config/logger.ini
ADDED
@@ -0,0 +1,20 @@
+; Copyright 2024 Infosys Ltd.
+
+; Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
+; to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+; and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+; The above copyright notice and this permission notice shall be included in all copies
+; or substantial portions of the Software.
+
+; THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+; INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
+; AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+; DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+; OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+[logDetails]
+LOG_LEVEL=ERROR
+FILE_NAME=responsible-ai-servicelogs
+VERBOSE=False
+LOG_DIR=/responsible-ai/logs
config/metadata.yaml
ADDED
@@ -0,0 +1,25 @@
+# Copyright 2024 Infosys Ltd.
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all copies
+# or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
+# AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+openapi_url: /rai/v1/llm-explainability/openapi.json
+docs_url: /rai/v1/llm-explainability/docs
+title: Infosys Responsible AI - responsible-ai-llm-explain - OpenAPI 3.0
+description: API specs for Infosys Responsible AI LLM-Explainability pillar in OpenAPI 3.0 format
+contact:
+  email: [email protected]
+version: v$version
+openapi_tags:
+  - name: LLM-Explainability
+    description: Operations required for explainability of an LLM
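
These fields map one-to-one onto FastAPI's application constructor arguments. The Space's entry point (src/main.py, referenced by the Dockerfile CMD) is not part of this 50-file view, so the wiring below is an illustrative sketch under that assumption, not the repository's actual code:

    # Illustrative sketch only: src/main.py is not included in this truncated diff.
    from fastapi import FastAPI
    import yaml

    with open("config/metadata.yaml") as f:
        meta = yaml.safe_load(f)   # the leading '#' license lines are YAML comments

    app = FastAPI(
        openapi_url=meta["openapi_url"],
        docs_url=meta["docs_url"],
        title=meta["title"],
        description=meta["description"],
        version=meta["version"],       # "v$version" is presumably substituted at deploy time
        contact=meta["contact"],
        openapi_tags=meta["openapi_tags"],
    )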
lib/aicloudlibs-0.1.0-py3-none-any.whl
ADDED
Binary file (11.1 kB)
models/gpt3tokenizer.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e1588f575baa6928e93cd483b385affa2744670c6a2fbe0eeffe5c1ad5eec53e
+size 59
requirement.txt
ADDED
@@ -0,0 +1,19 @@
+fastapi
+pydantic
+uvicorn
+python-dotenv
+pyyaml
+scikit-learn
+seaborn
+tenacity
+tiktoken
+openai
+kaleido
+plotly
+backoff
+numpy==1.26.4
+beautifulsoup4
+rouge_score
+pip-system-certs
+aiohttp
+../lib/aicloudlibs-0.1.0-py3-none-any.whl
requirements/requirement.txt
ADDED
@@ -0,0 +1,19 @@
+fastapi
+pydantic
+uvicorn
+python-dotenv
+pyyaml
+scikit-learn
+seaborn
+tenacity
+tiktoken
+openai
+kaleido
+plotly
+backoff
+numpy==1.26.4
+beautifulsoup4
+rouge_score
+pip-system-certs
+aiohttp
+../lib/aicloudlibs-0.1.0-py3-none-any.whl
setup.py
ADDED
@@ -0,0 +1,42 @@
+'''
+Copyright 2024 Infosys Ltd.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies
+or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
+AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+'''
+
+from setuptools import find_packages, setup
+from pathlib import Path
+
+def get_install_requires() -> list[str]:
+    """Returns requirements/requirement.txt parsed to a list"""
+    fname = Path(__file__).parent / 'requirements/requirement.txt'
+    targets = []
+    if fname.exists():
+        with open(fname, 'r') as f:
+            targets = f.read().splitlines()
+    return targets
+
+if __name__ == '__main__':
+    setup(
+        name='responsible-ai-llm_explain',
+        url="responsible_ai_llm_explain",
+        packages=find_packages(),
+        include_package_data=True,
+        python_requires='>=3.6',
+        version='0.1.0',
+        description='AI Cloud Project Management Services',
+        install_requires=get_install_requires(),
+        license='MIT',
+    )
src/llm_explain/.env
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
AZURE_OPENAI_API_KEY = "${apikey}"
|
2 |
+
AZURE_OPENAI_API_VERSION = "${apiversion}"
|
3 |
+
AZURE_OPENAI_ENDPOINT = "${azureendpoint}"
|
4 |
+
AZURE_DEPLOYMENT_ENGINE = "${engine}"
|
5 |
+
SERPER_KEY = "${serperkey}"
|
6 |
+
# ALLOWED_ORIGINS = "${allowedorigins}"
|
7 |
+
ALLOWED_ORIGINS = "*"
|
8 |
+
ERROR_LOG_TELEMETRY_URL = "${errorlogtelemetryurl}"
|
9 |
+
TELEMETRY_FLAG = "False"
|
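
The ${...} values are deploy-time placeholders, not literal secrets. The services read these settings through python-dotenv (pinned in requirement.txt) and os.getenv, as explain_router.py below does. A minimal loading sketch:

    # Minimal sketch: load the .env file and read the Azure OpenAI settings.
    # Variable names match the file above; load_dotenv() searches upward from
    # the caller's directory for a file named ".env".
    import os
    from dotenv import load_dotenv

    load_dotenv()

    api_key = os.getenv("AZURE_OPENAI_API_KEY")
    endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
    telemetry_on = os.getenv("TELEMETRY_FLAG") == "True"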
src/llm_explain/__init__.py
ADDED
@@ -0,0 +1,16 @@
+'''
+Copyright 2024 Infosys Ltd.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies
+or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
+AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+'''
src/llm_explain/config/__init__.py
ADDED
@@ -0,0 +1,16 @@
+'''
+Copyright 2024 Infosys Ltd.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies
+or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
+AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+'''
src/llm_explain/config/config.py
ADDED
@@ -0,0 +1,40 @@
+'''
+Copyright 2024 Infosys Ltd.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies
+or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
+AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+'''
+
+from configparser import ConfigParser
+import yaml
+
+def readConfig(section, filename):
+
+    parser = ConfigParser() # create a parser
+    parser.read(filename)   # read config file
+
+    # collect all key/value pairs from the requested section
+    db = {}
+    if parser.has_section(section):
+        params = parser.items(section)
+        for param in params:
+            db[param[0]] = param[1]
+    else:
+        raise Exception('Section {0} not found in the {1} file'.format(section, filename))
+
+    return db
+
+def read_config_yaml(filename):
+    with open(filename) as config_file:
+        config_details = yaml.safe_load(config_file)
+    return config_details
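
readConfig returns one INI section as a plain dict; ConfigParser lower-cases option names by default, which is why logger.py below reads keys like 'file_name' while logger.ini writes FILE_NAME. A short usage sketch against the files shipped above (the config/ paths assume the working directory is the repository root, as in the Docker image):

    # Usage sketch for the two helpers above.
    from llm_explain.config.config import readConfig, read_config_yaml

    log_params = readConfig('logDetails', 'config/logger.ini')
    print(log_params['log_level'])   # 'ERROR'
    print(log_params['file_name'])   # 'responsible-ai-servicelogs'

    meta = read_config_yaml('config/metadata.yaml')
    print(meta['title'])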
src/llm_explain/config/logger.py
ADDED
@@ -0,0 +1,164 @@
+'''
+Copyright 2024 Infosys Ltd.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies
+or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
+AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+'''
+
+import datetime
+import logging
+import os
+import sys
+from .config import readConfig
+import contextvars
+
+request_ids = []
+request_id_var = contextvars.ContextVar("request_id_var")
+
+class CustomLogger(logging.getLoggerClass()):
+    def __init__(self):
+        """Create a custom logger configured from `logger.ini`. When `log_dir` is empty, a simple
+        console logger is created. Otherwise, a file logger is created in addition to the console
+        logger.
+
+        By default, the five standard logging levels (DEBUG through CRITICAL) only display
+        information in the log file if a file handler is added to the logger, but **not** to the
+        console.
+
+        Settings read from `logger.ini`:
+        - file_name: name for the logger
+        - verbose: whether the logging should be verbose; if True, then all messages get logged
+          both to stdout and to the log file (if `log_dir` is specified); if False, then messages
+          only get logged to the log file (if `log_dir` is specified)
+        - log_dir: (optional) the directory for the log file; if not present, no log file is
+          created
+        """
+        # Create custom logger logging all five levels
+        BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+        log_cfg_path = os.path.join(BASE_DIR, 'logger.ini')
+        log_params = readConfig('logDetails', log_cfg_path)
+
+        name = log_params['file_name']
+        try:
+            # The INI value is a string, so compare it rather than calling bool()
+            verbose = log_params['verbose'].strip().lower() == 'true'
+        except (KeyError, AttributeError):
+            verbose = False
+
+        log_dir = str(log_params['log_dir'])
+
+        super().__init__(name)
+        self.setLevel(logging.DEBUG)
+
+        # Add new logging level
+        logging.addLevelName(logging.INFO, 'INFO')
+
+        # Determine verbosity settings
+        self.verbose = verbose
+
+        # Create stream handler for logging to stdout (log all five levels)
+        self.stdout_handler = logging.StreamHandler(sys.stdout)
+        self.stdout_handler.setLevel(logging.DEBUG)
+        self.stdout_handler.setFormatter(logging.Formatter('%(message)s'))
+        self.enable_console_output()
+
+        self.file_handler = None
+        if log_dir:
+            self.add_file_handler(name, log_dir)
+
+    def add_file_handler(self, name, log_dir):
+        """Add a file handler for this logger with the specified `name` (and store the log file
+        under `log_dir`)."""
+        # Format for file log
+        fmt = '%(asctime)s | %(levelname)9s | %(filename)s:%(lineno)d | %(message)s'
+        formatter = logging.Formatter(fmt)
+
+        # Determine log path and file name; create log path if it does not exist
+        now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
+        log_name = f'{str(name).replace(" ", "_")}_{now}'
+        if not os.path.exists(log_dir):
+            try:
+                os.makedirs(log_dir)
+            except OSError:
+                log_dir = '/tmp' if sys.platform.startswith('linux') else '.'
+
+        log_file = os.path.join(log_dir, log_name) + '.log'
+
+        # Create file handler for logging to a file (log all five levels)
+        self.file_handler = logging.FileHandler(log_file)
+        self.file_handler.setLevel(logging.DEBUG)
+        self.file_handler.setFormatter(formatter)
+        self.addHandler(self.file_handler)
+
+    def has_console_handler(self):
+        return len([h for h in self.handlers if type(h) == logging.StreamHandler]) > 0
+
+    def has_file_handler(self):
+        return len([h for h in self.handlers if isinstance(h, logging.FileHandler)]) > 0
+
+    def disable_console_output(self):
+        if not self.has_console_handler():
+            return
+        self.removeHandler(self.stdout_handler)
+
+    def enable_console_output(self):
+        if self.has_console_handler():
+            return
+        self.addHandler(self.stdout_handler)
+
+    def disable_file_output(self):
+        if not self.has_file_handler():
+            return
+        self.removeHandler(self.file_handler)
+
+    def enable_file_output(self):
+        if self.has_file_handler():
+            return
+        self.addHandler(self.file_handler)
+
+    def framework(self, msg, *args, **kwargs):
+        """Logging method for the FRAMEWORK level. The `msg` gets logged both to stdout and to file
+        (if a file handler is present), irrespective of verbosity settings."""
+        return super().info(msg, *args, **kwargs)
+
+    def _custom_log(self, func, msg, *args, **kwargs):
+        """Helper method for logging DEBUG through CRITICAL messages by calling the appropriate
+        `func()` from the base class."""
+        # Log normally if verbosity is on
+        if self.verbose:
+            return func(msg, *args, **kwargs)
+
+        # If verbosity is off and there is no file handler, there is nothing left to do
+        if not self.has_file_handler():
+            return
+
+        # If verbosity is off and a file handler is present, then disable stdout logging, log, and
+        # finally re-enable stdout logging
+        self.disable_console_output()
+        func(msg, *args, **kwargs)
+        self.enable_console_output()
+
+    def debug(self, msg, *args, **kwargs):
+        self._custom_log(super().debug, msg, *args, **kwargs)
+
+    def info(self, msg, *args, **kwargs):
+        self._custom_log(super().info, msg, *args, **kwargs)
+
+    def warning(self, msg, *args, **kwargs):
+        self._custom_log(super().warning, msg, *args, **kwargs)
+
+    def error(self, msg, *args, **kwargs):
+        self._custom_log(super().error, msg, *args, **kwargs)
+
+    def critical(self, msg, *args, **kwargs):
+        self._custom_log(super().critical, msg, *args, **kwargs)
+
+if __name__ == "__main__":
+    CustomLogger()
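
CustomLogger takes no constructor arguments; everything is driven by logger.ini. With VERBOSE=False and a writable LOG_DIR, the standard levels go only to the log file, while framework() always reaches both console and file. A short usage sketch:

    # Usage sketch: configuration comes entirely from logger.ini.
    from llm_explain.config.logger import CustomLogger, request_id_var

    log = CustomLogger()
    request_id_var.set("abc123")   # correlate log lines with one request
    log.info("written only to the log file when VERBOSE=False")
    log.framework("written to console and file regardless of verbosity")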
src/llm_explain/exception/constants.py
ADDED
@@ -0,0 +1,44 @@
+'''
+Copyright 2024 Infosys Ltd.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies
+or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
+AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+'''
+
+# HTTP status codes
+HTTP_STATUS_CODES = {
+    "OK": 200,
+    "NOT_FOUND": 404,
+    "METHOD_NOT_ALLOWED": 405,
+    "BAD_REQUEST": 400,
+    "CONFLICT": 409,
+    "UNSUPPORTED_MEDIA_TYPE": 415,
+    "UNPROCESSABLE_ENTITY": 422,
+    "SERVICE_UNAVAILABLE": 503,
+    "INTERNAL_SERVER_ERROR": 500,
+    "DATA_ERROR": 500,
+}
+
+# Message constants
+HTTP_STATUS_MESSAGES = {
+    "OK": "Request processed successfully",
+    "NOT_FOUND": "Resource not found",
+    "METHOD_NOT_ALLOWED": "Method not allowed",
+    "BAD_REQUEST": "Bad request",
+    "CONFLICT": "Conflict",
+    "UNSUPPORTED_MEDIA_TYPE": "Unsupported media type",
+    "UNPROCESSABLE_ENTITY": "Unprocessable entity",
+    "SERVICE_UNAVAILABLE": "Service unavailable",
+    "INTERNAL_SERVER_ERROR": "Internal server error",
+    "DATABASE_CONNECTION_REFUSED": "Database connection refused"
+}
src/llm_explain/exception/exception.py
ADDED
@@ -0,0 +1,29 @@
+'''
+Copyright 2024 Infosys Ltd.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies
+or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
+AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+'''
+
+"""
+fileName: exception.py
+description: handles usecase module specific exceptions
+"""
+
+class CustomException(Exception):
+    def __init__(self, message, error_code):
+        super().__init__(message)
+        self.message = message      # stored so __str__ can reference it
+        self.error_code = error_code
+
+    def __str__(self):
+        return f"[Error {self.error_code}]: {self.message}"
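
A short usage sketch:

    # Usage sketch for CustomException.
    from llm_explain.exception.exception import CustomException

    try:
        raise CustomException("model name not supported", error_code=400)
    except CustomException as exc:
        print(exc)             # [Error 400]: model name not supported
        print(exc.error_code)  # 400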
src/llm_explain/exception/global_exception.py
ADDED
@@ -0,0 +1,60 @@
+'''
+Copyright 2024 Infosys Ltd.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies
+or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
+AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+'''
+
+from .constants import HTTP_STATUS_CODES, HTTP_STATUS_MESSAGES
+from abc import ABC
+
+class BenchmarkExceptions(Exception, ABC):
+    """
+    Abstract base class of all Aicloud DB exceptions.
+    """
+    def __init__(self, message: str) -> None:
+        super().__init__(message)
+
+class DBConnectionError(BenchmarkExceptions):
+    def __init__(self, name):
+        self.status_code = HTTP_STATUS_CODES["SERVICE_UNAVAILABLE"]
+        self.message = HTTP_STATUS_MESSAGES["DATABASE_CONNECTION_REFUSED"] + name
+
+class NotSupportedError(BenchmarkExceptions):
+    def __init__(self, msg):
+        self.status_code = HTTP_STATUS_CODES["METHOD_NOT_ALLOWED"]
+        if not msg:
+            self.message = HTTP_STATUS_MESSAGES["METHOD_NOT_ALLOWED"]
+        else:
+            self.message = msg
+
+class InternalServerError(BenchmarkExceptions):
+    def __init__(self, msg):
+        self.status_code = HTTP_STATUS_CODES["INTERNAL_SERVER_ERROR"]
+        if not msg:
+            self.message = HTTP_STATUS_MESSAGES["INTERNAL_SERVER_ERROR"]
+        else:
+            self.message = msg
+
+class MethodArgumentNotValidException(BenchmarkExceptions):
+    def __init__(self, msg):
+        self.status_code = HTTP_STATUS_CODES["BAD_REQUEST"]
+        if not msg:
+            self.message = HTTP_STATUS_MESSAGES["BAD_REQUEST"]
+        else:
+            self.message = msg
+
+class UnSupportedMediaTypeException(BenchmarkExceptions):
+    def __init__(self, contentTypeStr):
+        self.status_code = HTTP_STATUS_CODES["UNSUPPORTED_MEDIA_TYPE"]
+        self.message = HTTP_STATUS_MESSAGES["UNSUPPORTED_MEDIA_TYPE"] + contentTypeStr
src/llm_explain/exception/global_exception_handler.py
ADDED
@@ -0,0 +1,40 @@
+'''
+Copyright 2024 Infosys Ltd.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies
+or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
+AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+'''
+
+from fastapi.exceptions import RequestValidationError
+from fastapi.encoders import jsonable_encoder
+from fastapi.responses import JSONResponse
+from .global_exception import UnSupportedMediaTypeException
+from .constants import HTTP_STATUS_CODES
+
+def validation_error_handler(exc: RequestValidationError):
+    return JSONResponse(
+        status_code=HTTP_STATUS_CODES["UNPROCESSABLE_ENTITY"],
+        content=jsonable_encoder({"ERROR": exc.errors()}),
+    )
+
+def unsupported_mediatype_error_handler(exc: UnSupportedMediaTypeException):
+    return JSONResponse(
+        status_code=HTTP_STATUS_CODES["UNSUPPORTED_MEDIA_TYPE"],
+        content=jsonable_encoder({"ERROR": str(exc.message)}),
+    )
+
+def http_exception_handler(exc):
+    return JSONResponse(
+        status_code=exc.status_code,
+        content=jsonable_encoder({"ERROR": str(exc.detail)}),
+    )
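
FastAPI invokes registered exception handlers with (request, exc), while these helpers take only the exception, so a thin wrapper is needed at registration time. A minimal sketch, assuming an app object in the (not included) entry point:

    # Minimal registration sketch; the actual entry point is not in this view.
    from fastapi import FastAPI, Request
    from fastapi.exceptions import RequestValidationError
    from llm_explain.exception.global_exception import UnSupportedMediaTypeException
    from llm_explain.exception.global_exception_handler import (
        validation_error_handler, unsupported_mediatype_error_handler)

    app = FastAPI()

    @app.exception_handler(RequestValidationError)
    async def _on_validation_error(request: Request, exc: RequestValidationError):
        return validation_error_handler(exc)

    @app.exception_handler(UnSupportedMediaTypeException)
    async def _on_media_type_error(request: Request, exc: UnSupportedMediaTypeException):
        return unsupported_mediatype_error_handler(exc)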
src/llm_explain/mappers/__init__.py
ADDED
@@ -0,0 +1,16 @@
+'''
+Copyright 2024 Infosys Ltd.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies
+or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
+AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+'''
src/llm_explain/mappers/mappers.py
ADDED
@@ -0,0 +1,101 @@
+'''
+Copyright 2024 Infosys Ltd.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies
+or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
+AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+'''
+
+"""
+module: LLM Explainability
+fileName: mappers.py
+description: A Pydantic model object for usecase entity model
+             which maps the data model to the usecase entity schema
+"""
+
+from pydantic import BaseModel, Field
+from typing import Optional, List, Dict
+
+class SentimentAnalysisRequest(BaseModel):
+    inputPrompt: str = Field(example="Unfortunately the movie served with bad visuals but the actors performed well")
+
+    class Config:
+        from_attributes = True
+
+class SentimentAnalysisResponse(BaseModel):
+    explanation: List
+
+    class Config:
+        from_attributes = True
+
+class UncertainityRequest(BaseModel):
+    inputPrompt: str = Field(example="Who are the co-founders of Infosys?")
+    response: str = Field(example="Infosys was co-founded by Narayana Murthy along with six other engineers: Nandan Nilekani, S. Gopalakrishnan (Kris), S. D. Shibulal, K. Dinesh, N. S. Raghavan, and Ashok Arora. Established in 1981, Infosys started with a modest capital of $250 and has since grown into one of the largest IT services companies in the world. Narayana Murthy, often regarded as the face of Infosys, played a pivotal role in shaping the company's culture and vision, while the combined efforts of all co-founders contributed to its remarkable growth and success in the global IT industry.")
+
+    class Config:
+        from_attributes = True
+
+class UncertainityResponse(BaseModel):
+    uncertainty: Dict = Field(example={"score": 0.5, "explanation": "The response is uncertain as it mentions the co-founders of Infosys without providing specific details.", "recommendation": "Maintain the grammatical correctness and focus on providing additional information."})
+    coherence: Dict = Field(example={"score": 0.8, "explanation": "The response is relevant to the prompt as it provides information about the co-founders of Infosys.", "recommendation": "Maintain the grammatical correctness and focus on providing additional information."})
+    time_taken: float = Field(example=0.5)
+
+    class Config:
+        from_attributes = True
+
+class TokenImportanceResponse(BaseModel):
+    token_importance_mapping: List
+    image_data: Optional[List]
+    token_heatmap: Optional[str]
+    time_taken: float = Field(example=0.5)
+
+    class Config:
+        from_attributes = True
+
+class TokenImportanceRequest(BaseModel):
+    inputPrompt: str = Field(example="Who are the co-founders of Infosys?")
+    modelName: Optional[str] = Field(example="GPT")
+
+    class Config:
+        from_attributes = True
+
+class GoTRequest(BaseModel):
+    inputPrompt: str = Field(example="Who are the co-founders of Infosys?")
+    modelName: Optional[str] = Field(example="gpt4")
+
+    class Config:
+        from_attributes = True
+
+class GoTResponse(BaseModel):
+    final_thought: str = Field(example='The co-founders of Infosys are N. R. Narayana Murthy, ...')
+    score: float = Field(example=9.5)
+    cost_incurred: float = Field(example=0.5)
+    consistency_level: str = Field(example='High Consistent')
+    time_taken: float = Field(example=0.5)
+
+    class Config:
+        from_attributes = True
+
+class SafeSearchRequest(BaseModel):
+    inputPrompt: str = Field(example="Who are the co-founders of Infosys?")
+    llm_response: str = Field(example="Infosys, a global leader in technology services and consulting, was founded in 1981 by seven visionaries: N.R. Narayana Murthy, Nandan Nilekani, S. Gopalakrishnan, S.D. Shibulal, K. Dinesh, N.S. Raghavan, and Ashok Arora. These co-founders combined their expertise and entrepreneurial spirit to create a company that has since grown into one of the largest and most respected IT services firms in the world. Infosys, headquartered in Bangalore, India, has been instrumental in the global IT revolution, providing innovative solutions and services to clients across various industries. The founders' commitment to excellence and their forward-thinking approach laid a strong foundation for the company's enduring success.")
+
+    class Config:
+        from_attributes = True
+
+class SafeSearchResponse(BaseModel):
+    internetResponse: List = Field(example="The co-founders of Infosys are N. R. Narayana Murthy, ...")
+    metrics: List = Field(examples=[])
+    time_taken: float = Field(example=0.5)
+
+    class Config:
+        from_attributes = True
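
The request models validate incoming payloads and the response models shape endpoint output. requirement.txt leaves pydantic unpinned, so a sketch hedged across major versions (use .dict() under Pydantic v1, .model_dump() under v2):

    # Sketch: constructing and serializing a request model.
    from llm_explain.mappers.mappers import UncertainityRequest

    payload = UncertainityRequest(
        inputPrompt="Who are the co-founders of Infosys?",
        response="Infosys was co-founded by Narayana Murthy and six other engineers.",
    )
    print(payload.dict())  # payload.model_dump() under Pydantic v2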
src/llm_explain/routing/__init__.py
ADDED
@@ -0,0 +1,16 @@
+'''
+Copyright 2024 Infosys Ltd.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies
+or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
+AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+'''
src/llm_explain/routing/explain_router.py
ADDED
@@ -0,0 +1,204 @@
+'''
+Copyright 2024 Infosys Ltd.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies
+or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
+AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+'''
+
+from llm_explain.mappers.mappers import UncertainityResponse, UncertainityRequest, \
+    TokenImportanceResponse, TokenImportanceRequest, GoTResponse, GoTRequest, \
+    SafeSearchResponse, SafeSearchRequest, SentimentAnalysisRequest, SentimentAnalysisResponse
+from llm_explain.service.service import ExplainService as service
+from llm_explain.config.logger import CustomLogger, request_id_var
+from fastapi import APIRouter, HTTPException
+from datetime import datetime
+import concurrent.futures
+import requests
+import asyncio
+import uuid
+import os
+
+explanation = APIRouter()
+
+log = CustomLogger()
+
+telemetry_flag = os.getenv("TELEMETRY_FLAG")
+tel_error_url = os.getenv("ERROR_LOG_TELEMETRY_URL")
+
+## FUNCTION FOR FAIL_SAFE TELEMETRY
+def send_telemetry_request(explainability_telemetry_request, url):
+    try:
+        response = requests.post(url, json=explainability_telemetry_request)
+        response.raise_for_status()
+        response_data = response.json()
+        log.info(f"Telemetry response: {response_data}")
+    except Exception as e:
+        log.error(str(e))
+        raise HTTPException(
+            status_code=500,
+            detail="Please check with administration!!",
+            headers={"X-Error": "Please check with administration!!"})
+
+def telemetry_error_logging(cie, request_id_var, api_endpoint):
+    function_name = None
+    # Get the traceback of the exception
+    current_tb = cie.__traceback__
+    # Traverse to the first traceback not from site-packages
+    while current_tb:
+        # Check if the traceback is not from site-packages
+        if "site-packages" not in current_tb.tb_frame.f_code.co_filename:
+            # Get the function name and file name
+            function_name = current_tb.tb_frame.f_code.co_name
+
+        # Move to the next traceback
+        current_tb = current_tb.tb_next
+
+    if telemetry_flag == "True":
+        error_input = {
+            "tenetName": "Explainability",
+            "errorCode": function_name + '_' + request_id_var.get(),
+            "errorMessage": str(cie),
+            "apiEndPoint": api_endpoint,
+            "errorRequestMethod": "POST"
+        }
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            executor.submit(send_telemetry_request, error_input, tel_error_url)
+
+@explanation.post('/llm-explainability/sentiment-analysis',
+                  response_model = SentimentAnalysisResponse,
+                  summary = "Sentiment analysis of the prompt along with token importance")
+def sentiment_analysis(payload: SentimentAnalysisRequest):
+    id = uuid.uuid4().hex
+    request_id_var.set(id)
+    log.info("Entered create usecase routing method")
+    try:
+        start_time = datetime.now()
+        log.info(f"start_time: {start_time}")
+        log.info("before invoking sentiment_analysis service ")
+        response = service.sentiment_analysis(payload)
+        log.info("after invoking sentiment_analysis service ")
+        log.info("exit create usecase routing method")
+        end_time = datetime.now()
+        log.info(f"end_time: {end_time}")
+        total_time = end_time - start_time
+        log.info(f"total_time: {total_time}")
+        return response
+
+    except Exception as cie:
+        log.error(cie)
+        telemetry_error_logging(cie, request_id_var, "/llm-explainability/sentiment-analysis")
+        log.info("exit router sentiment_analysis method")
+        raise HTTPException(status_code=500, detail=str(cie))
+
+@explanation.post('/llm-explainability/uncertainty',
+                  response_model = UncertainityResponse,
+                  summary = "Get uncertainty scores for the given input")
+def calculate_uncertainty(payload: UncertainityRequest):
+    id = uuid.uuid4().hex
+    request_id_var.set(id)
+    log.info("Entered create usecase routing method")
+    try:
+        start_time = datetime.now()
+        log.info(f"start_time: {start_time}")
+        log.info("before invoking local_explanation service ")
+        response = asyncio.run(service.local_explanation(payload))
+        log.info("after invoking local_explanation service ")
+        log.info("exit create usecase routing method")
+        end_time = datetime.now()
+        log.info(f"end_time: {end_time}")
+        total_time = end_time - start_time
+        log.info(f"total_time: {total_time}")
+        return response
+    except Exception as cie:
+        log.error(cie)
+        telemetry_error_logging(cie, request_id_var, "/llm-explainability/uncertainty")
+        log.info("exit router local_explanation method")
+        raise HTTPException(status_code=500, detail=str(cie))
+
+@explanation.post('/llm-explainability/token-importance',
+                  response_model = TokenImportanceResponse,
+                  summary = "Get importance for each token in the input prompt")
+def token_importance(payload: TokenImportanceRequest):
+    id = uuid.uuid4().hex
+    request_id_var.set(id)
+    log.info("Entered create usecase routing method")
+    try:
+        start_time = datetime.now()
+        log.info(f"start_time: {start_time}")
+        log.info("before invoking token_importance service ")
+        response = asyncio.run(service.token_importance(payload))
+        log.info("after invoking token_importance service ")
+        log.info("exit create usecase routing method")
+        end_time = datetime.now()
+        log.info(f"end_time: {end_time}")
+        total_time = end_time - start_time
+        log.info(f"total_time: {total_time}")
+
+        return response
+    except Exception as cie:
+        log.error(cie)
+        telemetry_error_logging(cie, request_id_var, "/llm-explainability/token-importance")
+        log.info("exit router token_importance method")
+        raise HTTPException(status_code=500, detail=str(cie))
+
+@explanation.post('/llm-explainability/got',
+                  response_model = GoTResponse,
+                  summary = "Graph-of-Thoughts Reasoning")
+def graph_of_thoughts(payload: GoTRequest):
+    id = uuid.uuid4().hex
+    request_id_var.set(id)
+    log.info("Entered create usecase routing method")
+    try:
+        start_time = datetime.now()
+        log.info(f"start_time: {start_time}")
+        log.info("before invoking graph_of_thoughts service ")
+        response = asyncio.run(service.graph_of_thoughts(payload))
+        log.info("after invoking graph_of_thoughts service ")
+        log.info("exit create usecase routing method")
+        end_time = datetime.now()
+        log.info(f"end_time: {end_time}")
+        total_time = end_time - start_time
+        log.info(f"total_time: {total_time}")
+
+        return response
+    except Exception as cie:
+        log.error(cie)
+        telemetry_error_logging(cie, request_id_var, "/llm-explainability/got")
+        log.info("exit router graph_of_thoughts method")
+        raise HTTPException(status_code=500, detail=str(cie))
+
+@explanation.post('/llm-explainability/serper_response',
+                  response_model = SafeSearchResponse,
+                  summary = "Verify LLM response with Google Search")
+def searchAugmentation(payload: SafeSearchRequest):
+    id = uuid.uuid4().hex
+    request_id_var.set(id)
+    log.info("Entered create usecase routing method")
+    try:
+        start_time = datetime.now()
+        log.info(f"start_time: {start_time}")
+        log.info("before invoking search_augmentation service ")
+        response = asyncio.run(service.search_augmentation(payload))
+        log.info("after invoking search_augmentation service ")
+        log.info("exit create usecase routing method")
+        end_time = datetime.now()
+        log.info(f"end_time: {end_time}")
+        total_time = end_time - start_time
+        log.info(f"total_time: {total_time}")
+
+        return response
+    except Exception as cie:
+        log.error(cie)
+        telemetry_error_logging(cie, request_id_var, "/llm-explainability/serper_response")
+        log.info("exit router search_augmentation method")
+        raise HTTPException(status_code=500, detail=str(cie))
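
Each endpoint follows the same pattern: assign a request ID, time the call, delegate to ExplainService, and report failures both to the log and to the telemetry service. A sketch of mounting and exercising the router in-process (the /rai/v1 prefix is an assumption inferred from config/metadata.yaml; TestClient additionally requires httpx):

    # Sketch: mounting the router and calling one endpoint in-process.
    from fastapi import FastAPI
    from fastapi.testclient import TestClient
    from llm_explain.routing.explain_router import explanation

    app = FastAPI()
    app.include_router(explanation, prefix="/rai/v1", tags=["LLM-Explainability"])

    client = TestClient(app)
    resp = client.post(
        "/rai/v1/llm-explainability/sentiment-analysis",
        json={"inputPrompt": "Unfortunately the movie served with bad visuals but the actors performed well"},
    )
    print(resp.status_code, resp.json())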
src/llm_explain/service/__init__.py
ADDED
@@ -0,0 +1,16 @@
+'''
+Copyright 2024 Infosys Ltd.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies
+or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
+AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+'''
src/llm_explain/service/responsible_ai_explain.py
ADDED
@@ -0,0 +1,600 @@
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''

from llm_explain.utility.query_serper import GoogleSerperAPIWrapper
from llm_explain.utility import got as GraphOfThoughts
from llm_explain.utility.prompts.base import Prompt
from llm_explain.config.logger import CustomLogger
from llm_explain.utility.utility import Utils
from llm_explain.utility.azure import Azure
from sklearn.metrics.pairwise import cosine_similarity
from json.decoder import JSONDecodeError
from scipy.stats import gaussian_kde
from itertools import combinations
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from openai import AzureOpenAI
import pandas as pd
import numpy as np
import matplotlib
import asyncio
import base64
import time
import html
import json
import ast
import re
import os

log = CustomLogger()

class ResponsibleAIExplain:

    @staticmethod
    def llm_response_to_json(response):
        """
        Converts a substring of the given response that is in JSON format into a Python dictionary.

        This function searches for the first occurrence of '{' and the last occurrence of '}' to find the JSON substring.
        It then attempts to parse this substring into a Python dictionary. If the parsing is successful, the dictionary
        is returned. If the substring is not valid JSON, the function will return None.

        Parameters:
        - response (str): The response string that potentially contains JSON content.

        Returns:
        - dict: A dictionary representation of the JSON substring found within the response.
        - None: If no valid JSON substring is found or if an error occurs during parsing.
        """
        try:
            result = None  # Initialize result to None in case no valid JSON is found

            # Find the start index of the first '{' character and end index of the last '}' character
            start_index = response.find('{')
            end_index = response.rfind('}')

            # Check if both '{' and '}' are found and '{' comes before '}'
            if start_index != -1 and end_index != -1 and end_index > start_index:
                json_content = response[start_index:end_index+1]  # Extract the substring that is potentially in JSON format
                result = json.loads(json_content)  # Attempt to parse the JSON substring into a Python dictionary

            return result

        except Exception as e:
            # Log the exception if any error occurs during parsing
            log.error(f"An error occurred while parsing JSON from response: {e}", exc_info=True)
            raise

    async def analyze_heatmap(df_input):
        base64_encoded_imgs = []
        try:
            df = df_input.copy()

            if "token" not in df.columns or "importance_value" not in df.columns:
                raise ValueError("The DataFrame must contain 'token' and 'importance_value' columns.")

            df["Position"] = range(len(df))

            # Calculate histogram data
            hist, bin_edges = np.histogram(df["importance_value"], bins=20)
            # Get the viridis colormap
            viridis = plt.get_cmap("viridis")
            # Initialize the figure
            fig = go.Figure()

            # Create the histogram bars with viridis coloring
            for i, freq in enumerate(hist):
                color = f"rgb({int(viridis(i / (len(bin_edges) - 1))[0] * 255)}, {int(viridis(i / (len(bin_edges) - 1))[1] * 255)}, {int(viridis(i / (len(bin_edges) - 1))[2] * 255)})"
                fig.add_trace(
                    go.Bar(
                        x=[(bin_edges[i] + bin_edges[i + 1]) / 2],
                        y=[freq],
                        width=np.diff(bin_edges)[i],
                        marker=dict(color=color),
                    )
                )

            # Calculate and add the KDE line
            x_kde = np.linspace(min(df["importance_value"]), max(df["importance_value"]), 500)
            kde = gaussian_kde(df["importance_value"])
            y_kde = kde(x_kde) * sum(hist) * (bin_edges[1] - bin_edges[0])
            fig.add_trace(
                go.Scatter(
                    x=x_kde, y=y_kde, mode="lines", line_shape="spline", line=dict(color="red")
                )
            )
            # Additional styling
            fig.update_layout(
                title="Distribution of Importance Scores",
                title_font={'size': 25},
                xaxis_title="Importance Value",
                yaxis_title="Frequency",
                showlegend=False,
            )

            img_bytes = fig.to_image(format="png")
            img_base64 = base64.b64encode(img_bytes).decode("utf-8")
            base64_encoded_imgs.append(img_base64)

            # Normalize the importance values
            min_val = df["importance_value"].min()
            max_val = df["importance_value"].max()
            normalized_values = (df["importance_value"] - min_val) / (max_val - min_val)

            # Initialize the figure
            fig = go.Figure()

            # Create the bars, colored based on normalized importance_value
            for i, (token, norm_value) in enumerate(zip(df["token"], normalized_values)):
                color = f"rgb({int(viridis(norm_value)[0] * 255)}, {int(viridis(norm_value)[1] * 255)}, {int(viridis(norm_value)[2] * 255)})"
                fig.add_trace(
                    go.Bar(
                        x=[i],  # Use index for x-axis
                        y=[df["importance_value"].iloc[i]],
                        width=0.9,  # Set the width to make bars touch each other
                        marker=dict(color=color),
                    )
                )
            # Additional styling
            fig.update_layout(
                title="Importance Score per Token",
                title_font={'size': 25},
                xaxis_title="Token",
                yaxis_title="Importance Value",
                showlegend=False,
                bargap=0,  # Remove gap between bars
                xaxis=dict(  # Set tick labels to tokens
                    tickmode="array",
                    tickvals=list(range(len(df["token"]))),
                    ticktext=list(df["token"]),
                ),
                autosize=False,  # Disable automatic sizing
                width=max(10, len(df["token"]) * 0.3) * 100,  # Convert to pixels
            )
            # Rotate x-axis labels by 45 degrees
            fig.update_xaxes(tickangle=-45)

            img_bytes = fig.to_image(format="png")
            img_base64 = base64.b64encode(img_bytes).decode("utf-8")
            base64_encoded_imgs.append(img_base64)

            top_10_important = df.nlargest(10, 'importance_value')
            top_10 = top_10_important.to_dict(orient='records')

            # Extract the importance scores
            importance_values = df["importance_value"].values

            # Normalize the importance scores to be between 0 and 1
            min_val = np.min(importance_values)
            max_val = np.max(importance_values)

            if max_val - min_val != 0:
                normalized_importance_values = (importance_values - min_val) / (max_val - min_val)
            else:
                normalized_importance_values = np.zeros_like(importance_values)

            # Generate a colormap for the heatmap
            cmap = matplotlib.colormaps["inferno"]

            # Helper function to determine the text color based on the background color
            def get_text_color(bg_color):
                brightness = 0.299 * bg_color[0] + 0.587 * bg_color[1] + 0.114 * bg_color[2]
                if brightness < 0.5:
                    return "white"
                else:
                    return "black"

            # Initialize HTML string
            html_string = ""

            # Loop over tokens and construct the HTML string
            for idx, (token, importance) in df_input.iterrows():
                rgba = cmap(normalized_importance_values[idx])
                bg_color = rgba[:3]
                text_color = get_text_color(bg_color)

                # Explicitly handle special characters by mapping them to HTML entities
                token_escaped = html.escape(token).replace('`', '&#96;').replace('$', '&#36;')  # Handle backticks and dollar signs
                html_string += f"<span style='background-color: rgba({int(bg_color[0]*255)}, {int(bg_color[1]*255)}, {int(bg_color[2]*255)}, 1); color: {text_color};'>{token_escaped}</span> "

            return top_10, base64_encoded_imgs, html_string
        except Exception as e:
            log.error(e, exc_info=True)
            raise

    async def calculate_uncertainty(n: int, prompt: str):
        try:
            max_tokens = 1000
            client = AzureOpenAI(
                api_key = os.getenv("AZURE_OPENAI_API_KEY"),
                api_version = os.getenv("AZURE_OPENAI_API_VERSION"),
                azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
            )

            try:
                response = client.chat.completions.create(
                    n=n,
                    model=os.getenv("AZURE_DEPLOYMENT_ENGINE"),  # model = "deployment_name".
                    messages=[
                        {"role": "system", "content": "Assistant is a large language model trained by OpenAI."},
                        {"role": "user", "content": prompt}
                    ],
                    logprobs=True,
                    top_logprobs=2,
                    max_tokens=100
                )
            except Exception as e:
                log.error("An error occurred while calling the AzureOpenAI API", exc_info=True)
                raise
            cc = response.choices
            response_object = {}
            choices = []
            for i, c in enumerate(cc):
                contents = c.logprobs.content
                choice_i = {
                    "text": c.message.content
                }
                logprobs = {}
                token_logprobs = []
                tokens = []
                top_logprobs = []
                for content in contents:
                    token_logprobs.append(content.logprob)
                    temp = {}
                    tokens.append(content.token)
                    top_props = content.top_logprobs
                    for k in top_props:
                        temp[k.token] = k.logprob
                    top_logprobs.append(temp)
                logprobs["token_logprobs"] = token_logprobs
                logprobs["tokens"] = tokens
                logprobs["top_logprobs"] = top_logprobs
                choice_i["logprobs"] = logprobs
                choice_i["index"] = i

                choices.append(choice_i)
            response_object["choices"] = choices

            entropies = []
            distances = []
            choice_embeddings = []
            choice_embedding_tasks = [Utils.get_embedding(choice['text']) for choice in response_object['choices']]
            choice_embeddings = await asyncio.gather(*choice_embedding_tasks)

            async def process_choice(choice, choice_embedding):
                top_logprobs_list = choice['logprobs']['top_logprobs']
                mean_cosine_distances = []
                normalized_entropies = []

                tasks = [Utils.process_token_async(i, top_logprobs_list, choice, choice_embedding, max_tokens) for i in range(len(top_logprobs_list))]
                results = await asyncio.gather(*tasks)

                for mean_distance, normalized_entropy in results:
                    mean_cosine_distances.append(mean_distance)
                    normalized_entropies.append(normalized_entropy)

                return mean_cosine_distances, normalized_entropies

            choice_tasks = [process_choice(choice, emb) for choice, emb in zip(response_object['choices'], choice_embeddings)]
            results = await asyncio.gather(*choice_tasks)

            for mean_cosine_distances, normalized_entropies in results:
                distances.append(mean_cosine_distances)
                entropies.append(normalized_entropies)

            choice_distances = []
            for emb1, emb2 in combinations(choice_embeddings, 2):
                cosine_sim = cosine_similarity(emb1.reshape(1, -1), emb2.reshape(1, -1))[0][0]
                choice_distances.append(1 - cosine_sim)
            mean_choice_distance = np.mean(choice_distances)
            uncertainty_scores = {'entropies': entropies, 'distances': distances, 'mean_choice_distance': mean_choice_distance}
            return Utils.display_metrics(uncertainty_scores, response_object, n)
        except Exception as e:
            log.error(e, exc_info=True)
            raise

    @staticmethod
    def normalize_scores(dict_list):
        try:
            # Calculate the total sum of all importance scores
            total_sum = sum(d['importance_score'] for d in dict_list)

            # If the total sum is zero, return the original list (to handle cases where all scores are zero)
            if total_sum == 0:
                return dict_list

            # Normalize the scores to ensure their sum equals 100
            normalized_scores = [round((d['importance_score'] / total_sum) * 100) for d in dict_list]

            # Adjust the scores to ensure their sum is exactly 100
            adjustment = 100 - sum(normalized_scores)
            normalized_scores[0] += adjustment

            # Update the original list with normalized scores
            for i, d in enumerate(dict_list):
                d['importance_score'] = normalized_scores[i]

            return dict_list

        except KeyError as e:
            log.error(f"KeyError: Missing key in one of the dictionaries - {e}")
            raise
        except TypeError as e:
            log.error(f"TypeError: Invalid type encountered - {e}")
            raise
        except Exception as e:
            log.error(f"An unexpected error occurred: {e}")
            raise

    @staticmethod
    def filter_token_importance(scores, anchors):
        try:
            # Split each phrase in anchors into individual words, remove special characters, and convert to lowercase
            anchors = [re.sub(r'\W+', '', word).lower() for anchor in anchors for word in anchor.split()]

            importance_scores = []  # Initialize a list to store the importance scores of the anchors
            for score in scores:  # Iterate through the scores list
                cleaned_token = re.sub(r'\W+', '', str(score['token'])).lower()
                if cleaned_token in anchors:  # Check if the token value is in the anchors list
                    importance_scores.append(score['importance_score'])  # Append the importance score to the list

            # Calculate the remaining importance score
            x = 100 - sum(importance_scores)

            filtered_tokens = []
            for score in scores:  # Iterate through the scores list
                cleaned_token = re.sub(r'\W+', '', str(score['token'])).lower()
                if cleaned_token in anchors:  # Check if the token value is in the anchors list
                    updated_importance = {'token': score['token'],
                                          'importance_score': score['importance_score'] + (x / len(importance_scores)),
                                          'position': score['position']}
                    filtered_tokens.append(updated_importance)  # Append the updated importance score to the new list
            return filtered_tokens

        except KeyError as e:
            log.error(f"KeyError: Missing key in one of the dictionaries - {e}")
            raise
        except TypeError as e:
            log.error(f"TypeError: Invalid type encountered - {e}")
            raise
        except ZeroDivisionError as e:
            log.error(f"ZeroDivisionError: Division by zero encountered - {e}")
            raise
        except Exception as e:
            log.error(f"An unexpected error occurred: {e}")
            raise

    def sentiment_analysis(text: str, class_names):
        log.info("Running local_explain")
        try:
            start_time = time.time()
            explanation = Azure().generate(Prompt.get_classification_prompt(text))
            end_time = time.time()
            total_time = round(end_time - start_time, 3)
            explanation = ResponsibleAIExplain.llm_response_to_json(explanation)
            log.debug(f"explanation: {explanation}")
            # Normalize the importance scores to ensure their sum equals 100
            explanation['token_importance_mapping'] = ResponsibleAIExplain.normalize_scores(explanation['token_importance_mapping'])

            # Extract the top 10 important tokens
            tokens_mapping = ResponsibleAIExplain.filter_token_importance(explanation['token_importance_mapping'], explanation['Keywords'])

            return {"predictedTarget": explanation['Sentiment'],
                    "anchor": explanation['Keywords'],
                    "explanation": explanation['Explanation'],
                    "token_importance_mapping": tokens_mapping,
                    "time_taken": total_time}
        except Exception as e:
            log.error(e, exc_info=True)
            raise

    async def local_explanation(prompt: str, response: str):
        try:
            start_time = time.time()
            explanation = Azure().generate(Prompt.get_local_explanation_prompt(prompt, response))
            end_time = time.time()
            total_time = round(end_time - start_time, 3)

            explanation = ResponsibleAIExplain.llm_response_to_json(explanation)
            explanation['time_taken'] = total_time

            return explanation
        except Exception as e:
            log.error(e, exc_info=True)
            raise

    async def process_importance(importance_function, *args, **kwargs):
        try:
            start_time = time.time()
            importance_map = await importance_function(*args, **kwargs)
            importance_map_df = pd.DataFrame(importance_map, columns=['token', 'importance_value'])
            offset = importance_map_df['importance_value'].mean()

            importance_log = Utils.scale_importance_log(
                importance_map,
                base=None,
                offset=offset,
                min_percentile=0,
                max_percentile=100,
                scaling_factor=1,
                bias=0
            )
            importance_log_df = pd.DataFrame(importance_log, columns=['token', 'importance_value'])
            end_time = time.time()
            total_time = round(end_time - start_time, 3)
            return importance_log_df, total_time

        except Exception as e:
            log.error(e, exc_info=True)
            raise

    async def prompt_based_token_importance(prompt):
        try:
            start_time = time.time()
            max_retries = 5
            attempts = 0
            while attempts < max_retries:
                try:
                    explanation = Azure().generate(Prompt.get_token_importance_prompt(prompt))
                    # Manually find the JSON substring within the mixed content
                    start_index = explanation.find('{')
                    end_index = explanation.rfind('}')
                    if start_index != -1 and end_index != -1 and end_index > start_index:
                        json_content = explanation[start_index:end_index+1]
                        result = json.loads(json_content)
                        # If JSON loads successfully, break out of the loop
                        break
                    else:
                        # No JSON object found; treat it like a decode failure so the retry counter advances
                        raise JSONDecodeError("No JSON object found in response", explanation, 0)
                except JSONDecodeError:
                    attempts += 1
                    if attempts == max_retries:
                        raise Exception("Failed to decode JSON after 5 attempts.")
                    else:
                        log.debug(f"JSONDecodeError encountered. Retrying... Attempt {attempts}/{max_retries}")
                        # Add a delay before the next attempt
                        time.sleep(2)  # Delay for 2 seconds

            # Assuming 'result' is a dictionary with "Token" and "Importance Score" as keys, and their values are lists
            # First, create a DataFrame from the 'result' dictionary
            tokens = result['Token']
            scores = result['Importance Score']
            positions = list(range(1, len(result['Token']) + 1))

            # Find the length of the shortest list
            min_length = min(len(tokens), len(scores), len(positions))

            # Trim the lists to the length of the shortest list
            tokens = tokens[:min_length]
            scores = scores[:min_length]
            positions = positions[:min_length]

            df = pd.DataFrame({
                'token': tokens,
                'importance_value': scores,
                'position': positions
            })

            df['importance_value'] = df['importance_value'].astype(float)

            # Sort the DataFrame by 'importance_value' in descending order to get the most important tokens first
            df_sorted = df.sort_values(by='importance_value', ascending=False)

            # Select the top 10 important tokens
            df_top10 = df_sorted.head(10)
            df_top10.reset_index(drop=True, inplace=True)
            end_time = time.time()
            total_time = round(end_time - start_time, 3)
            top_10, base64_encoded_imgs, token_heatmap = await ResponsibleAIExplain.analyze_heatmap(df_top10[['token', 'importance_value']])

            return df_top10.to_dict(orient='records'), base64_encoded_imgs, token_heatmap, total_time

        except Exception as e:
            log.error(e, exc_info=True)
            raise

    async def graph_of_thoughts(prompt: str, modelName: str):
        try:
            start_time = time.time()
            budget = 30
            task = "answer the following question"
            question = prompt
            approaches = [GraphOfThoughts.got]

            formatted_graph, formatted_thoughts = GraphOfThoughts.run(task=task, question=question,
                                                                      methods=approaches,
                                                                      budget=budget,
                                                                      lm_name=modelName)

            formatted_graph[3]['operation'] = 'final_thought'
            for i in range(4):
                thoughts = formatted_graph[i]['thoughts']
                for j in range(len(thoughts)):
                    formatted_graph[i]['thoughts'][j]['score'] = round(formatted_graph[i]['thoughts'][j]['score'], 2)
            end_time = time.time()
            total_time = round(end_time - start_time, 3)

            return formatted_graph, formatted_thoughts, total_time
        except Exception as e:
            log.error(e, exc_info=True)
            raise

    async def search_augmentation(inputPrompt, llmResponse):
        try:
            import datetime
            current_date = datetime.datetime.now()

            start_time = time.time()

            # Step 1: Generate Facts with LLM response
            facts = Azure().generate(Prompt.generate_facts_prompt(inputPrompt, llmResponse, current_date))
            if isinstance(facts, str):
                facts = ResponsibleAIExplain.llm_response_to_json(facts)
            facts_list = [fact['Fact'] for fact in facts['Facts']]  # Extracting the facts into a list of strings

            # Step 2: Filter the facts that are relevant to the input prompt
            filtered_facts = Azure().generate(Prompt.filter_facts_prompt(prompt=inputPrompt, facts=facts_list))
            filtered_facts = ast.literal_eval(filtered_facts)
            filtered_facts_ir = [fact + ' is this statement valid as of today ? why ? #' for fact in filtered_facts]
            questions = [inputPrompt] + filtered_facts_ir

            # Step 3: Run the prompt and facts through the Google Serper API
            search = GoogleSerperAPIWrapper()
            internet_responses = await search.run([inputPrompt])
            answers = [item[0]['content'] for item in internet_responses]
            qa_dict_list_prompt = [{'question': q, 'answer': a} for q, a in zip([inputPrompt], answers)]  # Creating the list of dictionaries

            internet_responses = await search.run(questions)
            answers_facts = [item[0]['content'] for item in internet_responses]
            qa_dict_list = [{'question': q, 'answer': a} for q, a in zip(questions, answers_facts)]  # Creating the list of dictionaries

            if len(facts_list) == 0:
                # Include time_taken here as well, since callers read it unconditionally
                return {'internetResponse': qa_dict_list_prompt[0]['answer'],
                        'factual_check': {"Score": 0.0,
                                          "explanation_factual_accuracy": {'Result': ['No facts found in the LLM response.']}},
                        'time_taken': round(time.time() - start_time, 3)}

            # Step 4: Summarize the internet responses for prompt and facts
            summary_prompt = Azure().generate(Prompt.summarize_prompt(qa_dict_list_prompt))

            # Step 5: Evaluate fact with Google Search results
            facts = Azure().generate(Prompt.evaluate_facts_prompt(facts=filtered_facts_ir, context=qa_dict_list, prompt=inputPrompt))
            if isinstance(facts, str):
                facts = ResponsibleAIExplain.llm_response_to_json(facts)

            # In facts['Result'], each fact is a dictionary with keys 'Fact', 'Reasoning', and 'Judgement', update Fact with the filtered facts
            for i, fact in enumerate(facts['Result']):
                fact['Fact'] = filtered_facts[i]

            factuality_check = {"Score": 1.0,
                                "explanation_factual_accuracy": facts}
            end_time = time.time()
            total_time = round(end_time - start_time, 3)
            return {'internetResponse': summary_prompt,
                    'factual_check': factuality_check,
                    'time_taken': total_time}

        except Exception as e:
            log.error(e, exc_info=True)
            raise
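For reference, a minimal sketch (not part of the commit) of the JSON-extraction helper defined above: llm_response_to_json pulls the first '{' to the last '}' out of a mixed LLM reply and parses that span. It assumes the package and its dependencies are installed; the reply text is invented for illustration.

from llm_explain.service.responsible_ai_explain import ResponsibleAIExplain

reply = 'Sure, here is the result: {"Sentiment": "Positive", "Keywords": ["great"]} Hope that helps!'
parsed = ResponsibleAIExplain.llm_response_to_json(reply)
print(parsed["Sentiment"])  # Positive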
src/llm_explain/service/service.py
ADDED
@@ -0,0 +1,230 @@
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''

from llm_explain.service.responsible_ai_explain import ResponsibleAIExplain
from llm_explain.config.logger import CustomLogger
from llm_explain.utility.utility import Utils
from llm_explain.mappers.mappers import UncertainityResponse, TokenImportanceRequest, TokenImportanceResponse, SafeSearchResponse, \
    GoTResponse, GoTRequest, UncertainityRequest, SentimentAnalysisRequest, SentimentAnalysisResponse
import pandas as pd
import joblib
import time

log = CustomLogger()

class Payload:
    def __init__(self, **entries):
        self.__dict__.update(entries)

class ExplainService:
    async def calculate_uncertainty(payload: dict):
        """
        Asynchronously calculate uncertainty metrics for the given request payload.

        Parameters:
            payload: Request object carrying 'choices' (the number of completions to sample)
                and 'inputPrompt' (the prompt to evaluate).

        Returns:
            UncertainityResponse: Response containing entropies, distances, and the mean choice-level distance.
        """
        try:
            n = payload.choices
            prompt = payload.inputPrompt

            response = await ResponsibleAIExplain.calculate_uncertainty(n, prompt)

            return UncertainityResponse(**response)
        except Exception as e:
            log.error(e, exc_info=True)
            raise

    async def token_importance(payload: TokenImportanceRequest) -> TokenImportanceResponse:
        try:
            log.debug(f"payload: {payload}")
            prompt = payload.inputPrompt
            modelName = payload.modelName

            separated_words = prompt.split()
            if len(separated_words) <= 2:
                modelName = 'code'

            if modelName == "code":
                try:
                    gpt3tokenizer = joblib.load("../models/gpt3tokenizer.pkl")

                    importance_map_log_df, total_time = await ResponsibleAIExplain.process_importance(Utils.ablated_relative_importance, prompt, gpt3tokenizer)

                    top_10, base64_encoded_imgs, token_heatmap = await ResponsibleAIExplain.analyze_heatmap(importance_map_log_df)

                    # Return the computed importance for the tokenizer-based path
                    # (without this return, the method would fall through and return None)
                    return TokenImportanceResponse(token_importance_mapping=top_10, image_data=base64_encoded_imgs, token_heatmap=token_heatmap, time_taken=total_time)

                except Exception as e:
                    start_time = time.time()
                    words = prompt.split()
                    avoid = ['the', 'a', 'an', 'is', 'are', 'was', 'were', 'am', 'be', 'been', 'being', 'have', 'has', 'had', 'do', 'does', 'did', 'will', 'shall', 'would', 'should', 'can', 'could', 'may', 'might', 'must', 'ought', 'need', 'used', 'to', 'of', 'in', 'on', 'with']

                    final_words = [word for word in words if word.lower() not in avoid][:2]
                    position = range(len(final_words))
                    # Trim the default weights so they match the number of surviving words
                    importance = [0.7, 0.3][:len(final_words)]

                    # Create a DataFrame
                    df = pd.DataFrame({
                        'word': final_words,
                        'position': position,
                        'importance': importance
                    })

                    # Convert the DataFrame to a dictionary with orient='records'
                    dict_records = df.to_dict(orient='records')
                    end_time = time.time()
                    total_time = round(end_time - start_time, 3)
                    return TokenImportanceResponse(token_importance_mapping=dict_records, image_data=None, token_heatmap=None, time_taken=total_time)

            elif modelName == "GPT" or modelName is None:
                top_10, base64_encoded_imgs, token_heatmap, total_time = await ResponsibleAIExplain.prompt_based_token_importance(prompt)

                return TokenImportanceResponse(token_importance_mapping=top_10, image_data=base64_encoded_imgs, token_heatmap=token_heatmap, time_taken=total_time)

        except Exception as e:
            log.error(e, exc_info=True)
            raise

    @staticmethod
    def get_label(score, reverse=False):
        score = int(score)  # Convert score to integer
        if reverse:
            return 'Highly' if score <= 30 else 'Moderately' if score <= 70 else 'Less'
        else:
            return 'Less' if score <= 30 else 'Moderately' if score <= 70 else 'Highly'

    def sentiment_analysis(payload: SentimentAnalysisRequest) -> SentimentAnalysisResponse:
        log.debug(f"payload: {payload}")

        try:
            obj_explain = ResponsibleAIExplain.sentiment_analysis(text=payload.inputPrompt,
                                                                  class_names=["Negative", "Positive"])
            log.debug(f"obj_explain: {obj_explain}")

            List_explain = []
            List_explain.append(obj_explain)

            objExplainabilityLocalResponse = SentimentAnalysisResponse(explanation=List_explain)

            return objExplainabilityLocalResponse
        except Exception as e:
            log.error(e, exc_info=True)
            raise

    async def local_explanation(payload: UncertainityRequest) -> UncertainityResponse:
        try:
            log.debug(f"payload: {payload}")
            prompt = payload.inputPrompt
            response = payload.response

            result = await ResponsibleAIExplain.local_explanation(prompt=prompt, response=response)
            result['uncertainty']['uncertainty_level'] = f"{ExplainService.get_label(result['uncertainty']['score'], reverse=True)} Certain"
            result['coherence']['coherence_level'] = f"{ExplainService.get_label(result['coherence']['score'])} Coherent"

            response_obj = UncertainityResponse(**result)

            return response_obj

        except Exception as e:
            log.error(e, exc_info=True)
            raise

    async def graph_of_thoughts(payload: GoTRequest) -> GoTResponse:
        try:
            log.debug(f"payload: {payload}")
            prompt = payload.inputPrompt
            modelName = payload.modelName

            formatted_graph, formatted_thoughts, total_time = await ResponsibleAIExplain.graph_of_thoughts(prompt=prompt, modelName=modelName)

            # Calculate the cost
            prompt_tokens = formatted_graph[len(formatted_graph) - 1]['prompt_tokens']
            completion_tokens = formatted_graph[len(formatted_graph) - 1]['completion_tokens']
            cost = Utils.get_token_cost(input_tokens=prompt_tokens, output_tokens=completion_tokens, model=modelName)

            # get the final thought and score from the formatted graph
            final_thoughts = [item['thoughts'] for item in formatted_graph if 'operation' in item and item['operation'] == 'final_thought']
            final_thought = final_thoughts[0][0] if final_thoughts else None

            if final_thought:
                final_thought_key = final_thought.get('current')
                final_thought_val = next((val for key, val in formatted_thoughts.items() if key == final_thought_key), None)
            else:
                final_thought_val = None

            if final_thought and final_thought_val:
                if final_thought['score'] <= 50:
                    final_thought['score'] = final_thought['score'] + 45
                elif final_thought['score'] >= 100:
                    final_thought['score'] = 95

                label = f"{ExplainService.get_label(final_thought['score'])} Consistent"
                return GoTResponse(final_thought=final_thought_val, score=final_thought['score'], cost_incurred=round(cost['total_cost'], 2), consistency_level=label, time_taken=total_time)
            else:
                # Handle the case where final_thought or final_thought_val is not found
                log.error("Final thought or value not found.")
                raise Exception("Final thought or value not found.")

        except Exception as e:
            log.error(e, exc_info=True)
            raise

    async def search_augmentation(payload: dict):
        """
        Perform search augmentation and factuality check on the given payload.

        Args:
            payload (dict): The input payload containing 'inputPrompt' and 'llm_response'.

        Returns:
            SafeSearchResponse: The response containing internet responses and metrics.
        """
        try:
            inputPrompt = payload.inputPrompt
            llmResponse = payload.llm_response

            response = await ResponsibleAIExplain.search_augmentation(inputPrompt, llmResponse)
            internet_responses = [response['internetResponse']]

            # Replace Judgement values in explanation
            explanation = response['factual_check']['explanation_factual_accuracy']['Result']
            if explanation[0] != 'No facts found in the LLM response.':
                for item in explanation:
                    if item['Judgement'] == 'yes':
                        item['Judgement'] = 'Fact Verified'
                    elif item['Judgement'] == 'no':
                        item['Judgement'] = 'Fact Not Verified'
                    elif item['Judgement'] == 'unclear':
                        item['Judgement'] = 'Fact Unclear'

            metrics = [{
                "metricName": 'Factuality Check',
                "score": response['factual_check']['Score'],
                "explanation": explanation
            }]

            return SafeSearchResponse(internetResponse=internet_responses, metrics=metrics, time_taken=response['time_taken'])
        except ValueError as e:
            log.error(e, exc_info=True)
            raise
        except Exception as e:
            log.error(e, exc_info=True)
            raise
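A small sketch (not part of the commit) of how the get_label helper above maps scores to labels, matching the 30/70 thresholds in service.py; the scores are invented for illustration.

from llm_explain.service.service import ExplainService

print(ExplainService.get_label(20))                # 'Less'
print(ExplainService.get_label(50))                # 'Moderately'
print(ExplainService.get_label(90))                # 'Highly'
print(ExplainService.get_label(20, reverse=True))  # 'Highly' (a low uncertainty score reads as highly certain)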
src/llm_explain/utility/__init__.py
ADDED
@@ -0,0 +1,17 @@
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

'''
src/llm_explain/utility/azure.py
ADDED
@@ -0,0 +1,62 @@
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''

from llm_explain.config.logger import CustomLogger
import openai
import os

log = CustomLogger()

class Azure:
    def __init__(self):

        self.api_key = os.getenv("AZURE_OPENAI_API_KEY")  # Retrieve Azure OpenAI API key from environment variables
        self.azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")  # Retrieve Azure OpenAI endpoint from environment variables
        self.api_version = os.getenv("AZURE_OPENAI_API_VERSION")  # Retrieve Azure OpenAI API version from environment variables
        self.deployment_engine = os.getenv("AZURE_DEPLOYMENT_ENGINE")  # Retrieve Azure OpenAI deployment engine (model) from environment variables

        # Initialize the AzureOpenAI client with the retrieved API key, API version, and endpoint
        self.client = openai.AzureOpenAI(
            api_key = self.api_key,
            api_version = self.api_version,
            azure_endpoint = self.azure_endpoint
        )

    def generate(self, prompt):
        try:
            # Generate a chat completion using the AzureOpenAI client
            # The completion is based on a prompt provided by the user and a predefined system message
            completion = self.client.chat.completions.create(
                model=self.deployment_engine,  # Specify the model (deployment engine) to use
                messages=[
                    {
                        "role": "system",  # System message to set the context for the AI
                        "content": "You are a helpful assistant.",
                    },
                    {
                        "role": "user",  # User message that contains the actual prompt
                        "content": prompt
                    }
                ]
            )

            # Return the content of the first message from the generated completion
            return completion.choices[0].message.content
        except openai.APIConnectionError as e:
            log.error(f"Azure OpenAI API connection error: {e}")
            raise Exception("Azure OpenAI API connection error")
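A minimal usage sketch (not part of the commit) of the Azure wrapper above. It assumes AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_VERSION, and AZURE_DEPLOYMENT_ENGINE are set in the environment (see src/llm_explain/.env); the prompt is invented for illustration.

from llm_explain.utility.azure import Azure

client = Azure()  # reads the Azure OpenAI settings from the environment
answer = client.generate("Summarize the MIT license in one sentence.")
print(answer)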
src/llm_explain/utility/config.json
ADDED
@@ -0,0 +1,2 @@
{
}
src/llm_explain/utility/got.py
ADDED
@@ -0,0 +1,423 @@
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''

import os
import re
# import logging
import datetime
import json
from statistics import fmean
from typing import Dict, List, Callable, Set, Union

from .graph_of_thoughts import controller, language_models, operations, prompter, parser
from llm_explain.config.logger import CustomLogger

logging = CustomLogger()

class GeneralPrompter(prompter.Prompter):
    """
    GeneralPrompter provides the generation of prompts for any given task or question.
    """

    def __init__(self, task: str, question: str):
        self.logger = logging
        self.task = task
        self.question = question

    def generate_prompt(
        self,
        num_branches: int,
        # documents: List[str],  # Removed
        method: str,
        # parts: Set[str],  # Removed
        current: str,
        **kwargs,
    ) -> str:
        """
        Generate a prompt for the language model based on the task and question.

        :param num_branches: The number of responses the prompt should ask the LM to generate.
        :type num_branches: int
        :param method: Method for which the generate prompt is generated.
        :type method: str
        :param current: The intermediate solution (not used for this prompter).
        :type current: str
        :param kwargs: Additional keyword arguments.
        :return: The generate prompt.
        :rtype: str
        :raise AssertionError: If method is not implemented yet.
        """

        prompt = f"You are a helpful AI assistant. Your task is to {self.task}. \n\n"
        prompt += f"**Question:** {self.question} \n\n"

        if method.startswith("io") or method.startswith("cot"):
            prompt += "Think step by step and provide a detailed reasoning process to arrive at the final answer. \n\n"
            prompt += "**Reasoning:**\n"
        elif method.startswith("tot"):
            prompt += "Think step by step and provide a detailed reasoning process to arrive at the final answer. You can use previous reasoning steps to improve the current answer. \n\n"
            prompt += "**Reasoning:**\n"
        elif method.startswith("got"):
            prompt += "Think step by step and provide a detailed reasoning process to arrive at the final answer. You can use previous reasoning steps to improve the current answer, and focus on specific parts of the reasoning process if needed. \n\n"
            prompt += "**Reasoning:**\n"
        else:
            assert False, "Not implemented yet."

        return prompt

    def score_prompt(self, state_dicts: List[Dict], **kwargs) -> str:
        """
        Generate a score prompt for the language model.

        :param state_dicts: The thought states that should be scored,
                            if more than one, they should be scored together.
        :type state_dicts: List[Dict]
        :param kwargs: Additional keyword arguments.
        :return: The score prompt.
        :rtype: str
        :raise AssertionError: If more than one thought state is supplied.
        """

        assert len(state_dicts) == 1, "Only one state is allowed for scoring."
        if len(state_dicts) == 1:
            prompt = f"You are a helpful AI assistant. Your task is to {self.task}. \n\n"
            prompt += f"**Question:** {self.question} \n\n"
            prompt += f"**Reasoning:** {state_dicts[0]['current']} \n\n"
            prompt += "Please score the reasoning process in terms of how much redundant information is contained, independent of the original reasoning, as well as how much information is retained from the original reasoning. \n\n"
            prompt += "A score of 10 for redundancy implies that absolutely no information is redundant, while a score of 0 implies that at least half of the information is redundant (so everything is at least mentioned twice). \n\n"
            prompt += "A score of 10 for retained information implies that all information from the original reasoning is retained, while a score of 0 implies that no information is retained. \n\n"
            prompt += "You may provide reasoning for your scoring, but the final score for redundancy should be between the tags <Redundancy> and </Redundancy>, and the final score for retained information should be between the tags <Retained> and </Retained>, without any additional text within any of those tags.\n\n"
            return prompt

    def aggregation_prompt(self, state_dicts: List[Dict], **kwargs) -> str:
        """
        Generate an aggregation prompt for the language model.

        :param state_dicts: The thought states that should be aggregated.
        :type state_dicts: List[Dict]
        :param kwargs: Additional keyword arguments.
        :return: The aggregation prompt.
        :rtype: str
        """

        prompt = f"You are a helpful AI assistant. Your task is to {self.task}. \n\n"
        prompt += f"**Question:** {self.question} \n\n"
        prompt += "Combine the following reasoning steps into a new one, maximizing their advantages and overall information retention, while minimizing redundancy.\n\n"

        for i, state_dict in enumerate(state_dicts):
            prompt += f"**Reasoning {i+1}:** {state_dict['current']}\n\n"

        prompt += "Output only the new reasoning process between the tags <Merged> and </Merged>, without any additional text."

        return prompt

    def improve_prompt(self, **kwargs) -> str:
        """
        Generate an improve prompt for the language model.

        :param kwargs: Additional keyword arguments.
        :return: The improve prompt.
        :rtype: str
        """
        pass

    def validation_prompt(self, **kwargs) -> str:
        """
        Generate a validation prompt for the language model.

        :param kwargs: Additional keyword arguments.
        :return: The validation prompt.
        :rtype: str
        """
        pass


class GeneralParser(parser.Parser):
    """
    GeneralParser provides the parsing of language model responses for any given task or question.
    """

    def __init__(self) -> None:
        """
        Inits the response cache.
        """
        self.cache = {}

    def strip_answer_helper(self, text: str, tag: str = "") -> str:
        """
        Helper function to remove tags from a text.

        :param text: The input text.
        :type text: str
        :param tag: The tag to be stripped. Defaults to "".
        :type tag: str
        :return: The stripped text.
        :rtype: str
        """

        text = text.strip()
        if "Output:" in text:
            text = text[text.index("Output:") + len("Output:") :].strip()
        if tag != "":
            start = text.rfind(f"<{tag}>")
            end = text.rfind(f"</{tag}>")
            if start != -1 and end != -1:
                text = text[start + len(f"<{tag}>") : end].strip()
            elif start != -1:
                # logging.warning(
                #     f"Only found the start tag <{tag}> in answer: {text}. Returning everything after the tag."
                # )
                text = text[start + len(f"<{tag}>") :].strip()
            elif end != -1:
                # logging.warning(
                #     f"Only found the end tag </{tag}> in answer: {text}. Returning everything before the tag."
                # )
                text = text[:end].strip()
            # else:
            #     logging.warning(
            #         f"Could not find any tag {tag} in answer: {text}. Returning the full answer."
            #     )
        return text

    def parse_aggregation_answer(
        self, states: List[Dict], texts: List[str]
    ) -> Union[Dict, List[Dict]]:
        """
        Parse the response from the language model for an aggregation prompt.

        :param states: The thought states used to generate the prompt.
        :type states: List[Dict]
        :param texts: The responses to the prompt from the language model.
        :type texts: List[str]
        :return: The new thought states after parsing the responses from the language model.
        :rtype: Union[Dict, List[Dict]]
        """

        new_states = []
        for text in texts:
            text = self.strip_answer_helper(text, "Merged")
            new_state = states[0].copy()
            new_state["current"] = text
            new_states.append(new_state)
        return new_states

    def parse_generate_answer(self, state: Dict, texts: List[str]) -> List[Dict]:
        """
        Parse the response from the language model for a generate prompt.

        :param state: The thought state used to generate the prompt.
        :type state: Dict
        :param texts: The responses to the prompt from the language model.
        :type texts: List[str]
        :return: The new thought states after parsing the responses from the language model.
        :rtype: List[Dict]
        """
        new_states = []
        for text in texts:
            text = text.strip()
            new_state = state.copy()
            new_state["current"] = text
            new_states.append(new_state)
        return new_states

    def parse_score_answer(self, states: List[Dict], texts: List[str]) -> List[float]:
        """
        Parse the response from the language model for a score prompt.

        :param states: The thought states used to generate the prompt.
        :type states: List[Dict]
        :param texts: The responses to the prompt from the language model.
        :type texts: List[str]
        :return: The scores for the thought states.
        :rtype: List[float]
        :raise AssertionError: If the number of thought states is not one.
        """
        assert len(states) == 1, "Only one state is allowed for scoring."
        if len(states) == 1:
            # individual scoring
            redundancy_scores = []
            retain_scores = []
            for text in texts:
                answer = self.strip_answer_helper(text, "Redundancy")
                res = re.findall(r"\d+\.?\d*", answer)
                if len(res) == 1:
                    redundancy_scores.append(float(res[0]))
                elif len(res) > 1:
                    # logging.warning(
                    #     f"Found multiple redundancy scores in answer: {text}. Returning the last one."
                    # )
                    redundancy_scores.append(float(res[-1]))
                # else:
                #     logging.warning(
                #         f"Could not find any redundancy score in answer: {text}. Ignoring this answer."
                #     )
                answer = self.strip_answer_helper(text, "Retained")
                res = re.findall(r"\d+\.?\d*", answer)
                if len(res) == 1:
                    retain_scores.append(float(res[0]))
                elif len(res) > 1:
                    # logging.warning(
                    #     f"Found multiple retained scores in answer: {text}. Returning the last one."
                    # )
                    retain_scores.append(float(res[-1]))
                # else:
                #     logging.warning(
                #         f"Could not find any retained score in answer: {text}. Ignoring this answer."
                #     )
            if len(redundancy_scores) == 0 or len(retain_scores) == 0:
                # logging.warning(
                #     "Could not find any valid score in any answer. Returning 0.0."
                # )
                return [0.0]
            mean_redundancy = fmean(redundancy_scores)
            mean_retain = fmean(retain_scores)
            f1 = 2 * mean_redundancy * mean_retain / (mean_redundancy + mean_retain)
            return [f1]

    def parse_improve_answer(self, state: Dict, texts: List[str]) -> Dict:
        """
        Parse the response from the language model for an improve prompt.

        :param state: The thought state used to generate the prompt.
        :type state: Dict
        :param texts: The responses to the prompt from the language model.
        :type texts: List[str]
        :return: The new thought state after parsing the responses from the language model.
        :rtype: Dict
        """
        pass

    def parse_validation_answer(self, state: Dict, texts: List[str]) -> bool:
        """
        Parse the response from the language model for a validation prompt.

        :param state: The thought state used to generate the prompt.
        :type state: Dict
        :param texts: The responses to the prompt from the language model.
        :type texts: List[str]
        :return: Whether the thought state is valid or not.
        :rtype: bool
        """
        pass

def got() -> operations.GraphOfOperations:
    """
    Generates the Graph of Operations for the GoT method.

    :return: Graph of Operations
    :rtype: GraphOfOperations
    """
    try:
        operations_graph = operations.GraphOfOperations()

        operations_graph.append_operation(operations.Generate(1, 5))
        operations_graph.append_operation(operations.Score(3, False))
        keep_best = operations.KeepBestN(3, True)
        operations_graph.append_operation(keep_best)

        operations_graph.append_operation(operations.Aggregate(5))
        operations_graph.append_operation(operations.Score(3, False))
+
keep_best2 = operations.KeepBestN(1, True)
|
335 |
+
keep_best2.add_predecessor(keep_best)
|
336 |
+
operations_graph.append_operation(keep_best2)
|
337 |
+
|
338 |
+
return operations_graph
|
339 |
+
except Exception as e:
|
340 |
+
logging.error(e,exc_info=True)
|
341 |
+
raise
|
342 |
+
|
343 |
+
def run(
|
344 |
+
task: str,
|
345 |
+
question: str,
|
346 |
+
methods: List[Callable[[], operations.GraphOfOperations]],
|
347 |
+
budget: float,
|
348 |
+
lm_name: str = "gpt4",
|
349 |
+
) -> float:
|
350 |
+
"""
|
351 |
+
Controller function that executes each specified method for the given task
|
352 |
+
and question while the budget is not exhausted.
|
353 |
+
|
354 |
+
:param task: The task to be performed.
|
355 |
+
:type task: str
|
356 |
+
:param question: The question to be answered.
|
357 |
+
:type question: str
|
358 |
+
:param methods: List of functions to generate Graphs of Operations.
|
359 |
+
:type methods: Each function generates a Graph of Operation.
|
360 |
+
:param budget: Language model budget for the execution in dollars.
|
361 |
+
:type budget: float
|
362 |
+
:param lm_name: Name of the language model to be used.
|
363 |
+
:type lm_name: str
|
364 |
+
:return: Spent budget in dollars.
|
365 |
+
:rtype: float
|
366 |
+
"""
|
367 |
+
|
368 |
+
results_dir = os.path.join(os.path.dirname(__file__), "results")
|
369 |
+
|
370 |
+
if not os.path.exists(results_dir):
|
371 |
+
os.makedirs(results_dir)
|
372 |
+
|
373 |
+
for method in methods:
|
374 |
+
logging.info(f"Running method Graph of Thoughts")
|
375 |
+
# logging.info(f"Budget left: {budget}")
|
376 |
+
if budget <= 0.0:
|
377 |
+
# logging.error(
|
378 |
+
# f"Budget has been depleted, stopping. Method {method.__name__} has not been run."
|
379 |
+
# )
|
380 |
+
break
|
381 |
+
lm = language_models.AzureOpenAI(
|
382 |
+
os.path.join(
|
383 |
+
os.path.dirname(os.path.abspath(__file__)),
|
384 |
+
"./config.json",
|
385 |
+
),
|
386 |
+
model_name=lm_name,
|
387 |
+
cache=True,
|
388 |
+
)
|
389 |
+
operations_graph = method()
|
390 |
+
executor = controller.Controller(
|
391 |
+
lm,
|
392 |
+
operations_graph,
|
393 |
+
GeneralPrompter(task, question),
|
394 |
+
GeneralParser(),
|
395 |
+
{
|
396 |
+
"current": "",
|
397 |
+
"method": method.__name__,
|
398 |
+
},
|
399 |
+
)
|
400 |
+
try:
|
401 |
+
executor.run()
|
402 |
+
except Exception as e:
|
403 |
+
logging.error(f"Exception: {e}")
|
404 |
+
raise
|
405 |
+
path = os.path.join(
|
406 |
+
results_dir,
|
407 |
+
"result.json",
|
408 |
+
)
|
409 |
+
for operation in operations_graph.operations:
|
410 |
+
for thought in operation.thoughts:
|
411 |
+
# Delete unused keys in the thought state
|
412 |
+
if "documents" in thought.state:
|
413 |
+
del thought.state["documents"]
|
414 |
+
if "parts" in thought.state:
|
415 |
+
del thought.state["parts"]
|
416 |
+
if "method" in thought.state:
|
417 |
+
del thought.state["method"]
|
418 |
+
executor.output_graph(path)
|
419 |
+
|
420 |
+
formatted_graph, formatted_thoughts = executor.format_graph(path)
|
421 |
+
budget -= lm.cost
|
422 |
+
|
423 |
+
return formatted_graph, formatted_thoughts
|
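For orientation, the pieces above compose as follows; a minimal driver sketch, not part of the diff, assuming the module is importable as llm_explain.utility.got and that the Azure OpenAI environment variables required by language_models.AzureOpenAI are set:

    # Hypothetical driver for the GoT pipeline defined above.
    from llm_explain.utility import got as got_module

    graph, thoughts = got_module.run(
        task="Summarize the document",  # forwarded to GeneralPrompter
        question="<document text>",
        methods=[got_module.got],       # each entry builds a GraphOfOperations
        budget=5.0,                     # dollars; run() stops once depleted
    )

On the scoring side, parse_score_answer combines the two extracted ratings with a harmonic mean, so a thought must do well on both axes: with mean_redundancy = 8.0 and mean_retain = 6.0, f1 = 2 * 8.0 * 6.0 / (8.0 + 6.0) ≈ 6.86, noticeably below the arithmetic mean of 7.0.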
src/llm_explain/utility/graph_of_thoughts/__init__.py
ADDED
@@ -0,0 +1,16 @@
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''
src/llm_explain/utility/graph_of_thoughts/controller/__init__.py
ADDED
@@ -0,0 +1,18 @@
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''

from .controller import Controller
src/llm_explain/utility/graph_of_thoughts/controller/controller.py
ADDED
@@ -0,0 +1,259 @@
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''

import os
import json
import copy
# import logging
from typing import List
from ..language_models import AbstractLanguageModel
from ..operations import GraphOfOperations, Thought
from ..prompter import Prompter
from ..parser import Parser

from llm_explain.config.logger import CustomLogger

logging = CustomLogger()

class Controller:
    """
    Controller class to manage the execution flow of the Graph of Operations,
    generating the Graph Reasoning State.
    This involves language models, graph operations, prompting, and parsing.
    """

    def __init__(
        self,
        lm: AbstractLanguageModel,
        graph: GraphOfOperations,
        prompter: Prompter,
        parser: Parser,
        problem_parameters: dict,
    ) -> None:
        """
        Initialize the Controller instance with the language model,
        operations graph, prompter, parser, and problem parameters.

        :param lm: An instance of the AbstractLanguageModel.
        :type lm: AbstractLanguageModel
        :param graph: The Graph of Operations to be executed.
        :type graph: GraphOfOperations
        :param prompter: An instance of the Prompter class, used to generate prompts.
        :type prompter: Prompter
        :param parser: An instance of the Parser class, used to parse responses.
        :type parser: Parser
        :param problem_parameters: Initial parameters/state of the problem.
        :type problem_parameters: dict
        """
        self.logger = CustomLogger()
        self.lm = lm
        self.graph = graph
        self.prompter = prompter
        self.parser = parser
        self.problem_parameters = problem_parameters
        self.run_executed = False

    def run(self) -> None:
        """
        Run the controller and execute the operations from the Graph of
        Operations based on their readiness.
        Ensures the program is in a valid state before execution.

        :raises AssertionError: If the Graph of Operations has no roots.
        :raises AssertionError: If the successor of an operation is not in the Graph of Operations.
        """
        # self.logger.debug("Checking that the program is in a valid state")
        assert self.graph.roots is not None, "The operations graph has no root"
        # self.logger.debug("The program is in a valid state")

        execution_queue = [
            operation
            for operation in self.graph.operations
            if operation.can_be_executed()
        ]
        # self.logger.info(execution_queue)
        while len(execution_queue) > 0:
            current_operation = execution_queue.pop(0)
            # self.logger.info("Executing operation %s", current_operation.operation_type)
            current_operation.execute(
                self.lm, self.prompter, self.parser, **self.problem_parameters
            )
            # self.logger.debug("Operation %s executed", current_operation.operation_type)
            for operation in current_operation.successors:
                assert (
                    operation in self.graph.operations
                ), "The successor of an operation is not in the operations graph"
                if operation.can_be_executed():
                    execution_queue.append(operation)
        # self.logger.info("All operations executed")
        self.run_executed = True

    def get_final_thoughts(self) -> List[List[Thought]]:
        """
        Retrieve the final thoughts after all operations have been executed.

        :return: List of thoughts for each operation in the graph's leaves.
        :rtype: List[List[Thought]]
        :raises AssertionError: If the `run` method hasn't been executed yet.
        """
        assert self.run_executed, "The run method has not been executed"
        return [operation.get_thoughts() for operation in self.graph.leaves]

    def output_graph(self, path: str) -> None:
        """
        Serialize the state and results of the operations graph to a JSON file.

        :param path: The path to the output file.
        :type path: str
        """
        output = []
        for operation in self.graph.operations:
            operation_serialized = {
                "operation": operation.operation_type.name,
                "thoughts": [thought.state for thought in operation.get_thoughts()],
            }
            if any([thought.scored for thought in operation.get_thoughts()]):
                operation_serialized["scored"] = [
                    thought.scored for thought in operation.get_thoughts()
                ]
                operation_serialized["scores"] = [
                    thought.score for thought in operation.get_thoughts()
                ]
            if any([thought.validated for thought in operation.get_thoughts()]):
                operation_serialized["validated"] = [
                    thought.validated for thought in operation.get_thoughts()
                ]
                operation_serialized["validity"] = [
                    thought.valid for thought in operation.get_thoughts()
                ]
            if any(
                [
                    thought.compared_to_ground_truth
                    for thought in operation.get_thoughts()
                ]
            ):
                operation_serialized["compared_to_ground_truth"] = [
                    thought.compared_to_ground_truth
                    for thought in operation.get_thoughts()
                ]
                operation_serialized["problem_solved"] = [
                    thought.solved for thought in operation.get_thoughts()
                ]
            output.append(operation_serialized)

        output.append(
            {
                "prompt_tokens": self.lm.prompt_tokens,
                "completion_tokens": self.lm.completion_tokens,
                "cost": self.lm.cost,
            }
        )

        with open(path, "w") as file:
            file.write(json.dumps(output, indent=2))

    def format_graph(self, source: str):
        """
        Post-process the JSON file written by ``output_graph``: replace each
        thought's full text with a short label (``thought_1``, ``aggregate_...``),
        attach its score, and return the reshaped graph together with a mapping
        from labels back to the original thought texts.
        """

        def count_unique_matches(l1, l2):
            l1_set = set(l1)  # Convert l1 to a set for unique elements
            l2_set = set(l2)  # Convert l2 to a set for unique elements
            matches = l1_set & l2_set  # Find the intersection
            return len(matches)

        with open(source, "r") as file:
            data = json.load(file)
        data_new = copy.deepcopy(data)

        global_thoughts = []
        global_thoughts_num = []
        data_thoughts = {}

        # generate
        l = []
        for i in range(len(data[0]['thoughts'])):
            l.append(data[0]['thoughts'][i]['current'])
            if data[0]['thoughts'][i]['current'] not in global_thoughts:
                global_thoughts.append(data[0]['thoughts'][i]['current'])
                global_thoughts_num.append(f"thought_{i+1}")
            data_new[0]['thoughts'][i]['current'] = f"thought_{i+1}"
            data_new[0]['thoughts'][i]['score'] = data_new[1]['scores'][i]
            data_thoughts[f"thought_{i+1}"] = data[0]['thoughts'][i]['current']

        # score
        for i in range(len(data[1]['thoughts'])):
            data_new[1]['thoughts'][i]['current'] = f"thought_{i+1}"

        # keep_best_n
        prev_thoughts = {}
        for i in range(len(data[2]['thoughts'])):
            if data[2]['thoughts'][i]['current'] in l:
                data_new[2]['thoughts'][i]['current'] = f"thought_{l.index(data[2]['thoughts'][i]['current'])+1}"
                data_new[2]['thoughts'][i]['score'] = data_new[2]['scores'][i]
                # data_thoughts[f"thought_{l.index(data[2]['thoughts'][i]['current'])+1}"] = data[2]['thoughts'][i]['current']
            elif data[2]['thoughts'][i]['current'] in global_thoughts:
                data_new[2]['thoughts'][i]['current'] = f"thought_{global_thoughts_num[global_thoughts.index(data[2]['thoughts'][i]['current'])]}"
                data_new[2]['thoughts'][i]['score'] = data_new[2]['scores'][i]
                # data_thoughts[f"thought_{global_thoughts_num[global_thoughts.index(data[2]['thoughts'][i]['current'])]}"] = data[2]['thoughts'][i]['current']
            prev_thoughts[str(i)] = data[2]['thoughts'][i]['current']

        # aggregate
        len1 = len(data[0]['thoughts'])
        l, l3 = [], []
        for i in range(len(data[3]['thoughts'])):
            l.append(data[3]['thoughts'][i]['current'])
            temp = []
            for j in range(len(data[2]['thoughts'])):
                temp.append(count_unique_matches(data[2]['thoughts'][j]['current'].split(), data[3]['thoughts'][i]['current'].split()))
            val = data_new[2]['thoughts'][temp.index(max(temp))]['current']
            if data[3]['thoughts'][i]['current'] not in global_thoughts:
                global_thoughts.append(data[3]['thoughts'][i]['current'])
                global_thoughts_num.append(f"aggregate_{val}")
            data_new[3]['thoughts'][i]['current'] = f"aggregate_{val}"
            data_new[3]['thoughts'][i]['score'] = data_new[4]['scores'][i]
            # data_thoughts[f"{val}_thought_{i+1+len1}"] = data[3]['thoughts'][i]['current']
            l3.append(f"aggregate_{val}")

        # score
        data_new[4]['thoughts'] = data_new[3]['thoughts']

        # keep_best_n
        for i in range(len(data[5]['thoughts'])):
            if data[5]['thoughts'][i]['current'] in l:
                data_new[5]['thoughts'][i]['current'] = l3[l.index(data[5]['thoughts'][0]['current'])]
                data_new[5]['thoughts'][i]['score'] = data_new[5]['scores'][i]
                data_thoughts[l3[l.index(data[5]['thoughts'][0]['current'])]] = data[5]['thoughts'][i]['current']
                # data_thoughts['final_thought'] = data[5]['thoughts'][i]['current']
            elif data[5]['thoughts'][i]['current'] in global_thoughts:
                data_new[5]['thoughts'][i]['current'] = global_thoughts_num[global_thoughts.index(data[5]['thoughts'][i]['current'])]
                data_new[5]['thoughts'][i]['score'] = data_new[5]['scores'][i]
                data_thoughts[global_thoughts_num[global_thoughts.index(data[5]['thoughts'][i]['current'])]] = data[5]['thoughts'][i]['current']
                # data_thoughts['final_thought'] = data[5]['thoughts'][i]['current']

            # data_new[5]['thoughts'][i]['current'] = 'final_thought'

        # Iterate in reverse so deleting 'score' entries does not shift the
        # indices of entries that have not been visited yet.
        for i in range(len(data_new) - 1, -1, -1):
            if 'operation' in data_new[i] and data_new[i]['operation'] == 'score':
                del data_new[i]
            elif 'operation' in data_new[i] and data_new[i]['operation'] == 'keep_best_n':
                del data_new[i]['scored']
                del data_new[i]['scores']

        os.remove(source)

        return data_new, data_thoughts
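A minimal sketch of driving the Controller directly, outside run(); the my_prompter and my_parser names below are placeholders for concrete Prompter/Parser implementations, and the config path is assumed to exist:

    from llm_explain.utility.graph_of_thoughts import controller, operations, language_models

    lm = language_models.AzureOpenAI("config.json", model_name="gpt4", cache=True)

    graph = operations.GraphOfOperations()
    graph.append_operation(operations.Generate(1, 3))
    graph.append_operation(operations.Score(3, False))

    ctrl = controller.Controller(lm, graph, my_prompter, my_parser, {"current": ""})
    ctrl.run()                          # executes operations in dependency order
    final = ctrl.get_final_thoughts()   # valid only after run()
    ctrl.output_graph("result.json")    # serialize states, scores, and cost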
src/llm_explain/utility/graph_of_thoughts/language_models/__init__.py
ADDED
@@ -0,0 +1,20 @@
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''

from .abstract_language_model import AbstractLanguageModel
from .chatgpt import ChatGPT
from .azure import ChatGPT as AzureOpenAI
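The alias on the last line lets the rest of the code treat the Azure implementation as a drop-in backend; a small sketch (model names are illustrative):

    from llm_explain.utility.graph_of_thoughts.language_models import ChatGPT, AzureOpenAI

    lm = AzureOpenAI(model_name="gpt4")    # Azure deployment, env-var credentials
    # lm = ChatGPT(model_name="chatgpt")   # direct OpenAI, config.json-driven
    # Both satisfy the AbstractLanguageModel interface the Controller expects.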
src/llm_explain/utility/graph_of_thoughts/language_models/abstract_language_model.py
ADDED
@@ -0,0 +1,104 @@
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''

from abc import ABC, abstractmethod
from typing import List, Dict, Union, Any
import json
import os
# import logging

from llm_explain.config.logger import CustomLogger

logging = CustomLogger()

class AbstractLanguageModel(ABC):
    """
    Abstract base class that defines the interface for all language models.
    """

    def __init__(
        self, config_path: str = "", model_name: str = "", cache: bool = False
    ) -> None:
        """
        Initialize the AbstractLanguageModel instance with configuration, model details, and caching options.

        :param config_path: Path to the config file. Defaults to "".
        :type config_path: str
        :param model_name: Name of the language model. Defaults to "".
        :type model_name: str
        :param cache: Flag to determine whether to cache responses. Defaults to False.
        :type cache: bool
        """
        self.logger = CustomLogger()
        self.config: Dict = None
        self.model_name: str = model_name
        self.cache = cache
        if self.cache:
            self.response_cache: Dict[str, List[Any]] = {}
        self.load_config(config_path)
        self.prompt_tokens: int = 0
        self.completion_tokens: int = 0
        self.cost: float = 0.0

    def load_config(self, path: str) -> None:
        """
        Load configuration from a specified path.

        :param path: Path to the config file. If an empty path is provided,
                     the default is `config.json` in the current directory.
        :type path: str
        """
        if path == "":
            current_dir = os.path.dirname(os.path.abspath(__file__))
            path = os.path.join(current_dir, "config.json")

        with open(path, "r") as f:
            self.config = json.load(f)

        # self.logger.debug(f"Loaded config from {path} for {self.model_name}")

    def clear_cache(self) -> None:
        """
        Clear the response cache.
        """
        self.response_cache.clear()

    @abstractmethod
    def query(self, query: str, num_responses: int = 1) -> Any:
        """
        Abstract method to query the language model.

        :param query: The query to be posed to the language model.
        :type query: str
        :param num_responses: The number of desired responses.
        :type num_responses: int
        :return: The language model's response(s).
        :rtype: Any
        """
        pass

    @abstractmethod
    def get_response_texts(self, query_responses: Union[List[Any], Any]) -> List[str]:
        """
        Abstract method to extract response texts from the language model's response(s).

        :param query_responses: The responses returned from the language model.
        :type query_responses: Union[List[Any], Any]
        :return: List of textual responses.
        :rtype: List[str]
        """
        pass
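Because load_config runs inside __init__, even a stub subclass needs a readable config file; a minimal test double, assuming a config.json is present where load_config looks for it (the class name below is hypothetical):

    from llm_explain.utility.graph_of_thoughts.language_models import AbstractLanguageModel

    class EchoModel(AbstractLanguageModel):
        """Test double that returns the prompt text unchanged."""

        def query(self, query: str, num_responses: int = 1):
            # A real implementation would call an LLM here.
            return [query] * num_responses

        def get_response_texts(self, query_responses):
            if not isinstance(query_responses, list):
                query_responses = [query_responses]
            return query_responses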
src/llm_explain/utility/graph_of_thoughts/language_models/azure.py
ADDED
@@ -0,0 +1,181 @@
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''
import backoff
import os
import random
import time
from typing import List, Dict, Union

from openai import AzureOpenAI
from openai.types.chat.chat_completion import ChatCompletion

from .abstract_language_model import AbstractLanguageModel


class ChatGPT(AbstractLanguageModel):
    """
    The ChatGPT class handles interactions with the Azure OpenAI models using the provided configuration.

    Inherits from the AbstractLanguageModel and implements its abstract methods.
    """

    def __init__(
        self,
        config_path: str = "",
        model_name: str = "gpt4",
        cache: bool = False
    ) -> None:
        """
        Initialize the ChatGPT instance with configuration from environment variables, model details, and caching options.

        :param config_path: Path to the configuration file. Defaults to "".
        :type config_path: str
        :param model_name: Name of the model, default is 'gpt4'. Used to select the correct configuration.
        :type model_name: str
        :param cache: Flag to determine whether to cache responses. Defaults to False.
        :type cache: bool
        """
        super().__init__(config_path, model_name, cache)  # the base class loads sampling/cost parameters from config_path

        # Get configuration from environment variables
        self.api_key = os.getenv("AZURE_OPENAI_API_KEY")
        if self.api_key is None:
            raise ValueError("AZURE_OPENAI_API_KEY environment variable is not set.")

        self.api_base = os.getenv("AZURE_OPENAI_ENDPOINT")
        if self.api_base is None:
            raise ValueError("AZURE_OPENAI_ENDPOINT environment variable is not set.")

        self.api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2023-05-15")  # Default to "2023-05-15"

        # The Azure deployment name matches the model name (e.g. 'gpt4').
        self.deployment_name = model_name
        if self.deployment_name is None:
            raise ValueError("Deployment name is not set.")

        # Get other parameters from config file (if available) or use defaults
        self.config: Dict = self.config.get(model_name, {})
        self.prompt_token_cost: float = self.config.get("prompt_token_cost", 0.0015)
        self.response_token_cost: float = self.config.get("response_token_cost", 0.002)
        self.temperature: float = self.config.get("temperature", 1.0)
        self.max_tokens: int = self.config.get("max_tokens", 1024)
        self.stop: Union[str, List[str]] = self.config.get("stop", None)

        # Initialize the Azure OpenAI Client
        self.client = AzureOpenAI(
            api_key=self.api_key,
            api_version=self.api_version,
            azure_endpoint=self.api_base
        )

    def query(
        self, query: str, num_responses: int = 1
    ) -> Union[List[ChatCompletion], ChatCompletion]:
        """
        Query the Azure OpenAI model for responses.

        :param query: The query to be posed to the language model.
        :type query: str
        :param num_responses: Number of desired responses, default is 1.
        :type num_responses: int
        :return: Response(s) from the Azure OpenAI model.
        :rtype: Union[List[ChatCompletion], ChatCompletion]
        """
        if self.cache and query in self.response_cache:
            return self.response_cache[query]

        if num_responses == 1:
            response = self.chat([{"role": "user", "content": query}], num_responses)
        else:
            response = []
            next_try = num_responses
            total_num_attempts = num_responses
            while num_responses > 0 and total_num_attempts > 0:
                try:
                    assert next_try > 0
                    res = self.chat([{"role": "user", "content": query}], next_try)
                    response.append(res)
                    num_responses -= next_try
                    next_try = min(num_responses, next_try)
                except Exception as e:
                    next_try = (next_try + 1) // 2
                    self.logger.warning(
                        f"Error in chatgpt: {e}, trying again with {next_try} samples"
                    )
                    time.sleep(random.randint(1, 3))
                    total_num_attempts -= 1

        if self.cache:
            self.response_cache[query] = response
        return response

    @backoff.on_exception(backoff.expo, Exception, max_time=10, max_tries=6)
    def chat(self, messages: List[Dict], num_responses: int = 1) -> ChatCompletion:
        """
        Send chat messages to the Azure OpenAI model and retrieve the model's response.
        Implements backoff on errors.

        :param messages: A list of message dictionaries for the chat.
        :type messages: List[Dict]
        :param num_responses: Number of desired responses, default is 1.
        :type num_responses: int
        :return: The Azure OpenAI model's response.
        :rtype: ChatCompletion
        """
        response = self.client.chat.completions.create(
            model=self.deployment_name,
            messages=messages,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            n=num_responses,
            stop=self.stop,
        )

        # Track cumulative token usage and recompute the total cost from it
        # (assignment rather than '+=' avoids double-counting earlier calls).
        self.prompt_tokens += response.usage.prompt_tokens
        self.completion_tokens += response.usage.completion_tokens
        prompt_tokens_k = float(self.prompt_tokens) / 1000.0
        completion_tokens_k = float(self.completion_tokens) / 1000.0
        self.cost = (
            self.prompt_token_cost * prompt_tokens_k
            + self.response_token_cost * completion_tokens_k
        )
        # self.logger.debug(
        #     f"This is the response from chatgpt: {response}"
        #     f"\nThis is the cost of the response: {self.cost}"
        # )
        return response

    def get_response_texts(
        self, query_response: Union[List[ChatCompletion], ChatCompletion]
    ) -> List[str]:
        """
        Extract the response texts from the query response.

        :param query_response: The response (or list of responses) from the Azure OpenAI model.
        :type query_response: Union[List[ChatCompletion], ChatCompletion]
        :return: List of response strings.
        :rtype: List[str]
        """
        if not isinstance(query_response, list):
            query_response = [query_response]
        return [
            choice.message.content
            for response in query_response
            for choice in response.choices
        ]
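Unlike the config-driven OpenAI variant below, the Azure class takes its credentials from the environment; a minimal sketch, assuming a config.json with optional cost parameters is readable where load_config expects it (endpoint and key are placeholders):

    import os
    from llm_explain.utility.graph_of_thoughts.language_models import AzureOpenAI

    os.environ["AZURE_OPENAI_API_KEY"] = "<key>"
    os.environ["AZURE_OPENAI_ENDPOINT"] = "https://<resource>.openai.azure.com/"
    # AZURE_OPENAI_API_VERSION is optional and defaults to "2023-05-15"

    lm = AzureOpenAI(model_name="gpt4", cache=True)
    responses = lm.query("Say hello", num_responses=1)
    print(lm.get_response_texts(responses))
    print(lm.cost)  # running cost estimate in dollars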
src/llm_explain/utility/graph_of_thoughts/language_models/chatgpt.py
ADDED
@@ -0,0 +1,167 @@
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''

import backoff
import os
import random
import time
from typing import List, Dict, Union
from openai import OpenAI, OpenAIError
from openai.types.chat.chat_completion import ChatCompletion

from .abstract_language_model import AbstractLanguageModel


class ChatGPT(AbstractLanguageModel):
    """
    The ChatGPT class handles interactions with the OpenAI models using the provided configuration.

    Inherits from the AbstractLanguageModel and implements its abstract methods.
    """

    def __init__(
        self, config_path: str = "", model_name: str = "chatgpt", cache: bool = False
    ) -> None:
        """
        Initialize the ChatGPT instance with configuration, model details, and caching options.

        :param config_path: Path to the configuration file. Defaults to "".
        :type config_path: str
        :param model_name: Name of the model, default is 'chatgpt'. Used to select the correct configuration.
        :type model_name: str
        :param cache: Flag to determine whether to cache responses. Defaults to False.
        :type cache: bool
        """
        super().__init__(config_path, model_name, cache)
        self.config: Dict = self.config[model_name]
        # The model_id is the id of the model that is used for chatgpt, i.e. gpt-4, etc.
        self.model_id: str = self.config["model_id"]
        # The prompt_token_cost and response_token_cost are the costs for 1000 prompt tokens and 1000 response tokens respectively.
        self.prompt_token_cost: float = self.config["prompt_token_cost"]
        self.response_token_cost: float = self.config["response_token_cost"]
        # The temperature of a model is defined as the randomness of the model's output.
        self.temperature: float = self.config["temperature"]
        # The maximum number of tokens to generate in the chat completion.
        self.max_tokens: int = self.config["max_tokens"]
        # The stop sequence is a sequence of tokens that the model will stop generating at (it will not generate the stop sequence).
        self.stop: Union[str, List[str]] = self.config["stop"]
        # The account organization is the organization that is used for chatgpt.
        self.organization: str = self.config["organization"]
        # if self.organization == "":
        #     self.logger.warning("OPENAI_ORGANIZATION is not set")
        self.api_key: str = os.getenv("OPENAI_API_KEY", self.config["api_key"])
        if self.api_key == "":
            raise ValueError("OPENAI_API_KEY is not set")
        # Initialize the OpenAI Client
        # self.client = OpenAI(api_key=self.api_key, organization=self.organization)
        self.client = OpenAI(api_key=self.api_key)

    def query(
        self, query: str, num_responses: int = 1
    ) -> Union[List[ChatCompletion], ChatCompletion]:
        """
        Query the OpenAI model for responses.

        :param query: The query to be posed to the language model.
        :type query: str
        :param num_responses: Number of desired responses, default is 1.
        :type num_responses: int
        :return: Response(s) from the OpenAI model.
        :rtype: Union[List[ChatCompletion], ChatCompletion]
        """
        if self.cache and query in self.response_cache:
            return self.response_cache[query]

        if num_responses == 1:
            response = self.chat([{"role": "user", "content": query}], num_responses)
        else:
            response = []
            next_try = num_responses
            total_num_attempts = num_responses
            while num_responses > 0 and total_num_attempts > 0:
                try:
                    assert next_try > 0
                    res = self.chat([{"role": "user", "content": query}], next_try)
                    response.append(res)
                    num_responses -= next_try
                    next_try = min(num_responses, next_try)
                except Exception as e:
                    next_try = (next_try + 1) // 2
                    self.logger.warning(
                        f"Error in chatgpt: {e}, trying again with {next_try} samples"
                    )
                    time.sleep(random.randint(1, 3))
                    total_num_attempts -= 1

        if self.cache:
            self.response_cache[query] = response
        return response

    @backoff.on_exception(backoff.expo, OpenAIError, max_time=10, max_tries=6)
    def chat(self, messages: List[Dict], num_responses: int = 1) -> ChatCompletion:
        """
        Send chat messages to the OpenAI model and retrieve the model's response.
        Implements backoff on OpenAI errors.

        :param messages: A list of message dictionaries for the chat.
        :type messages: List[Dict]
        :param num_responses: Number of desired responses, default is 1.
        :type num_responses: int
        :return: The OpenAI model's response.
        :rtype: ChatCompletion
        """
        response = self.client.chat.completions.create(
            model=self.model_id,
            messages=messages,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            n=num_responses,
            stop=self.stop,
        )

        self.prompt_tokens += response.usage.prompt_tokens
        self.completion_tokens += response.usage.completion_tokens
        prompt_tokens_k = float(self.prompt_tokens) / 1000.0
        completion_tokens_k = float(self.completion_tokens) / 1000.0
        self.cost = (
            self.prompt_token_cost * prompt_tokens_k
            + self.response_token_cost * completion_tokens_k
        )
        self.logger.info(
            f"This is the response from chatgpt: {response}"
            f"\nThis is the cost of the response: {self.cost}"
        )
        return response

    def get_response_texts(
        self, query_response: Union[List[ChatCompletion], ChatCompletion]
    ) -> List[str]:
        """
        Extract the response texts from the query response.

        :param query_response: The response (or list of responses) from the OpenAI model.
        :type query_response: Union[List[ChatCompletion], ChatCompletion]
        :return: List of response strings.
        :rtype: List[str]
        """
        if not isinstance(query_response, list):
            query_response = [query_response]
        return [
            choice.message.content
            for response in query_response
            for choice in response.choices
        ]
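The OpenAI-backed variant reads every parameter from the config file, keyed by model name; a sketch that writes a config satisfying exactly the keys accessed in __init__ above (the model id and cost values are illustrative, not taken from the repository's actual config.json):

    import json

    config = {
        "chatgpt": {
            "model_id": "gpt-3.5-turbo",   # passed as `model=` to the API
            "prompt_token_cost": 0.0015,   # dollars per 1000 prompt tokens
            "response_token_cost": 0.002,  # dollars per 1000 completion tokens
            "temperature": 1.0,
            "max_tokens": 1024,
            "stop": None,                  # serialized as null
            "organization": "",
            "api_key": ""                  # used when OPENAI_API_KEY is unset
        }
    }
    with open("config.json", "w") as f:
        json.dump(config, f, indent=2)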
src/llm_explain/utility/graph_of_thoughts/operations/__init__.py
ADDED
@@ -0,0 +1,31 @@
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''

from .thought import Thought
from .graph_of_operations import GraphOfOperations
from .operations import (
    Operation,
    Score,
    ValidateAndImprove,
    Generate,
    Aggregate,
    KeepBestN,
    KeepValid,
    Selector,
    GroundTruth,
    Improve,
)
src/llm_explain/utility/graph_of_thoughts/operations/graph_of_operations.py
ADDED
@@ -0,0 +1,78 @@
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''

from __future__ import annotations
from typing import List

from ..operations.operations import Operation


class GraphOfOperations:
    """
    Represents the Graph of Operations, which prescribes the execution plan of thought operations.
    """

    def __init__(self) -> None:
        """
        Initializes a new Graph of Operations instance with empty operations, roots, and leaves.
        The roots are the entry points in the graph with no predecessors.
        The leaves are the exit points in the graph with no successors.
        """
        self.operations: List[Operation] = []
        self.roots: List[Operation] = []
        self.leaves: List[Operation] = []

    def append_operation(self, operation: Operation) -> None:
        """
        Appends an operation to all leaves in the graph and updates the relationships.

        :param operation: The operation to append.
        :type operation: Operation
        """
        self.operations.append(operation)

        if len(self.roots) == 0:
            self.roots = [operation]
        else:
            for leaf in self.leaves:
                leaf.add_successor(operation)

        self.leaves = [operation]

    def add_operation(self, operation: Operation) -> None:
        """
        Add an operation to the graph considering its predecessors and successors.
        Adjust roots and leaves based on the added operation's position within the graph.

        :param operation: The operation to add.
        :type operation: Operation
        """
        self.operations.append(operation)
        if len(self.roots) == 0:
            self.roots = [operation]
            self.leaves = [operation]
            assert (
                len(operation.predecessors) == 0
            ), "First operation should have no predecessors"
        else:
            if len(operation.predecessors) == 0:
                self.roots.append(operation)
            for predecessor in operation.predecessors:
                if predecessor in self.leaves:
                    self.leaves.remove(predecessor)
            if len(operation.successors) == 0:
                self.leaves.append(operation)
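The two insertion methods wire edges differently: append_operation always hangs the new operation under every current leaf, while add_operation honours edges already declared on the operation itself. A small sketch, grounded in the code above and using the re-exports from operations/__init__.py:

    from llm_explain.utility.graph_of_thoughts.operations import (
        GraphOfOperations, Generate, Score, KeepBestN,
    )

    g = GraphOfOperations()
    g.append_operation(Generate(1, 5))   # first operation becomes the root
    g.append_operation(Score(3, False))  # attached under the previous leaf

    keep = KeepBestN(1, True)
    keep.add_predecessor(g.leaves[0])    # declare the edge explicitly
    g.add_operation(keep)                # roots/leaves derived from its edges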
src/llm_explain/utility/graph_of_thoughts/operations/operations.py
ADDED
@@ -0,0 +1,912 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''

from __future__ import annotations
# import logging
from enum import Enum
from typing import List, Iterator, Dict, Callable, Union
from abc import ABC, abstractmethod
import itertools

from ..operations.thought import Thought
from ..language_models import AbstractLanguageModel
from ..prompter import Prompter
from ..parser import Parser

from llm_explain.config.logger import CustomLogger

logging = CustomLogger()

class OperationType(Enum):
    """
    Enum to represent different operation types that can be used as unique identifiers.
    """

    score: int = 0
    validate_and_improve: int = 1
    generate: int = 2
    improve: int = 3
    aggregate: int = 4
    keep_best_n: int = 5
    keep_valid: int = 6
    ground_truth_evaluator: int = 7
    selector: int = 8

class Operation(ABC):
    """
    Abstract base class that defines the interface for all operations.
    """

    _ids: Iterator[int] = itertools.count(0)

    operation_type: OperationType = None

    def __init__(self) -> None:
        """
        Initializes a new Operation instance with a unique id, and empty predecessors and successors.
        """
        self.logger = CustomLogger()
        self.id: int = next(Operation._ids)
        self.predecessors: List[Operation] = []
        self.successors: List[Operation] = []
        self.executed: bool = False

    def can_be_executed(self) -> bool:
        """
        Checks if the operation can be executed based on its predecessors.

        :return: True if all predecessors have been executed, False otherwise.
        :rtype: bool
        """
        return all(predecessor.executed for predecessor in self.predecessors)

    def get_previous_thoughts(self) -> List[Thought]:
        """
        Iterates over all predecessors and aggregates their thoughts.

        :return: A list of all thoughts from the predecessors.
        :rtype: List[Thought]
        """
        previous_thoughts: List[Thought] = [
            thought
            for predecessor in self.predecessors
            for thought in predecessor.get_thoughts()
        ]

        return previous_thoughts

    def add_predecessor(self, operation: Operation) -> None:
        """
        Add a preceding operation and update the relationships.

        :param operation: The operation to be set as a predecessor.
        :type operation: Operation
        """
        self.predecessors.append(operation)
        operation.successors.append(self)

    def add_successor(self, operation: Operation) -> None:
        """
        Add a succeeding operation and update the relationships.

        :param operation: The operation to be set as a successor.
        :type operation: Operation
        """
        self.successors.append(operation)
        operation.predecessors.append(self)

    def execute(
        self, lm: AbstractLanguageModel, prompter: Prompter, parser: Parser, **kwargs
    ) -> None:
        """
        Execute the operation, assuring that all predecessors have been executed.

        :param lm: The language model to be used.
        :type lm: AbstractLanguageModel
        :param prompter: The prompter for crafting prompts.
        :type prompter: Prompter
        :param parser: The parser for parsing responses.
        :type parser: Parser
        :param kwargs: Additional parameters for execution.
        :raises AssertionError: If not all predecessors have been executed.
        """
        assert self.can_be_executed(), "Not all predecessors have been executed"
        # self.logger.info(
        #     "Executing operation %d of type %s", self.id, self.operation_type
        # )
        self._execute(lm, prompter, parser, **kwargs)
        # self.logger.debug("Operation %d executed", self.id)
        self.executed = True

    @abstractmethod
    def _execute(
        self, lm: AbstractLanguageModel, prompter: Prompter, parser: Parser, **kwargs
    ) -> None:
        """
        Abstract method for the actual execution of the operation.
        This should be implemented in derived classes.

        :param lm: The language model to be used.
        :type lm: AbstractLanguageModel
        :param prompter: The prompter for crafting prompts.
        :type prompter: Prompter
        :param parser: The parser for parsing responses.
        :type parser: Parser
        :param kwargs: Additional parameters for execution.
        """
        pass

    @abstractmethod
    def get_thoughts(self) -> List[Thought]:
        """
        Abstract method to retrieve the thoughts associated with the operation.
        This should be implemented in derived classes.

        :return: List of associated thoughts.
        :rtype: List[Thought]
        """
        pass

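A quick sketch of how the wiring above behaves (illustrative only; DummyOp is a hypothetical stand-in written just for this note, not part of the diff): add_predecessor links both directions, and can_be_executed gates execute until every predecessor has run.

    # DummyOp: hypothetical no-op Operation subclass for the demo
    class DummyOp(Operation):
        def _execute(self, lm, prompter, parser, **kwargs) -> None:
            pass  # no LM call needed here

        def get_thoughts(self) -> List[Thought]:
            return []

    a, b = DummyOp(), DummyOp()
    b.add_predecessor(a)             # also appends b to a.successors
    assert not b.can_be_executed()   # a has not executed yet
    a.execute(None, None, None)      # lm/prompter/parser are unused by DummyOp
    assert b.can_be_executed()
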
class Score(Operation):
    """
    Operation to score thoughts.
    """

    operation_type: OperationType = OperationType.score

    def __init__(
        self,
        num_samples: int = 1,
        combined_scoring: bool = False,
        scoring_function: Callable[
            [Union[List[Dict], Dict]], Union[List[float], float]
        ] = None,
    ) -> None:
        """
        Initializes a new Score operation.

        :param num_samples: Number of samples to use for scoring. Defaults to 1.
        :type num_samples: int
        :param combined_scoring: Whether to score all thoughts together or individually. Defaults to False.
        :type combined_scoring: bool
        :param scoring_function: A function to score thoughts (if not using LM). Defaults to None.
        :type scoring_function: Takes a list of thought states or a single thought state and
            returns a list of scores or a single score.
        """
        super().__init__()
        self.num_samples: int = num_samples
        self.combined_scoring: bool = combined_scoring
        self.thoughts: List[Thought] = []
        self.scoring_function: Callable[
            [Union[List[Dict], Dict]], Union[List[float], float]
        ] = scoring_function

    def get_thoughts(self) -> List[Thought]:
        """
        Returns the thoughts associated with the operation.

        :return: List of scored thoughts.
        :rtype: List[Thought]
        """
        return self.thoughts

    def _execute(
        self, lm: AbstractLanguageModel, prompter: Prompter, parser: Parser, **kwargs
    ) -> None:
        """
        Executes the scoring operation by scoring the thoughts from the predecessors.
        If combined scoring is used, the thoughts are scored together, otherwise individually.
        If a scoring function is provided, it is used, otherwise the LM is prompted.

        :param lm: The language model to be used.
        :type lm: AbstractLanguageModel
        :param prompter: The prompter for crafting prompts.
        :type prompter: Prompter
        :param parser: The parser for parsing responses.
        :type parser: Parser
        :param kwargs: Additional parameters for execution.
        :raises AssertionError: If operation has no predecessors.
        """
        previous_thoughts: List[Thought] = self.get_previous_thoughts()

        assert (
            len(self.predecessors) > 0
        ), "Score operation needs at least one predecessor"

        if self.combined_scoring:
            previous_thoughts_states = [thought.state for thought in previous_thoughts]
            if self.scoring_function is not None:
                # self.logger.debug(
                #     "Using scoring function %s to score states", self.scoring_function
                # )
                scores = self.scoring_function(previous_thoughts_states)
            else:
                prompt = prompter.score_prompt(previous_thoughts_states)
                # self.logger.debug("Prompt for LM: %s", prompt)

                responses = lm.get_response_texts(
                    lm.query(prompt, num_responses=self.num_samples)
                )
                # self.logger.debug("Responses from LM: %s", responses)
                scores = parser.parse_score_answer(previous_thoughts_states, responses)
            for thought, score in zip(previous_thoughts, scores):
                new_thought = Thought.from_thought(thought)
                new_thought.score = score
                self.thoughts.append(new_thought)
        else:
            for thought in previous_thoughts:
                new_thought = Thought.from_thought(thought)
                if self.scoring_function is not None:
                    # self.logger.debug(
                    #     "Using scoring function %s to score state",
                    #     self.scoring_function,
                    # )
                    score = self.scoring_function(thought.state)
                else:
                    prompt = prompter.score_prompt([thought.state])
                    # self.logger.debug("Prompt for LM: %s", prompt)

                    responses = lm.get_response_texts(
                        lm.query(prompt, num_responses=self.num_samples)
                    )
                    # self.logger.debug("Responses from LM: %s", responses)
                    score = parser.parse_score_answer([thought.state], responses)[0]

                new_thought.score = score
                self.thoughts.append(new_thought)

        # self.logger.debug(
        #     "Score operation %d scored %d thoughts",
        #     self.id,
        #     len(self.thoughts),
        # )

class ValidateAndImprove(Operation):
    """
    Operation to validate and improve thoughts.
    """

    operation_type: OperationType = OperationType.validate_and_improve

    def __init__(
        self,
        num_samples: int = 1,
        improve: bool = True,
        num_tries: int = 3,
        validate_function: Callable[[Dict], bool] = None,
    ) -> None:
        """
        Initializes a new ValidateAndImprove operation.

        :param num_samples: Number of samples to use for validation. Defaults to 1.
        :type num_samples: int
        :param improve: Whether to improve the thought if it is not valid. Defaults to True.
        :type improve: bool
        :param num_tries: Number of tries to improve the thought before giving up. Defaults to 3.
        :type num_tries: int
        :param validate_function: A function to validate thoughts (if not using LM). Defaults to None.
        :type validate_function: Takes a thought state and returns a boolean.
        """
        super().__init__()
        self.num_samples: int = num_samples
        self.improve: bool = improve
        self.num_tries: int = num_tries
        self.validate_function: Callable[[Dict], bool] = validate_function
        self.thoughts: List[List[Thought]] = []

    def get_thoughts(self) -> List[Thought]:
        """
        Returns the list of final thoughts, after validation and improvement.

        :return: List of final validated and improved thoughts.
        :rtype: List[Thought]
        """
        return [thought_list[-1] for thought_list in self.thoughts]

    def _execute(
        self, lm: AbstractLanguageModel, prompter: Prompter, parser: Parser, **kwargs
    ) -> None:
        """
        Executes the ValidateAndImprove operation by validating and improving the predecessors' thoughts.
        If a validation function is provided, it is used, otherwise the LM is prompted.
        If improvement is enabled, the LM is prompted to improve the thought, if it is not valid.

        :param lm: The language model to be used.
        :type lm: AbstractLanguageModel
        :param prompter: The prompter for crafting prompts.
        :type prompter: Prompter
        :param parser: The parser for parsing responses.
        :type parser: Parser
        :param kwargs: Additional parameters for execution.
        :raises AssertionError: If operation has no predecessors.
        """
        previous_thoughts: List[Thought] = self.get_previous_thoughts()

        assert (
            len(self.predecessors) > 0
        ), "ValidateAndImprove operation needs at least one predecessor"

        for thought in previous_thoughts:
            thought_list = []
            current_thought = Thought.from_thought(thought)
            current_try = 0
            while True:
                if self.validate_function is not None:
                    # self.logger.debug(
                    #     "Using validate function %s to score states",
                    #     self.validate_function,
                    # )
                    valid = self.validate_function(current_thought.state)
                else:
                    prompt = prompter.validation_prompt(**current_thought.state)
                    # self.logger.debug("Prompt for LM: %s", prompt)
                    responses = lm.get_response_texts(
                        lm.query(prompt, num_responses=self.num_samples)
                    )
                    # self.logger.debug("Responses from LM: %s", responses)

                    valid = parser.parse_validation_answer(
                        current_thought.state, responses
                    )
                current_thought.valid = valid
                thought_list.append(current_thought)
                if (
                    not self.improve
                    or current_thought.valid
                    or current_try >= self.num_tries
                ):
                    break
                improve_prompt = prompter.improve_prompt(**current_thought.state)
                # self.logger.debug("Prompt for LM: %s", improve_prompt)
                responses = lm.get_response_texts(
                    lm.query(improve_prompt, num_responses=1)
                )
                # self.logger.debug("Responses from LM: %s", responses)
                state_update = parser.parse_improve_answer(
                    current_thought.state, responses
                )
                current_thought = Thought({**current_thought.state, **state_update})
                current_try += 1
            self.thoughts.append(thought_list)

        # self.logger.debug(
        #     "Validate and improve operation %d created %d valid thoughts from %d previous thoughts",
        #     self.id,
        #     len(
        #         [
        #             thought_list[-1]
        #             for thought_list in self.thoughts
        #             if thought_list[-1].valid
        #         ]
        #     ),
        #     len(previous_thoughts),
        # )

class Generate(Operation):
    """
    Operation to generate thoughts.
    """

    operation_type: OperationType = OperationType.generate

    def __init__(
        self, num_branches_prompt: int = 1, num_branches_response: int = 1
    ) -> None:
        """
        Initializes a new Generate operation.

        :param num_branches_prompt: Number of responses that each prompt should generate (passed to prompter). Defaults to 1.
        :type num_branches_prompt: int
        :param num_branches_response: Number of responses the LM should generate for each prompt. Defaults to 1.
        :type num_branches_response: int
        """
        super().__init__()
        self.num_branches_prompt: int = num_branches_prompt
        self.num_branches_response: int = num_branches_response
        self.thoughts: List[Thought] = []

    def get_thoughts(self) -> List[Thought]:
        """
        Returns the thoughts associated with the operation.

        :return: List of generated thoughts.
        :rtype: List[Thought]
        """
        return self.thoughts

    def _execute(
        self, lm: AbstractLanguageModel, prompter: Prompter, parser: Parser, **kwargs
    ) -> None:
        """
        Executes the Generate operation by generating thoughts from the predecessors.
        The thoughts are generated by prompting the LM with the predecessors' thought states.
        If there are no predecessors, the kwargs are used as a base state.

        :param lm: The language model to be used.
        :type lm: AbstractLanguageModel
        :param prompter: The prompter for crafting prompts.
        :type prompter: Prompter
        :param parser: The parser for parsing responses.
        :type parser: Parser
        :param kwargs: Additional parameters for execution.
        """
        previous_thoughts: List[Thought] = self.get_previous_thoughts()

        if len(previous_thoughts) == 0 and len(self.predecessors) > 0:
            return

        if len(previous_thoughts) == 0:
            # no predecessors, use kwargs as base state
            previous_thoughts = [Thought(state=kwargs)]

        for thought in previous_thoughts:
            base_state = thought.state
            prompt = prompter.generate_prompt(self.num_branches_prompt, **base_state)
            # self.logger.debug("Prompt for LM: %s", prompt)
            responses = lm.get_response_texts(
                lm.query(prompt, num_responses=self.num_branches_response)
            )
            # self.logger.debug("Responses from LM: %s", responses)
            for new_state in parser.parse_generate_answer(base_state, responses):
                new_state = {**base_state, **new_state}
                self.thoughts.append(Thought(new_state))
                # self.logger.debug(
                #     "New thought %d created with state %s",
                #     self.thoughts[-1].id,
                #     self.thoughts[-1].state,
                # )
        if (
            len(self.thoughts)
            > self.num_branches_prompt
            * self.num_branches_response
            * len(previous_thoughts)
            and self.num_branches_prompt > 0
        ):
            self.logger.warning(
                "Generate operation %d created more thoughts than expected",
                self.id,
            )
        # self.logger.debug(
        #     "Generate operation %d created %d new thoughts", self.id, len(self.thoughts)
        # )

class Improve(Operation):
    """
    Operation to improve thoughts.
    """

    operation_type: OperationType = OperationType.improve

    def __init__(self) -> None:
        """
        Initializes a new Improve operation.
        """
        super().__init__()
        self.thoughts: List[Thought] = []

    def get_thoughts(self) -> List[Thought]:
        """
        Returns the thoughts associated with the operation after improvement.

        :return: List of improved thoughts.
        :rtype: List[Thought]
        """
        return self.thoughts

    def _execute(
        self, lm: AbstractLanguageModel, prompter: Prompter, parser: Parser, **kwargs
    ) -> None:
        """
        Executes the Improve operation by improving the predecessors' thoughts.
        The thoughts are improved by prompting the LM with the predecessors' thought states.

        :param lm: The language model to be used.
        :type lm: AbstractLanguageModel
        :param prompter: The prompter for crafting prompts.
        :type prompter: Prompter
        :param parser: The parser for parsing responses.
        :type parser: Parser
        :param kwargs: Additional parameters for execution.
        :raises AssertionError: If operation has no predecessors.
        """
        previous_thoughts: List[Thought] = self.get_previous_thoughts()

        assert len(self.predecessors) > 0, "Needs at least one predecessor"

        for thought in previous_thoughts:
            improve_prompt = prompter.improve_prompt(**thought.state)
            # self.logger.debug("Prompt for LM: %s", improve_prompt)
            responses = lm.get_response_texts(lm.query(improve_prompt, num_responses=1))
            # self.logger.debug("Responses from LM: %s", responses)
            state_update = parser.parse_improve_answer(thought.state, responses)
            self.thoughts.append(Thought({**thought.state, **state_update}))

        # self.logger.debug(
        #     "Improve operation %d improved %d thoughts", self.id, len(self.thoughts)
        # )

class Aggregate(Operation):
    """
    Operation to aggregate thoughts.
    """

    operation_type: OperationType = OperationType.aggregate

    def __init__(self, num_responses: int = 1) -> None:
        """
        Initializes a new Aggregate operation.

        :param num_responses: Number of responses to use for aggregation. Defaults to 1.
        :type num_responses: int
        """
        super().__init__()
        self.thoughts: List[Thought] = []
        self.num_responses: int = num_responses

    def get_thoughts(self) -> List[Thought]:
        """
        Returns the thoughts associated with the operation after aggregation.

        :return: List of aggregated thoughts.
        :rtype: List[Thought]
        """
        return self.thoughts

    def _execute(
        self, lm: AbstractLanguageModel, prompter: Prompter, parser: Parser, **kwargs
    ) -> None:
        """
        Executes the Aggregate operation by aggregating the predecessors' thoughts.
        The thoughts are aggregated by prompting the LM with the predecessors' thought states.

        :param lm: The language model to be used.
        :type lm: AbstractLanguageModel
        :param prompter: The prompter for crafting prompts.
        :type prompter: Prompter
        :param parser: The parser for parsing responses.
        :type parser: Parser
        :param kwargs: Additional parameters for execution.
        :raises AssertionError: If operation has no predecessors.
        """
        assert (
            len(self.predecessors) >= 1
        ), "Aggregate operation must have at least one predecessor"

        previous_thoughts: List[Thought] = self.get_previous_thoughts()

        if len(previous_thoughts) == 0:
            return

        # applied in order of score
        base_state: Dict = {}
        for thought in sorted(previous_thoughts, key=lambda thought: thought.score):
            base_state = {**base_state, **thought.state}

        previous_thought_states = [thought.state for thought in previous_thoughts]
        prompt = prompter.aggregation_prompt(previous_thought_states)

        # self.logger.debug("Prompt for LM: %s", prompt)

        responses = lm.get_response_texts(
            lm.query(prompt, num_responses=self.num_responses)
        )

        # self.logger.debug("Responses from LM: %s", responses)

        parsed = parser.parse_aggregation_answer(previous_thought_states, responses)

        if isinstance(parsed, dict):
            parsed = [parsed]
        for new_state in parsed:
            self.thoughts.append(Thought({**base_state, **new_state}))

class KeepBestN(Operation):
    """
    Operation to keep the best N thoughts from predecessors based on their score.
    """

    operation_type: OperationType = OperationType.keep_best_n

    def __init__(self, n: int, higher_is_better: bool = True) -> None:
        """
        Initializes a new KeepBestN operation.

        :param n: Maximum number of thoughts to keep.
        :type n: int
        :param higher_is_better: Whether higher scores are better. Defaults to True.
        :type higher_is_better: bool
        :raises AssertionError: If `n` is not greater than zero.
        """
        super().__init__()
        self.n: int = n
        assert self.n > 0, "KeepBestN operation must keep at least one thought"
        self.higher_is_better: bool = higher_is_better
        self.thoughts: List[Thought] = []

    def get_best_n(self) -> List[Thought]:
        """
        Returns the best N thoughts from the predecessors based on their score.

        :return: List of best N thoughts.
        :rtype: List[Thought]
        :raises AssertionError: If not all thoughts have been scored.
        """
        previous_thoughts: List[Thought] = self.get_previous_thoughts()
        assert all(
            previous_thought.scored for previous_thought in previous_thoughts
        ), "Not all thoughts have been scored"

        try:
            return sorted(
                previous_thoughts,
                key=lambda thought: thought.score,
                reverse=self.higher_is_better,
            )[: self.n]
        except Exception:
            # Sorting can fail when scores have mixed, unorderable types;
            # log the state and fall back to the thoughts with float scores only.
            self.logger.error("Error in KeepBestN operation")
            self.logger.error(
                "Previous operation: %s", [op.id for op in self.predecessors]
            )
            self.logger.error("Previous thoughts: %s", previous_thoughts)
            self.logger.error(
                "Scores: %s", [thought.score for thought in previous_thoughts]
            )
            return sorted(
                [i for i in previous_thoughts if isinstance(i.score, float)],
                key=lambda thought: thought.score,
                reverse=self.higher_is_better,
            )[: self.n]

    def get_thoughts(self) -> List[Thought]:
        """
        Returns the thoughts kept by the operation.

        :return: List of kept thoughts.
        :rtype: List[Thought]
        """
        return self.thoughts

    def _execute(
        self, lm: AbstractLanguageModel, prompter: Prompter, parser: Parser, **kwargs
    ) -> None:
        """
        Executes the KeepBestN operation by keeping the best N thoughts from the predecessors according to their score.

        :param lm: The language model to be used.
        :type lm: AbstractLanguageModel
        :param prompter: The prompter for crafting prompts.
        :type prompter: Prompter
        :param parser: The parser for parsing responses.
        :type parser: Parser
        :param kwargs: Additional parameters for execution.
        :raises AssertionError: If operation has no predecessors.
        :raises AssertionError: If not all thoughts have been scored.
        """
        assert (
            len(self.predecessors) >= 1
        ), "KeepBestN operation must have at least one predecessor"

        self.thoughts = [Thought.from_thought(thought) for thought in self.get_best_n()]

        # for thought in self.thoughts:
        #     self.logger.debug(
        #         "Thought %d with state %s kept", thought.id, thought.state
        #     )

        # self.logger.debug(
        #     "KeepBestN operation %d kept %d thoughts", self.id, len(self.thoughts)
        # )

class KeepValid(Operation):
    """
    Operation to keep valid thoughts from predecessors.
    """

    operation_type: OperationType = OperationType.keep_valid

    def __init__(self) -> None:
        """
        Initializes a new KeepValid operation.
        """
        super().__init__()
        self.thoughts: List[Thought] = []

    def get_thoughts(self) -> List[Thought]:
        """
        Returns the thoughts kept by the operation.

        :return: List of kept thoughts.
        :rtype: List[Thought]
        """
        return self.thoughts

    def _execute(
        self, lm: AbstractLanguageModel, prompter: Prompter, parser: Parser, **kwargs
    ) -> None:
        """
        Executes the KeepValid operation by keeping the valid thoughts from the predecessors.
        Keeps unvalidated thoughts as well.

        :param lm: The language model to be used.
        :type lm: AbstractLanguageModel
        :param prompter: The prompter for crafting prompts.
        :type prompter: Prompter
        :param parser: The parser for parsing responses.
        :type parser: Parser
        :param kwargs: Additional parameters for execution.
        :raises AssertionError: If operation has no predecessors.
        """
        assert (
            len(self.predecessors) >= 1
        ), "KeepValid operation must have at least one predecessor"

        self.thoughts: List[Thought] = [
            Thought.from_thought(thought)
            for thought in self.get_previous_thoughts()
            if not thought.validated or thought.valid
        ]

        # if any(not thought.validated for thought in self.thoughts):
        #     self.logger.warning(
        #         "KeepValid operation %d has unvalidated thoughts", self.id
        #     )

        # for thought in self.thoughts:
        #     self.logger.debug(
        #         "Thought %d with state %s kept", thought.id, thought.state
        #     )

        # self.logger.debug(
        #     "KeepValid operation %d kept %d thoughts", self.id, len(self.thoughts)
        # )

class GroundTruth(Operation):
    """
    Operation to evaluate if thoughts correctly solve the problem, using a ground truth evaluator.
    """

    operation_type: OperationType = OperationType.ground_truth_evaluator

    def __init__(self, ground_truth_evaluator: Callable[[Dict], bool]) -> None:
        """
        Initializes a new GroundTruth operation.

        :param ground_truth_evaluator: A function to evaluate if a thought solves the problem.
        :type ground_truth_evaluator: A function that takes a thought state and returns a boolean.
        """
        super().__init__()
        self.ground_truth_evaluator: Callable[[Dict], bool] = ground_truth_evaluator
        self.thoughts: List[Thought] = []

    def get_thoughts(self) -> List[Thought]:
        """
        Returns the thoughts associated with the operation.

        :return: List of evaluated thoughts.
        :rtype: List[Thought]
        """
        return self.thoughts

    def _execute(
        self, lm: AbstractLanguageModel, prompter: Prompter, parser: Parser, **kwargs
    ) -> None:
        """
        Executes the GroundTruth operation by evaluating the predecessors' thoughts using the ground truth evaluator function.

        :param lm: The language model to be used.
        :type lm: AbstractLanguageModel
        :param prompter: The prompter for crafting prompts.
        :type prompter: Prompter
        :param parser: The parser for parsing responses.
        :type parser: Parser
        :param kwargs: Additional parameters for execution.
        :raises AssertionError: If operation has no predecessor.
        """
        assert (
            len(self.predecessors) >= 1
        ), "GroundTruth operation must have at least one predecessor"

        previous_thoughts: List[Thought] = self.get_previous_thoughts()

        for thought in previous_thoughts:
            new_thought = Thought.from_thought(thought)
            try:
                new_thought.solved = self.ground_truth_evaluator(new_thought.state)
            except Exception:
                # Treat evaluator failures as unsolved rather than crashing the run.
                new_thought.solved = False
            self.thoughts.append(new_thought)

        # self.logger.debug(
        #     "GroundTruth operation %d evaluated %d thoughts and %d solved the problem",
        #     self.id,
        #     len(self.thoughts),
        #     len([thought for thought in self.thoughts if thought.solved]),
        # )

class Selector(Operation):
    """
    Operation to select thoughts from predecessors.
    Useful for separating thoughts to perform different, subsequent operations on them.
    """

    operation_type: OperationType = OperationType.selector

    def __init__(self, selector: Callable[[List[Thought]], List[Thought]]) -> None:
        """
        Initializes a new Selector operation.

        :param selector: A function to select thoughts from the predecessors' thoughts.
        :type selector: A function that takes a list of thoughts and returns a list of thoughts.
        """
        super().__init__()
        self.selector: Callable[[List[Thought]], List[Thought]] = selector
        self.thoughts: List[Thought] = []

    def get_thoughts(self) -> List[Thought]:
        """
        Returns the thoughts selected by the operation.

        :return: List of selected thoughts.
        :rtype: List[Thought]
        """
        return self.thoughts

    def _execute(
        self, lm: AbstractLanguageModel, prompter: Prompter, parser: Parser, **kwargs
    ) -> None:
        """
        Executes the Selector operation by selecting thoughts from the predecessors using the selector function.
        If the Selector has no predecessors, the selector function is called with a thought containing the kwargs as state.

        :param lm: The language model to be used.
        :type lm: AbstractLanguageModel
        :param prompter: The prompter for crafting prompts.
        :type prompter: Prompter
        :param parser: The parser for parsing responses.
        :type parser: Parser
        :param kwargs: Additional parameters for execution.
        """
        previous_thoughts: List[Thought] = self.get_previous_thoughts()

        if len(previous_thoughts) == 0:
            previous_thoughts = [Thought(kwargs)]

        self.thoughts = [
            Thought.from_thought(thought)
            for thought in self.selector(previous_thoughts)
        ]

        # for thought in self.thoughts:
        #     self.logger.debug(
        #         "Thought %d with state %s selected", thought.id, thought.state
        #     )

        # self.logger.debug(
        #     "Selector operation %d selected %d thoughts", self.id, len(self.thoughts)
        # )
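A minimal end-to-end sketch of how these operations compose (illustrative only; the state key "text" and the length-based scoring function are assumptions for the demo). Because Score is given a Python scoring_function and Selector/KeepBestN never prompt, no language model, prompter, or parser is needed:

    root = Selector(selector=lambda thoughts: thoughts)   # pass-through seed operation
    score = Score(scoring_function=lambda state: float(len(state["text"])))
    best = KeepBestN(n=1)

    score.add_predecessor(root)
    best.add_predecessor(score)

    root.execute(None, None, None, text="hello world")    # kwargs become the seed state
    score.execute(None, None, None)                       # scoring_function path, no LM call
    best.execute(None, None, None)
    print(best.get_thoughts()[0].score)                   # 11.0
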
src/llm_explain/utility/graph_of_thoughts/operations/thought.py
ADDED
@@ -0,0 +1,129 @@
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''

from __future__ import annotations
# import logging
from typing import Iterator, Dict, Optional
import itertools

from llm_explain.config.logger import CustomLogger

logging = CustomLogger()

class Thought:
    """
    Represents an LLM thought with its state, constructed by the parser, and various flags.
    """

    _ids: Iterator[int] = itertools.count(0)

    def __init__(self, state: Optional[Dict] = None) -> None:
        """
        Initializes a new Thought instance with a state and various default flags.

        :param state: The state of the thought. Defaults to None.
        :type state: Optional[Dict]
        """
        self.logger = CustomLogger()
        self.id: int = next(Thought._ids)
        self.state: Dict = state
        self._score: float = 0.0
        self._valid: bool = False
        self._solved: bool = False
        self.scored: bool = False
        self.validated: bool = False
        self.compared_to_ground_truth: bool = False

    @staticmethod
    def from_thought(thought: Thought) -> Thought:
        """
        Creates a new thought from an existing one.

        :param thought: An instance of a Thought to clone.
        :return: A new Thought instance with properties copied from the input thought.
        """
        new_thought = Thought(thought.state)
        new_thought.score = thought.score
        new_thought.valid = thought.valid
        new_thought.solved = thought.solved
        new_thought.scored = thought.scored
        new_thought.validated = thought.validated
        new_thought.compared_to_ground_truth = thought.compared_to_ground_truth
        return new_thought

    @property
    def valid(self) -> bool:
        """
        Returns the validity of the thought.

        :return: The validity of the thought.
        :rtype: bool
        """
        return self._valid

    @valid.setter
    def valid(self, valid: bool) -> None:
        """
        Sets the validity of the thought and the validated flag.

        :param valid: The validity of the thought.
        :type valid: bool
        """
        self.validated = True
        self._valid = valid

    @property
    def score(self) -> float:
        """
        Returns the score of the thought.

        :return: The score of the thought.
        :rtype: float
        """
        return self._score

    @score.setter
    def score(self, new_score: float) -> None:
        """
        Sets the score of the thought and the scored flag.

        :param new_score: The score of the thought.
        :type new_score: float
        """
        self.scored = True
        self._score = new_score

    @property
    def solved(self) -> bool:
        """
        Returns the solved flag of the thought.

        :return: The solved flag of the thought.
        :rtype: bool
        """
        return self._solved

    @solved.setter
    def solved(self, solved: bool) -> None:
        """
        Sets the solved flag of the thought and the compared_to_ground_truth flag.

        :param solved: Whether the thought contains a solution to the problem.
        :type solved: bool
        """
        self.compared_to_ground_truth = True
        self._solved = solved
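A small usage sketch of the flag bookkeeping above (illustrative): assigning through the score/valid/solved properties flips the matching scored/validated/compared_to_ground_truth flags, and from_thought restores every flag on the clone.

    t = Thought({"text": "draft answer"})
    assert not t.scored and not t.validated
    t.score = 0.8        # setter also sets t.scored = True
    t.valid = True       # setter also sets t.validated = True
    clone = Thought.from_thought(t)
    assert clone.scored and clone.validated and clone.score == 0.8
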
src/llm_explain/utility/graph_of_thoughts/parser/__init__.py
ADDED
@@ -0,0 +1,18 @@
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''

from .parser import Parser
src/llm_explain/utility/graph_of_thoughts/parser/parser.py
ADDED
@@ -0,0 +1,99 @@
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''

from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Dict, List, Union


class Parser(ABC):
    """
    Abstract base class that defines the interface for all parsers.
    Parsers are used to parse the responses from the language models.
    """

    @abstractmethod
    def parse_aggregation_answer(
        self, states: List[Dict], texts: List[str]
    ) -> Union[Dict, List[Dict]]:
        """
        Parse the response from the language model for an aggregation prompt.

        :param states: The thought states used to generate the prompt.
        :type states: List[Dict]
        :param texts: The responses to the prompt from the language model.
        :type texts: List[str]
        :return: The new thought states after parsing the response from the language model.
        :rtype: Union[Dict, List[Dict]]
        """
        pass

    @abstractmethod
    def parse_improve_answer(self, state: Dict, texts: List[str]) -> Dict:
        """
        Parse the response from the language model for an improve prompt.

        :param state: The thought state used to generate the prompt.
        :type state: Dict
        :param texts: The responses to the prompt from the language model.
        :type texts: List[str]
        :return: The new thought state after parsing the response from the language model.
        :rtype: Dict
        """
        pass

    @abstractmethod
    def parse_generate_answer(self, state: Dict, texts: List[str]) -> List[Dict]:
        """
        Parse the response from the language model for a generate prompt.

        :param state: The thought state used to generate the prompt.
        :type state: Dict
        :param texts: The responses to the prompt from the language model.
        :type texts: List[str]
        :return: The new thought states after parsing the response from the language model.
        :rtype: List[Dict]
        """
        pass

    @abstractmethod
    def parse_validation_answer(self, state: Dict, texts: List[str]) -> bool:
        """
        Parse the response from the language model for a validation prompt.

        :param state: The thought state used to generate the prompt.
        :type state: Dict
        :param texts: The responses to the prompt from the language model.
        :type texts: List[str]
        :return: Whether the thought state is valid or not.
        :rtype: bool
        """
        pass

    @abstractmethod
    def parse_score_answer(self, states: List[Dict], texts: List[str]) -> List[float]:
        """
        Parse the response from the language model for a score prompt.

        :param states: The thought states used to generate the prompt.
        :type states: List[Dict]
        :param texts: The responses to the prompt from the language model.
        :type texts: List[str]
        :return: The scores for the thought states.
        :rtype: List[float]
        """
        pass
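A minimal concrete Parser sketch (hypothetical; the JSON-per-response and one-score-per-response conventions below are assumptions for illustration, not the project's actual parsing rules):

    import json

    class SimpleJsonParser(Parser):
        def parse_aggregation_answer(self, states, texts):
            return [json.loads(t) for t in texts]          # assumes each response is a JSON object

        def parse_improve_answer(self, state, texts):
            return json.loads(texts[0])

        def parse_generate_answer(self, state, texts):
            return [json.loads(t) for t in texts]

        def parse_validation_answer(self, state, texts):
            return texts[0].strip().lower().startswith("valid")

        def parse_score_answer(self, states, texts):
            return [float(t.strip()) for t in texts[: len(states)]]  # one numeric score per response
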
src/llm_explain/utility/graph_of_thoughts/prompter/__init__.py
ADDED
@@ -0,0 +1,18 @@
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''

from .prompter import Prompter
src/llm_explain/utility/graph_of_thoughts/prompter/prompter.py
ADDED
@@ -0,0 +1,95 @@
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''

from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Dict, List


class Prompter(ABC):
    """
    Abstract base class that defines the interface for all prompters.
    Prompters are used to generate the prompts for the language models.
    """

    @abstractmethod
    def aggregation_prompt(self, state_dicts: List[Dict], **kwargs) -> str:
        """
        Generate an aggregation prompt for the language model.

        :param state_dicts: The thought states that should be aggregated.
        :type state_dicts: List[Dict]
        :param kwargs: Additional keyword arguments.
        :return: The aggregation prompt.
        :rtype: str
        """
        pass

    @abstractmethod
    def improve_prompt(self, **kwargs) -> str:
        """
        Generate an improve prompt for the language model.
        The thought state is unpacked to allow for additional keyword arguments
        and concrete implementations to specify required arguments explicitly.

        :param kwargs: Additional keyword arguments.
        :return: The improve prompt.
        :rtype: str
        """
        pass

    @abstractmethod
    def generate_prompt(self, num_branches: int, **kwargs) -> str:
        """
        Generate a generate prompt for the language model.
        The thought state is unpacked to allow for additional keyword arguments
        and concrete implementations to specify required arguments explicitly.

        :param num_branches: The number of responses the prompt should ask the LM to generate.
        :type num_branches: int
        :param kwargs: Additional keyword arguments.
        :return: The generate prompt.
        :rtype: str
        """
        pass

    @abstractmethod
    def validation_prompt(self, **kwargs) -> str:
        """
        Generate a validation prompt for the language model.
        The thought state is unpacked to allow for additional keyword arguments
        and concrete implementations to specify required arguments explicitly.

        :param kwargs: Additional keyword arguments.
        :return: The validation prompt.
        :rtype: str
        """
        pass

    @abstractmethod
    def score_prompt(self, state_dicts: List[Dict], **kwargs) -> str:
        """
        Generate a score prompt for the language model.

        :param state_dicts: The thought states that should be scored,
            if more than one, they should be scored together.
        :type state_dicts: List[Dict]
        :param kwargs: Additional keyword arguments.
        :return: The score prompt.
        :rtype: str
        """
        pass
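A minimal concrete Prompter sketch (hypothetical; the template wording and the "text" state key are assumptions for illustration):

    class SimplePrompter(Prompter):
        def aggregation_prompt(self, state_dicts, **kwargs):
            return "Merge these candidate answers into one:\n" + "\n".join(
                d.get("text", "") for d in state_dicts)

        def improve_prompt(self, **kwargs):
            return f"Improve this answer:\n{kwargs.get('text', '')}"

        def generate_prompt(self, num_branches, **kwargs):
            return f"Give {num_branches} candidate answers for: {kwargs.get('text', '')}"

        def validation_prompt(self, **kwargs):
            return f"Reply 'valid' or 'invalid' for:\n{kwargs.get('text', '')}"

        def score_prompt(self, state_dicts, **kwargs):
            return "Score each answer from 0 to 1, one number per line:\n" + "\n".join(
                d.get("text", "") for d in state_dicts)
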
src/llm_explain/utility/prompt_utils.py
ADDED
@@ -0,0 +1,111 @@
1 |
+
'''
|
2 |
+
Copyright 2024 Infosys Ltd.
|
3 |
+
|
4 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''

class Prompt:

    def get_prompt(prompt, response):

        template = f'''
        You are a Responsible AI expert with extensive experience in Explainable AI for Large Language Models. Your role is to clearly generate explanations for why the Large Language Model has generated a certain response for the given prompt.

        You are a helpful assistant. Do not fabricate information or provide assumptions in your response.

        Given the following prompt-response pair:

        Prompt: {prompt}
        Response: {response}

        Please assess the following metrics:

        1. Sentiment: Evaluate the sentiment associated with the response using a score between -1 (negative), 0 (neutral), and 1 (positive). Additionally, explain your reasoning behind the assigned sentiment score.

        2. Grammatical Mistakes: Evaluate the grammatical correctness of the response using a score between 0 (no mistakes) and 1 (many mistakes). Additionally, explain your reasoning behind the assigned grammatical mistakes score.

        3. Uncertainty: Evaluate the uncertainty associated with the response for the given prompt using a score between 0 (certain) and 1 (highly uncertain). Additionally, explain your reasoning behind the assigned uncertainty score.

        4. Out of Vocabulary (OOV): Assess the percentage of out-of-vocabulary words in the response relative to the prompt using a score between 0 and 100. Additionally, explain your reasoning behind the assigned OOV words percentage.

        5. Coherence: Evaluate the logical flow and coherence of the response using a score between 0 (incoherent) and 1 (highly coherent). Additionally, explain your reasoning behind the assigned coherence score.

        6. Relevance: Assess the relevance of the response to the given prompt using a score between 0 (irrelevant) and 1 (highly relevant). Additionally, explain your reasoning behind the assigned relevance score.

        7. How did you arrive at the following response based on the prompt provided?
        Prompt: {prompt}
        Response: {response}
        Explain the reasoning and steps taken to generate this response.

        Your response should be in the following JSON format:
        {{
            "sentiment": {{
                "score": "",
                "explanation": ""
            }},
            "grammatical_mistakes": {{
                "score": "",
                "explanation": ""
            }},
            "uncertainty": {{
                "score": "",
                "explanation": ""
            }},
            "out_of_vocabulary": {{
                "score": "",
                "explanation": ""
            }},
            "coherence": {{
                "score": "",
                "explanation": ""
            }},
            "relevance": {{
                "score": "",
                "explanation": ""
            }},
            "reasoning": ""
        }}

        Do not provide any response other than the JSON object.
        '''
        return template

    def get_token_importance_prompt(prompt):

        template = f'''
        You are a helpful assistant. Do not fabricate information or provide assumptions in your response.

        Given the following prompt:

        Prompt: {prompt}

        Please assess the following metric:

        1. Token Importance: Evaluate the importance of each token in the prompt. Calculate the importance value of each token in the given prompt using a
        score between 0 (low importance) and 1 (high importance). Provide all the tokens as a list. Ensure that you give an importance score for all tokens,
        and there are no empty spaces or inconsistencies in the output, which might cause issues while parsing. Make your analysis consistent so that if given
        the same input again, you produce a similar output.

        Your response should be in the following JSON format:
        {{
            "Token": ["Each Token from input prompt"],
            "Importance Score": ["The value here should be a comma-separated list of importance scores"],
            "Position": ["The value here should be a comma-separated list of respective token index positions"]
        }}

        Do not provide any response other than the JSON object.
        '''
        return template
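As a minimal usage sketch (not part of the repo), assuming an Azure OpenAI chat deployment named "gpt-4o" and the same AZURE_OPENAI_* environment variables the rest of this codebase reads, the builder above might be wired up like this; the JSON parse relies on the model honoring the "JSON only" instruction:

import json
import os
from openai import AzureOpenAI

client = AzureOpenAI(
    api_key=os.getenv("AZURE_OPENAI_API_KEY"),
    api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
)
explain_prompt = Prompt.get_prompt("What is the capital of France?", "Paris is the capital of France.")
reply = client.chat.completions.create(
    model="gpt-4o",  # assumed deployment name, not taken from this file
    messages=[{"role": "user", "content": explain_prompt}],
)
metrics = json.loads(reply.choices[0].message.content)  # expects the JSON object described in the template
print(metrics["sentiment"]["score"], metrics["coherence"]["score"])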
src/llm_explain/utility/prompts/base.py
ADDED
@@ -0,0 +1,281 @@
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''

from llm_explain.utility.prompts.output_format import *
from llm_explain.utility.prompts.few_shot import *
from llm_explain.utility.prompts.instructions import *

class Prompt:

    def get_classification_prompt(input_prompt):

        template = f"""
        Imagine you are a Responsible AI expert with experience in explaining why a model has made a decision.
        Your task is to determine the sentiment of the given prompt, identify the keywords the model used to arrive at that sentiment, and provide a clear explanation of why the model classified the prompt as that sentiment.

        Calculate the importance value of each token towards getting the overall sentiment from the given prompt using a score between 1 (low importance) and 100 (high importance).
        Provide all the tokens as a list. Ensure that you give an importance score for all tokens, and there are no empty spaces or inconsistencies in the output, which might cause issues while parsing.
        Make your analysis consistent so that if given the same input again, you produce a similar output.

        Similarly, provide the sentiment, the keywords identified to determine the sentiment, and an explanation for the information given below.

        Make sure the response is simple and easy to understand. Use polite language. Do not write in the third person. Do not include "Certainly!" at the beginning of your response; just give the response.

        Example Data:
        {one_shot_sentiment_analysis}

        Return the output only in the following JSON format. Do not output anything other than this JSON object:
        {output_format_sentiment_analysis}

        Task Data:
        [Prompt]: {input_prompt}
        """
        return template

    def get_local_explanation_prompt(prompt, response):

        template = f'''
        You are a Responsible AI expert with extensive experience in Explainable AI for Large Language Models. Your role is to clearly generate explanations for why the Large Language Model has generated a certain response for the given prompt.

        You are a helpful assistant. Do not fabricate information or provide assumptions in your response.

        Given the following prompt-response pair:

        Prompt: {prompt}
        Response: {response}

        Please assess the following metrics:

        1. Uncertainty: Evaluate the uncertainty associated with the response for the given prompt using a score between 0 (certain) and 95 (highly uncertain). Additionally, explain your reasoning behind the assigned score.
        2. Coherence: Evaluate the logical flow and coherence of the response using a score between 0 (incoherent) and 95 (highly coherent). Additionally, explain your reasoning behind the assigned score.

        Based on the score and explanation for each metric, provide a recommendation for how to change the input prompt so that the response will be better and the scores will improve. Ensure that each recommendation is concrete and actionable. If a metric has a perfect score, provide positive reinforcement or suggest maintaining the current quality.

        Return the output only in the following JSON format. Do not output anything other than this JSON object:
        {output_format_local_explanation}
        '''
        return template

    def get_token_importance_prompt(prompt):

        template = f'''
        You are a helpful assistant. Do not fabricate information or provide assumptions in your response.

        Given the following prompt:

        Prompt: {prompt}

        Please assess the following metric:
        1. Token Importance: Evaluate the importance of each token in the prompt. Calculate the importance value of each token in the given prompt using a
        score between 0 (low importance) and 1 (high importance). Provide all the tokens as a list. Ensure that you give an importance score for all tokens,
        and there are no empty spaces or inconsistencies in the output, which might cause issues while parsing. Make your analysis consistent so that if given
        the same input again, you produce a similar output.

        Return the output only in the following JSON format. Do not output anything other than this JSON object:
        {output_format_token_importance}
        '''
        return template

    def get_tone_prediction_prompt(response):

        template = f'''
        You are a detail-oriented LLM that pays close attention to the nuances of language.
        You will be given a text and your job is to analyze its tone.

        Specifically, you need to consider the following tones and identify which tone is most appropriate for the given text:

        Formal: Professional, respectful, objective (e.g., scientific reports, business emails)
        Informal: Casual, conversational, friendly (e.g., text messages, social media posts)
        Informative: Primarily focused on conveying information clearly and concisely (e.g., news reports, summaries)
        Positive: Happy, optimistic, encouraging (e.g., motivational speeches, congratulations)
        Negative: Sad, angry, frustrated (e.g., complaints, critical reviews)
        Neutral: Objective, unbiased, unemotional (e.g., factual summaries, news reports)
        Humorous: Funny, witty, sarcastic (e.g., jokes, lighthearted stories)
        Dramatic: Suspenseful, exciting, intense (e.g., fictional narratives, descriptions of events)
        Inspiring: Uplifting, motivating, thought-provoking (e.g., speeches, self-help content)
        Persuasive: Trying to convince the reader of something (e.g., marketing copy, arguments)
        Empathetic: Understanding and supportive (e.g., responses to someone going through a tough time)
        Authoritative: Confident, knowledgeable (e.g., expert opinions, instructions)

        Based on the score and explanation for each metric, provide a recommendation for how to change the input prompt so that the response will be better and the scores will improve. Ensure that each recommendation is concrete and actionable. If a metric has a perfect score, provide positive reinforcement or suggest maintaining the current quality.

        Example Data:
        {few_shot_examples_tone_analysis}

        Return the output only in the following JSON format. Do not output anything other than this JSON object:
        {output_format_tone_analysis}

        Task Data:
        [Response]: {response}
        '''
        return template

    def get_coherehce_prompt(response):

        template = f"""
        You are a detail-oriented LLM which pays close attention to the details. You are given a text and your job is to evaluate the quality of the provided text, focusing on the coherence aspect.

        Coherence is the quality of the text that makes it logical and consistent. It is important that the text is well-organized and the ideas are connected in a clear and meaningful way. A coherent text is easy to follow and understand.

        Please provide a score on the scale of 1-5, with 1 meaning that the text is completely incoherent and the elements in the text do not stitch together to produce meaningful text, and 5 meaning that the text is completely coherent and the elements in the text stitch together to produce meaningful text.

        Example Data.
        {LANGUAGE_COHERENCE_FEW_SHOT__COT}

        First, analyze the text and determine how fluent and natural sounding it is. Consider the structure, connectivity of ideas, and overall readability of the text. Write down step-by-step reasoning to make sure that your conclusion is correct.

        {CHAIN_OF_THOUGHT}

        Return the output only in the corresponding JSON format. Do not output anything other than this JSON object:
        {LANGUAGE_COHERENCE_OUTPUT_FORMAT__COT}

        Task data.
        [Response]: {response}
        """
        return template

    def get_response_revelance_prompt(prompt, response):

        template = f"""
        You are a detail-oriented LLM which pays close attention to the details, checks for consistency, and is adept at identifying logical fallacies, incorrect assumptions, or other errors in reasoning.
        Your task is to determine the degree of irrelevant information present in the given response.

        Example Data.
        {RESPONSE_RELEVANCE_FEW_SHOT__COT}

        For the given task data, carefully examine the response and assess if it has any additional irrelevant information or not. Don't focus on aspects like style, grammar, or punctuation.
        Assign a score between 0 and 1, where 0 indicates that the response is completely irrelevant to the prompt, and 1 indicates that the response is highly relevant to the prompt.
        {CHAIN_OF_THOUGHT}

        Return the output only in the corresponding JSON format. Do not output anything other than this JSON object:
        {RESPONSE_RELEVANCE_OUTPUT_FORMAT__COT}

        Task Data.
        [Question]: {prompt}
        [Response]: {response}
        [Output]:
        """

        return template

    def generate_facts_prompt(prompt, response, current_date):

        template = f"""
        You are given a response along with its question. For the given task data, please break down the response into independent
        facts. A fact is a sentence that is true and can only be stated from the response. Facts should not depend on each other
        and must not convey the same information. While generating facts, ensure that the facts are contextually mentioned and
        do not begin with pronouns like 'He,' 'She,' or references to third-party entities. Limit to 5 facts in the output.

        The response may contain information that is not asked for in the question; consider only the information in the response
        that is relevant to the question.

        Example Data.
        {FACT_GENERATE_FEW_SHOT}

        Return the output only in the corresponding JSON format. Do not output anything other than this JSON object:
        {FACT_GENERATE_OUTPUT_FORMAT}

        Task Data.
        [Question]: {prompt}
        [Response]: {response}
        [Output]:
        """

        return template

    def evaluate_facts_prompt(facts, context, prompt):

        template = f"""
        You are a detail-oriented LLM whose task is to determine if the given facts or questions are supported by the given context
        and prompt.
        Each fact or question is separated by the following symbol: "#".

        For the given task data, go over each fact or question sentence one by one, and write down your judgement.
        If it is a question, then answer the question based on the context and prompt.
        Use important dates, if any are available in the context, to make a better judgement.
        If it is a fact, determine whether the fact is supported by both the context and the prompt.

        Before answering, reason in a step-by-step manner to provide your final judgement.
        If the reasoning is clear, then give the judgement as "yes" or "no"; otherwise give the judgement as "unclear".

        Example Data.
        {FACT_EVAL_FEW_SHOT__COT}

        Return the output only in the corresponding JSON format. Do not output anything other than this JSON object:
        {FACT_EVALUATE_OUTPUT_FORMAT__COT}

        Task Data.
        [Prompt]: {prompt}
        [Facts]: {facts}
        [Context]: {context}
        [Output]:
        """

        return template

    def filter_facts_prompt(prompt, facts):

        FACT_FILTER_PROMPT_TEMPLATE = f"""
        You are provided with a list of facts generated from a response to a specific question. Your task is to filter and retain only those
        facts that are directly relevant to the question. Ignore any facts that do not pertain to the original question.

        While filtering facts, ensure that the facts are contextually mentioned and
        do not begin with pronouns like 'He,' 'She,' or references to third-party entities. If any such facts are identified, rewrite them.

        Focus on identifying and selecting the facts that specifically answer the question asked, while discarding any irrelevant
        or off-topic information.

        Example Data:
        {FACT_FILTER_FEW_SHOT}

        Return the output in the specified JSON format. If no relevant facts are found, return an empty list [].

        Task Data:
        [Question]: {prompt}
        [Response Facts]: {facts}

        [Output]:
        """

        return FACT_FILTER_PROMPT_TEMPLATE

    def summarize_prompt(qa_dict_list):

        SUMMARIZATION_PROMPT_TEMPLATE = f"""
        You are provided with a list of JSON objects, each containing a 'question' and an 'answer'. The answer is obtained from
        the Google Search API and is a detailed response to the corresponding question.

        Your task is to create a separate summary for each question-answer pair, preserving the context and tense of the original
        question and answer. Don't mention anything other than what is given in the answer.

        Ensure that:
        - Each summary is in a separate paragraph.
        - There is a one-line space between each paragraph.
        - The tense and context of the original question and answer are maintained accurately.

        Ensure that there is no fabricated or hallucinated information in your response, and make sure that there are no conflicting statements
        from one paragraph to another in your summary. Do not mention the word 'paragraph' or similar meta-language in your response; just summarize and provide the answers.

        Task Data:
        [Input]:
        {qa_dict_list}

        [Output]:
        """

        return SUMMARIZATION_PROMPT_TEMPLATE
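A rough sketch of how the fact-checking prompts above might chain together; `ask_llm`, `question`, `answer`, and `context` are hypothetical placeholders, not part of this file, and the parsing assumes the model returns the documented JSON shapes:

import json

def ask_llm(text: str) -> str:
    """Hypothetical helper: send `text` to a chat model and return the raw reply."""
    raise NotImplementedError

# 1. Break the response into independent facts.
facts_raw = ask_llm(Prompt.generate_facts_prompt(question, answer, current_date="2024-06-01"))
facts = [item["Fact"] for item in json.loads(facts_raw)["Facts"]]

# 2. Judge each fact against retrieved context; facts are joined with the "#" separator
#    that evaluate_facts_prompt documents.
verdicts_raw = ask_llm(Prompt.evaluate_facts_prompt(" # ".join(facts), context, question))
verdicts = json.loads(verdicts_raw)["Result"]  # list of {"Fact", "Reasoning", "Judgement"}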
src/llm_explain/utility/prompts/few_shot.py
ADDED
@@ -0,0 +1,218 @@
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''

# Sentiment Analysis
one_shot_sentiment_analysis = """
Prompt: Unfortunately the movie served with hacked script but the actors performed well
{
    "Sentiment": "Negative",
    "Keywords": ['unfortunately', 'hacked'],
    "Explanation": "The model predicted the input as Negative because it anchored on the words 'unfortunately' and 'hacked' in your input. In this context, the model likely learned that when someone mentions a hacked script or says 'unfortunately', it usually indicates a negative opinion about a movie. Despite the actors performing well, the emphasis on the script being hacked led the model to classify the overall sentiment as negative. This shows that the model paid close attention to specific keywords, like 'unfortunately' and 'hacked', to make its prediction.",
    "token_importance_mapping": [
        {"token": "Unfortunately", "importance_score": 0.8, "position": 0},
        {"token": "movie", "importance_score": 0.5, "position": 1},
        {"token": "served", "importance_score": 0.3, "position": 2},
        {"token": "with", "importance_score": 0.2, "position": 3},
        {"token": "hacked", "importance_score": 0.9, "position": 4},
        {"token": "script", "importance_score": 0.7, "position": 5},
        {"token": "but", "importance_score": 0.4, "position": 6},
        {"token": "the", "importance_score": 0.1, "position": 7},
        {"token": "actors", "importance_score": 0.6, "position": 8},
        {"token": "performed", "importance_score": 0.5, "position": 9},
        {"token": "well", "importance_score": 0.7, "position": 10}
    ]
}
"""

# Tone Analysis
few_shot_examples_tone_analysis = """
[Response]: Absolutely! I'm thrilled about the opportunity and eager to learn more about the position.
[Output]:
{
    "Reasoning": "The text uses positive and enthusiastic language, along with keywords like 'thrilled' and 'eager,' indicating a strong interest in the position.",
    "Tone": ['Positive', 'Formal'],
    "Score": 0.8,
    "Recommendation": ""
}

[Response]: Ugh, another interview? Why do I even bother?
[Output]:
{
    "Reasoning": "The text expresses negativity and frustration with keywords like 'ugh' and 'bother.'",
    "Tone": ['Negative', 'Informal'],
    "Score": 0.2,
    "Recommendation": ""
}

[Response]: Congratulations on your promotion! You absolutely deserve it!
[Output]:
{
    "Reasoning": "The text uses positive and celebratory language with encouraging words.",
    "Tone": ['Positive', 'Informal'],
    "Score": 0.5,
    "Recommendation": ""
}
"""

# Response Coherence
LANGUAGE_COHERENCE_FEW_SHOT__COT = """
[Response]: Exercise is beneficial for both physical and mental health. It strengthens the body and uplifts the mind.
[Output]:
{
    "Reasoning": "The text is coherent and effectively conveys the message with clear organization of ideas.",
    "Score": 5
}

[Response]: Regular exercise contributes to overall well-being by enhancing physical strength and mental clarity.
[Output]:
{
    "Reasoning": "The text maintains coherence by linking ideas logically, providing a clear flow of information.",
    "Score": 4
}

[Response]: Exercise good. Health. Make body strong. Help mind. Benefits many.
[Output]:
{
    "Reasoning": "The text lacks coherence, as it presents fragmented ideas without clear connections.",
    "Score": 2
}
"""

# Response Relevance
RESPONSE_RELEVANCE_FEW_SHOT__COT = """
[Question]: Who is Lionel Messi?
[Response]: Lionel Andrés Messi (born 24 June 1987), also known as Leo Messi, is an Argentine professional footballer who plays as a forward for and captains both Major League Soccer club Inter Miami and the Argentina national team. Widely regarded as one of the greatest players of all time, Messi has won a record seven Ballon d'Or awards and a record six European Golden Shoes, and in 2020 he was named to the Ballon d'Or Dream Team. Until leaving the club in 2021, he had spent his entire professional career with Barcelona, where he won a club-record 34 trophies, including ten La Liga titles, seven Copa del Rey titles and the UEFA Champions League four times.
[Output]:
{
    "Reasoning": "While the given response provides information about the birth date, nationality and occupation of Lionel Messi, it includes some irrelevant details about Messi's career such as association to multiple clubs and trophies won.",
    "Score": 0.5
}

[Question]: Who is Lionel Messi?
[Response]: Lionel Andrés Messi (born 24 June 1987), also known as Leo Messi, is an Argentine professional footballer who plays as a forward for and captains both Major League Soccer club Inter Miami and the Argentina national team. Widely regarded as one of the greatest players of all time, Messi has won a record seven Ballon d'Or awards and a record six European Golden Shoes, and in 2020 he was named to the Ballon d'Or Dream Team. Until leaving the club in 2021, he had spent his entire professional career with Barcelona, where he won a club-record 34 trophies, including ten La Liga titles, seven Copa del Rey titles and the UEFA Champions League four times. During his time with Barcelona, Barcelona featured star players like Neymar, Andres Iniesta and was managed by Luis Enrique.
[Output]:
{
    "Reasoning": "While the given response provides information about the birth date, nationality and occupation of Lionel Messi, it includes a lot of irrelevant information such as details about Messi's career and the Barcelona club.",
    "Score": 0.0
}
"""

# Fact Generation
FACT_GENERATE_FEW_SHOT = """
[Question]: Which is the tallest monument in Paris?
[Response]: The Eiffel Tower, located in Paris, is one of the most visited monuments in the world. It was named after the engineer Gustave Eiffel, whose company designed and built the tower. Constructed from 1887 to 1889, it was initially criticized by some of France's leading artists and intellectuals.
[Output]:
[
    {
        "Fact": "The Eiffel Tower is located in Paris."
    },
    {
        "Fact": "The Eiffel Tower is the tallest structure in Paris."
    },
    {
        "Fact": "The Eiffel Tower is very old."
    }
]

[Question]: Who is the Chief Minister of Andhra Pradesh?
[Response]: As of my last training data in October 2021, the Chief Minister of Andhra Pradesh is Yeduguri Sandinti Jaganmohan Reddy. He assumed office on 30th May 2019. He is a member of the Yuvajana Sramika Rythu Congress Party (YSR Congress Party). However, you should verify from a reliable source for the most current information, as political positions can change.
[Output]:
[
    {
        "Fact": "The Chief Minister of Andhra Pradesh is Yeduguri Sandinti Jaganmohan Reddy."
    },
    {
        "Fact": "Yeduguri Sandinti Jaganmohan Reddy assumed office on 30th May 2019."
    },
    {
        "Fact": "Yeduguri Sandinti Jaganmohan Reddy is a member of the Yuvajana Sramika Rythu Congress Party (YSR Congress Party)."
    }
]

[Question]: Is Leaning Tower of Pisa, which is located in Italy, the oldest monument in Europe?
[Response]: No
[Output]:
[
    {
        "Fact": "The Leaning Tower of Pisa is not the oldest monument in Europe."
    }
]
"""

# Factuality Evaluation
FACT_EVAL_FEW_SHOT__COT = """
[Facts]: ["1. The Eiffel Tower is located in Paris.", "2. The Eiffel Tower is the tallest structure in Paris.", "3. The Eiffel Tower is very old."]
[Context]: The Eiffel Tower, located in Paris, is one of the most visited monuments in the world. It was named after the engineer Gustave Eiffel, whose company designed and built the tower. Constructed from 1887 to 1889, it was initially criticized by some of France's leading artists and intellectuals.
[Output]:
{
    "Result": [
        {
            "Fact": "1. The Eiffel Tower is located in Paris.",
            "Reasoning": "The context explicitly states that the Eiffel Tower, one of the most visited monuments in the world, is located in Paris.",
            "Judgement": "yes"
        },
        {
            "Fact": "2. The Eiffel Tower is the tallest structure in Paris.",
            "Reasoning": "While the context speaks about the popularity of the Eiffel Tower, it makes no mention of its height or whether it is the tallest structure.",
            "Judgement": "no"
        },
        {
            "Fact": "3. The Eiffel Tower is very old.",
            "Reasoning": "While the context mentions that the Eiffel Tower was built in the 1880s, it doesn't clarify what 'very old' means.",
            "Judgement": "unclear"
        }
    ]
}
"""

# Filter Facts
FACT_FILTER_FEW_SHOT = """
[Question]: What are the main components of a computer?
[Response Facts]:
[
    "A computer has a CPU, which is the brain of the computer.",
    "Computers are widely used in schools and offices.",
    "A computer's motherboard connects all the components.",
    "The CPU can be overclocked for better performance.",
    "Computers were invented in the 20th century."
]

[Output]:
[
    "A computer has a CPU, which is the brain of the computer.",
    "A computer's motherboard connects all the components.",
    "The CPU can be overclocked for better performance."
]

[Question]: What are the benefits of exercise?
[Response Facts]:
[
    "Regular exercise improves cardiovascular health.",
    "Many people exercise in the morning.",
    "Exercise helps maintain a healthy weight.",
    "A balanced diet is also important for health.",
    "Exercise can reduce the risk of chronic diseases."
]

[Output]:
[
    "Regular exercise improves cardiovascular health.",
    "Exercise helps maintain a healthy weight.",
    "Exercise can reduce the risk of chronic diseases."
]
"""
src/llm_explain/utility/prompts/instructions.py
ADDED
@@ -0,0 +1,20 @@
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''

CHAIN_OF_THOUGHT = (
    "Before answering, reason in a step-by-step manner so as to arrive at the right answer."
)
src/llm_explain/utility/prompts/output_format.py
ADDED
@@ -0,0 +1,122 @@
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''

# Sentiment Analysis
output_format_sentiment_analysis = """
{
    "Sentiment": ["Sentiment of the given prompt"],
    "Keywords": [List of contributing keywords for the sentiment],
    "Explanation": ["Explanation for the sentiment"],
    "token_importance_mapping": [Dictionary of token importance mapping]
}
"""

# Local Explanation
output_format_local_explanation = """
{
    "uncertainty": {
        "score": Integer,
        "explanation": "",
        "recommendation": ""
    },
    "coherence": {
        "score": Integer,
        "explanation": "",
        "recommendation": ""
    }
}
"""

# Token Importance
output_format_token_importance = """
{
    "Token": ["Each Token from input prompt"],
    "Importance Score": ["The value here should be a comma-separated list of importance scores"],
    "Position": ["The value here should be a comma-separated list of respective token index positions"]
}
"""

# Tone Analysis
output_format_tone_analysis = """
{
    "Reasoning": [Reasoning], # Reasoning to critique the tone of the response, start with "Your response...."
    "Tone": [Tones],
    "Score": [Score],
    "Recommendation": [Recommendation]
}
"""

# Coherence
LANGUAGE_COHERENCE_OUTPUT_FORMAT__COT = """
{
    "Reasoning": [Reasoning], # Reasoning to critique the coherence of the response,
    "Score": [Score], # Score between 1 to 5, to evaluate the coherence of the response
}
"""

# Response Relevance
RESPONSE_RELEVANCE_OUTPUT_FORMAT__COT = """
{
    "Reasoning": [Reasoning], # Reasoning to determine the conciseness of the response for answering the query,
    "Score": [Score] # Score assigned for the relevance of the response to the query
}
"""

# Fact Generation
FACT_GENERATE_OUTPUT_FORMAT = """
{
    "Facts": [ # List of all the facts
        {
            [1st Fact], # 1st fact being analysed
        },
        {
            [2nd Fact], # 2nd fact being analysed
        },
        ... # Do for all the facts
    ]
}
In case of no facts, return an empty list, i.e. [].
"""

# Factuality Evaluation
FACT_EVALUATE_OUTPUT_FORMAT__COT = """
{
    "Result": [ # List containing data for all the facts
        {
            "Fact": [1st Fact], # 1st fact being analysed,
            "Reasoning": [Reasoning for 1st Fact], # Reasoning to determine if the 1st fact can be verified from the context or not,
            "Judgement": [Judgement for 1st Fact] # Judgement for 1st fact. Select one of the three - "yes", "unclear" or "no",
        },
        {
            "Fact": [2nd Fact], # 2nd fact being analysed,
            "Reasoning": [Reasoning for 2nd Fact], # Reasoning to determine if the 2nd fact can be verified from the context or not,
            "Judgement": [Judgement for 2nd Fact] # Judgement for 2nd fact. Select one of the three - "yes", "unclear" or "no",
        },
        ... # Do for all the facts
    ]
}
"""

# Filter Facts
FACT_FILTER_OUTPUT_FORMAT = """
{
    "RelevantFacts": [
        "[Relevant Fact 1]",
        "[Relevant Fact 2]",
        ...
    ]
}
"""
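Since every template insists on "JSON only" output but models occasionally wrap the object in code fences or extra prose, a best-effort extraction helper is a common companion to these format strings. This is a hypothetical sketch, not a helper defined anywhere in this repo:

import json
import re

def parse_model_json(raw: str) -> dict:
    """Best-effort parse: pull the outermost {...} span out of a model reply and decode it."""
    match = re.search(r"\{.*\}", raw, re.DOTALL)
    if match is None:
        raise ValueError("no JSON object found in model output")
    return json.loads(match.group(0))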
src/llm_explain/utility/query_serper.py
ADDED
@@ -0,0 +1,143 @@
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''

"""Util that calls Google Search using the Serper.dev API."""
from llm_explain.config.logger import CustomLogger
import asyncio
import aiohttp
import os
from dotenv import load_dotenv
load_dotenv()

log = CustomLogger()

class GoogleSerperAPIWrapper():
    """Wrapper around the Serper.dev Google Search API. You can create a free API key at https://serper.dev.
    To use, you should have the environment variable ``SERPER_KEY`` set with your API key, or
    pass `serper_api_key` as a named parameter to the constructor.
    """
    def __init__(self, snippet_cnt = 10) -> None:
        self.k = snippet_cnt
        self.gl = "us"
        self.hl = "en"
        self.serper_api_key = os.environ.get("SERPER_KEY", None)
        assert self.serper_api_key is not None and self.serper_api_key != '', "Please set the SERPER_KEY environment variable."

    async def _google_serper_search_results(self, session, search_term: str, gl: str, hl: str) -> dict:
        headers = {
            "X-API-KEY": self.serper_api_key or "",
            "Content-Type": "application/json",
        }
        params = {"q": search_term, "gl": gl, "hl": hl}
        try:
            async with session.post(
                "https://google.serper.dev/search", headers=headers, params=params, raise_for_status=True
            ) as response:
                return await response.json()
        except aiohttp.ClientError as e:
            log.error(f"HTTP request failed: {e}")
            raise
        except Exception as e:
            log.error(f"An error occurred: {e}")
            raise

    def _parse_results(self, results):
        snippets = []

        try:
            if results.get("answerBox"):
                answer_box = results.get("answerBox", {})
                if answer_box.get("answer"):
                    element = {"content": answer_box.get("answer"), "source": "None"}
                    return [element]
                elif answer_box.get("snippet"):
                    element = {"content": answer_box.get("snippet").replace("\n", " "), "source": "None"}
                    return [element]
                elif answer_box.get("snippetHighlighted"):
                    element = {"content": answer_box.get("snippetHighlighted"), "source": "None"}
                    return [element]

            if results.get("knowledgeGraph"):
                kg = results.get("knowledgeGraph", {})
                title = kg.get("title")
                entity_type = kg.get("type")
                if entity_type:
                    element = {"content": f"{title}: {entity_type}", "source": "None"}
                    snippets.append(element)
                description = kg.get("description")
                if description:
                    element = {"content": description, "source": "None"}
                    snippets.append(element)
                for attribute, value in kg.get("attributes", {}).items():
                    element = {"content": f"{attribute}: {value}", "source": "None"}
                    snippets.append(element)

            for result in results["organic"][: self.k]:
                if "snippet" in result:
                    element = {"content": result["snippet"], "source": result["link"]}
                    snippets.append(element)
                for attribute, value in result.get("attributes", {}).items():
                    element = {"content": f"{attribute}: {value}", "source": result["link"]}
                    snippets.append(element)

            if len(snippets) == 0:
                element = {"content": "No good Google Search Result was found", "source": "None"}
                return [element]

            # keep only the first k/2 snippets
            snippets = snippets[:int(self.k / 2)]
        except AttributeError as e:
            if "'ClientResponseError' object has no attribute 'get'" in str(e):
                log.error("Serper API key is invalid or has expired.")
                raise Exception("Serper API key is invalid or has expired.")
            else:
                log.error(f"An error occurred while parsing results: {e}")
                raise
        except Exception as e:
            log.error(f"An error occurred while parsing results: {e}")
            raise

        return snippets

    async def parallel_searches(self, search_queries, gl, hl):
        async with aiohttp.ClientSession() as session:
            tasks = [self._google_serper_search_results(session, query, gl, hl) for query in search_queries]
            try:
                search_results = await asyncio.gather(*tasks, return_exceptions=True)
            except Exception as e:
                log.error(f"An error occurred while running parallel searches: {e}")
                raise
            return search_results

    async def run(self, queries):
        """Run the queries through Google Search and parse the results."""

        try:
            results = await self.parallel_searches(queries, gl=self.gl, hl=self.hl)
        except Exception as e:
            log.error(f"An error occurred while running searches: {e}")
            raise
        snippets_list = []

        for i in range(len(results)):
            snippets_list.append(self._parse_results(results[i]))

        return snippets_list
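A minimal usage sketch for the wrapper above, assuming ``SERPER_KEY`` is set; the query string is only an illustration:

import asyncio

async def main():
    serper = GoogleSerperAPIWrapper(snippet_cnt=4)
    snippets = await serper.run(["When was the Eiffel Tower built?"])
    for snippet in snippets[0]:  # one snippet list per query
        print(snippet["content"], "--", snippet["source"])

asyncio.run(main())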
src/llm_explain/utility/utility.py
ADDED
@@ -0,0 +1,316 @@
'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''

from llm_explain.config.logger import CustomLogger
from sklearn.metrics.pairwise import cosine_similarity
from matplotlib import pyplot as plt
from openai import AzureOpenAI
from tenacity import retry
from tqdm import tqdm
import pandas as pd
import numpy as np
import asyncio
import base64
import os
import io

from dotenv import load_dotenv
load_dotenv()

log = CustomLogger()

class Utils:

    client = AzureOpenAI(
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
        api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT")
    )

    def normalize_vector(v):
        norm = np.linalg.norm(v)
        if norm == 0:
            return v
        return v / norm

    def display_metrics(uncertainty_scores, completions, n):
        try:
            results = {}
            structural_uncertainty = np.mean([np.mean(x) for x in uncertainty_scores['entropies']])
            conceptual_uncertainty = (0.5 * uncertainty_scores['mean_choice_distance']) + (0.5 * np.mean([np.mean(x) for x in uncertainty_scores['distances']]))

            results["overall_cosine_distance"] = uncertainty_scores['mean_choice_distance']
            results["Overall_Structural_Uncertainty"] = structural_uncertainty
            results["Overall_Conceptual_Uncertainty"] = conceptual_uncertainty

            results["choices"] = []

            for i in range(n):
                choice = {}
                choice_text = completions['choices'][i]['text']
                entropies = uncertainty_scores['entropies'][i]
                distances = uncertainty_scores['distances'][i]

                logprobs = completions['choices'][i]['logprobs']['top_logprobs']
                mean_entropy = np.mean(entropies)
                mean_distance = np.mean(distances)

                choice["mean_entropy"] = mean_entropy
                choice["mean_distance"] = mean_distance

                tokens = completions['choices'][i]['logprobs']['tokens']

                fixed_spacing = 1

                x_positions = [0]
                for j in range(1, len(tokens)):
                    x_positions.append(x_positions[-1] + len(tokens[j-1]) + fixed_spacing)

                df = pd.DataFrame({
                    'x': x_positions,
                    'y_text': [1] * len(tokens),
                    'y_entropy': [1.2 + entropy for entropy in entropies],
                    'y_distance': [1.2 + dist for dist in distances],
                    'tokens': tokens,
                    'logprobs': ['\n'.join([f"{k}: {v}" for k, v in lp.items()]) for lp in logprobs],
                    'entropy': entropies,
                    'distance': distances,
                })

                plt.figure(figsize=(10, 6))
                plt.title(f"Choice {i+1}")
                plt.plot(df['x'], df['y_entropy'], label='Entropy', color='blue')
                plt.plot(df['x'], df['y_distance'], label='Distance', color='red')
                plt.xlabel('Token Position')
                plt.ylabel('Normalization value')
                plt.legend()

                buf = io.BytesIO()
                plt.savefig(buf, format='png')
                buf.seek(0)

                img_base64 = base64.b64encode(buf.read()).decode('utf-8')

                choice["plot_image_base64"] = img_base64

                choice['response'] = choice_text

                results["choices"].append(choice)
            return results
        except Exception as e:
            log.error(e, exc_info=True)
            raise

    def calculate_normalized_entropy(logprobs):
        """
        Calculate the normalized entropy of a list of log probabilities.

        Parameters:
        logprobs (list): List of log probabilities.

        Returns:
        float: Normalized entropy.
        """
        try:
            entropy = -np.sum(np.exp(logprobs) * logprobs)

            # Calculate the maximum possible entropy for N tokens sampled
            N = len(logprobs)
            max_entropy = np.log(N)

            # Normalize the entropy
            normalized_entropy = entropy / max_entropy
            return normalized_entropy
        except Exception as e:
            log.error(e, exc_info=True)
            raise
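    # Worked example (illustrative; not part of the original file): for a top-5 list of
    # equal log probabilities, logprobs = [ln(0.2)] * 5 ≈ [-1.609] * 5, so
    # entropy = -sum(0.2 * -1.609) ≈ 1.609 and max_entropy = ln(5) ≈ 1.609,
    # giving a normalized entropy of 1.0 (maximally uncertain). A peaked list such
    # as [-0.01, -5, -5, -5, -5] yields roughly 0.14 / 1.61 ≈ 0.09, i.e. near-certain.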
async def process_token_async(i, top_logprobs_list, choice, choice_embedding, max_tokens):
|
144 |
+
|
145 |
+
|
146 |
+
"""
|
147 |
+
Asynchronously process a token to calculate its normalized entropy and mean cosine distance.
|
148 |
+
|
149 |
+
Parameters:
|
150 |
+
i (int): Token index.
|
151 |
+
top_logprobs_list (list): List of top log probabilities for each token.
|
152 |
+
choice (dict): The choice containing log probabilities and tokens.
|
153 |
+
choice_embedding (array): Embedding of the full choice.
|
154 |
+
max_tokens (int or None): Maximum number of tokens to consider for the partial string.
|
155 |
+
|
156 |
+
Returns:
|
157 |
+
tuple: Mean cosine distance and normalized entropy for the token.
|
158 |
+
"""
|
159 |
+
try:
|
160 |
+
top_logprobs = top_logprobs_list[i]
|
161 |
+
normalized_entropy = Utils.calculate_normalized_entropy(list(top_logprobs.values()))
|
162 |
+
|
163 |
+
tasks = []
|
164 |
+
|
165 |
+
# Loop through each sampled token to construct partial strings and calculate embeddings
|
166 |
+
for sampled_token in top_logprobs:
|
167 |
+
tokens_to_use = choice['logprobs']['tokens'][:i] + [sampled_token]
|
168 |
+
|
169 |
+
# Limit the number of tokens in the partial string if max_tokens is specified
|
170 |
+
if max_tokens is not None and len(tokens_to_use) > max_tokens:
|
171 |
+
tokens_to_use = tokens_to_use[-max_tokens:]
|
172 |
+
|
173 |
+
constructed_string = ''.join(tokens_to_use)
|
174 |
+
task = Utils.get_embedding(constructed_string)
|
175 |
+
|
176 |
+
tasks.append(task)
|
177 |
+
|
178 |
+
embeddings = await asyncio.gather(*tasks)
|
179 |
+
|
180 |
+
cosine_distances = []
|
181 |
+
|
182 |
+
# Calculate cosine distances between embeddings of partial strings and the full choice
|
183 |
+
for new_embedding in embeddings:
|
184 |
+
cosine_sim = cosine_similarity(new_embedding.reshape(1, -1), choice_embedding.reshape(1, -1))[0][0]
|
185 |
+
cosine_distances.append(1 - cosine_sim)
|
186 |
+
|
187 |
+
mean_distance = np.mean(cosine_distances)
|
188 |
+
|
189 |
+
return mean_distance, normalized_entropy
|
190 |
+
except Exception as e:
|
191 |
+
log.error(e,exc_info=True)
|
192 |
+
raise
|
193 |
+
|
194 |
+
def decoded_tokens(string, tokenizer):
|
195 |
+
return [tokenizer.decode([x]) for x in tokenizer.encode(string)]
|
196 |
+
def scale_importance_log(importance_scores, base=None, offset=0.0, min_percentile=0, max_percentile=100, smoothing_constant=1e-10, scaling_factor=1.0, bias=0.0):
    # Extract the importance values
    try:
        importance_values = np.array([score[1] for score in importance_scores])

        # Apply optional percentile-based clipping
        if min_percentile > 0 or max_percentile < 100:
            min_val = np.percentile(importance_values, min_percentile)
            max_val = np.percentile(importance_values, max_percentile)
            importance_values = np.clip(importance_values, min_val, max_val)

        # Subtract the minimum value and add the optional offset
        importance_values = importance_values - np.min(importance_values) + offset

        # Add smoothing constant to ensure non-zero values
        importance_values += smoothing_constant

        # Apply logarithmic scaling, with an optional base
        scaled_values = np.log(importance_values) if base is None else np.log(importance_values) / np.log(base)

        # Apply scaling factor and bias
        scaled_values = scaling_factor * scaled_values + bias

        # Normalize to the range [0, 1]
        scaled_values = (scaled_values - np.min(scaled_values)) / (np.max(scaled_values) - np.min(scaled_values))

        # Pair the scaled values with the original tokens
        scaled_importance_scores = [(token, scaled_value) for token, scaled_value in zip([score[0] for score in importance_scores], scaled_values)]

        return scaled_importance_scores
    except Exception as e:
        log.error(e, exc_info=True)
        raise

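# Usage sketch for scale_importance_log (the scores are illustrative): raw
# cosine distances are clipped, shifted positive, log-compressed, then min-max
# normalized, so the largest raw score maps to 1.0 and the smallest to 0.0.
# Note that the smoothing constant is what keeps log(0) finite, which is also
# why the minimum token lands far below the rest after compression.

def _scale_importance_demo():
    scores = [("The", 0.0021), ("model", 0.0180), ("hallucinated", 0.0925), ("today", 0.0043)]
    # -> approximately [('The', 0.0), ('model', 0.92), ('hallucinated', 1.0), ('today', 0.82)]
    return Utils.scale_importance_log(scores)
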
@retry
async def get_embedding(input_text):
    try:
        response = Utils.client.embeddings.create(
            input=input_text,
            model="text-embedding-ada-002",
            timeout=4.0
        )
        return np.array(response.data[0].embedding)
    except Exception as e:
        log.error(e, exc_info=True)
        raise

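# Note: Utils.client.embeddings.create above appears to be a synchronous call
# inside an async def (it is not awaited), so each embedding request blocks the
# event loop and the gather() in process_token_async effectively runs them one
# at a time. A non-blocking variant is sketched below, assuming the openai>=1.x
# async Azure client with credentials and API version taken from the standard
# AZURE_OPENAI_* / OPENAI_API_VERSION environment variables:

async def _get_embedding_nonblocking(input_text):
    from openai import AsyncAzureOpenAI
    aclient = AsyncAzureOpenAI()  # reads endpoint, key, and api version from the environment
    response = await aclient.embeddings.create(
        input=input_text,
        model="text-embedding-ada-002",
        timeout=4.0
    )
    return np.array(response.data[0].embedding)
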
async def approximate_importance(perturbed_text, original_embedding, model=None, tokenizer=None):
    # model and tokenizer are accepted for interface compatibility but are unused here
    try:
        perturbed_embedding = await Utils.get_embedding(perturbed_text)
        cosine_dist = 1 - cosine_similarity(original_embedding.reshape(1, -1), perturbed_embedding.reshape(1, -1))[0][0]
        return cosine_dist
    except Exception as e:
        log.error(e, exc_info=True)
        raise

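# The importance proxy above is plain cosine distance, 1 - cos(theta), between
# the original and perturbed embeddings: identical vectors score 0.0 and
# orthogonal ones score 1.0. A tiny self-contained check:

def _cosine_distance_demo():
    a = np.array([1.0, 0.0]).reshape(1, -1)
    b = np.array([1.0, 1.0]).reshape(1, -1)
    return 1 - cosine_similarity(a, b)[0][0]  # ~0.293, since cos(45 deg) ~ 0.707
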
async def ablated_relative_importance(input_text, tokenizer, model=None):
    try:
        original_embedding = await Utils.get_embedding(input_text)
        tokens = Utils.decoded_tokens(input_text, tokenizer)
        importance_scores = []

        with tqdm(total=len(tokens), desc="Calculating Token Importances", position=0, leave=True) as progress:
            for i in range(len(tokens)):
                # Skip very short tokens, but still advance the bar so it can complete
                if len(tokens[i]) < 4:
                    progress.update(1)
                    continue
                perturbed_text = "".join(tokens[:i] + tokens[i+1:])
                importance = await Utils.approximate_importance(perturbed_text, original_embedding, model, tokenizer)
                importance_scores.append((tokens[i], importance))
                progress.update(1)

        return importance_scores
    except Exception as e:
        log.error(e, exc_info=True)
        raise

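# End-to-end, ablated_relative_importance drops one token at a time, re-embeds
# the remaining text, and records how far the embedding moves; tokens shorter
# than four characters are skipped entirely. A usage sketch (hypothetical
# caller; assumes a running event loop and a tiktoken-style tokenizer):

async def _token_importance_demo(text):
    import tiktoken
    tokenizer = tiktoken.get_encoding("r50k_base")  # illustrative tokenizer choice
    raw_scores = await Utils.ablated_relative_importance(text, tokenizer)
    return Utils.scale_importance_log(raw_scores)   # map raw distances to [0, 1]

# e.g. asyncio.run(_token_importance_demo("The model hallucinated a citation"))
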
def get_price_details(model: str):
    '''
    Returns the price per 1,000 tokens for the given model, as a
    (prompt_price, response_price) tuple.

    Parameters:
        model (str): Model name (Ex: gpt-4o)
    '''
    prompt_price_per_1000_tokens = {
        "gpt-4o": 0.0050,
        "gpt-35-turbo": 0.0005,
        "gpt-35-turbo-instruct": 0.0015,
        "gpt4": 0.0300
    }

    response_price_per_1000_tokens = {
        "gpt-4o": 0.0150,
        "gpt-35-turbo": 0.0015,
        "gpt-35-turbo-instruct": 0.0020,
        "gpt4": 0.0600
    }

    try:
        return prompt_price_per_1000_tokens[model], response_price_per_1000_tokens[model]
    except KeyError:
        raise ValueError(f"Model '{model}' is not found in the pricing details. Only gpt-4o, gpt-35-turbo, gpt-35-turbo-instruct & gpt4 are available. Please contact the administrator")

def get_token_cost(input_tokens: int, output_tokens: int, model: str):
    '''
    Calculates the total cost for the given token counts.

    Parameters:
        input_tokens (int): Prompt tokens
        output_tokens (int): Completion tokens
        model (str): Model name (Ex: gpt4)
    '''

    # Example pricing (this should be replaced with actual pricing from Azure documentation)
    prompt_price_per_1000_tokens, response_price_per_1000_tokens = Utils.get_price_details(model)

    # Calculate cost: each price is per 1,000 tokens
    total_cost = ((input_tokens / 1000) * prompt_price_per_1000_tokens) + ((output_tokens / 1000) * response_price_per_1000_tokens)

    return {
        "total_cost": total_cost
    }
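# Worked example against the price table above: a gpt-4o call that consumes
# 1,200 prompt tokens and 800 completion tokens costs
# (1200 / 1000) * 0.0050 + (800 / 1000) * 0.0150 = 0.006 + 0.012 = $0.018.
#
#     Utils.get_token_cost(input_tokens=1200, output_tokens=800, model="gpt-4o")
#     # -> {'total_cost': 0.018} (up to floating-point rounding)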