Spaces:
Running
Running
File size: 9,349 Bytes
e91ac58 9d06861 e91ac58 9d06861 e91ac58 9d06861 e91ac58 9d06861 e91ac58 9d06861 e91ac58 9d06861 e91ac58 9d06861 e91ac58 9d06861 e91ac58 9d06861 e91ac58 9d06861 e91ac58 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 |
import os, json, gc
import time
import torch
import transformers
import random
from transformers import BitsAndBytesConfig#, AutoModelForCausalLM, AutoTokenizer
from langchain.output_parsers import RetryWithErrorOutputParser
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_experimental.llms import JsonFormer
from langchain.tools import tool
# from langchain_community.llms import CTransformers
# from ctransformers import AutoModelForCausalLM, AutoConfig, Config
from langchain_community.llms import LlamaCpp
# from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.base import BaseCallbackHandler
from huggingface_hub import hf_hub_download
from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens, save_individual_prompt
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
from vouchervision.utils_geolocate_HERE import validate_coordinates_here
from vouchervision.tool_wikipedia import WikipediaLinks
class LocalCPUMistralHandler:
    """Run a local Mistral-7B-Instruct GGUF model through LangChain's ``LlamaCpp``
    and turn its completion into a validated, template-aligned JSON record.

    On construction the GGUF weights are fetched with ``hf_hub_download`` and the
    prompt, JSON parser, retrying parser and ``prompt | llm`` chain are built.
    :meth:`call_llm_local_cpu_MistralAI` then retries generation, raising the
    sampling temperature after every failed attempt, until a JSON object that
    can be aligned with ``JSON_dict_structure`` is produced or ``MAX_RETRIES``
    attempts have been made.
    """

    RETRY_DELAY = 2  # Seconds to wait before retrying (not used by this handler)
    MAX_RETRIES = 5  # Maximum number of retries
    STARTING_TEMP = 0.1
    TOKENIZER_NAME = None
    VENDOR = 'mistral'
    SEED = 2023

    def __init__(self, logger, model_name, JSON_dict_structure):
        """Download the model weights and build the prompt/model/parser pipeline.

        Args:
            logger: logger used for progress and error messages.
            model_name: requested model name. NOTE: it is currently overridden;
                this handler always loads "Mistral-7B-Instruct-v0.2-GGUF".
            JSON_dict_structure: template dict that parsed output is aligned to.

        Raises:
            ValueError: if the effective model name has no known GGUF file.
        """
        self.logger = logger
        self.monitor = SystemLoadMonitor(logger)
        self.has_GPU = torch.cuda.is_available()
        self.JSON_dict_structure = JSON_dict_structure
        self.model_file = None

        # https://medium.com/@scholarly360/mistral-7b-complete-guide-on-colab-129fa5e9a04d
        # The caller's ``model_name`` is deliberately ignored: only one GGUF build
        # is supported. Equivalent manual download:
        #   huggingface-cli download TheBloke/Mistral-7B-Instruct-v0.2-GGUF mistral-7b-instruct-v0.2.Q4_K_M.gguf --local-dir /home/brlab/.cache --local-dir-use-symlinks False
        self.model_name = "Mistral-7B-Instruct-v0.2-GGUF"
        self.model_id = f"TheBloke/{self.model_name}"

        if self.model_name == "Mistral-7B-Instruct-v0.2-GGUF":
            self.model_file = 'mistral-7b-instruct-v0.2.Q4_K_M.gguf'
            self.model_path = hf_hub_download(repo_id=self.model_id,
                                              filename=self.model_file,
                                              repo_type="model")
        else:
            raise ValueError(f"Unsupported GGUF model name: {self.model_name}")

        self.gpu_usage = {'max_load': 0, 'max_memory_usage': 0, 'monitoring': True}

        # Temperature schedule: start low, add ``temp_increment`` after each failure.
        self.starting_temp = float(self.STARTING_TEMP)
        self.temp_increment = 0.2
        self.adjust_temp = self.starting_temp

        system_prompt = "You are a helpful AI assistant who answers queries with JSON objects and no explanations."
        # Mistral instruction format; ``{query}`` is left as a PromptTemplate variable.
        template = (
            "\n"
            f"<s>[INST]{system_prompt}[/INST]</s>\n"
            "[INST]{query}[/INST]\n"
        )

        # Create a prompt from the template so we can use it with LangChain
        self.prompt = PromptTemplate(template=template, input_variables=["query"])
        # Set up a parser
        self.parser = JsonOutputParser()

        self._set_config()

    def _set_config(self):
        """Initialise the generation settings and (re)build the model chain."""
        self.config = {'max_new_tokens': 1024,
                       'temperature': self.starting_temp,
                       'seed': self.SEED,
                       'top_p': 1,
                       'top_k': 40,
                       'n_ctx': 4096,
                       'do_sample': True,  # not a LlamaCpp option; not forwarded to the model
                       }
        self._build_model_chain_parser()

    def _adjust_config(self):
        """Raise the sampling temperature by ``temp_increment`` for the next attempt.

        Only the stored value changes; the model is not rebuilt. The new value is
        bound to the LLM on the next attempt in ``call_llm_local_cpu_MistralAI``.
        Requires ``self.json_report`` to have been set.
        """
        new_temp = self.adjust_temp + self.temp_increment
        self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
        self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
        self.adjust_temp = new_temp
        self.config['temperature'] = self.adjust_temp

    def _reset_config(self):
        """Restore the sampling temperature to ``starting_temp``.

        Requires ``self.json_report`` to have been set.
        """
        self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
        self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
        self.adjust_temp = self.starting_temp
        self.config['temperature'] = self.starting_temp

    def _build_model_chain_parser(self):
        """Create the LlamaCpp LLM, the retrying JSON parser, and the prompt|llm chain."""
        self.local_model = LlamaCpp(
            model_path=self.model_path,
            max_tokens=self.config.get('max_new_tokens'),
            temperature=self.config.get('temperature'),
            top_p=self.config.get('top_p'),
            top_k=self.config.get('top_k'),
            seed=self.config.get('seed'),
            n_ctx=self.config.get('n_ctx'),
            stop=["[INST]"],  # generation stops as soon as the model emits "[INST]"
            verbose=False,
            streaming=False,
            # Optional llama.cpp tuning knobs left at their defaults:
            # n_gpu_layers, n_batch, callback_manager.
        )

        # Wrap the JSON parser so a failed parse is retried through the LLM,
        # up to MAX_RETRIES times.
        self.retry_parser = RetryWithErrorOutputParser.from_llm(parser=self.parser, llm=self.local_model, max_retries=self.MAX_RETRIES)
        # Base chain; the per-attempt temperature is bound at invoke time.
        self.chain = self.prompt | self.local_model

    def call_llm_local_cpu_MistralAI(self, prompt_template, json_report, paths):
        """Generate, parse and validate a JSON record for one prompt.

        Each attempt runs the prompt through the local model at the current
        temperature. It parses the completion as JSON (via the retrying parser)
        and aligns the keys with ``JSON_dict_structure``. On success, the record
        is passed through WFO taxonomy and HERE coordinate validation and
        Wikipedia link gathering. After a failed attempt the temperature is
        raised. It is reset before returning.

        Args:
            prompt_template: rendered user prompt text, inserted as ``{query}``.
            json_report: progress reporter exposing ``set_text(text_main=...)``.
            paths: 7-item sequence; only the last two are used here (the
                Wikipedia JSON output path and the individual-prompt text path).

        Returns:
            ``(output, nt_in, nt_out, WFO_record, GEO_record, usage_report)`` on
            success, or ``(None, nt_in, nt_out, None, None, usage_report)`` after
            ``MAX_RETRIES`` failed attempts.
        """
        _, _, _, _, _, json_file_path_wiki, txt_file_path_ind_prompt = paths
        self.json_report = json_report
        self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
        self.monitor.start_monitoring_usage()

        nt_in = 0
        nt_out = 0

        # The retry parser calls ``.to_string()`` on its prompt, so it needs a
        # PromptValue rather than the raw query string.
        prompt_value = self.prompt.format_prompt(query=prompt_template)

        ind = 0
        while ind < self.MAX_RETRIES:
            ind += 1
            try:
                # Bind the current temperature to the LLM call itself; extra keys
                # in the prompt's input dict never reach the model.
                chain = self.prompt | self.local_model.bind(temperature=self.adjust_temp)
                results = chain.invoke({"query": prompt_template})

                # Parse the response, letting the retry parser re-ask the LLM on failure
                output = self.retry_parser.parse_with_prompt(results, prompt_value=prompt_value)
                if output is None:
                    self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{results}')
                    self._adjust_config()
                    continue

                nt_in = count_tokens(prompt_template, self.VENDOR, self.TOKENIZER_NAME)
                nt_out = count_tokens(results, self.VENDOR, self.TOKENIZER_NAME)

                output = validate_and_align_JSON_keys_with_template(output, self.JSON_dict_structure)
                if output is None:
                    self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{results}')
                    self._adjust_config()
                    continue

                self.monitor.stop_inference_timer()  # Starts tool timer too
                json_report.set_text(text_main='Working on WFO, Geolocation, Links')
                output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False)  # TODO: make this configurable
                output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False)  # TODO: make this configurable

                Wiki = WikipediaLinks(json_file_path_wiki)
                Wiki.gather_wikipedia_results(output)

                save_individual_prompt(Wiki.sanitize(prompt_template), txt_file_path_ind_prompt)

                self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")

                usage_report = self.monitor.stop_monitoring_report_usage()

                if self.adjust_temp != self.starting_temp:
                    self._reset_config()

                json_report.set_text(text_main='LLM call successful')
                return output, nt_in, nt_out, WFO_record, GEO_record, usage_report

            except Exception as e:
                self.logger.error(f'{e}')
                self._adjust_config()

        self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
        self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')

        usage_report = self.monitor.stop_monitoring_report_usage()
        self._reset_config()
        json_report.set_text(text_main='LLM call failed')
        return None, nt_in, nt_out, None, None, usage_report
|