WilliamGazeley committed on
Commit
5894c9b
·
1 Parent(s): c124df1

Add final output agent

Browse files
app.py CHANGED
@@ -2,22 +2,14 @@ import os
2
  import huggingface_hub
3
  import streamlit as st
4
  from config import config
5
- from vllm import LLM, SamplingParams
6
  from functioncall import ModelInference
 
7
 
8
- sys_msg = """You are an expert financial advisor named IRAI. You have a comprehensive understanding of finance and investing with experience and expertise in all areas of finance.
9
- #Objective:
10
- Answer questions accurately and truthfully given your current knowledge. You do not have access to up-to-date current market data; this will be available in the future. Answer the question directly.
11
- #Style and tone:
12
- Answer in a friendly and engaging manner representing a top female investment professional working at a leading investment bank.
13
- #Audience:
14
- The questions will be asked by top technology executives and CFO of large fintech companies and successful startups.
15
- #Response:
16
- Direct answer to question, concise yet insightful."""
17
 
18
  @st.cache_resource(show_spinner="Loading model..")
19
  def init_llm():
20
- huggingface_hub.login(token=os.getenv("HF_TOKEN"), new_session=False)
21
  llm = ModelInference(chat_template='chatml')
22
  return llm
23
 
@@ -31,7 +23,20 @@ def get_response(prompt):
31
  )
32
  except Exception as e:
33
  return f"An error occurred: {str(e)}"
34
-
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  def main():
37
  st.title("LLM-ADE 9B Demo")
@@ -41,8 +46,8 @@ def main():
41
  if st.button("Generate"):
42
  if input_text:
43
  with st.spinner('Generating response...'):
44
- response_text = get_response(input_text)
45
- st.write(response_text)
46
  else:
47
  st.warning("Please enter some text to generate a response.")
48
 
@@ -50,8 +55,12 @@ llm = init_llm()
50
 
51
  def main_headless():
52
  while True:
53
- input_text = input("Enter your text here: ")
54
- print(get_response(input_text))
 
55
 
56
  if __name__ == "__main__":
57
- main_headless()
 
 
 
 
2
  import huggingface_hub
3
  import streamlit as st
4
  from config import config
5
+ from utils import get_assistant_message
6
  from functioncall import ModelInference
7
+ from prompter import PromptManager
8
 
 
 
 
 
 
 
 
 
 
9
 
10
  @st.cache_resource(show_spinner="Loading model..")
11
  def init_llm():
12
+ huggingface_hub.login(token=config.hf_token, new_session=False)
13
  llm = ModelInference(chat_template='chatml')
14
  return llm
15
 
 
23
  )
24
  except Exception as e:
25
  return f"An error occurred: {str(e)}"
26
+
27
def get_output(context, user_input):
    """Produce the final answer from the agent's intermediate result.

    Builds a two-message conversation: a system prompt loaded from
    ``prompt_assets/output_sys_prompt.yml`` with ``context`` appended under an
    ``Information:`` heading, followed by the user's input. The conversation is
    run through ``llm`` and the assistant's message is extracted from the raw
    completion.

    Args:
        context: Text appended to the system prompt; callers in this file pass
            the value returned by ``get_response``.
        user_input: The text the user entered.

    Returns:
        The extracted assistant message, or an ``"An error occurred: ..."``
        string if any step raises.
    """
    try:
        schema = llm.prompter.read_yaml_file("prompt_assets/output_sys_prompt.yml")
        # An empty mapping is passed as the second argument to the formatter.
        base_prompt = llm.prompter.format_yaml_prompt(schema, dict())
        messages = [
            {"role": "system", "content": base_prompt + f"Information:\n{context}"},
            {"role": "user", "content": user_input},
        ]
        completion = llm.run_inference(messages)
        return get_assistant_message(completion, config.chat_template, llm.tokenizer.eos_token)
    except Exception as e:
        # Same convention as get_response: report the failure as text.
        return f"An error occurred: {str(e)}"
40
 
41
  def main():
42
  st.title("LLM-ADE 9B Demo")
 
46
  if st.button("Generate"):
47
  if input_text:
48
  with st.spinner('Generating response...'):
49
+ agent_resp = get_response(input_text)
50
+ st.write(get_output(agent_resp, input_text))
51
  else:
52
  st.warning("Please enter some text to generate a response.")
53
 
 
55
 
56
  def main_headless():
57
  while True:
58
+ input_text = input("Enter your text here: ")
59
+ agent_resp = get_response(input_text)
60
+ print('\033[94m' + get_output(agent_resp, input_text) + '\033[0m')
61
 
