momenaca committed on
Commit
579d749
·
1 Parent(s): a925f26

add azure search vectorstore link for presse

Browse files
app.py CHANGED
@@ -4,24 +4,27 @@ import logging
4
  import gradio as gr
5
  from langchain.prompts.chat import ChatPromptTemplate
6
  from huggingface_hub import hf_hub_download, whoami
7
- from app.source.backend.llm_utils import get_llm
8
- from app.source.backend.document_store import pickle_to_document_store
9
- from app.source.backend.get_prompts import get_qa_prompts
10
- from app.source.frontend.utils import (
11
  make_html_source,
12
  make_html_presse_source,
13
  init_env,
14
  )
15
- from app.source.backend.prompt_utils import to_chat_instruction, SpecialTokens
 
 
 
16
 
17
  init_env()
18
 
19
- with open("./app/config.yaml") as f:
20
  config = yaml.full_load(f)
21
 
22
  prompts = {}
23
  for source in config["prompt_naming"]:
24
- with open(f"./app/prompt_{source}.yaml") as f:
25
  prompts[source] = yaml.full_load(f)
26
 
27
  ## Building LLM
@@ -40,8 +43,11 @@ qdrants = {
40
  )
41
  )
42
  for tab in config["prompt_naming"]
 
43
  }
44
 
 
 
45
  ## Load Prompts
46
  print("Loading Prompts")
47
  chat_qa_prompts, chat_reformulation_prompts, chat_summarize_memory_prompts = {}, {}, {}
@@ -51,6 +57,7 @@ for source, prompt in prompts.items():
51
  chat_reformulation_prompts[source] = chat_reformulation_prompt
52
  # chat_summarize_memory_prompts[source] = chat_summarize_memory_prompt
53
 
 
54
  with open("./assets/style.css", "r") as f:
55
  css = f.read()
56
 
@@ -277,18 +284,29 @@ def get_html_sources(buttons, cards):
277
  """
278
 
279
 
280
- def get_sources(outils, question, tab, qdrants=qdrants, config=config):
 
 
281
  k = config["num_document_retrieved"]
282
  min_similarity = config["min_similarity"]
283
  if tab in outils:
284
- sources = qdrants[
285
- config["source_mapping"][tab]
286
- ].similarity_search_with_relevance_scores(
287
- config["query_preprompt"]
288
- + question.replace("<p>", "").replace("</p>\n", ""),
289
- k=k,
290
- # filter=get_qdrant_filters(filters),
 
 
 
 
 
 
 
 
291
  )
 
292
  sources = [(doc, score) for doc, score in sources if score >= min_similarity]
293
 
294
  buttons_ids = list(range(len(sources)))
@@ -323,9 +341,11 @@ def get_sources(outils, question, tab, qdrants=qdrants, config=config):
323
  return "", ""
324
 
325
 
326
- def retrieve_sources(outils, *questions, qdrants=qdrants, config=config):
 
 
327
  results = [
328
- get_sources(outils, question, tab, qdrants, config)
329
  for question, tab in zip(questions, config["tabs"])
330
  ]
331
  formated_sources = [source[0] for source in results]
 
4
  import gradio as gr
5
  from langchain.prompts.chat import ChatPromptTemplate
6
  from huggingface_hub import hf_hub_download, whoami
7
+ from spinoza_project.source.backend.llm_utils import get_llm, get_vectorstore
8
+ from spinoza_project.source.backend.document_store import pickle_to_document_store
9
+ from spinoza_project.source.backend.get_prompts import get_qa_prompts
10
+ from spinoza_project.source.frontend.utils import (
11
  make_html_source,
12
  make_html_presse_source,
13
  init_env,
14
  )
15
+ from spinoza_project.source.backend.prompt_utils import (
16
+ to_chat_instruction,
17
+ SpecialTokens,
18
+ )
19
 
20
  init_env()
21
 
22
+ with open("./spinoza_project/config.yaml") as f:
23
  config = yaml.full_load(f)
24
 
25
  prompts = {}
26
  for source in config["prompt_naming"]:
27
+ with open(f"./spinoza_project/prompt_{source}.yaml") as f:
28
  prompts[source] = yaml.full_load(f)
29
 
30
  ## Building LLM
 
43
  )
44
  )
45
  for tab in config["prompt_naming"]
46
+ if tab != "Presse"
47
  }
48
 
49
+ bdd_presse = get_vectorstore("presse")
50
+
51
  ## Load Prompts
52
  print("Loading Prompts")
53
  chat_qa_prompts, chat_reformulation_prompts, chat_summarize_memory_prompts = {}, {}, {}
 
57
  chat_reformulation_prompts[source] = chat_reformulation_prompt
58
  # chat_summarize_memory_prompts[source] = chat_summarize_memory_prompt
59
 
60
+
61
  with open("./assets/style.css", "r") as f:
62
  css = f.read()
63
 
 
284
  """
