srossitto79 committed 5ec7b76 (parent: d13f3dd)

added airLLM

Files changed (6)
  1. .DS_Store +0 -0
  2. AirLLM.py +52 -0
  3. RBotReloaded.py +6 -2
  4. agent_llama_ui.py +2 -1
  5. requirements.txt +2 -1
  6. start_agent.sh +22 -0
.DS_Store ADDED
Binary file (8.2 kB)
 
AirLLM.py ADDED
@@ -0,0 +1,52 @@
+ from typing import Any, List, Mapping, Optional
+
+ from langchain.callbacks.manager import CallbackManagerForLLMRun
+ from langchain.llms.base import LLM
+
+ from airllm import AirLLMLlama2
+
+
+ class AirLLM(LLM):
+     """LangChain LLM wrapper around an AirLLM Llama 2 model."""
+
+     max_len: int
+     model: AirLLMLlama2
+
+     def __init__(self, llama2_model_id: str, max_len: int, compression: str = ""):
+         # could use a local path or a Hugging Face model repo id
+         model = AirLLMLlama2(llama2_model_id, compression=compression)
+         # LLM is a pydantic model, so fields must be set through the base initializer
+         super().__init__(model=model, max_len=max_len)
+
+     @property
+     def _llm_type(self) -> str:
+         return "custom"
+
+     def _call(
+         self,
+         prompt: str,
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[CallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> str:
+         if stop is not None:
+             raise ValueError("stop kwargs are not permitted.")
+
+         # Tokenize the prompt, truncating to the configured context length
+         input_tokens = self.model.tokenizer(prompt,
+                                             return_tensors="pt",
+                                             return_attention_mask=False,
+                                             truncation=True,
+                                             max_length=self.max_len,
+                                             padding=True)
+
+         # Generate on GPU; AirLLM loads layers on demand to keep memory low
+         generation_output = self.model.generate(
+             input_tokens['input_ids'].cuda(),
+             max_new_tokens=20,
+             use_cache=True,
+             return_dict_in_generate=True)
+
+         output = self.model.tokenizer.decode(generation_output.sequences[0])
+         return output
+
+     @property
+     def _identifying_params(self) -> Mapping[str, Any]:
+         """Get the identifying parameters."""
+         return {"max_len": self.max_len}
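A minimal usage sketch for this wrapper (the repo id and settings below are illustrative, not part of the commit):

    from AirLLM import AirLLM

    # hypothetical checkpoint; any Llama 2 style path or Hugging Face repo id should work
    llm = AirLLM(llama2_model_id="meta-llama/Llama-2-7b-hf", max_len=8192, compression="4bit")
    print(llm("Summarize what AirLLM does in one sentence."))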
RBotReloaded.py CHANGED
@@ -34,12 +34,13 @@ from typing import Any, Dict, List
  import torch
  from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
  import inspect
+ from AirLLM import AirLLM

  # Config
  EMBD_CHUNK_SIZE = 512
  AI_NAME = "Agent Llama"
  USER_NAME = "Buddy"
- MODELS_DIR = "models"
+ MODELS_DIR = "./models"

  def validate_and_fix_params(tool_name, params_list):
      try:
@@ -66,7 +67,7 @@ def validate_and_fix_params(tool_name, params_list):
          return []

  # Helper to load LM
- def create_llm(model_id=f"{MODELS_DIR}/mistral-7b-instruct-v0.1.Q4_K_M.gguf", load_4bit=False, load_8bit=False, ctx_len=8192, temperature=0.5, top_p=0.95):
+ def create_llm(model_id=f"{MODELS_DIR}/deepseek-coder-6.7b-instruct.Q5_K_M.gguf", load_4bit=False, load_8bit=False, ctx_len=8192, temperature=0.5, top_p=0.95):
      if (model_id.startswith("http")):
          print(f"Creating TextGen LLM base_url:{model_id}")
          return TextGen(model_url=model_id, callbacks=[StreamingStdOutCallbackHandler()])
@@ -89,6 +90,9 @@ def create_llm(model_id=f"{MODELS_DIR}/mistral-7b-instruct-v0.1.Q4_K_M.gguf", lo
      except Exception as ex:
          print(f"Load Error {str(ex)}")
          return None
+     else:
+         print(f"Trying AirLLM to load model_id:{model_id}")
+         return AirLLM(llama2_model_id=model_id, max_len=ctx_len, compression=("4bit" if load_4bit else "8bit" if load_8bit else ""))

  # Class to store pages and run queries
  class StorageRetrievalLLM:
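An illustrative call that would reach the new AirLLM fallback in create_llm (the model id is hypothetical; URLs still go to TextGen and local GGUF files to the existing loaders first):

    # falls through to AirLLM when no other loader handles the model id
    llm = create_llm(model_id="meta-llama/Llama-2-13b-hf", load_4bit=True, ctx_len=4096)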
agent_llama_ui.py CHANGED
@@ -15,7 +15,7 @@ from langchain.schema import AIMessage, HumanMessage
  load_dotenv()


- default_model = "mistral-7b-instruct-v0.1.Q4_K_M.gguf"
+ default_model = ""
  default_context = 8192
  default_load_type = "Auto"
  default_iterations = 2
@@ -43,6 +43,7 @@ def get_models():
      models = os.listdir(models_directory)
      # Filter out any subdirectories, if any
      models = [model for model in models if (model.lower().split(".")[-1] in supported_extensions) and os.path.isfile(os.path.join(models_directory, model))]
+
      if len(models) == 0:
          st.write("Downloading models")
          from huggingface_hub import hf_hub_download
requirements.txt CHANGED
@@ -42,4 +42,5 @@ Pillow
  langchain
  googletrans
  python-dotenv
- omegaconf
+ omegaconf
+ airllm
start_agent.sh ADDED
@@ -0,0 +1,22 @@
+ #!/bin/bash
+
+ # Define the name of your virtual environment
+ ENV_NAME="myenv"
+
+ # Check if the virtual environment folder exists
+ if [ ! -d "$ENV_NAME" ]; then
+     # Create a new virtual environment
+     python -m venv "$ENV_NAME"
+ fi
+
+ # Activate the virtual environment
+ source "$ENV_NAME/bin/activate"
+
+ # Install the required packages from requirements.txt
+ python -m pip install -r requirements.txt
+
+ # Run the Streamlit application
+ python -m streamlit run agent_llama_ui.py
+
+ # Deactivate the virtual environment
+ deactivate