62
  if __name__ == "__main__":
63
+ if config.headless:
64
+ main_headless()
65
+ else:
66
+ main()
config.py CHANGED
@@ -3,7 +3,8 @@ from pydantic_settings import BaseSettings
3
 
4
  class Config(BaseSettings):
5
  hf_token: str = Field(...)
6
- model: str = Field("InvestmentResearchAI/LLM-ADE-dev")
 
7
 
8
  chat_template: str = Field("chatml", description="Chat template for prompt formatting")
9
  num_fewshot: int | None = Field(None, description="Option to use json mode examples")
 
3
 
4
  class Config(BaseSettings):
5
  hf_token: str = Field(...)
6
+ model_path: str = Field("InvestmentResearchAI/LLM-ADE-dev")
7
+ headless: bool = Field(False, description="Run in headless mode.")
8
 
9
  chat_template: str = Field("chatml", description="Chat template for prompt formatting")
10
  num_fewshot: int | None = Field(None, description="Option to use json mode examples")
functioncall.py CHANGED
@@ -2,9 +2,14 @@ import argparse
2
  import torch
3
  import json
4
  from config import config
 
5
  from vllm import LLM, SamplingParams
6
 
7
- from transformers import BitsAndBytesConfig
 
 
 
 
8
 
9
  import functions
10
  from prompter import PromptManager
@@ -28,9 +33,17 @@ class ModelInference:
28
  bnb_4bit_quant_type="nf4",
29
  bnb_4bit_use_double_quant=True,
30
  )
31
- self.model = LLM(model=config.model)
32
-
33
- self.tokenizer = self.model.get_tokenizer()
 
 
 
 
 
 
 
 
34
  self.tokenizer.pad_token = self.tokenizer.eos_token
35
  self.tokenizer.padding_side = "left"
36
 
@@ -69,17 +82,23 @@ class ModelInference:
69
  results_dict = f'{{"name": "{function_name}", "content": {function_response}}}'
70
  return results_dict
71
 
72
- def run_inference(self, prompt):
73
- sampling_params = SamplingParams(
74
- temperature=0.8,
75
- top_p=0.95,
 
 
 
 
 
 
 
76
  repetition_penalty=1.1,
77
- max_tokens=500,
78
- stop_token_ids=[128009])
79
-
80
- outputs = self.model.generate([prompt], sampling_params)
81
- for output in outputs:
82
- return output.outputs[0].text
83
 
84
  def generate_function_call(self, query, chat_template, num_fewshot, max_depth=5):
85
  try:
@@ -120,7 +139,7 @@ class ModelInference:
120
  return
121
 
122
  completion = self.run_inference(prompt)
123
- recursive_loop(prompt, completion, depth)
124
  elif error_message:
125
  inference_logger.info(f"Assistant Message:\n{assistant_message}")
126
  tool_message += f"<tool_response>\nThere was an error parsing function calls\n Here's the error stack trace: {error_message}\nPlease call the function again with correct syntax<tool_response>"
@@ -132,32 +151,13 @@ class ModelInference:
132
  return
133
 
134
  completion = self.run_inference(prompt)
135
- recursive_loop(prompt, completion, depth)
136
  else:
137
  inference_logger.info(f"Assistant Message:\n{assistant_message}")
 
138
 
139
- recursive_loop(prompt, completion, depth)
140
 
141
  except Exception as e:
142
  inference_logger.error(f"Exception occurred: {e}")
143
  raise e
144
-
145
- if __name__ == "__main__":
146
- parser = argparse.ArgumentParser(description="Run recursive function calling loop")
147
- parser.add_argument("--model_path", type=str, help="Path to the model folder")
148
- parser.add_argument("--chat_template", type=str, default="chatml", help="Chat template for prompt formatting")
149
- parser.add_argument("--num_fewshot", type=int, default=None, help="Option to use json mode examples")
150
- parser.add_argument("--load_in_4bit", type=str, default="False", help="Option to load in 4bit with bitsandbytes")
151
- parser.add_argument("--query", type=str, default="I need the current stock price of Tesla (TSLA)")
152
- parser.add_argument("--max_depth", type=int, default=5, help="Maximum number of recursive iteration")
153
- args = parser.parse_args()
154
-
155
- # specify custom model path
156
- if args.model_path:
157
- inference = ModelInference(args.model_path, args.chat_template, args.load_in_4bit)
158
- else:
159
- model_path = 'InvestmentResearchAI/LLM-ADE-dev'
160
- inference = ModelInference(model_path, args.chat_template, args.load_in_4bit)
161
-
162
- # Run the model evaluator
163
- inference.generate_function_call(args.query, args.chat_template, args.num_fewshot, args.max_depth)
 