285
 
286
 
287
+ def get_sources(
288
+ outils, question, tab, qdrants=qdrants, bdd_presse=bdd_presse, config=config
289
+ ):
290
  k = config["num_document_retrieved"]
291
  min_similarity = config["min_similarity"]
292
  if tab in outils:
293
+ sources = (
294
+ (
295
+ bdd_presse.similarity_search_with_relevance_scores(
296
+ question.replace("<p>", "").replace("</p>\n", ""),
297
+ k=k,
298
+ )
299
+ )
300
+ if tab == "Presse"
301
+ else qdrants[
302
+ config["source_mapping"][tab]
303
+ ].similarity_search_with_relevance_scores(
304
+ config["query_preprompt"]
305
+ + question.replace("<p>", "").replace("</p>\n", ""),
306
+ k=k,
307
+ )
308
  )
309
+
310
  sources = [(doc, score) for doc, score in sources if score >= min_similarity]
311
 
312
  buttons_ids = list(range(len(sources)))
 
341
  return "", ""
342
 
343
 
344
+ def retrieve_sources(
345
+ outils, *questions, qdrants=qdrants, bdd_presse=bdd_presse, config=config
346
+ ):
347
  results = [
348
+ get_sources(outils, question, tab, qdrants, bdd_presse, config)
349
  for question, tab in zip(questions, config["tabs"])
350
  ]
351
  formated_sources = [source[0] for source in results]
