WilliamGazeley commited on
Commit
391d6e2
1 Parent(s): e40d8d8

Update get_analysis to better version

Browse files
Files changed (5) hide show
  1. ex.env +4 -0
  2. requirements.txt +1 -0
  3. src/config.py +4 -2
  4. src/functions.py +27 -16
  5. tests/test_functions.py +13 -0
ex.env ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ HF_TOKEN=
2
+ OLLAMA_MODEL=
3
+ AZ_SEARCH_API_KEY=
4
+ AZURE_OPENAI_API_KEY=
requirements.txt CHANGED
@@ -26,3 +26,4 @@ accelerate==0.27.2
26
  azure-search-documents==11.6.0b1
27
  azure-identity==1.16.0
28
  loguru==0.7.2
 
 
26
  azure-search-documents==11.6.0b1
27
  azure-identity==1.16.0
28
  loguru==0.7.2
29
+ openai==1.30.1
src/config.py CHANGED
@@ -17,9 +17,11 @@ class Config(BaseSettings):
17
 
18
  az_search_endpoint: str = Field("https://analysis-bank.search.windows.net")
19
  az_search_api_key: str = Field(...)
20
- az_search_idx_name: str = Field("analysis-index")
21
  az_search_top_k: int = Field(4, description="Max number of results to retrun")
22
- az_search_min_score: float = Field(9.0, description="Only results above this confidence score is used")
 
 
23
 
24
  chat_template: str = Field("chatml", description="Chat template for prompt formatting")
25
  num_fewshot: int | None = Field(None, description="Option to use json mode examples")
 
17
 
18
  az_search_endpoint: str = Field("https://analysis-bank.search.windows.net")
19
  az_search_api_key: str = Field(...)
20
+ az_search_idx_name: str = Field("analysis-index-2024-05-19")
21
  az_search_top_k: int = Field(4, description="Max number of results to retrun")
22
+
23
+ azure_openai_api_key: str = Field(...)
24
+ azure_openai_endpoint: str = Field("https://irai-openai-eastus.openai.azure.com/")
25
 
26
  chat_template: str = Field("chatml", description="Chat template for prompt formatting")
27
  num_fewshot: int | None = Field(None, description="Option to use json mode examples")
src/functions.py CHANGED
@@ -1,24 +1,31 @@
1
  import re
2
- import inspect
3
  import requests
4
  import pandas as pd
5
  import yfinance as yf
6
  import concurrent.futures
 
7
  from datetime import datetime
8
 
9
  from typing import List
10
  from bs4 import BeautifulSoup
11
  from logger import logger
 
12
  from langchain.tools import tool
13
  from langchain_core.utils.function_calling import convert_to_openai_tool
14
  from config import config
15
 
16
  from azure.core.credentials import AzureKeyCredential
17
  from azure.search.documents import SearchClient
 
18
 
19
 
20
  az_creds = AzureKeyCredential(config.az_search_api_key)
21
  az_search_client = SearchClient(config.az_search_endpoint, config.az_search_idx_name, az_creds)
 
 
 
 
 
22
 
23
  @tool
24
  def get_analysis(query: str) -> dict:
@@ -32,25 +39,29 @@ def get_analysis(query: str) -> dict:
32
  Returns:
33
  list: A list of dictionaries containing the pieces of analysis.
34
  """
 
 
 
 
 
 
35
  results = az_search_client.search(
36
- query_type="semantic",
37
- search_text=query,
38
- select="title,content,asset_name,write_date",
39
- include_total_count=True,
40
- top=config.az_search_top_k,
41
- semantic_configuration_name="basic-keywords",
42
- vector_queries=None, # Docs are too semantically similar, disable for now
43
- )
44
 
45
  output = []
46
  for x in results:
47
- if x["@search.score"] >= config.az_search_min_score:
48
- output.append({
49
- "security": x["asset_name"],
50
- "date written": datetime.strptime(x["write_date"], "%Y%m%d").date(),
51
- "title": x["title"],
52
- "content": x["content"]
53
- })
54
  return output
55
 
56
  @tool
 
1
  import re
 
2
  import requests
3
  import pandas as pd
4
  import yfinance as yf
5
  import concurrent.futures
6
+ from time import time
7
  from datetime import datetime
8
 
9
  from typing import List
10
  from bs4 import BeautifulSoup
11
  from logger import logger
12
+ from openai import AzureOpenAI
13
  from langchain.tools import tool
14
  from langchain_core.utils.function_calling import convert_to_openai_tool
15
  from config import config
16
 
17
  from azure.core.credentials import AzureKeyCredential
18
  from azure.search.documents import SearchClient
19
+ from azure.search.documents.models import VectorizedQuery
20
 
21
 
22
  az_creds = AzureKeyCredential(config.az_search_api_key)
23
  az_search_client = SearchClient(config.az_search_endpoint, config.az_search_idx_name, az_creds)
24
+ openai_client = AzureOpenAI(
25
+ azure_endpoint=config.azure_openai_endpoint,
26
+ api_key=config.azure_openai_api_key,
27
+ api_version="2024-02-01"
28
+ )
29
 
30
  @tool
31
  def get_analysis(query: str) -> dict:
 
39
  Returns:
40
  list: A list of dictionaries containing the pieces of analysis.
41
  """
42
+ start_time = time()
43
+
44
+ embed_model = "default-large-embeddings"
45
+
46
+ vec = openai_client.embeddings.create(input=[query], model=embed_model).data[0].embedding
47
+ vector_query = VectorizedQuery(vector=vec, k_nearest_neighbors=config.az_search_top_k * 2, fields="vector")
48
  results = az_search_client.search(
49
+ search_text="*",
50
+ vector_queries=[vector_query],
51
+ select=["date", "popularity", "sequence", "context", "securities"],
52
+ order_by=["securities desc", "date desc", "popularity desc"],
53
+ top=config.az_search_top_k,
54
+ )
 
 
55
 
56
  output = []
57
  for x in results:
58
+ output.append({
59
+ "securities": x["securities"],
60
+ "date written": x["date"].split("T")[0],
61
+ "summary": x["context"],
62
+ "content": x["sequence"],
63
+ })
64
+ print(f"Search took {time() - start_time:.2f} seconds\n---")
65
  return output
66
 
67
  @tool
tests/test_functions.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Tests the functions themselves, not the function calling
2
+
3
+ from pprint import pprint
4
+ from functions import get_analysis
5
+
6
+ def test_get_analysis():
7
+ query = "How is MSTR doing?"
8
+ output = get_analysis(query)
9
+
10
+ pprint(output)
11
+ assert len(output) != 0
12
+ assert "MSTR" in output[0]['securities']
13
+