WilliamGazeley committed · Commit 9efba8b · 1 Parent(s): 391d6e2

Integrate working langchain ollama model
Files changed:
- Dockerfile +1 -0
- src/agents/__init__.py +34 -0
- src/agents/format_scratchpad/functions.py +63 -0
- src/agents/functions_agent/base.py +48 -0
- src/agents/output_parsers/functions.py +77 -0
- src/agents/output_parsers/utils.py +64 -0
- src/functioncall.py +3 -2
- src/functions.py +21 -17
- src/prompts/prompt.py +17 -0
- src/prompts/rag_template.yaml +12 -0
Dockerfile
CHANGED
@@ -45,6 +45,7 @@ RUN pyenv install ${PYTHON_VERSION} && \
 COPY --chown=1000 ./requirements.txt /tmp/requirements.txt
 RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt && \
     pip install flash-attn --no-build-isolation
+RUN ollama pull ${OLLAMA_MODEL}
 
 COPY --chown=1000 src ${HOME}/app
 
src/agents/__init__.py
ADDED
@@ -0,0 +1,34 @@
+from langchain_community.chat_models import ChatOllama
+from prompts.prompt import rag_agent_prompt
+from agents.functions_agent.base import create_functions_agent
+from langchain.agents import AgentExecutor
+from langchain.memory import ChatMessageHistory
+from functions import get_openai_functions, tools, get_openai_tools
+from config import config
+
+llm = ChatOllama(model=config.ollama_model, temperature=0.55)
+
+tools_dict = get_openai_tools()
+
+history = ChatMessageHistory()
+
+functions_agent = create_functions_agent(llm=llm, prompt=rag_agent_prompt)
+functions_agent_executor = AgentExecutor(agent=functions_agent, tools=tools, verbose=True, return_intermediate_steps=True)
+
+
+if __name__ == "__main__":
+    while True:
+        try:
+            inp = input("User:")
+            if inp == "/bye":
+                break
+
+            response = functions_agent_executor.invoke({"input": inp, "chat_history": history, "tools": tools_dict})
+            response['output'] = response['output'].replace("<|im_end|>", "")
+            history.add_user_message(inp)
+            history.add_ai_message(response['output'])
+
+            print(response['output'])
+        except Exception as e:
+            print(e)
+            continue
src/agents/format_scratchpad/functions.py
ADDED
@@ -0,0 +1,63 @@
+import json
+from typing import List, Sequence, Tuple
+
+from langchain_core.agents import AgentAction, AgentActionMessageLog
+from langchain_core.messages import AIMessage, BaseMessage, FunctionMessage
+
+def _convert_agent_action_to_messages(
+    agent_action: AgentAction, observation: str
+) -> List[BaseMessage]:
+    """Convert an agent action to a message.
+    This code is used to reconstruct the original AI message from the agent action.
+    Args:
+        agent_action: Agent action to convert.
+    Returns:
+        AIMessage that corresponds to the original tool invocation.
+    """
+
+    if isinstance(agent_action, AgentActionMessageLog):
+        return list(agent_action.message_log) + [f"<tool_response>\n{_create_function_message(agent_action, observation)}\n</tool_response>"]
+    else:
+        return [AIMessage(content=agent_action.log)]
+
+def _create_function_message(
+    agent_action: AgentAction, observation: str
+) -> str:
+    """Convert agent action and observation into a function message.
+    Args:
+        agent_action: the tool invocation request from the agent
+        observation: the result of the tool invocation
+    Returns:
+        FunctionMessage that corresponds to the original tool invocation
+    """
+
+    if not isinstance(observation, str):
+        try:
+            content = json.dumps(observation, ensure_ascii=False)
+        except Exception:
+            content = str(observation)
+    else:
+        content = observation
+    tool_response = {
+        "name": agent_action.tool,
+        "content": content,
+    }
+    return json.dumps(tool_response)
+
+def format_to_function_messages(
+    intermediate_steps: Sequence[Tuple[AgentAction, str]],
+) -> List[BaseMessage]:
+    """Convert (AgentAction, tool output) tuples into FunctionMessages.
+    Args:
+        intermediate_steps: Steps the LLM has taken to date, along with observations
+    Returns:
+        list of messages to send to the LLM for the next prediction
+    """
+
+    messages = []
+    for agent_action, observation in intermediate_steps:
+        messages.extend(_convert_agent_action_to_messages(agent_action, observation))
+    return messages
+
+# Backwards compatibility
+format_to_functions = format_to_function_messages
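
A minimal usage sketch of format_to_function_messages; the tool input and observation are made up, though get_current_stock_price is one of the tools registered in src/functions.py. Note the AgentActionMessageLog branch appends the tool response as a raw string rather than a FunctionMessage:

    from agents.format_scratchpad.functions import format_to_function_messages
    from langchain_core.agents import AgentActionMessageLog
    from langchain_core.messages import AIMessage

    action = AgentActionMessageLog(
        tool="get_current_stock_price",
        tool_input={"symbol": "AAPL"},
        log="Invoking `get_current_stock_price` with `{'symbol': 'AAPL'}`",
        message_log=[AIMessage(content='<tool_call>{"arguments": {"symbol": "AAPL"}, "name": "get_current_stock_price"}</tool_call>')],
    )
    # One (action, observation) pair yields the original AIMessage followed by
    # a "<tool_response>...</tool_response>" string for the next model turn.
    messages = format_to_function_messages([(action, "189.84")])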
src/agents/functions_agent/base.py
ADDED
@@ -0,0 +1,48 @@
+from typing import Sequence
+
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.prompts.chat import ChatPromptTemplate
+from langchain_core.runnables import Runnable, RunnablePassthrough
+from langchain_core.tools import BaseTool
+
+from agents.format_scratchpad.functions import (
+    format_to_function_messages,
+)
+from agents.output_parsers.functions import (
+    FunctionsAgentOutputParser,
+)
+
+def create_functions_agent(
+    llm: BaseLanguageModel, prompt: ChatPromptTemplate
+) -> Runnable:
+    """Create an agent that uses function calling.
+
+    Args:
+        llm: LLM to use as the agent. Should work with Nous Hermes function calling,
+            so either be a Nous Hermes based model that supports that or a wrapper of
+            a different model that adds in equivalent support.
+        prompt: The prompt to use. See Prompt section below for more.
+
+    Returns:
+        A Runnable sequence representing an agent. It takes as input all the same input
+        variables as the prompt passed in does. It returns as output either an
+        AgentAction or AgentFinish.
+    """
+    if "agent_scratchpad" not in (
+        prompt.input_variables + list(prompt.partial_variables)
+    ):
+        raise ValueError(
+            "Prompt must have input variable `agent_scratchpad`, but wasn't found. "
+            f"Found {prompt.input_variables} instead."
+        )
+    agent = (
+        RunnablePassthrough.assign(
+            agent_scratchpad=lambda x: format_to_function_messages(
+                x["intermediate_steps"]
+            )
+        )
+        | prompt
+        | llm
+        | FunctionsAgentOutputParser()
+    )
+    return agent
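
A usage sketch, condensed from src/agents/__init__.py above, showing how this agent is wired into an executor:

    from langchain_community.chat_models import ChatOllama
    from langchain.agents import AgentExecutor
    from agents.functions_agent.base import create_functions_agent
    from config import config
    from functions import tools
    from prompts.prompt import rag_agent_prompt

    llm = ChatOllama(model=config.ollama_model, temperature=0.55)
    agent = create_functions_agent(llm=llm, prompt=rag_agent_prompt)
    executor = AgentExecutor(agent=agent, tools=tools, verbose=True, return_intermediate_steps=True)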
src/agents/output_parsers/functions.py
ADDED
@@ -0,0 +1,77 @@
+import re
+from json import JSONDecodeError
+from typing import List, Union
+
+from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
+from langchain_core.exceptions import OutputParserException
+from langchain_core.messages import (
+    AIMessage,
+    BaseMessage,
+)
+from langchain_core.outputs import ChatGeneration, Generation
+
+from langchain.agents.agent import AgentOutputParser
+from agents.output_parsers.utils import parse_tool_call, check_tool_call
+import ast
+
+class FunctionsAgentOutputParser(AgentOutputParser):
+    """Parses a message into agent action/finish.
+
+    Meant to be used with a model built on Nous Hermes 2 Pro, as it relies on the specific
+    function_call format from Nous Research to convey what tools to use.
+
+    If a function_call parameter is passed, then that is used to get
+    the tool and tool input.
+
+    If one is not passed, then the AIMessage is assumed to be the final output.
+    """
+
+    @property
+    def _type(self) -> str:
+        return "functions-agent"
+
+    @staticmethod
+    def _parse_ai_message(message: BaseMessage):
+        """Parse an AI message."""
+        if not isinstance(message, AIMessage):
+            raise TypeError(f"Expected an AI message, got {type(message)}")
+
+        actions = []
+
+        pattern = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
+        try:
+            tool_calls = [parse_tool_call(t.strip()) for t in pattern.findall(message.content)]
+        except Exception:
+            raise OutputParserException(
+                f"Could not parse tool calls from message content: {message.content}. Please ensure that the tool calls are valid JSON."
+            )
+
+        if not tool_calls:
+            return AgentFinish(
+                return_values={"output": message.content}, log=str(message.content)
+            )
+
+        for tool_call in tool_calls:
+            tool_name, tool_input = check_tool_call(tool_call)
+            content_msg = f"\n{message.content}\n" if message.content else "\n"
+            log = f"\nInvoking: `{tool_name}` with `{tool_input}`\n{content_msg}\n"
+            actions.append(AgentActionMessageLog(
+                tool=tool_name,
+                tool_input=tool_input,
+                log=log,
+                message_log=[message],
+            ))
+
+        return actions
+
+    def parse_result(
+        self, result: List[Generation], *, partial: bool = False
+    ) -> Union[AgentAction, AgentFinish]:
+        if not isinstance(result[0], ChatGeneration):
+            raise ValueError("This output parser only works on ChatGeneration output")
+        message = result[0].message
+        return self._parse_ai_message(message)
+
+    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
+        raise ValueError("Can only parse messages")
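
A sketch of feeding a Hermes-style completion through the parser (illustrative content; needs to run inside this repo so that check_tool_call can see the registered tools):

    from agents.output_parsers.functions import FunctionsAgentOutputParser
    from langchain_core.messages import AIMessage

    msg = AIMessage(content='<tool_call>\n{"arguments": {"symbol": "AAPL"}, "name": "get_current_stock_price"}\n</tool_call>')
    parser = FunctionsAgentOutputParser()
    # Returns a list of AgentActionMessageLog entries, one per <tool_call> block;
    # a completion with no <tool_call> tags would instead return an AgentFinish.
    actions = parser._parse_ai_message(msg)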
src/agents/output_parsers/utils.py
ADDED
@@ -0,0 +1,64 @@
+from langchain_core.utils.function_calling import convert_to_openai_function
+from functions import tools
+import re
+import ast
+
+def parse_args(args: str):
+    args = args.strip()
+    args = args.replace("true", "True")
+    args = args.replace("false", "False")
+    args = args.replace("null", "None")
+    args = args.replace("\"", "\"\"\"")
+    i = 0
+    while args[i] != "\"" and args[i] != "\'" and i < len(args) - 1:
+        i += 1
+    args = args[i:]
+    if args[-4:] != "True" and args[-5:] != "False":
+        i = len(args) - 1
+        while args[i] != "\"" and args[i] != "\'" and i > 0:
+            i -= 1
+        args = args[:i + 1]
+    print(args)
+    return ast.literal_eval("{" + args + "}")
+
+def parse_tool_call(call: str):
+    call = call.strip()
+    name: bool = "\"name\": " in call or "\'name\':" in call
+    args: bool = "\"arguments\": " in call or "\'arguments\':" in call
+    if not name:
+        print({"arguments": {}, "name": "missing_function_call"})
+        return {"arguments": {}, "name": "missing_function_call"}
+    if not args:
+        pattern = re.compile(r"\"name\": \"(.*?)\"|\'name\': \'(.*?)\'", re.DOTALL)
+        match = pattern.findall(call)
+        for n in match:
+            if isinstance(n, tuple):
+                n = n[0]
+            print({"arguments": {}, "name": n})
+            return {"arguments": {}, "name": n}
+    args_pattern = re.compile(r"\"arguments\": {(.*?)}|\'arguments\': {(.*?)}", re.DOTALL)
+    args_match = args_pattern.findall(call)
+    for a in args_match:
+        print(a, "\n")
+        print(a[0])
+        args = parse_args(a[0])
+    name_pattern = re.compile(r"\"name\": \"(.*?)\"", re.DOTALL)
+    name_match = name_pattern.findall(call)
+    for n in name_match:
+        if isinstance(n, tuple):
+            n = n[0]
+        print({"arguments": args, "name": n})
+        return {"arguments": args, "name": n}
+
+
+def check_tool_call(call: dict):
+    global tools
+    tools = [convert_to_openai_function(t) for t in tools]
+    if call["name"] not in [t["name"] for t in tools]:
+        return "handle_tools_error", {"error": {"error": {"name": call["name"]}}}
+    tool = next((t for t in tools if t["name"] == call["name"]), None)
+
+    if set(list(tool["parameters"]["properties"])) != set(list(call["arguments"])):
+        print({"tool_response": {"error": {"expected": list(tool["parameters"]["properties"]), "received": list(call["arguments"])}, "name": call["name"]}})
+        return "handle_tools_error", {"error": {"error": {"expected": list(tool["parameters"]["properties"]), "received": list(call["arguments"])}, "name": call["name"]}}
+    return call["name"], call["arguments"]
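
Illustrative inputs and the shapes these helpers return, given the code as written above:

    from agents.output_parsers.utils import parse_tool_call, check_tool_call

    call = '{"arguments": {"symbol": "AAPL"}, "name": "get_current_stock_price"}'
    parse_tool_call(call)
    # -> {"arguments": {"symbol": "AAPL"}, "name": "get_current_stock_price"}

    check_tool_call({"arguments": {"symbol": "AAPL"}, "name": "get_current_stock_price"})
    # -> ("get_current_stock_price", {"symbol": "AAPL"}) when the name and
    #    argument set match a registered tool; otherwise
    #    ("handle_tools_error", {...error payload...})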
src/functioncall.py
CHANGED
@@ -11,6 +11,7 @@ import functions
 from prompter import PromptManager
 from validator import validate_function_call_schema
 from langchain_community.chat_models import ChatOllama
+from langchain_community.llms import Ollama
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import StrOutputParser
 
@@ -23,7 +24,7 @@ class ModelInference:
     def __init__(self, chat_template: str):
         self.prompter = PromptManager()
 
-        self.model =
+        self.model = Ollama(model=config.ollama_model, temperature=0.0, format='json')
         template = PromptTemplate(template="""<|im_start|>system\nYou are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools> {"type": "function", "function": {"name": "get_stock_fundamentals", "description": "get_stock_fundamentals(symbol: str) -> dict - Get fundamental data for a given stock symbol using yfinance API.\\n\\n    Args:\\n        symbol (str): The stock symbol.\\n\\n    Returns:\\n        dict: A dictionary containing fundamental data.\\n        Keys:\\n            - \'symbol\': The stock symbol.\\n            - \'company_name\': The long name of the company.\\n            - \'sector\': The sector to which the company belongs.\\n            - \'industry\': The industry to which the company belongs.\\n            - \'market_cap\': The market capitalization of the company.\\n            - \'pe_ratio\': The forward price-to-earnings ratio.\\n            - \'pb_ratio\': The price-to-book ratio.\\n            - \'dividend_yield\': The dividend yield.\\n            - \'eps\': The trailing earnings per share.\\n            - \'beta\': The beta value of the stock.\\n            - \'52_week_high\': The 52-week high price of the stock.\\n            - \'52_week_low\': The 52-week low price of the stock.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}} </tools> Use the following pydantic model json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"} For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:\n<tool_call>\n{"arguments": <args-dict>, "name": <function-name>}\n</tool_call><|im_end|>\n""", input_variables=["question"])
         chain = template | self.model | StrOutputParser()
 
@@ -69,6 +70,7 @@ class ModelInference:
             add_generation_prompt=True,
             tokenize=False,
         )
+        inputs = inputs.replace("<|begin_of_text|>", "")  # Something wrong with the chat template, hotfix
         completion = self.model.invoke(inputs, format='json')
         return completion.content
 
@@ -84,7 +86,6 @@ class ModelInference:
 
         def recursive_loop(prompt, completion, depth):
            nonlocal max_depth
-           breakpoint()
            tool_calls, assistant_message, error_message = self.process_completion_and_validate(completion, chat_template)
            prompt.append({"role": "assistant", "content": assistant_message})
 
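
The Ollama wrapper introduced above follows the standard langchain_community call pattern; a minimal sketch (the model tag is a placeholder, and a local ollama server must be running):

    from langchain_community.llms import Ollama

    llm = Ollama(model="hermes-2-pro", temperature=0.0, format="json")
    print(llm.invoke('Return a JSON object with a single key "ok".'))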
src/functions.py
CHANGED
@@ -11,7 +11,7 @@ from bs4 import BeautifulSoup
 from logger import logger
 from openai import AzureOpenAI
 from langchain.tools import tool
-from langchain_core.utils.function_calling import convert_to_openai_tool
+from langchain_core.utils.function_calling import convert_to_openai_tool, convert_to_openai_function
 from config import config
 
 from azure.core.credentials import AzureKeyCredential
@@ -281,20 +281,24 @@ def get_company_profile(symbol: str) -> dict:
         print(f"Error fetching company profile for {symbol}: {e}")
         return {}
 
+tools = [
+    get_analysis,
+    # google_search_and_scrape,
+    get_current_stock_price,
+    get_company_news,
+    # get_company_profile,
+    # get_stock_fundamentals,
+    # get_financial_statements,
+    get_key_financial_ratios,
+    # get_analyst_recommendations,
+    # get_dividend_data,
+    # get_technical_indicators
+]
+
 def get_openai_tools() -> List[dict]:
-    functions = [
-        get_analysis,
-        # google_search_and_scrape,
-        get_current_stock_price,
-        get_company_news,
-        # get_company_profile,
-        # get_stock_fundamentals,
-        # get_financial_statements,
-        get_key_financial_ratios,
-        # get_analyst_recommendations,
-        # get_dividend_data,
-        # get_technical_indicators
-    ]
-
-    tools = [convert_to_openai_tool(f) for f in functions]
-    return tools
+    tools_ = [convert_to_openai_tool(f) for f in tools]
+    return tools_
+
+def get_openai_functions() -> List[str]:
+    functions = [convert_to_openai_function(f) for f in tools]
+    return functions
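
For a sense of what the two converters emit, a sketch with a made-up @tool-decorated function (the output shapes come from langchain_core; the function itself is hypothetical):

    from langchain.tools import tool
    from langchain_core.utils.function_calling import (
        convert_to_openai_tool,
        convert_to_openai_function,
    )

    @tool
    def get_price(symbol: str) -> float:
        """Return the latest price for a ticker symbol."""
        return 0.0

    convert_to_openai_function(get_price)
    # -> {"name": "get_price", "description": ..., "parameters": {...}}
    convert_to_openai_tool(get_price)
    # -> {"type": "function", "function": {<same payload as above>}}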
src/prompts/prompt.py
ADDED
@@ -0,0 +1,17 @@
+import os
+from langchain_core.prompts import ChatPromptTemplate
+from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
+import yaml
+
+current_dir = os.path.dirname(os.path.abspath(__file__))
+with open(f"{current_dir}/rag_template.yaml", "r") as yaml_file:
+    templates = yaml.safe_load(yaml_file)
+
+# RAG Agent
+sys_msg_template: str = templates["sys_msg"]
+human_msg_template: str = templates["human_msg"]
+rag_agent_prompt = ChatPromptTemplate.from_messages([
+    SystemMessagePromptTemplate.from_template(sys_msg_template),
+    HumanMessagePromptTemplate.from_template(human_msg_template),
+    MessagesPlaceholder(variable_name="agent_scratchpad")
+])
src/prompts/rag_template.yaml
ADDED
@@ -0,0 +1,12 @@
+sys_msg: "
+  You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools:
+  <tools>
+  {tools}
+  </tools>
+  Use the following pydantic model json schema for each tool call you will make: {{\"properties\": {{\"arguments\": {{\"title\": \"Arguments\", \"type\": \"object\"}}, \"name\": {{\"title\": \"Name\", \"type\": \"string\"}}}}, \"required\": [\"arguments\", \"name\"], \"title\": \"FunctionCall\", \"type\": \"object\"}}
+  For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
+  <tool_call>
+  {{\"arguments\": <args-dict>, \"name\": <function-name>}}
+  </tool_call>"
+human_msg: "
+  {input}"