WilliamGazeley committed on
Commit 9e2a95f
1 Parent(s): 1b19641

Migrate to Ollama

Files changed (3)
  1. src/app.py +27 -21
  2. src/config.py +8 -3
  3. src/functioncall.py +14 -42
src/app.py CHANGED
@@ -1,12 +1,11 @@
 import os
+from time import time
 import huggingface_hub
 import streamlit as st
 from config import config
 from utils import get_assistant_message
 from functioncall import ModelInference
-from prompter import PromptManager
 
-print("Why, hello there!", flush=True)
 
 @st.cache_resource(show_spinner="Loading model..")
 def init_llm():
@@ -14,40 +13,44 @@ def init_llm():
     llm = ModelInference(chat_template=config.chat_template)
     return llm
 
+
 def get_response(prompt):
     try:
         return llm.generate_function_call(
-            prompt,
-            config.chat_template,
-            config.num_fewshot,
-            config.max_depth
+            prompt, config.chat_template, config.num_fewshot, config.max_depth
         )
     except Exception as e:
         return f"An error occurred: {str(e)}"
-
+
+
 def get_output(context, user_input):
     try:
         config.status.update(label=":bulb: Preparing answer..")
-        prompt_schema = llm.prompter.read_yaml_file("prompt_assets/output_sys_prompt.yml")
-        sys_prompt = llm.prompter.format_yaml_prompt(prompt_schema, dict()) + \
-            f"Information:\n{context}"
+        script_dir = os.path.dirname(os.path.abspath(__file__))
+        prompt_path = os.path.join(script_dir, 'prompt_assets', 'output_sys_prompt.yml')
+        prompt_schema = llm.prompter.read_yaml_file(prompt_path)
+        sys_prompt = (
+            llm.prompter.format_yaml_prompt(prompt_schema, dict())
+            + f"Information:\n{context}"
+        )
         convo = [
             {"role": "system", "content": sys_prompt},
             {"role": "user", "content": user_input},
         ]
         response = llm.run_inference(convo)
-        return get_assistant_message(response, config.chat_template, llm.tokenizer.eos_token)
+        return response
     except Exception as e:
         return f"An error occurred: {str(e)}"
 
+
 def main():
     st.title("LLM-ADE 9B Demo")
-
+
     input_text = st.text_area("Enter your text here:", value="", height=200)
-
+
     if st.button("Generate"):
         if input_text:
-            with st.status('Generating response...') as status:
+            with st.status("Generating response...") as status:
                 config.status = status
                 agent_resp = get_response(input_text)
                 st.write(get_output(agent_resp, input_text))
@@ -55,17 +58,20 @@ def main():
         else:
             st.warning("Please enter some text to generate a response.")
 
+
 llm = init_llm()
 
-def main_headless():
-    while True:
-        input_text = input("Enter your text here: ")
-        agent_resp = get_response(input_text)
-        print('\033[94m' + get_output(agent_resp, input_text) + '\033[0m')
+
+def main_headless(prompt: str):
+    start = time()
+    agent_resp = get_response(prompt)
+    print("\033[94m" + get_output(agent_resp, prompt) + "\033[0m")
+    print(f"Time taken: {time() - start:.2f}s\n" + "-" * 20)
+
 
 if __name__ == "__main__":
-    print(f"Test env vars: {os.getenv('TEST_SECRET')}")
     if config.headless:
-        main_headless()
+        import fire
+        fire.Fire(main_headless)
     else:
         main()
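Note on the headless entry point: it no longer loops on input(); python-fire turns main_headless into a one-shot CLI that takes the prompt as a positional argument. A minimal sketch of that behaviour (illustrative only; the stub body below stands in for the real get_response/get_output pipeline):

import fire

def main_headless(prompt: str):
    # In app.py this calls get_response() and get_output(); stubbed here for illustration.
    print(f"Would run the agent on: {prompt!r}")

if __name__ == "__main__":
    # `python sketch.py "Summarize today's filings"` maps the positional
    # CLI argument onto main_headless(prompt=...).
    fire.Fire(main_headless)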
src/config.py CHANGED
@@ -2,12 +2,18 @@ from pydantic import Field
 from pydantic_settings import BaseSettings
 from typing import Dict, Any
 
+class MockStatus():
+    # Required for headless mode
+    def update(self, *args, **kwargs):
+        print("MockStatus update called with args: ", args, " and kwargs: ", kwargs)
+
 class Config(BaseSettings):
     hf_token: str = Field(...)
-    hf_model: str = Field("InvestmentResearchAI/LLM-ADE-dev")
+    hf_model: str = Field("InvestmentResearchAI/LLM-ADE-dev")  # We need this because I can't get the model template out of the ollama model
+    ollama_model: str = Field("llama3")
     headless: bool = Field(False, description="Run in headless mode.")
 
-    status: Any = None  # Hold the status
+    status: Any = MockStatus()
 
     az_search_endpoint: str = Field("https://analysis-bank.search.windows.net")
     az_search_api_key: str = Field(...)
@@ -17,7 +23,6 @@ class Config(BaseSettings):
 
     chat_template: str = Field("chatml", description="Chat template for prompt formatting")
     num_fewshot: int | None = Field(None, description="Option to use json mode examples")
-    load_in_4bit: str = Field("False", description="Option to load in 4bit with bitsandbytes")
     max_depth: int = Field(3, description="Maximum number of recursive iteration")
 
 config = Config(_env_file=".env")
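Because config.status now defaults to MockStatus(), headless runs can call config.status.update(...) without a Streamlit st.status widget being attached. A minimal sketch of that behaviour under pydantic-settings (field set trimmed to defaults; real values such as hf_token would come from .env):

from typing import Any
from pydantic import Field
from pydantic_settings import BaseSettings

class MockStatus:
    def update(self, *args, **kwargs):
        # Stand-in for Streamlit's status widget when running headless.
        print("status update:", args, kwargs)

class Config(BaseSettings):
    ollama_model: str = Field("llama3")
    headless: bool = Field(False)
    status: Any = MockStatus()

config = Config()                     # normally Config(_env_file=".env")
config.status.update(label=":bulb: Preparing answer..")  # safe even without Streamlit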
src/functioncall.py CHANGED
@@ -13,6 +13,7 @@ from transformers import (
 import functions
 from prompter import PromptManager
 from validator import validate_function_call_schema
+from langchain_community.chat_models import ChatOllama
 
 from utils import (
     inference_logger,
@@ -22,26 +23,12 @@ from utils import (
 )
 
 class ModelInference:
-    def __init__(self, chat_template: str, load_in_4bit: bool = False):
+    def __init__(self, chat_template: str):
         self.prompter = PromptManager()
-        self.bnb_config = None
-
-        if load_in_4bit == "True":  # Never use this
-            self.bnb_config = BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_quant_type="nf4",
-                bnb_4bit_use_double_quant=True,
-            )
-        self.model = AutoModelForCausalLM.from_pretrained(
-            config.hf_model,
-            trust_remote_code=True,
-            return_dict=True,
-            quantization_config=self.bnb_config,
-            torch_dtype=torch.float16,
-            attn_implementation="flash_attention_2",
-            device_map="auto",
-        )
-
+
+        self.model = ChatOllama(model=config.ollama_model,
+                                temperature=0.0, format='json')
+
         self.tokenizer = AutoTokenizer.from_pretrained(config.hf_model, trust_remote_code=True)
         self.tokenizer.pad_token = self.tokenizer.eos_token
         self.tokenizer.padding_side = "left"
@@ -49,24 +36,18 @@ class ModelInference:
         if self.tokenizer.chat_template is None:
             print("No chat template defined, getting chat_template...")
             self.tokenizer.chat_template = get_chat_template(chat_template)
-
-        inference_logger.info(self.model.config)
-        inference_logger.info(self.model.generation_config)
-        inference_logger.info(self.tokenizer.special_tokens_map)
 
-    def process_completion_and_validate(self, completion, chat_template):
-
-        assistant_message = get_assistant_message(completion, chat_template, self.tokenizer.eos_token)
 
-        if assistant_message:
-            validation, tool_calls, error_message = validate_and_extract_tool_calls(assistant_message)
+    def process_completion_and_validate(self, completion, chat_template):
+        if completion:
+            validation, tool_calls, error_message = validate_and_extract_tool_calls(completion)
 
             if validation:
                 inference_logger.info(f"parsed tool calls:\n{json.dumps(tool_calls, indent=2)}")
-                return tool_calls, assistant_message, error_message
+                return tool_calls, completion, error_message
             else:
                 tool_calls = None
-                return tool_calls, assistant_message, error_message
+                return tool_calls, completion, error_message
         else:
            inference_logger.warning("Assistant message is None")
            raise ValueError("Assistant message is None")
@@ -86,19 +67,10 @@ class ModelInference:
         inputs = self.tokenizer.apply_chat_template(
             prompt,
             add_generation_prompt=True,
-            return_tensors='pt'
-        )
-
-        tokens = self.model.generate(
-            inputs.to(self.model.device),
-            max_new_tokens=1500,
-            temperature=0.8,
-            repetition_penalty=1.2,
-            do_sample=True,
-            eos_token_id=self.tokenizer.eos_token_id
+            tokenize=False,
         )
-        completion = self.tokenizer.decode(tokens[0], skip_special_tokens=False, clean_up_tokenization_space=True)
-        return completion
+        completion = self.model.invoke(inputs, format='json')
+        return completion.content
 
     def generate_function_call(self, query, chat_template, num_fewshot, max_depth=5):
         try:
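Net effect of the migration: the HF tokenizer is kept only to render the chat template into a plain string (tokenize=False), and generation is delegated to a local Ollama server through LangChain's ChatOllama. A rough end-to-end sketch under that assumption (requires Ollama running locally with the model already pulled; the conversation content is a placeholder, not from the repo):

from langchain_community.chat_models import ChatOllama
from transformers import AutoTokenizer

# The tokenizer only supplies the chat template; no model weights are loaded locally anymore.
tokenizer = AutoTokenizer.from_pretrained("InvestmentResearchAI/LLM-ADE-dev", trust_remote_code=True)
model = ChatOllama(model="llama3", temperature=0.0, format="json")

convo = [
    {"role": "system", "content": "You are a function-calling assistant. Answer in JSON."},
    {"role": "user", "content": "What is NVDA's latest closing price?"},
]

# Render the conversation to a plain chatml-formatted string instead of token tensors.
prompt = tokenizer.apply_chat_template(convo, add_generation_prompt=True, tokenize=False)

# ChatOllama returns an AIMessage; the generated JSON text is in .content.
completion = model.invoke(prompt)
print(completion.content)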