Spaces:
Sleeping
Sleeping
Merge branch 'simple-rag'
Browse files- .gitignore +7 -0
- app.py +36 -28
- config.py +14 -0
- functioncall.py +163 -0
- functions.py +262 -0
- prompt_assets/few_shot.json +8 -0
- prompt_assets/output_sys_prompt.yml +10 -0
- prompt_assets/sys_prompt.yml +38 -0
- prompter.py +76 -0
- requirements.txt +130 -5
- schema.py +23 -0
- utils.py +149 -0
- validator.py +132 -0
.gitignore
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.env
|
2 |
+
|
3 |
+
# Python
|
4 |
+
__pycache__/
|
5 |
+
|
6 |
+
# vLLM
|
7 |
+
inference_logs/
|
app.py
CHANGED
@@ -1,43 +1,43 @@
|
|
1 |
import os
|
2 |
import huggingface_hub
|
3 |
import streamlit as st
|
4 |
-
from
|
|
|
|
|
|
|
5 |
|
6 |
-
sys_msg = """#Context:
|
7 |
-
You are an expert financial advisor named IRAI. You have a comprehensive understanding of finance and investing with experience and expertise in all areas of finance.
|
8 |
-
#Objective:
|
9 |
-
Please answer questions as best as possible given your current knowledge. You do not have access to up-to-date current market data. Try to demonstrate analytical depth and showcase ability to integrate complex data into practical advice, but answer the question directly.
|
10 |
-
#Style and tone:
|
11 |
-
Answer in a friendly and engaging manner representing a top female investment professional working at a leading investment bank.
|
12 |
-
#Audience:
|
13 |
-
The questions will be asked by top executives and managers of successful startups. Assume the audience is composed of 40 year old males with high wealth and income, high risk appetite with high threshold for volatility.
|
14 |
-
#Response:
|
15 |
-
Direct answer to question, concise yet insightful."""
|
16 |
|
17 |
@st.cache_resource(show_spinner="Loading model..")
|
18 |
def init_llm():
|
19 |
-
huggingface_hub.login(token=
|
20 |
-
llm =
|
21 |
-
tok = llm.get_tokenizer()
|
22 |
-
tok.eos_token = '<|im_end|>' # Override to use turns
|
23 |
return llm
|
24 |
|
25 |
def get_response(prompt):
|
26 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
convo = [
|
28 |
-
{"role": "system", "content":
|
29 |
-
{"role": "user", "content":
|
30 |
]
|
31 |
-
|
32 |
-
|
33 |
-
sampling_params = SamplingParams(temperature=0.3, top_p=0.95, max_tokens=500, stop_token_ids=[128009])
|
34 |
-
outputs = llm.generate(prompts, sampling_params)
|
35 |
-
for output in outputs:
|
36 |
-
return output.outputs[0].text
|
37 |
except Exception as e:
|
38 |
return f"An error occurred: {str(e)}"
|
39 |
|
40 |
-
|
41 |
def main():
|
42 |
st.title("LLM-ADE 9B Demo")
|
43 |
|
@@ -46,13 +46,21 @@ def main():
|
|
46 |
if st.button("Generate"):
|
47 |
if input_text:
|
48 |
with st.spinner('Generating response...'):
|
49 |
-
|
50 |
-
st.write(
|
51 |
else:
|
52 |
st.warning("Please enter some text to generate a response.")
|
53 |
|
54 |
llm = init_llm()
|
55 |
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
58 |
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
import huggingface_hub
|
3 |
import streamlit as st
|
4 |
+
from config import config
|
5 |
+
from utils import get_assistant_message
|
6 |
+
from functioncall import ModelInference
|
7 |
+
from prompter import PromptManager
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
@st.cache_resource(show_spinner="Loading model..")
|
11 |
def init_llm():
|
12 |
+
huggingface_hub.login(token=config.hf_token, new_session=False)
|
13 |
+
llm = ModelInference(chat_template=config.chat_template)
|
|
|
|
|
14 |
return llm
|
15 |
|
16 |
def get_response(prompt):
|
17 |
try:
|
18 |
+
return llm.generate_function_call(
|
19 |
+
prompt,
|
20 |
+
config.chat_template,
|
21 |
+
config.num_fewshot,
|
22 |
+
config.max_depth
|
23 |
+
)
|
24 |
+
except Exception as e:
|
25 |
+
return f"An error occurred: {str(e)}"
|
26 |
+
|
27 |
+
def get_output(context, user_input):
|
28 |
+
try:
|
29 |
+
prompt_schema = llm.prompter.read_yaml_file("prompt_assets/output_sys_prompt.yml")
|
30 |
+
sys_prompt = llm.prompter.format_yaml_prompt(prompt_schema, dict()) + \
|
31 |
+
f"Information:\n{context}"
|
32 |
convo = [
|
33 |
+
{"role": "system", "content": sys_prompt},
|
34 |
+
{"role": "user", "content": user_input},
|
35 |
]
|
36 |
+
response = llm.run_inference(convo)
|
37 |
+
return get_assistant_message(response, config.chat_template, llm.tokenizer.eos_token)
|
|
|
|
|
|
|
|
|
38 |
except Exception as e:
|
39 |
return f"An error occurred: {str(e)}"
|
40 |
|
|
|
41 |
def main():
|
42 |
st.title("LLM-ADE 9B Demo")
|
43 |
|
|
|
46 |
if st.button("Generate"):
|
47 |
if input_text:
|
48 |
with st.spinner('Generating response...'):
|
49 |
+
agent_resp = get_response(input_text)
|
50 |
+
st.write(get_output(agent_resp, input_text))
|
51 |
else:
|
52 |
st.warning("Please enter some text to generate a response.")
|
53 |
|
54 |
llm = init_llm()
|
55 |
|
56 |
+
def main_headless():
|
57 |
+
while True:
|
58 |
+
input_text = input("Enter your text here: ")
|
59 |
+
agent_resp = get_response(input_text)
|
60 |
+
print('\033[94m' + get_output(agent_resp, input_text) + '\033[0m')
|
61 |
|
62 |
+
if __name__ == "__main__":
|
63 |
+
if config.headless:
|
64 |
+
main_headless()
|
65 |
+
else:
|
66 |
+
main()
|
config.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import Field
|
2 |
+
from pydantic_settings import BaseSettings
|
3 |
+
|
4 |
+
class Config(BaseSettings):
|
5 |
+
hf_token: str = Field(...)
|
6 |
+
model_path: str = Field("InvestmentResearchAI/LLM-ADE-dev")
|
7 |
+
headless: bool = Field(False, description="Run in headless mode.")
|
8 |
+
|
9 |
+
chat_template: str = Field("chatml", description="Chat template for prompt formatting")
|
10 |
+
num_fewshot: int | None = Field(None, description="Option to use json mode examples")
|
11 |
+
load_in_4bit: str = Field("False", description="Option to load in 4bit with bitsandbytes")
|
12 |
+
max_depth: int = Field(5, description="Maximum number of recursive iteration")
|
13 |
+
|
14 |
+
config = Config(_env_file=".env")
|
functioncall.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
import json
|
4 |
+
from config import config
|
5 |
+
from typing import List, Dict
|
6 |
+
from vllm import LLM, SamplingParams
|
7 |
+
|
8 |
+
from transformers import (
|
9 |
+
AutoModelForCausalLM,
|
10 |
+
AutoTokenizer,
|
11 |
+
BitsAndBytesConfig
|
12 |
+
)
|
13 |
+
|
14 |
+
import functions
|
15 |
+
from prompter import PromptManager
|
16 |
+
from validator import validate_function_call_schema
|
17 |
+
|
18 |
+
from utils import (
|
19 |
+
inference_logger,
|
20 |
+
get_assistant_message,
|
21 |
+
get_chat_template,
|
22 |
+
validate_and_extract_tool_calls
|
23 |
+
)
|
24 |
+
|
25 |
+
class ModelInference:
|
26 |
+
def __init__(self, chat_template: str, load_in_4bit: bool = False):
|
27 |
+
self.prompter = PromptManager()
|
28 |
+
self.bnb_config = None
|
29 |
+
|
30 |
+
if load_in_4bit == "True": # Never use this
|
31 |
+
self.bnb_config = BitsAndBytesConfig(
|
32 |
+
load_in_4bit=True,
|
33 |
+
bnb_4bit_quant_type="nf4",
|
34 |
+
bnb_4bit_use_double_quant=True,
|
35 |
+
)
|
36 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
37 |
+
config.model_path,
|
38 |
+
trust_remote_code=True,
|
39 |
+
return_dict=True,
|
40 |
+
quantization_config=self.bnb_config,
|
41 |
+
torch_dtype=torch.float16,
|
42 |
+
attn_implementation="flash_attention_2",
|
43 |
+
device_map="auto",
|
44 |
+
)
|
45 |
+
|
46 |
+
self.tokenizer = AutoTokenizer.from_pretrained(config.model_path, trust_remote_code=True)
|
47 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
48 |
+
self.tokenizer.padding_side = "left"
|
49 |
+
|
50 |
+
if self.tokenizer.chat_template is None:
|
51 |
+
print("No chat template defined, getting chat_template...")
|
52 |
+
self.tokenizer.chat_template = get_chat_template(chat_template)
|
53 |
+
|
54 |
+
inference_logger.info(self.model.config)
|
55 |
+
inference_logger.info(self.model.generation_config)
|
56 |
+
inference_logger.info(self.tokenizer.special_tokens_map)
|
57 |
+
|
58 |
+
def process_completion_and_validate(self, completion, chat_template):
|
59 |
+
|
60 |
+
assistant_message = get_assistant_message(completion, chat_template, self.tokenizer.eos_token)
|
61 |
+
|
62 |
+
if assistant_message:
|
63 |
+
validation, tool_calls, error_message = validate_and_extract_tool_calls(assistant_message)
|
64 |
+
|
65 |
+
if validation:
|
66 |
+
inference_logger.info(f"parsed tool calls:\n{json.dumps(tool_calls, indent=2)}")
|
67 |
+
return tool_calls, assistant_message, error_message
|
68 |
+
else:
|
69 |
+
tool_calls = None
|
70 |
+
return tool_calls, assistant_message, error_message
|
71 |
+
else:
|
72 |
+
inference_logger.warning("Assistant message is None")
|
73 |
+
raise ValueError("Assistant message is None")
|
74 |
+
|
75 |
+
def execute_function_call(self, tool_call):
|
76 |
+
function_name = tool_call.get("name")
|
77 |
+
function_to_call = getattr(functions, function_name, None)
|
78 |
+
function_args = tool_call.get("arguments", {})
|
79 |
+
|
80 |
+
inference_logger.info(f"Invoking function call {function_name} ...")
|
81 |
+
function_response = function_to_call(*function_args.values())
|
82 |
+
results_dict = f'{{"name": "{function_name}", "content": {function_response}}}'
|
83 |
+
return results_dict
|
84 |
+
|
85 |
+
def run_inference(self, prompt: List[Dict[str, str]]):
|
86 |
+
inputs = self.tokenizer.apply_chat_template(
|
87 |
+
prompt,
|
88 |
+
add_generation_prompt=True,
|
89 |
+
return_tensors='pt'
|
90 |
+
)
|
91 |
+
|
92 |
+
tokens = self.model.generate(
|
93 |
+
inputs.to(self.model.device),
|
94 |
+
max_new_tokens=1500,
|
95 |
+
temperature=0.8,
|
96 |
+
repetition_penalty=1.1,
|
97 |
+
do_sample=True,
|
98 |
+
eos_token_id=self.tokenizer.eos_token_id
|
99 |
+
)
|
100 |
+
completion = self.tokenizer.decode(tokens[0], skip_special_tokens=False, clean_up_tokenization_space=True)
|
101 |
+
return completion
|
102 |
+
|
103 |
+
def generate_function_call(self, query, chat_template, num_fewshot, max_depth=5):
|
104 |
+
try:
|
105 |
+
depth = 0
|
106 |
+
user_message = f"{query}\nThis is the first turn and you don't have <tool_results> to analyze yet"
|
107 |
+
chat = [{"role": "user", "content": user_message}]
|
108 |
+
tools = functions.get_openai_tools()
|
109 |
+
prompt = self.prompter.generate_prompt(chat, tools, num_fewshot)
|
110 |
+
completion = self.run_inference(prompt)
|
111 |
+
|
112 |
+
def recursive_loop(prompt, completion, depth):
|
113 |
+
nonlocal max_depth
|
114 |
+
tool_calls, assistant_message, error_message = self.process_completion_and_validate(completion, chat_template)
|
115 |
+
prompt.append({"role": "assistant", "content": assistant_message})
|
116 |
+
|
117 |
+
tool_message = f"Agent iteration {depth} to assist with user query: {query}\n"
|
118 |
+
if tool_calls:
|
119 |
+
inference_logger.info(f"Assistant Message:\n{assistant_message}")
|
120 |
+
|
121 |
+
for tool_call in tool_calls:
|
122 |
+
validation, message = validate_function_call_schema(tool_call, tools)
|
123 |
+
if validation:
|
124 |
+
try:
|
125 |
+
function_response = self.execute_function_call(tool_call)
|
126 |
+
tool_message += f"<tool_response>\n{function_response}\n</tool_response>\n"
|
127 |
+
inference_logger.info(f"Here's the response from the function call: {tool_call.get('name')}\n{function_response}")
|
128 |
+
except Exception as e:
|
129 |
+
inference_logger.info(f"Could not execute function: {e}")
|
130 |
+
tool_message += f"<tool_response>\nThere was an error when executing the function: {tool_call.get('name')}\nHere's the error traceback: {e}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
|
131 |
+
else:
|
132 |
+
inference_logger.info(message)
|
133 |
+
tool_message += f"<tool_response>\nThere was an error validating function call against function signature: {tool_call.get('name')}\nHere's the error traceback: {message}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
|
134 |
+
prompt.append({"role": "tool", "content": tool_message})
|
135 |
+
|
136 |
+
depth += 1
|
137 |
+
if depth >= max_depth:
|
138 |
+
print(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")
|
139 |
+
return
|
140 |
+
|
141 |
+
completion = self.run_inference(prompt)
|
142 |
+
return recursive_loop(prompt, completion, depth)
|
143 |
+
elif error_message:
|
144 |
+
inference_logger.info(f"Assistant Message:\n{assistant_message}")
|
145 |
+
tool_message += f"<tool_response>\nThere was an error parsing function calls\n Here's the error stack trace: {error_message}\nPlease call the function again with correct syntax<tool_response>"
|
146 |
+
prompt.append({"role": "tool", "content": tool_message})
|
147 |
+
|
148 |
+
depth += 1
|
149 |
+
if depth >= max_depth:
|
150 |
+
print(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")
|
151 |
+
return
|
152 |
+
|
153 |
+
completion = self.run_inference(prompt)
|
154 |
+
return recursive_loop(prompt, completion, depth)
|
155 |
+
else:
|
156 |
+
inference_logger.info(f"Assistant Message:\n{assistant_message}")
|
157 |
+
return assistant_message
|
158 |
+
|
159 |
+
return recursive_loop(prompt, completion, depth)
|
160 |
+
|
161 |
+
except Exception as e:
|
162 |
+
inference_logger.error(f"Exception occurred: {e}")
|
163 |
+
raise e
|
functions.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import inspect
|
3 |
+
import requests
|
4 |
+
import pandas as pd
|
5 |
+
import yfinance as yf
|
6 |
+
import concurrent.futures
|
7 |
+
|
8 |
+
from typing import List
|
9 |
+
from bs4 import BeautifulSoup
|
10 |
+
from utils import inference_logger
|
11 |
+
from langchain.tools import tool
|
12 |
+
from langchain_core.utils.function_calling import convert_to_openai_tool
|
13 |
+
|
14 |
+
@tool
|
15 |
+
def google_search_and_scrape(query: str) -> dict:
|
16 |
+
"""
|
17 |
+
Performs a Google search for the given query, retrieves the top search result URLs,
|
18 |
+
and scrapes the text content and table data from those pages in parallel.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
query (str): The search query.
|
22 |
+
Returns:
|
23 |
+
list: A list of dictionaries containing the URL, text content, and table data for each scraped page.
|
24 |
+
"""
|
25 |
+
num_results = 2
|
26 |
+
url = 'https://www.google.com/search'
|
27 |
+
params = {'q': query, 'num': num_results}
|
28 |
+
headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/94.0.4606.61 Safari/537.3'}
|
29 |
+
|
30 |
+
inference_logger.info(f"Performing google search with query: {query}\nplease wait...")
|
31 |
+
response = requests.get(url, params=params, headers=headers)
|
32 |
+
soup = BeautifulSoup(response.text, 'html.parser')
|
33 |
+
urls = [result.find('a')['href'] for result in soup.find_all('div', class_='tF2Cxc')]
|
34 |
+
|
35 |
+
inference_logger.info(f"Scraping text from urls, please wait...")
|
36 |
+
[inference_logger.info(url) for url in urls]
|
37 |
+
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
|
38 |
+
futures = [executor.submit(lambda url: (url, requests.get(url, headers=headers).text if isinstance(url, str) else None), url) for url in urls[:num_results] if isinstance(url, str)]
|
39 |
+
results = []
|
40 |
+
for future in concurrent.futures.as_completed(futures):
|
41 |
+
url, html = future.result()
|
42 |
+
soup = BeautifulSoup(html, 'html.parser')
|
43 |
+
paragraphs = [p.text.strip() for p in soup.find_all('p') if p.text.strip()]
|
44 |
+
text_content = ' '.join(paragraphs)
|
45 |
+
text_content = re.sub(r'\s+', ' ', text_content)
|
46 |
+
table_data = [[cell.get_text(strip=True) for cell in row.find_all('td')] for table in soup.find_all('table') for row in table.find_all('tr')]
|
47 |
+
if text_content or table_data:
|
48 |
+
results.append({'url': url, 'content': text_content, 'tables': table_data})
|
49 |
+
return results
|
50 |
+
|
51 |
+
@tool
|
52 |
+
def get_current_stock_price(symbol: str) -> float:
|
53 |
+
"""
|
54 |
+
Get the current stock price for a given symbol.
|
55 |
+
|
56 |
+
Args:
|
57 |
+
symbol (str): The stock symbol.
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
float: The current stock price, or None if an error occurs.
|
61 |
+
"""
|
62 |
+
try:
|
63 |
+
stock = yf.Ticker(symbol)
|
64 |
+
# Use "regularMarketPrice" for regular market hours, or "currentPrice" for pre/post market
|
65 |
+
current_price = stock.info.get("regularMarketPrice", stock.info.get("currentPrice"))
|
66 |
+
return current_price if current_price else None
|
67 |
+
except Exception as e:
|
68 |
+
print(f"Error fetching current price for {symbol}: {e}")
|
69 |
+
return None
|
70 |
+
|
71 |
+
@tool
|
72 |
+
def get_stock_fundamentals(symbol: str) -> dict:
|
73 |
+
"""
|
74 |
+
Get fundamental data for a given stock symbol using yfinance API.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
symbol (str): The stock symbol.
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
dict: A dictionary containing fundamental data.
|
81 |
+
Keys:
|
82 |
+
- 'symbol': The stock symbol.
|
83 |
+
- 'company_name': The long name of the company.
|
84 |
+
- 'sector': The sector to which the company belongs.
|
85 |
+
- 'industry': The industry to which the company belongs.
|
86 |
+
- 'market_cap': The market capitalization of the company.
|
87 |
+
- 'pe_ratio': The forward price-to-earnings ratio.
|
88 |
+
- 'pb_ratio': The price-to-book ratio.
|
89 |
+
- 'dividend_yield': The dividend yield.
|
90 |
+
- 'eps': The trailing earnings per share.
|
91 |
+
- 'beta': The beta value of the stock.
|
92 |
+
- '52_week_high': The 52-week high price of the stock.
|
93 |
+
- '52_week_low': The 52-week low price of the stock.
|
94 |
+
"""
|
95 |
+
try:
|
96 |
+
stock = yf.Ticker(symbol)
|
97 |
+
info = stock.info
|
98 |
+
fundamentals = {
|
99 |
+
'symbol': symbol,
|
100 |
+
'company_name': info.get('longName', ''),
|
101 |
+
'sector': info.get('sector', ''),
|
102 |
+
'industry': info.get('industry', ''),
|
103 |
+
'market_cap': info.get('marketCap', None),
|
104 |
+
'pe_ratio': info.get('forwardPE', None),
|
105 |
+
'pb_ratio': info.get('priceToBook', None),
|
106 |
+
'dividend_yield': info.get('dividendYield', None),
|
107 |
+
'eps': info.get('trailingEps', None),
|
108 |
+
'beta': info.get('beta', None),
|
109 |
+
'52_week_high': info.get('fiftyTwoWeekHigh', None),
|
110 |
+
'52_week_low': info.get('fiftyTwoWeekLow', None)
|
111 |
+
}
|
112 |
+
return fundamentals
|
113 |
+
except Exception as e:
|
114 |
+
print(f"Error getting fundamentals for {symbol}: {e}")
|
115 |
+
return {}
|
116 |
+
|
117 |
+
@tool
|
118 |
+
def get_financial_statements(symbol: str) -> dict:
|
119 |
+
"""
|
120 |
+
Get financial statements for a given stock symbol.
|
121 |
+
|
122 |
+
Args:
|
123 |
+
symbol (str): The stock symbol.
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
dict: Dictionary containing financial statements (income statement, balance sheet, cash flow statement).
|
127 |
+
"""
|
128 |
+
try:
|
129 |
+
stock = yf.Ticker(symbol)
|
130 |
+
financials = stock.financials
|
131 |
+
return financials
|
132 |
+
except Exception as e:
|
133 |
+
print(f"Error fetching financial statements for {symbol}: {e}")
|
134 |
+
return {}
|
135 |
+
|
136 |
+
@tool
|
137 |
+
def get_key_financial_ratios(symbol: str) -> dict:
|
138 |
+
"""
|
139 |
+
Get key financial ratios for a given stock symbol.
|
140 |
+
|
141 |
+
Args:
|
142 |
+
symbol (str): The stock symbol.
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
dict: Dictionary containing key financial ratios.
|
146 |
+
"""
|
147 |
+
try:
|
148 |
+
stock = yf.Ticker(symbol)
|
149 |
+
key_ratios = stock.info
|
150 |
+
return key_ratios
|
151 |
+
except Exception as e:
|
152 |
+
print(f"Error fetching key financial ratios for {symbol}: {e}")
|
153 |
+
return {}
|
154 |
+
|
155 |
+
@tool
|
156 |
+
def get_analyst_recommendations(symbol: str) -> pd.DataFrame:
|
157 |
+
"""
|
158 |
+
Get analyst recommendations for a given stock symbol.
|
159 |
+
|
160 |
+
Args:
|
161 |
+
symbol (str): The stock symbol.
|
162 |
+
|
163 |
+
Returns:
|
164 |
+
pd.DataFrame: DataFrame containing analyst recommendations.
|
165 |
+
"""
|
166 |
+
try:
|
167 |
+
stock = yf.Ticker(symbol)
|
168 |
+
recommendations = stock.recommendations
|
169 |
+
return recommendations
|
170 |
+
except Exception as e:
|
171 |
+
print(f"Error fetching analyst recommendations for {symbol}: {e}")
|
172 |
+
return pd.DataFrame()
|
173 |
+
|
174 |
+
@tool
|
175 |
+
def get_dividend_data(symbol: str) -> pd.DataFrame:
|
176 |
+
"""
|
177 |
+
Get dividend data for a given stock symbol.
|
178 |
+
|
179 |
+
Args:
|
180 |
+
symbol (str): The stock symbol.
|
181 |
+
|
182 |
+
Returns:
|
183 |
+
pd.DataFrame: DataFrame containing dividend data.
|
184 |
+
"""
|
185 |
+
try:
|
186 |
+
stock = yf.Ticker(symbol)
|
187 |
+
dividends = stock.dividends
|
188 |
+
return dividends
|
189 |
+
except Exception as e:
|
190 |
+
print(f"Error fetching dividend data for {symbol}: {e}")
|
191 |
+
return pd.DataFrame()
|
192 |
+
|
193 |
+
@tool
|
194 |
+
def get_company_news(symbol: str) -> pd.DataFrame:
|
195 |
+
"""
|
196 |
+
Get company news and press releases for a given stock symbol.
|
197 |
+
|
198 |
+
Args:
|
199 |
+
symbol (str): The stock symbol.
|
200 |
+
|
201 |
+
Returns:
|
202 |
+
pd.DataFrame: DataFrame containing company news and press releases.
|
203 |
+
"""
|
204 |
+
try:
|
205 |
+
news = yf.Ticker(symbol).news
|
206 |
+
return news
|
207 |
+
except Exception as e:
|
208 |
+
print(f"Error fetching company news for {symbol}: {e}")
|
209 |
+
return pd.DataFrame()
|
210 |
+
|
211 |
+
@tool
|
212 |
+
def get_technical_indicators(symbol: str) -> pd.DataFrame:
|
213 |
+
"""
|
214 |
+
Get technical indicators for a given stock symbol.
|
215 |
+
|
216 |
+
Args:
|
217 |
+
symbol (str): The stock symbol.
|
218 |
+
|
219 |
+
Returns:
|
220 |
+
pd.DataFrame: DataFrame containing technical indicators.
|
221 |
+
"""
|
222 |
+
try:
|
223 |
+
indicators = yf.Ticker(symbol).history(period="max")
|
224 |
+
return indicators
|
225 |
+
except Exception as e:
|
226 |
+
print(f"Error fetching technical indicators for {symbol}: {e}")
|
227 |
+
return pd.DataFrame()
|
228 |
+
|
229 |
+
@tool
|
230 |
+
def get_company_profile(symbol: str) -> dict:
|
231 |
+
"""
|
232 |
+
Get company profile and overview for a given stock symbol.
|
233 |
+
|
234 |
+
Args:
|
235 |
+
symbol (str): The stock symbol.
|
236 |
+
|
237 |
+
Returns:
|
238 |
+
dict: Dictionary containing company profile and overview.
|
239 |
+
"""
|
240 |
+
try:
|
241 |
+
profile = yf.Ticker(symbol).info
|
242 |
+
return profile
|
243 |
+
except Exception as e:
|
244 |
+
print(f"Error fetching company profile for {symbol}: {e}")
|
245 |
+
return {}
|
246 |
+
|
247 |
+
def get_openai_tools() -> List[dict]:
|
248 |
+
functions = [
|
249 |
+
google_search_and_scrape,
|
250 |
+
get_current_stock_price,
|
251 |
+
get_company_news,
|
252 |
+
get_company_profile,
|
253 |
+
get_stock_fundamentals,
|
254 |
+
get_financial_statements,
|
255 |
+
get_key_financial_ratios,
|
256 |
+
get_analyst_recommendations,
|
257 |
+
get_dividend_data,
|
258 |
+
get_technical_indicators
|
259 |
+
]
|
260 |
+
|
261 |
+
tools = [convert_to_openai_tool(f) for f in functions]
|
262 |
+
return tools
|
prompt_assets/few_shot.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"example": "```\nSYSTEM: You are a helpful assistant who has access to functions. Use them if required\n<tools>[\n {\n \"name\": \"calculate_distance\",\n \"description\": \"Calculate the distance between two locations\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"origin\": {\n \"type\": \"string\",\n \"description\": \"The starting location\"\n },\n \"destination\": {\n \"type\": \"string\",\n \"description\": \"The destination location\"\n },\n \"mode\": {\n \"type\": \"string\",\n \"description\": \"The mode of transportation\"\n }\n },\n \"required\": [\n \"origin\",\n \"destination\",\n \"mode\"\n ]\n }\n },\n {\n \"name\": \"generate_password\",\n \"description\": \"Generate a random password\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"length\": {\n \"type\": \"integer\",\n \"description\": \"The length of the password\"\n }\n },\n \"required\": [\n \"length\"\n ]\n }\n }\n]\n\n</tools>\nUSER: Hi, I need to know the distance from New York to Los Angeles by car.\nASSISTANT:\n<tool_call>\n{\"arguments\": {\"origin\": \"New York\",\n \"destination\": \"Los Angeles\", \"mode\": \"car\"}, \"name\": \"calculate_distance\"}\n</tool_call>\n```\n"
|
4 |
+
},
|
5 |
+
{
|
6 |
+
"example": "```\nSYSTEM: You are a helpful assistant with access to functions. Use them if required\n<tools>[\n {\n \"name\": \"calculate_distance\",\n \"description\": \"Calculate the distance between two locations\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"origin\": {\n \"type\": \"string\",\n \"description\": \"The starting location\"\n },\n \"destination\": {\n \"type\": \"string\",\n \"description\": \"The destination location\"\n },\n \"mode\": {\n \"type\": \"string\",\n \"description\": \"The mode of transportation\"\n }\n },\n \"required\": [\n \"origin\",\n \"destination\",\n \"mode\"\n ]\n }\n },\n {\n \"name\": \"generate_password\",\n \"description\": \"Generate a random password\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"length\": {\n \"type\": \"integer\",\n \"description\": \"The length of the password\"\n }\n },\n \"required\": [\n \"length\"\n ]\n }\n }\n]\n\n</tools>\nUSER: Can you help me generate a random password with a length of 8 characters?\nASSISTANT:\n<tool_call>\n{\"arguments\": {\"length\": 8}, \"name\": \"generate_password\"}\n</tool_call>\n```"
|
7 |
+
}
|
8 |
+
]
|
prompt_assets/output_sys_prompt.yml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Role: |
|
2 |
+
You are an expert financial advisor named IRAI.
|
3 |
+
You have a comprehensive understanding of finance and investing with experience and expertise in all areas of finance.
|
4 |
+
You can use information given to you, but do not mention function calls.
|
5 |
+
Objective: |
|
6 |
+
Answer questions accurately and truthfully given your current knowledge. Answer the question directly.
|
7 |
+
Instructions: |
|
8 |
+
The questions will be asked by top technology executives and CFO of large fintech companies and successful startups.
|
9 |
+
Answer in a friendly and engaging manner representing a top female investment professional working at a leading investment bank.
|
10 |
+
Give a direct answer to question, concise yet insightful.
|
prompt_assets/sys_prompt.yml
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Role: |
|
2 |
+
You are a function calling AI agent with self-recursion.
|
3 |
+
You can call only one function at a time and analyse data you get from function response.
|
4 |
+
You are provided with function signatures within <tools></tools> XML tags.
|
5 |
+
The current date is: {date}.
|
6 |
+
Objective: |
|
7 |
+
You may use agentic frameworks for reasoning and planning to help with user query.
|
8 |
+
Please call a function and wait for function results to be provided to you in the next iteration.
|
9 |
+
Don't make assumptions about what values to plug into function arguments.
|
10 |
+
Once you have called a function, results will be fed back to you within <tool_response></tool_response> XML tags.
|
11 |
+
Don't make assumptions about tool results if <tool_response> XML tags are not present since function hasn't been executed yet.
|
12 |
+
Analyze the data once you get the results and call another function.
|
13 |
+
At each iteration please continue adding the your analysis to previous summary.
|
14 |
+
Your final response should directly answer the user query with an anlysis or summary of the results of function calls.
|
15 |
+
Tools: |
|
16 |
+
Here are the available tools:
|
17 |
+
<tools> {tools} </tools>
|
18 |
+
If the provided function signatures doesn't have the function you must call, you may write executable python code in markdown syntax and call code_interpreter() function as follows:
|
19 |
+
<tool_call>
|
20 |
+
{{"arguments": {{"code_markdown": <python-code>, "name": "code_interpreter"}}}}
|
21 |
+
</tool_call>
|
22 |
+
Make sure that the json object above with code markdown block is parseable with json.loads() and the XML block with XML ElementTree.
|
23 |
+
Examples: |
|
24 |
+
Here are some example usage of functions:
|
25 |
+
{examples}
|
26 |
+
Schema: |
|
27 |
+
Use the following pydantic model json schema for each tool call you will make:
|
28 |
+
{schema}
|
29 |
+
Instructions: |
|
30 |
+
At the very first turn you don't have <tool_results> so you shouldn't not make up the results.
|
31 |
+
Please keep a running summary with analysis of previous function results and summaries from previous iterations.
|
32 |
+
Do not stop calling functions until the task has been accomplished or you've reached max iteration of 10.
|
33 |
+
Calling multiple functions at once can overload the system and increase cost so call one function at a time please.
|
34 |
+
If you plan to continue with analysis, always call another function.
|
35 |
+
For each function call return a valid json object (using doulbe quotes) with function name and arguments within <tool_call></tool_call> XML tags as follows:
|
36 |
+
<tool_call>
|
37 |
+
{{"arguments": <args-dict>, "name": <function-name>}}
|
38 |
+
</tool_call>
|
prompter.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
from pydantic import BaseModel
|
3 |
+
from typing import Dict
|
4 |
+
from schema import FunctionCall
|
5 |
+
from utils import (
|
6 |
+
get_fewshot_examples
|
7 |
+
)
|
8 |
+
import yaml
|
9 |
+
import json
|
10 |
+
import os
|
11 |
+
|
12 |
+
class PromptSchema(BaseModel):
|
13 |
+
Role: str
|
14 |
+
Objective: str
|
15 |
+
Tools: str
|
16 |
+
Examples: str
|
17 |
+
Schema: str
|
18 |
+
Instructions: str
|
19 |
+
|
20 |
+
class PromptManager:
|
21 |
+
def __init__(self):
|
22 |
+
self.script_dir = os.path.dirname(os.path.abspath(__file__))
|
23 |
+
|
24 |
+
def format_yaml_prompt(self, prompt_schema: PromptSchema, variables: Dict) -> str:
|
25 |
+
formatted_prompt = ""
|
26 |
+
for field, value in prompt_schema.dict().items():
|
27 |
+
if field == "Examples" and variables.get("examples") is None:
|
28 |
+
continue
|
29 |
+
formatted_value = value.format(**variables)
|
30 |
+
if field == "Instructions":
|
31 |
+
formatted_prompt += f"{formatted_value}"
|
32 |
+
else:
|
33 |
+
formatted_value = formatted_value.replace("\n", " ")
|
34 |
+
formatted_prompt += f"{formatted_value}"
|
35 |
+
return formatted_prompt
|
36 |
+
|
37 |
+
def read_yaml_file(self, file_path: str) -> PromptSchema:
|
38 |
+
with open(file_path, 'r') as file:
|
39 |
+
yaml_content = yaml.safe_load(file)
|
40 |
+
|
41 |
+
prompt_schema = PromptSchema(
|
42 |
+
Role=yaml_content.get('Role', ''),
|
43 |
+
Objective=yaml_content.get('Objective', ''),
|
44 |
+
Tools=yaml_content.get('Tools', ''),
|
45 |
+
Examples=yaml_content.get('Examples', ''),
|
46 |
+
Schema=yaml_content.get('Schema', ''),
|
47 |
+
Instructions=yaml_content.get('Instructions', ''),
|
48 |
+
)
|
49 |
+
return prompt_schema
|
50 |
+
|
51 |
+
def generate_prompt(self, user_prompt, tools, num_fewshot=None):
|
52 |
+
prompt_path = os.path.join(self.script_dir, 'prompt_assets', 'sys_prompt.yml')
|
53 |
+
prompt_schema = self.read_yaml_file(prompt_path)
|
54 |
+
|
55 |
+
if num_fewshot is not None:
|
56 |
+
examples = get_fewshot_examples(num_fewshot)
|
57 |
+
else:
|
58 |
+
examples = None
|
59 |
+
|
60 |
+
schema_json = json.loads(FunctionCall.schema_json())
|
61 |
+
|
62 |
+
variables = {
|
63 |
+
"date": datetime.date.today(),
|
64 |
+
"tools": tools,
|
65 |
+
"examples": examples,
|
66 |
+
"schema": schema_json
|
67 |
+
}
|
68 |
+
sys_prompt = self.format_yaml_prompt(prompt_schema, variables)
|
69 |
+
|
70 |
+
prompt = [
|
71 |
+
{'content': sys_prompt, 'role': 'system'}
|
72 |
+
]
|
73 |
+
prompt.extend(user_prompt)
|
74 |
+
return prompt
|
75 |
+
|
76 |
+
|
requirements.txt
CHANGED
@@ -1,6 +1,131 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
xformers==0.0.23
|
6 |
-
|
|
|
|
1 |
+
aiohttp==3.9.5
|
2 |
+
aioprometheus==23.12.0
|
3 |
+
aiosignal==1.3.1
|
4 |
+
altair==5.3.0
|
5 |
+
annotated-types==0.6.0
|
6 |
+
anyio==4.3.0
|
7 |
+
appdirs==1.4.4
|
8 |
+
async-timeout==4.0.3
|
9 |
+
attrs==23.2.0
|
10 |
+
beautifulsoup4==4.12.3
|
11 |
+
blinker==1.8.2
|
12 |
+
cachetools==5.3.3
|
13 |
+
certifi==2024.2.2
|
14 |
+
charset-normalizer==3.3.2
|
15 |
+
click==8.1.7
|
16 |
+
dataclasses-json==0.6.5
|
17 |
+
dnspython==2.6.1
|
18 |
+
email_validator==2.1.1
|
19 |
+
exceptiongroup==1.2.1
|
20 |
+
fastapi==0.111.0
|
21 |
+
fastapi-cli==0.0.3
|
22 |
+
filelock==3.14.0
|
23 |
+
frozendict==2.4.4
|
24 |
+
frozenlist==1.4.1
|
25 |
+
fsspec==2024.3.1
|
26 |
+
gitdb==4.0.11
|
27 |
+
GitPython==3.1.43
|
28 |
+
greenlet==3.0.3
|
29 |
+
h11==0.14.0
|
30 |
+
html5lib==1.1
|
31 |
+
httpcore==1.0.5
|
32 |
+
httptools==0.6.1
|
33 |
+
httpx==0.27.0
|
34 |
+
huggingface-hub==0.23.0
|
35 |
+
idna==3.7
|
36 |
+
Jinja2==3.1.4
|
37 |
+
jsonpatch==1.33
|
38 |
+
jsonpointer==2.4
|
39 |
+
jsonschema==4.22.0
|
40 |
+
jsonschema-specifications==2023.12.1
|
41 |
+
langchain==0.1.17
|
42 |
+
langchain-community==0.0.37
|
43 |
+
langchain-core==0.1.52
|
44 |
+
langchain-text-splitters==0.0.1
|
45 |
+
langsmith==0.1.54
|
46 |
+
lxml==5.2.1
|
47 |
+
markdown-it-py==3.0.0
|
48 |
+
MarkupSafe==2.1.5
|
49 |
+
marshmallow==3.21.2
|
50 |
+
mdurl==0.1.2
|
51 |
+
mpmath==1.3.0
|
52 |
+
msgpack==1.0.8
|
53 |
+
multidict==6.0.5
|
54 |
+
multitasking==0.0.11
|
55 |
+
mypy-extensions==1.0.0
|
56 |
+
networkx==3.3
|
57 |
+
ninja==1.11.1.1
|
58 |
+
numpy==1.26.4
|
59 |
+
nvidia-cublas-cu12==12.1.3.1
|
60 |
+
nvidia-cuda-cupti-cu12==12.1.105
|
61 |
+
nvidia-cuda-nvrtc-cu12==12.1.105
|
62 |
+
nvidia-cuda-runtime-cu12==12.1.105
|
63 |
+
nvidia-cudnn-cu12==8.9.2.26
|
64 |
+
nvidia-cufft-cu12==11.0.2.54
|
65 |
+
nvidia-curand-cu12==10.3.2.106
|
66 |
+
nvidia-cusolver-cu12==11.4.5.107
|
67 |
+
nvidia-cusparse-cu12==12.1.0.106
|
68 |
+
nvidia-nccl-cu12==2.18.1
|
69 |
+
nvidia-nvjitlink-cu12==12.4.127
|
70 |
+
nvidia-nvtx-cu12==12.1.105
|
71 |
+
orjson==3.10.3
|
72 |
+
packaging==23.2
|
73 |
+
pandas==2.2.2
|
74 |
+
peewee==3.17.3
|
75 |
+
pillow==10.3.0
|
76 |
+
protobuf==4.25.3
|
77 |
+
psutil==5.9.8
|
78 |
+
pyarrow==16.0.0
|
79 |
+
pydantic==2.7.1
|
80 |
+
pydantic-settings==2.2.1
|
81 |
+
pydantic_core==2.18.2
|
82 |
+
pydeck==0.9.0
|
83 |
+
Pygments==2.18.0
|
84 |
+
python-dateutil==2.9.0.post0
|
85 |
+
python-dotenv==1.0.1
|
86 |
+
python-multipart==0.0.9
|
87 |
+
pytz==2024.1
|
88 |
+
PyYAML==6.0.1
|
89 |
+
quantile-python==1.1
|
90 |
+
ray==2.20.0
|
91 |
+
referencing==0.35.1
|
92 |
+
regex==2024.4.28
|
93 |
+
requests==2.31.0
|
94 |
+
rich==13.7.1
|
95 |
+
rpds-py==0.18.1
|
96 |
+
safetensors==0.4.3
|
97 |
+
sentencepiece==0.2.0
|
98 |
+
shellingham==1.5.4
|
99 |
+
six==1.16.0
|
100 |
+
smmap==5.0.1
|
101 |
+
sniffio==1.3.1
|
102 |
+
soupsieve==2.5
|
103 |
+
SQLAlchemy==2.0.30
|
104 |
+
starlette==0.37.2
|
105 |
+
streamlit==1.34.0
|
106 |
+
sympy==1.12
|
107 |
+
tenacity==8.3.0
|
108 |
+
tokenizers==0.19.1
|
109 |
+
toml==0.10.2
|
110 |
+
toolz==0.12.1
|
111 |
+
torch==2.1.1
|
112 |
+
tornado==6.4
|
113 |
+
tqdm==4.66.4
|
114 |
+
transformers==4.40.2
|
115 |
+
triton==2.1.0
|
116 |
+
typer==0.12.3
|
117 |
+
typing-inspect==0.9.0
|
118 |
+
typing_extensions==4.11.0
|
119 |
+
tzdata==2024.1
|
120 |
+
ujson==5.9.0
|
121 |
+
urllib3==2.2.1
|
122 |
+
uvicorn==0.29.0
|
123 |
+
uvloop==0.19.0
|
124 |
+
vllm==0.2.5
|
125 |
+
watchdog==4.0.0
|
126 |
+
watchfiles==0.21.0
|
127 |
+
webencodings==0.5.1
|
128 |
+
websockets==12.0
|
129 |
xformers==0.0.23
|
130 |
+
yarl==1.9.4
|
131 |
+
yfinance==0.2.38
|
schema.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel
|
2 |
+
from typing import List, Dict, Literal, Optional
|
3 |
+
|
4 |
+
class FunctionCall(BaseModel):
|
5 |
+
arguments: dict
|
6 |
+
"""
|
7 |
+
The arguments to call the function with, as generated by the model in JSON
|
8 |
+
format. Note that the model does not always generate valid JSON, and may
|
9 |
+
hallucinate parameters not defined by your function schema. Validate the
|
10 |
+
arguments in your code before calling your function.
|
11 |
+
"""
|
12 |
+
|
13 |
+
name: str
|
14 |
+
"""The name of the function to call."""
|
15 |
+
|
16 |
+
class FunctionDefinition(BaseModel):
|
17 |
+
name: str
|
18 |
+
description: Optional[str] = None
|
19 |
+
parameters: Optional[Dict[str, object]] = None
|
20 |
+
|
21 |
+
class FunctionSignature(BaseModel):
|
22 |
+
function: FunctionDefinition
|
23 |
+
type: Literal["function"]
|
utils.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
import json
|
5 |
+
import logging
|
6 |
+
import datetime
|
7 |
+
import xml.etree.ElementTree as ET
|
8 |
+
|
9 |
+
from logging.handlers import RotatingFileHandler
|
10 |
+
|
11 |
+
logging.basicConfig(
|
12 |
+
format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
|
13 |
+
datefmt="%Y-%m-%d:%H:%M:%S",
|
14 |
+
level=logging.INFO,
|
15 |
+
)
|
16 |
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
17 |
+
now = datetime.datetime.now()
|
18 |
+
log_folder = os.path.join(script_dir, "inference_logs")
|
19 |
+
os.makedirs(log_folder, exist_ok=True)
|
20 |
+
log_file_path = os.path.join(
|
21 |
+
log_folder, f"function-calling-inference_{now.strftime('%Y-%m-%d_%H-%M-%S')}.log"
|
22 |
+
)
|
23 |
+
# Use RotatingFileHandler from the logging.handlers module
|
24 |
+
file_handler = RotatingFileHandler(log_file_path, maxBytes=0, backupCount=0)
|
25 |
+
file_handler.setLevel(logging.INFO)
|
26 |
+
|
27 |
+
formatter = logging.Formatter("%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s", datefmt="%Y-%m-%d:%H:%M:%S")
|
28 |
+
file_handler.setFormatter(formatter)
|
29 |
+
|
30 |
+
inference_logger = logging.getLogger("function-calling-inference")
|
31 |
+
inference_logger.addHandler(file_handler)
|
32 |
+
|
33 |
+
def get_fewshot_examples(num_fewshot):
|
34 |
+
"""return a list of few shot examples"""
|
35 |
+
example_path = os.path.join(script_dir, 'prompt_assets', 'few_shot.json')
|
36 |
+
with open(example_path, 'r') as file:
|
37 |
+
examples = json.load(file) # Use json.load with the file object, not the file path
|
38 |
+
if num_fewshot > len(examples):
|
39 |
+
raise ValueError(f"Not enough examples (got {num_fewshot}, but there are only {len(examples)} examples).")
|
40 |
+
return examples[:num_fewshot]
|
41 |
+
|
42 |
+
def get_chat_template(chat_template):
|
43 |
+
"""read chat template from jinja file"""
|
44 |
+
template_path = os.path.join(script_dir, 'chat_templates', f"{chat_template}.j2")
|
45 |
+
|
46 |
+
if not os.path.exists(template_path):
|
47 |
+
print
|
48 |
+
inference_logger.error(f"Template file not found: {chat_template}")
|
49 |
+
return None
|
50 |
+
try:
|
51 |
+
with open(template_path, 'r') as file:
|
52 |
+
template = file.read()
|
53 |
+
return template
|
54 |
+
except Exception as e:
|
55 |
+
print(f"Error loading template: {e}")
|
56 |
+
return None
|
57 |
+
|
58 |
+
def get_assistant_message(completion, chat_template, eos_token):
|
59 |
+
"""define and match pattern to find the assistant message"""
|
60 |
+
completion = completion.strip()
|
61 |
+
|
62 |
+
if chat_template == "zephyr":
|
63 |
+
assistant_pattern = re.compile(r'<\|assistant\|>((?:(?!<\|assistant\|>).)*)$', re.DOTALL)
|
64 |
+
elif chat_template == "chatml":
|
65 |
+
assistant_pattern = re.compile(r'<\|im_start\|>\s*assistant((?:(?!<\|im_start\|>\s*assistant).)*)$', re.DOTALL)
|
66 |
+
|
67 |
+
elif chat_template == "vicuna":
|
68 |
+
assistant_pattern = re.compile(r'ASSISTANT:\s*((?:(?!ASSISTANT:).)*)$', re.DOTALL)
|
69 |
+
else:
|
70 |
+
raise NotImplementedError(f"Handling for chat_template '{chat_template}' is not implemented.")
|
71 |
+
|
72 |
+
assistant_match = assistant_pattern.search(completion)
|
73 |
+
if assistant_match:
|
74 |
+
assistant_content = assistant_match.group(1).strip()
|
75 |
+
if chat_template == "vicuna":
|
76 |
+
eos_token = f"</s>{eos_token}"
|
77 |
+
return assistant_content.replace(eos_token, "")
|
78 |
+
else:
|
79 |
+
assistant_content = None
|
80 |
+
inference_logger.info("No match found for the assistant pattern")
|
81 |
+
return assistant_content
|
82 |
+
|
83 |
+
def validate_and_extract_tool_calls(assistant_content):
|
84 |
+
validation_result = False
|
85 |
+
tool_calls = []
|
86 |
+
error_message = None
|
87 |
+
|
88 |
+
try:
|
89 |
+
# wrap content in root element
|
90 |
+
xml_root_element = f"<root>{assistant_content}</root>"
|
91 |
+
root = ET.fromstring(xml_root_element)
|
92 |
+
|
93 |
+
# extract JSON data
|
94 |
+
for element in root.findall(".//tool_call"):
|
95 |
+
json_data = None
|
96 |
+
try:
|
97 |
+
json_text = element.text.strip()
|
98 |
+
|
99 |
+
try:
|
100 |
+
# Prioritize json.loads for better error handling
|
101 |
+
json_data = json.loads(json_text)
|
102 |
+
except json.JSONDecodeError as json_err:
|
103 |
+
try:
|
104 |
+
# Fallback to ast.literal_eval if json.loads fails
|
105 |
+
json_data = ast.literal_eval(json_text)
|
106 |
+
except (SyntaxError, ValueError) as eval_err:
|
107 |
+
error_message = f"JSON parsing failed with both json.loads and ast.literal_eval:\n"\
|
108 |
+
f"- JSON Decode Error: {json_err}\n"\
|
109 |
+
f"- Fallback Syntax/Value Error: {eval_err}\n"\
|
110 |
+
f"- Problematic JSON text: {json_text}"
|
111 |
+
inference_logger.error(error_message)
|
112 |
+
continue
|
113 |
+
except Exception as e:
|
114 |
+
error_message = f"Cannot strip text: {e}"
|
115 |
+
inference_logger.error(error_message)
|
116 |
+
|
117 |
+
if json_data is not None:
|
118 |
+
tool_calls.append(json_data)
|
119 |
+
validation_result = True
|
120 |
+
|
121 |
+
except ET.ParseError as err:
|
122 |
+
error_message = f"XML Parse Error: {err}"
|
123 |
+
inference_logger.error(f"XML Parse Error: {err}")
|
124 |
+
|
125 |
+
# Return default values if no valid data is extracted
|
126 |
+
return validation_result, tool_calls, error_message
|
127 |
+
|
128 |
+
def extract_json_from_markdown(text):
|
129 |
+
"""
|
130 |
+
Extracts the JSON string from the given text using a regular expression pattern.
|
131 |
+
|
132 |
+
Args:
|
133 |
+
text (str): The input text containing the JSON string.
|
134 |
+
|
135 |
+
Returns:
|
136 |
+
dict: The JSON data loaded from the extracted string, or None if the JSON string is not found.
|
137 |
+
"""
|
138 |
+
json_pattern = r'```json\r?\n(.*?)\r?\n```'
|
139 |
+
match = re.search(json_pattern, text, re.DOTALL)
|
140 |
+
if match:
|
141 |
+
json_string = match.group(1)
|
142 |
+
try:
|
143 |
+
data = json.loads(json_string)
|
144 |
+
return data
|
145 |
+
except json.JSONDecodeError as e:
|
146 |
+
print(f"Error decoding JSON string: {e}")
|
147 |
+
else:
|
148 |
+
print("JSON string not found in the text.")
|
149 |
+
return None
|
validator.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
import json
|
3 |
+
from jsonschema import validate
|
4 |
+
from pydantic import ValidationError
|
5 |
+
from utils import inference_logger, extract_json_from_markdown
|
6 |
+
from schema import FunctionCall, FunctionSignature
|
7 |
+
|
8 |
+
def validate_function_call_schema(call, signatures):
|
9 |
+
try:
|
10 |
+
call_data = FunctionCall(**call)
|
11 |
+
except ValidationError as e:
|
12 |
+
return False, str(e)
|
13 |
+
|
14 |
+
for signature in signatures:
|
15 |
+
try:
|
16 |
+
signature_data = FunctionSignature(**signature)
|
17 |
+
if signature_data.function.name == call_data.name:
|
18 |
+
# Validate types in function arguments
|
19 |
+
for arg_name, arg_schema in signature_data.function.parameters.get('properties', {}).items():
|
20 |
+
if arg_name in call_data.arguments:
|
21 |
+
call_arg_value = call_data.arguments[arg_name]
|
22 |
+
if call_arg_value:
|
23 |
+
try:
|
24 |
+
validate_argument_type(arg_name, call_arg_value, arg_schema)
|
25 |
+
except Exception as arg_validation_error:
|
26 |
+
return False, str(arg_validation_error)
|
27 |
+
|
28 |
+
# Check if all required arguments are present
|
29 |
+
required_arguments = signature_data.function.parameters.get('required', [])
|
30 |
+
result, missing_arguments = check_required_arguments(call_data.arguments, required_arguments)
|
31 |
+
if not result:
|
32 |
+
return False, f"Missing required arguments: {missing_arguments}"
|
33 |
+
|
34 |
+
return True, None
|
35 |
+
except Exception as e:
|
36 |
+
# Handle validation errors for the function signature
|
37 |
+
return False, str(e)
|
38 |
+
|
39 |
+
# No matching function signature found
|
40 |
+
return False, f"No matching function signature found for function: {call_data.name}"
|
41 |
+
|
42 |
+
def check_required_arguments(call_arguments, required_arguments):
|
43 |
+
missing_arguments = [arg for arg in required_arguments if arg not in call_arguments]
|
44 |
+
return not bool(missing_arguments), missing_arguments
|
45 |
+
|
46 |
+
def validate_enum_value(arg_name, arg_value, enum_values):
|
47 |
+
if arg_value not in enum_values:
|
48 |
+
raise Exception(
|
49 |
+
f"Invalid value '{arg_value}' for parameter {arg_name}. Expected one of {', '.join(map(str, enum_values))}"
|
50 |
+
)
|
51 |
+
|
52 |
+
def validate_argument_type(arg_name, arg_value, arg_schema):
|
53 |
+
arg_type = arg_schema.get('type', None)
|
54 |
+
if arg_type:
|
55 |
+
if arg_type == 'string' and 'enum' in arg_schema:
|
56 |
+
enum_values = arg_schema['enum']
|
57 |
+
if None not in enum_values and enum_values != []:
|
58 |
+
try:
|
59 |
+
validate_enum_value(arg_name, arg_value, enum_values)
|
60 |
+
except Exception as e:
|
61 |
+
# Propagate the validation error message
|
62 |
+
raise Exception(f"Error validating function call: {e}")
|
63 |
+
|
64 |
+
python_type = get_python_type(arg_type)
|
65 |
+
if not isinstance(arg_value, python_type):
|
66 |
+
raise Exception(f"Type mismatch for parameter {arg_name}. Expected: {arg_type}, Got: {type(arg_value)}")
|
67 |
+
|
68 |
+
def get_python_type(json_type):
|
69 |
+
type_mapping = {
|
70 |
+
'string': str,
|
71 |
+
'number': (int, float),
|
72 |
+
'integer': int,
|
73 |
+
'boolean': bool,
|
74 |
+
'array': list,
|
75 |
+
'object': dict,
|
76 |
+
'null': type(None),
|
77 |
+
}
|
78 |
+
return type_mapping[json_type]
|
79 |
+
|
80 |
+
def validate_json_data(json_object, json_schema):
|
81 |
+
valid = False
|
82 |
+
error_message = None
|
83 |
+
result_json = None
|
84 |
+
|
85 |
+
try:
|
86 |
+
# Attempt to load JSON using json.loads
|
87 |
+
try:
|
88 |
+
result_json = json.loads(json_object)
|
89 |
+
except json.decoder.JSONDecodeError:
|
90 |
+
# If json.loads fails, try ast.literal_eval
|
91 |
+
try:
|
92 |
+
result_json = ast.literal_eval(json_object)
|
93 |
+
except (SyntaxError, ValueError) as e:
|
94 |
+
try:
|
95 |
+
result_json = extract_json_from_markdown(json_object)
|
96 |
+
except Exception as e:
|
97 |
+
error_message = f"JSON decoding error: {e}"
|
98 |
+
inference_logger.info(f"Validation failed for JSON data: {error_message}")
|
99 |
+
return valid, result_json, error_message
|
100 |
+
|
101 |
+
# Return early if both json.loads and ast.literal_eval fail
|
102 |
+
if result_json is None:
|
103 |
+
error_message = "Failed to decode JSON data"
|
104 |
+
inference_logger.info(f"Validation failed for JSON data: {error_message}")
|
105 |
+
return valid, result_json, error_message
|
106 |
+
|
107 |
+
# Validate each item in the list against schema if it's a list
|
108 |
+
if isinstance(result_json, list):
|
109 |
+
for index, item in enumerate(result_json):
|
110 |
+
try:
|
111 |
+
validate(instance=item, schema=json_schema)
|
112 |
+
inference_logger.info(f"Item {index+1} is valid against the schema.")
|
113 |
+
except ValidationError as e:
|
114 |
+
error_message = f"Validation failed for item {index+1}: {e}"
|
115 |
+
break
|
116 |
+
else:
|
117 |
+
# Default to validation without list
|
118 |
+
try:
|
119 |
+
validate(instance=result_json, schema=json_schema)
|
120 |
+
except ValidationError as e:
|
121 |
+
error_message = f"Validation failed: {e}"
|
122 |
+
|
123 |
+
except Exception as e:
|
124 |
+
error_message = f"Error occurred: {e}"
|
125 |
+
|
126 |
+
if error_message is None:
|
127 |
+
valid = True
|
128 |
+
inference_logger.info("JSON data is valid against the schema.")
|
129 |
+
else:
|
130 |
+
inference_logger.info(f"Validation failed for JSON data: {error_message}")
|
131 |
+
|
132 |
+
return valid, result_json, error_message
|