poetry.lock CHANGED
@@ -223,6 +223,70 @@ files = [
223
  [package.dependencies]
224
  cryptography = "*"
225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  [[package]]
227
  name = "certifi"
228
  version = "2024.6.2"
@@ -1419,6 +1483,20 @@ files = [
1419
  {file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"},
1420
  ]
1421
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1422
  [[package]]
1423
  name = "itsdangerous"
1424
  version = "2.2.0"
@@ -1998,6 +2076,25 @@ requests = ">=2.0.0,<3"
1998
  [package.extras]
1999
  broker = ["pymsalruntime (>=0.13.2,<0.17)"]
2000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2001
  [[package]]
2002
  name = "multidict"
2003
  version = "6.0.5"
@@ -4859,4 +4956,4 @@ multidict = ">=4.0"
4859
  [metadata]
4860
  lock-version = "2.0"
4861
  python-versions = "^3.10"
4862
- content-hash = "4b2da1198ef4ee6995118810fcfe20d43f21d1eac07b71d7b71a2bab98b633e0"
 
223
  [package.dependencies]
224
  cryptography = "*"
225
 
226
+ [[package]]
227
+ name = "azure-common"
228
+ version = "1.1.28"
229
+ description = "Microsoft Azure Client Library for Python (Common)"
230
+ optional = false
231
+ python-versions = "*"
232
+ files = [
233
+ {file = "azure-common-1.1.28.zip", hash = "sha256:4ac0cd3214e36b6a1b6a442686722a5d8cc449603aa833f3f0f40bda836704a3"},
234
+ {file = "azure_common-1.1.28-py2.py3-none-any.whl", hash = "sha256:5c12d3dcf4ec20599ca6b0d3e09e86e146353d443e7fcc050c9a19c1f9df20ad"},
235
+ ]
236
+
237
+ [[package]]
238
+ name = "azure-core"
239
+ version = "1.30.2"
240
+ description = "Microsoft Azure Core Library for Python"
241
+ optional = false
242
+ python-versions = ">=3.8"
243
+ files = [
244
+ {file = "azure-core-1.30.2.tar.gz", hash = "sha256:a14dc210efcd608821aa472d9fb8e8d035d29b68993819147bc290a8ac224472"},
245
+ {file = "azure_core-1.30.2-py3-none-any.whl", hash = "sha256:cf019c1ca832e96274ae85abd3d9f752397194d9fea3b41487290562ac8abe4a"},
246
+ ]
247
+
248
+ [package.dependencies]
249
+ requests = ">=2.21.0"
250
+ six = ">=1.11.0"
251
+ typing-extensions = ">=4.6.0"
252
+
253
+ [package.extras]
254
+ aio = ["aiohttp (>=3.0)"]
255
+
256
+ [[package]]
257
+ name = "azure-identity"
258
+ version = "1.17.1"
259
+ description = "Microsoft Azure Identity Library for Python"
260
+ optional = false
261
+ python-versions = ">=3.8"
262
+ files = [
263
+ {file = "azure-identity-1.17.1.tar.gz", hash = "sha256:32ecc67cc73f4bd0595e4f64b1ca65cd05186f4fe6f98ed2ae9f1aa32646efea"},
264
+ {file = "azure_identity-1.17.1-py3-none-any.whl", hash = "sha256:db8d59c183b680e763722bfe8ebc45930e6c57df510620985939f7f3191e0382"},
265
+ ]
266
+
267
+ [package.dependencies]
268
+ azure-core = ">=1.23.0"
269
+ cryptography = ">=2.5"
270
+ msal = ">=1.24.0"
271
+ msal-extensions = ">=0.3.0"
272
+ typing-extensions = ">=4.0.0"
273
+
274
+ [[package]]
275
+ name = "azure-search-documents"
276
+ version = "11.4.0"
277
+ description = "Microsoft Azure Cognitive Search Client Library for Python"
278
+ optional = false
279
+ python-versions = ">=3.7"
280
+ files = [
281
+ {file = "azure-search-documents-11.4.0.tar.gz", hash = "sha256:599f269f106fb51e646ff426a218c21811575598e6a769b23fa4a0127c0f57e0"},
282
+ {file = "azure_search_documents-11.4.0-py3-none-any.whl", hash = "sha256:e435266dc992a3450dc475309c9475f89a4bb0e9dac838140e609d9f1c7608ac"},
283
+ ]
284
+
285
+ [package.dependencies]
286
+ azure-common = ">=1.1,<2.0"
287
+ azure-core = ">=1.28.0,<2.0.0"
288
+ isodate = ">=0.6.0"
289
+
290
  [[package]]
291
  name = "certifi"
292
  version = "2024.6.2"
 
1483
  {file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"},
1484
  ]
1485
 
1486
+ [[package]]
1487
+ name = "isodate"
1488
+ version = "0.6.1"
1489
+ description = "An ISO 8601 date/time/duration parser and formatter"
1490
+ optional = false
1491
+ python-versions = "*"
1492
+ files = [
1493
+ {file = "isodate-0.6.1-py2.py3-none-any.whl", hash = "sha256:0751eece944162659049d35f4f549ed815792b38793f07cf73381c1c87cbed96"},
1494
+ {file = "isodate-0.6.1.tar.gz", hash = "sha256:48c5881de7e8b0a0d648cb024c8062dc84e7b840ed81e864c7614fd3c127bde9"},
1495
+ ]
1496
+
1497
+ [package.dependencies]
1498
+ six = "*"
1499
+
1500
  [[package]]
1501
  name = "itsdangerous"
1502
  version = "2.2.0"
 
2076
  [package.extras]
2077
  broker = ["pymsalruntime (>=0.13.2,<0.17)"]
2078
 
2079
+ [[package]]
2080
+ name = "msal-extensions"
2081
+ version = "1.1.0"
2082
+ description = "Microsoft Authentication Library extensions (MSAL EX) provides a persistence API that can save your data on disk, encrypted on Windows, macOS and Linux. Concurrent data access will be coordinated by a file lock mechanism."
2083
+ optional = false
2084
+ python-versions = ">=3.7"
2085
+ files = [
2086
+ {file = "msal-extensions-1.1.0.tar.gz", hash = "sha256:6ab357867062db7b253d0bd2df6d411c7891a0ee7308d54d1e4317c1d1c54252"},
2087
+ {file = "msal_extensions-1.1.0-py3-none-any.whl", hash = "sha256:01be9711b4c0b1a151450068eeb2c4f0997df3bba085ac299de3a66f585e382f"},
2088
+ ]
2089
+
2090
+ [package.dependencies]
2091
+ msal = ">=0.4.1,<2.0.0"
2092
+ packaging = "*"
2093
+ portalocker = [
2094
+ {version = ">=1.0,<3", markers = "platform_system != \"Windows\""},
2095
+ {version = ">=1.6,<3", markers = "platform_system == \"Windows\""},
2096
+ ]
2097
+
2098
  [[package]]
2099
  name = "multidict"
2100
  version = "6.0.5"
 
4956
  [metadata]
4957
  lock-version = "2.0"
4958
  python-versions = "^3.10"
4959
+ content-hash = "adbf6715ae5e4d0cd93b90c6754be4996023e0ad1296f700b738887a06123a73"
pyproject.toml CHANGED
@@ -1,10 +1,10 @@
1
  [tool.poetry]
2
- name = "spinoza-project"
3
  version = "0.1.0"
4
  description = ""
5
  authors = ["Miguel Omenaca Muro <[email protected]>"]
6
  readme = "README.md"
7
- package-mode = false
8
 
9
  [tool.poetry.dependencies]
10
  python = "^3.10"
@@ -18,6 +18,8 @@ loadenv = "^0.1.1"
18
  datasets = "^2.20.0"
19
  langchain-community = "^0.2.5"
20
  transformers = "4.39.0"
 
 
21
 
22
 
23
  [build-system]
 
1
  [tool.poetry]
2
+ name = "spinoza_project"
3
  version = "0.1.0"
4
  description = ""
5
  authors = ["Miguel Omenaca Muro <[email protected]>"]
6
  readme = "README.md"
7
+ package-mode = true
8
 
9
  [tool.poetry.dependencies]
10
  python = "^3.10"
 
18
  datasets = "^2.20.0"
19
  langchain-community = "^0.2.5"
20
  transformers = "4.39.0"
21
+ azure-search-documents = "^11.4.0"
22
+ azure-identity = "^1.17.1"
23
 
24
 
25
  [build-system]
requirements.txt CHANGED
@@ -99,6 +99,18 @@ attrs==23.2.0 ; python_version >= "3.10" and python_version < "4.0" \
99
  authlib==1.3.1 ; python_version >= "3.10" and python_version < "4.0" \
100
  --hash=sha256:7ae843f03c06c5c0debd63c9db91f9fda64fa62a42a77419fa15fbb7e7a58917 \
101
  --hash=sha256:d35800b973099bbadc49b42b256ecb80041ad56b7fe1216a362c7943c088f377
 
 
 
 
 
 
 
 
 
 
 
 
102
  certifi==2024.6.2 ; python_version >= "3.10" and python_version < "4.0" \
103
  --hash=sha256:3cd43f1c6fa7dedc5899d69d3ad0398fd018ad1a17fba83ddaf78aa46c747516 \
104
  --hash=sha256:ddc6c8ce995e6987e7faf5e3f1b02b302836a0e5d98ece18392cb1a36c72ad56
@@ -724,6 +736,9 @@ intel-openmp==2021.4.0 ; python_version >= "3.10" and python_version < "4.0" and
724
  --hash=sha256:6e863d8fd3d7e8ef389d52cf97a50fe2afe1a19247e8c0d168ce021546f96fc9 \
725
  --hash=sha256:e2240ab8d01472fed04f3544a878cda5da16c26232b7ea1b59132dbfb48b186e \
726
  --hash=sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f
 
 
 
727
  itsdangerous==2.2.0 ; python_version >= "3.10" and python_version < "4.0" \
728
  --hash=sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef \
729
  --hash=sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173
@@ -980,6 +995,9 @@ mkl==2021.4.0 ; python_version >= "3.10" and python_version < "4.0" and platform
980
  mpmath==1.3.0 ; python_version >= "3.10" and python_version < "4.0" \
981
  --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \
982
  --hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c
 
 
 
983
  msal==1.28.1 ; python_version >= "3.10" and python_version < "4.0" \
984
  --hash=sha256:563c2d70de77a2ca9786aab84cb4e133a38a6897e6676774edc23d610bfc9e7b \
985
  --hash=sha256:d72bbfe2d5c2f2555f4bc6205be4450ddfd12976610dd9a16a9ab0f05c68b64d
 
99
  authlib==1.3.1 ; python_version >= "3.10" and python_version < "4.0" \
100
  --hash=sha256:7ae843f03c06c5c0debd63c9db91f9fda64fa62a42a77419fa15fbb7e7a58917 \
101
  --hash=sha256:d35800b973099bbadc49b42b256ecb80041ad56b7fe1216a362c7943c088f377
102
+ azure-common==1.1.28 ; python_version >= "3.10" and python_version < "4.0" \
103
+ --hash=sha256:4ac0cd3214e36b6a1b6a442686722a5d8cc449603aa833f3f0f40bda836704a3 \
104
+ --hash=sha256:5c12d3dcf4ec20599ca6b0d3e09e86e146353d443e7fcc050c9a19c1f9df20ad
105
+ azure-core==1.30.2 ; python_version >= "3.10" and python_version < "4.0" \
106
+ --hash=sha256:a14dc210efcd608821aa472d9fb8e8d035d29b68993819147bc290a8ac224472 \
107
+ --hash=sha256:cf019c1ca832e96274ae85abd3d9f752397194d9fea3b41487290562ac8abe4a
108
+ azure-identity==1.17.1 ; python_version >= "3.10" and python_version < "4.0" \
109
+ --hash=sha256:32ecc67cc73f4bd0595e4f64b1ca65cd05186f4fe6f98ed2ae9f1aa32646efea \
110
+ --hash=sha256:db8d59c183b680e763722bfe8ebc45930e6c57df510620985939f7f3191e0382
111
+ azure-search-documents==11.4.0 ; python_version >= "3.10" and python_version < "4.0" \
112
+ --hash=sha256:599f269f106fb51e646ff426a218c21811575598e6a769b23fa4a0127c0f57e0 \
113
+ --hash=sha256:e435266dc992a3450dc475309c9475f89a4bb0e9dac838140e609d9f1c7608ac
114
  certifi==2024.6.2 ; python_version >= "3.10" and python_version < "4.0" \
115
  --hash=sha256:3cd43f1c6fa7dedc5899d69d3ad0398fd018ad1a17fba83ddaf78aa46c747516 \
116
  --hash=sha256:ddc6c8ce995e6987e7faf5e3f1b02b302836a0e5d98ece18392cb1a36c72ad56
 
736
  --hash=sha256:6e863d8fd3d7e8ef389d52cf97a50fe2afe1a19247e8c0d168ce021546f96fc9 \
737
  --hash=sha256:e2240ab8d01472fed04f3544a878cda5da16c26232b7ea1b59132dbfb48b186e \
738
  --hash=sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f
739
+ isodate==0.6.1 ; python_version >= "3.10" and python_version < "4.0" \
740
+ --hash=sha256:0751eece944162659049d35f4f549ed815792b38793f07cf73381c1c87cbed96 \
741
+ --hash=sha256:48c5881de7e8b0a0d648cb024c8062dc84e7b840ed81e864c7614fd3c127bde9
742
  itsdangerous==2.2.0 ; python_version >= "3.10" and python_version < "4.0" \
743
  --hash=sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef \
744
  --hash=sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173
 
995
  mpmath==1.3.0 ; python_version >= "3.10" and python_version < "4.0" \
996
  --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \
997
  --hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c
998
+ msal-extensions==1.1.0 ; python_version >= "3.10" and python_version < "4.0" \
999
+ --hash=sha256:01be9711b4c0b1a151450068eeb2c4f0997df3bba085ac299de3a66f585e382f \
1000
+ --hash=sha256:6ab357867062db7b253d0bd2df6d411c7891a0ee7308d54d1e4317c1d1c54252
1001
  msal==1.28.1 ; python_version >= "3.10" and python_version < "4.0" \
1002
  --hash=sha256:563c2d70de77a2ca9786aab84cb4e133a38a6897e6676774edc23d610bfc9e7b \
1003
  --hash=sha256:d72bbfe2d5c2f2555f4bc6205be4450ddfd12976610dd9a16a9ab0f05c68b64d
spinoza_project/config.yaml ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ demo_name: Spinoza Q&A
2
+ tabs:
3
+ GIEC et IPBES: "*Outil dédié aux rapports du GIEC et de l'IPBES.*"
4
+ Textes Juridiques: "*Outil dédié aux codes Français modifiés par la loi climat (21/73).*"
5
+ Documents Stratégiques: "*Outil dédié aux données centrées sur le plan politique (SNBC).*"
6
+ ADEME:
7
+ "*Outil dédié aux données issues de l'ADEME et nous avons sélectionnés notamment différentes catégories de rapports:*\n
8
+ * *Les guides mis à disposition de la population*\n
9
+ * *Les rapport d'expériences sur des nouvelles technologies*\n
10
+ * *Des études et recherches sur des impacts locaux*\n
11
+ * *Des documents institutionnels (analyses demandées par la France & rapports d'activité)*\n
12
+ * *Les plans de transition sectoriels pour les secteurs industriels les plus émetteurs : (verre, papier, ciment, acier, aluminium, chimie, sucre)*"
13
+ Presse: "*Outil dédié aux données fournies par Aday concernant la presse.*"
14
+
15
+ logo_rsf: ""
16
+
17
+ logo_ap: ""
18
+
19
+ source_mapping:
20
+ GIEC et IPBES: "Science"
21
+ Textes Juridiques: "Loi"
22
+ Documents Stratégiques: "Politique"
23
+ ADEME: "ADEME"
24
+ Presse: "Presse"
25
+
26
+ prompt_naming:
27
+ Science: "Science"
28
+ Loi: "Loi"
29
+ Politique: "Politique"
30
+ ADEME: "ADEME"
31
+ Presse: "Presse"
32
+
33
+ database_index_path: "./app/data/database_tab_placeholder.pickle"
34
+ query_preprompt: "query: "
35
+ passage_preprompt: "passage: "
36
+ embedding_model: "intfloat/multilingual-e5-base"
37
+ num_document_retrieved: 5
38
+ min_similarity: 0.05
39
+
40
+ ## Chat API
41
+ user_token: "user"
42
+ assistant_token: "assistant"
43
+ system_token: "system"
44
+ stop_token: "" ## useless in chat mode
spinoza_project/source/backend/llm_utils.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_openai import AzureChatOpenAI
2
+ from msal import ConfidentialClientApplication
3
+ from langchain_openai import AzureOpenAIEmbeddings
4
+ from langchain.vectorstores.azuresearch import AzureSearch
5
+ import os
6
+
7
+
8
class LLM:
    """Thin convenience wrapper around a LangChain chat model.

    Every method takes a prompt template exposing
    ``format_messages(**kwargs)`` and a dict of arguments used to fill it,
    then forwards the formatted messages to the wrapped model.
    """

    def __init__(self, llm):
        # Wrapped chat model; the methods below call its ``stream``,
        # ``invoke`` and ``ainvoke`` methods.
        self.llm = llm
        # Callbacks copied onto the model before each non-streaming call.
        self.callbacks = []

    def stream(self, prompt, prompt_arguments):
        """Stream an answer, yielding the text accumulated so far.

        Args:
            prompt: Template providing ``format_messages(**kwargs)``.
            prompt_arguments: Keyword arguments used to fill the template.

        Yields:
            The concatenation of every chunk's ``content`` received so far
            (a growing string, not the individual deltas).
        """
        self.llm.streaming = True
        streamed_content = self.llm.stream(prompt.format_messages(**prompt_arguments))
        output = ""
        for chunk in streamed_content:
            # Consumers expect the full text so far, so the running total is
            # yielded on every chunk.
            output += chunk.content
            yield output

    def get_prediction(self, prompt, prompt_arguments):
        """Return the ``content`` of a single synchronous completion."""
        self.llm.callbacks = self.callbacks
        # ``invoke`` replaces the deprecated ``predict_messages``.
        return self.llm.invoke(prompt.format_messages(**prompt_arguments)).content

    async def get_aprediction(self, prompt, prompt_arguments):
        """Return the message object of a single asynchronous completion.

        Unlike ``get_prediction`` this returns the whole message rather than
        its ``content``; kept as-is so existing callers are not broken.
        """
        self.llm.callbacks = self.callbacks
        # ``ainvoke`` replaces the deprecated ``apredict_messages``.
        prediction = await self.llm.ainvoke(
            prompt.format_messages(**prompt_arguments)
        )
        return prediction

    async def get_apredictions(self, prompts, prompts_arguments):
        """Run several prompts one after another and collect their contents.

        Args:
            prompts: Mapping whose values are prompt templates; iterated in
                the mapping's order.
            prompts_arguments: Iterable of argument dicts, paired positionally
                with the prompts (extra items on either side are ignored by
                ``zip``).

        Returns:
            A list with the ``content`` of each completion, in order.
        """
        self.llm.callbacks = self.callbacks
        predictions = []
        for template, template_args in zip(prompts.values(), prompts_arguments):
            # Awaited sequentially on purpose: same ordering as before.
            prediction = await self.llm.ainvoke(
                template.format_messages(**template_args)
            )
            predictions.append(prediction.content)
        return predictions
43
+
44
+
45
def get_token() -> str | None:
    """Acquire an access token through MSAL's client-credentials flow.

    Reads ``CLIENT_ID``, ``CLIENT_SECRET``, ``TENANT_ID`` and ``SCOPE`` from
    the environment.

    Returns:
        The ``access_token`` string, or ``None`` if MSAL returns no result.

    Raises:
        KeyError: If the response carries no ``access_token`` (an MSAL error
            response). The message includes the response's ``error`` and
            ``error_description`` fields so the failure can be diagnosed.
    """
    app = ConfidentialClientApplication(
        client_id=os.getenv("CLIENT_ID"),
        client_credential=os.getenv("CLIENT_SECRET"),
        authority=f"https://login.microsoftonline.com/{os.getenv('TENANT_ID')}",
    )
    result = app.acquire_token_for_client(scopes=[os.getenv("SCOPE")])
    if result is None:
        return None
    if "access_token" not in result:
        # Same exception type as the former bare subscript, but with the
        # reason reported by the identity provider attached.
        raise KeyError(
            "access_token missing from MSAL response: "
            f"{result.get('error')}: {result.get('error_description')}"
        )
    return result["access_token"]
54
+
55
+
56
def get_llm():
    """Configure the Azure OpenAI environment and return a wrapped chat model.

    Side effects: sets ``OPENAI_API_KEY`` to the token from ``get_token()``
    and ``AZURE_OPENAI_ENDPOINT`` to a URL assembled from
    ``OPENAI_API_ENDPOINT``, ``DEPLOYMENT_ID`` and ``OPENAI_API_VERSION``.

    Returns:
        An ``LLM`` wrapping a default-constructed ``AzureChatOpenAI``.
    """
    # The token is stored first, exactly as before, then the endpoint.
    os.environ["OPENAI_API_KEY"] = get_token()
    chat_endpoint = f"{os.getenv('OPENAI_API_ENDPOINT')}{os.getenv('DEPLOYMENT_ID')}/chat/completions?api-version={os.getenv('OPENAI_API_VERSION')}"
    os.environ["AZURE_OPENAI_ENDPOINT"] = chat_endpoint

    # AzureChatOpenAI is built with no arguments, so it presumably picks its
    # settings up from the environment prepared above.
    chat_model = AzureChatOpenAI()
    return LLM(chat_model)
63
+
64
+
65
def get_vectorstore(index_name, model="text-embedding-ada-002"):
    """Build an Azure AI Search vector store for the given index.

    Side effects: overwrites ``AZURE_OPENAI_ENDPOINT`` with a URL assembled
    from ``OPENAI_API_ENDPOINT``, ``DEPLOYMENT_EMB_ID`` and
    ``OPENAI_API_VERSION``, and sets ``AZURE_OPENAI_API_KEY`` to the token
    from ``get_token()``.

    Args:
        index_name: Name of the search index to query.
        model: Passed to ``AzureOpenAIEmbeddings`` as ``azure_deployment``.

    Returns:
        An ``AzureSearch`` store that embeds queries with
        ``AzureOpenAIEmbeddings.embed_query``.
    """
    embeddings_endpoint = f"{os.getenv('OPENAI_API_ENDPOINT')}{os.getenv('DEPLOYMENT_EMB_ID')}/embeddings?api-version={os.getenv('OPENAI_API_VERSION')}"
    os.environ["AZURE_OPENAI_ENDPOINT"] = embeddings_endpoint
    # NOTE(review): the token is fetched once here; if it expires, later
    # embedding calls may fail — confirm against the deployment's token TTL.
    os.environ["AZURE_OPENAI_API_KEY"] = get_token()

    embedder = AzureOpenAIEmbeddings(
        azure_deployment=model,
        openai_api_version=os.getenv("OPENAI_API_VERSION"),
    )

    search_endpoint = os.getenv("VECTOR_STORE_ADDRESS")
    search_key = os.getenv("VECTOR_STORE_PASSWORD")
    return AzureSearch(
        azure_search_endpoint=search_endpoint,
        azure_search_key=search_key,
        index_name=index_name,
        embedding_function=embedder.embed_query,
    )