2
  import torch
3
  import json
4
  from config import config
5
+ from typing import List, Dict
6
  from vllm import LLM, SamplingParams
7
 
8
+ from transformers import (
9
+ AutoModelForCausalLM,
10
+ AutoTokenizer,
11
+ BitsAndBytesConfig
12
+ )
13
 
14
  import functions
15
  from prompter import PromptManager
 
33
  bnb_4bit_quant_type="nf4",
34
  bnb_4bit_use_double_quant=True,
35
  )
36
+ self.model = AutoModelForCausalLM.from_pretrained(
37
+ config.model_path,
38
+ trust_remote_code=True,
39
+ return_dict=True,
40
+ quantization_config=self.bnb_config,
41
+ torch_dtype=torch.float16,
42
+ attn_implementation="flash_attention_2",
43
+ device_map="auto",
44
+ )
45
+
46
+ self.tokenizer = AutoTokenizer.from_pretrained(config.model_path, trust_remote_code=True)
47
  self.tokenizer.pad_token = self.tokenizer.eos_token
48
  self.tokenizer.padding_side = "left"
49
 
 
82
  results_dict = f'{{"name": "{function_name}", "content": {function_response}}}'
83
  return results_dict
84
 
85
def run_inference(self, prompt: List[Dict[str, str]]):
    """Generate a completion for a chat conversation.

    Args:
        prompt: Conversation as a list of ``{"role": ..., "content": ...}``
            messages, rendered with the tokenizer's chat template.

    Returns:
        The decoded text of the first generated sequence. Special tokens are
        kept (``skip_special_tokens=False``), and the whole sequence returned
        by ``generate`` is decoded, not just a suffix of it.
    """
    # Render the conversation to input ids, with the generation prompt added
    # so the model continues as the assistant.
    inputs = self.tokenizer.apply_chat_template(
        prompt,
        add_generation_prompt=True,
        return_tensors='pt',
    )

    tokens = self.model.generate(
        inputs.to(self.model.device),
        max_new_tokens=1500,
        temperature=0.8,
        repetition_penalty=1.1,
        do_sample=True,
        eos_token_id=self.tokenizer.eos_token_id,
    )

    # The keyword is `clean_up_tokenization_spaces` (plural); the previous
    # singular spelling did not match any decode() parameter.
    completion = self.tokenizer.decode(
        tokens[0],
        skip_special_tokens=False,
        clean_up_tokenization_spaces=True,
    )
    return completion
 
102
 
103
  def generate_function_call(self, query, chat_template, num_fewshot, max_depth=5):
104
  try:
 
139
  return
140
 
141
  completion = self.run_inference(prompt)
142
+ return recursive_loop(prompt, completion, depth)
143
  elif error_message:
144
  inference_logger.info(f"Assistant Message:\n{assistant_message}")
145
  tool_message += f"<tool_response>\nThere was an error parsing function calls\n Here's the error stack trace: {error_message}\nPlease call the function again with correct syntax<tool_response>"
 
151
  return
152
 
153
  completion = self.run_inference(prompt)
154
+ return recursive_loop(prompt, completion, depth)
155
  else:
156
  inference_logger.info(f"Assistant Message:\n{assistant_message}")
157
+ return assistant_message
158
 
159
+ return recursive_loop(prompt, completion, depth)
160
 
161
  except Exception as e:
162
  inference_logger.error(f"Exception occurred: {e}")
163
  raise e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
functions.py CHANGED
@@ -11,57 +11,6 @@ from utils import inference_logger
11
  from langchain.tools import tool
12
  from langchain_core.utils.function_calling import convert_to_openai_tool
13
 
14
- @tool
15
- def code_interpreter(code_markdown: str) -> dict | str:
16
- """
17
- Execute the provided Python code string on the terminal using exec.
18
-
19
- The string should contain valid, executable and pure Python code in markdown syntax.
20
- Code should also import any required Python packages.
21
-
22
- Args:
23
- code_markdown (str): The Python code with markdown syntax to be executed.
24
- For example: ```python\n<code-string>\n```
25
-
26
- Returns:
27
- dict | str: A dictionary containing variables declared and values returned by function calls,
28
- or an error message if an exception occurred.
29
-
30
- Note:
31
- Use this function with caution, as executing arbitrary code can pose security risks.
32
- """
33
- try:
34
- # Extracting code from Markdown code block
35
- code_lines = code_markdown.split('\n')[1:-1]
36
- code_without_markdown = '\n'.join(code_lines)
37
-
38
- # Create a new namespace for code execution
39
- exec_namespace = {}
40
-
41
- # Execute the code in the new namespace
42
- exec(code_without_markdown, exec_namespace)
43
-
44
- # Collect variables and function call results
45
- result_dict = {}
46
- for name, value in exec_namespace.items():
47
- if callable(value):
48
- try:
49
- result_dict[name] = value()
50
- except TypeError:
51
- # If the function requires arguments, attempt to call it with arguments from the namespace
52
- arg_names = inspect.getfullargspec(value).args
53
- args = {arg_name: exec_namespace.get(arg_name) for arg_name in arg_names}
54
- result_dict[name] = value(**args)
55
- elif not name.startswith('_'): # Exclude variables starting with '_'
56
- result_dict[name] = value
57
-
58
- return result_dict
59
-
60
- except Exception as e:
61
- error_message = f"An error occurred: {e}"
62
- inference_logger.error(error_message)
63
- return error_message
64
-
65
  @tool
66
  def google_search_and_scrape(query: str) -> dict:
67
  """
@@ -297,7 +246,6 @@ def get_company_profile(symbol: str) -> dict:
297
 
298
  def get_openai_tools() -> List[dict]:
299
  functions = [
300
- code_interpreter,
301
  google_search_and_scrape,
302
  get_current_stock_price,
303
  get_company_news,
 
11
  from langchain.tools import tool
12
  from langchain_core.utils.function_calling import convert_to_openai_tool
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  @tool
15
  def google_search_and_scrape(query: str) -> dict:
16
  """
 
246
 
247
  def get_openai_tools() -> List[dict]:
248
  functions = [
 
249
  google_search_and_scrape,
250
  get_current_stock_price,
251
  get_company_news,
prompt_assets/output_sys_prompt.yml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ Role: |
2
+ You are an expert financial advisor named IRAI.
3
+ You have a comprehensive understanding of finance and investing with experience and expertise in all areas of finance.
4
+ You can use information given to you, but do not mention function calls.
5
+ Objective: |
6
+ Answer questions accurately and truthfully given your current knowledge. Answer the question directly.
7
+ Instructions: |
8
+ The questions will be asked by top technology executives and CFOs of large fintech companies and successful startups.
9
+ Answer in a friendly and engaging manner representing a top female investment professional working at a leading investment bank.
10
+ Give a direct answer to the question, concise yet insightful.
prompt_assets/sys_prompt.yml CHANGED
@@ -1,5 +1,4 @@
1
  Role: |
2
- You are an expert financial advisor named IRAI. You have a comprehensive understanding of finance and investing with experience and expertise in all areas of finance.
3
  You are a function calling AI agent with self-recursion.
4
  You can call only one function at a time and analyse data you get from function response.
5
  You are provided with function signatures within <tools></tools> XML tags.
@@ -37,7 +36,3 @@ Instructions: |
37
  <tool_call>
38
  {{"arguments": <args-dict>, "name": <function-name>}}
39
  </tool_call>
40
- Style and tone: |
41
- Answer in a friendly and engaging manner representing a top female investment professional working at a leading investment bank.
42
- Audience: |
43
- The questions will be asked by top technology executives and CFO of large fintech companies and successful startups.
 
1
  Role: |
 
2
  You are a function calling AI agent with self-recursion.
3
  You can call only one function at a time and analyse data you get from function response.
4
  You are provided with function signatures within <tools></tools> XML tags.
 
36
  <tool_call>
37
  {{"arguments": <args-dict>, "name": <function-name>}}
38
  </tool_call>
 
 
 
 
requirements.txt CHANGED
@@ -1,6 +1,131 @@
1
- streamlit
2
- transformers
3
- torch
4
- vllm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  xformers==0.0.23
6
-
 
 
1
+ aiohttp==3.9.5
2
+ aioprometheus==23.12.0
3
+ aiosignal==1.3.1
4
+ altair==5.3.0
5
+ annotated-types==0.6.0
6
+ anyio==4.3.0
7
+ appdirs==1.4.4
8
+ async-timeout==4.0.3
9
+ attrs==23.2.0
10
+ beautifulsoup4==4.12.3
11
+ blinker==1.8.2
12
+ cachetools==5.3.3
13
+ certifi==2024.2.2
14
+ charset-normalizer==3.3.2
15
+ click==8.1.7
16
+ dataclasses-json==0.6.5
17
+ dnspython==2.6.1
18
+ email_validator==2.1.1
19
+ exceptiongroup==1.2.1
20
+ fastapi==0.111.0
21
+ fastapi-cli==0.0.3
22
+ filelock==3.14.0
23
+ frozendict==2.4.4
24
+ frozenlist==1.4.1
25
+ fsspec==2024.3.1
26
+ gitdb==4.0.11
27
+ GitPython==3.1.43
28
+ greenlet==3.0.3
29
+ h11==0.14.0
30
+ html5lib==1.1
31
+ httpcore==1.0.5
32
+ httptools==0.6.1
33
+ httpx==0.27.0
34
+ huggingface-hub==0.23.0
35
+ idna==3.7
36
+ Jinja2==3.1.4
37
+ jsonpatch==1.33
38
+ jsonpointer==2.4
39
+ jsonschema==4.22.0
40
+ jsonschema-specifications==2023.12.1
41
+ langchain==0.1.17
42
+ langchain-community==0.0.37
43
+ langchain-core==0.1.52
44
+ langchain-text-splitters==0.0.1
45
+ langsmith==0.1.54
46
+ lxml==5.2.1
47
+ markdown-it-py==3.0.0
48
+ MarkupSafe==2.1.5
49
+ marshmallow==3.21.2
50
+ mdurl==0.1.2
51
+ mpmath==1.3.0
52
+ msgpack==1.0.8
53
+ multidict==6.0.5
54
+ multitasking==0.0.11
55
+ mypy-extensions==1.0.0
56
+ networkx==3.3
57
+ ninja==1.11.1.1
58
+ numpy==1.26.4
59
+ nvidia-cublas-cu12==12.1.3.1
60
+ nvidia-cuda-cupti-cu12==12.1.105
61
+ nvidia-cuda-nvrtc-cu12==12.1.105
62
+ nvidia-cuda-runtime-cu12==12.1.105
63
+ nvidia-cudnn-cu12==8.9.2.26
64
+ nvidia-cufft-cu12==11.0.2.54
65
+ nvidia-curand-cu12==10.3.2.106
66
+ nvidia-cusolver-cu12==11.4.5.107
67
+ nvidia-cusparse-cu12==12.1.0.106
68
+ nvidia-nccl-cu12==2.18.1
69
+ nvidia-nvjitlink-cu12==12.4.127
70
+ nvidia-nvtx-cu12==12.1.105
71
+ orjson==3.10.3
72
+ packaging==23.2
73
+ pandas==2.2.2
74
+ peewee==3.17.3
75
+ pillow==10.3.0
76
+ protobuf==4.25.3
77
+ psutil==5.9.8
78
+ pyarrow==16.0.0
79
+ pydantic==2.7.1
80
+ pydantic-settings==2.2.1
81
+ pydantic_core==2.18.2
82
+ pydeck==0.9.0
83
+ Pygments==2.18.0
84
+ python-dateutil==2.9.0.post0
85
+ python-dotenv==1.0.1
86
+ python-multipart==0.0.9
87
+ pytz==2024.1
88
+ PyYAML==6.0.1
89
+ quantile-python==1.1
90
+ ray==2.20.0
91
+ referencing==0.35.1
92
+ regex==2024.4.28
93
+ requests==2.31.0
94
+ rich==13.7.1
95
+ rpds-py==0.18.1
96
+ safetensors==0.4.3
97
+ sentencepiece==0.2.0
98
+ shellingham==1.5.4
99
+ six==1.16.0
100
+ smmap==5.0.1
101
+ sniffio==1.3.1
102
+ soupsieve==2.5
103
+ SQLAlchemy==2.0.30
104
+ starlette==0.37.2
105
+ streamlit==1.34.0
106
+ sympy==1.12
107
+ tenacity==8.3.0
108
+ tokenizers==0.19.1
109
+ toml==0.10.2
110
+ toolz==0.12.1
111
+ torch==2.1.1
112
+ tornado==6.4
113
+ tqdm==4.66.4
114
+ transformers==4.40.2
115
+ triton==2.1.0
116
+ typer==0.12.3
117
+ typing-inspect==0.9.0
118
+ typing_extensions==4.11.0
119
+ tzdata==2024.1
120
+ ujson==5.9.0
121
+ urllib3==2.2.1
122
+ uvicorn==0.29.0
123
+ uvloop==0.19.0
124
+ vllm==0.2.5
125
+ watchdog==4.0.0
126
+ watchfiles==0.21.0
127
+ webencodings==0.5.1
128
+ websockets==12.0
129
  xformers==0.0.23
130
+ yarl==1.9.4
131
+ yfinance==0.2.38