from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig, BitsAndBytesConfig
from peft import PeftModel
from callbacks import Iteratorize, Stream
import transformers
import locale
import torch
import tqdm
import sys
import os
locale.getpreferredencoding = lambda *args: "UTF-8"  # force UTF-8 regardless of the system locale
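# Pick the best available device: Apple-silicon MPS, then CUDA, then CPU.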
if torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with open('../pwd.txt', 'r') as pwd:
    folder_location = pwd.read().strip()  # strip the trailing newline so the model path stays valid
model_id = f"{folder_location}models/CRYSTAL-model"
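# Quantize the base model to 4-bit NF4 with double quantization and bfloat16 compute.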
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
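# Load the tokenizer and the quantized base model, sharding it across available devices.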
tokenizer = LlamaTokenizer.from_pretrained(
model_id,
use_fast=False)
model = LlamaForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,  # 4-bit config above; load_in_8bit would conflict with it
    device_map="auto",
    torch_dtype=torch.float16,
)
FIRST_RUN = True
PROMPT = '''### Instruction:
{}
### Input:
{}
### Response:'''
if FIRST_RUN:
    # LLaMA special tokens: <unk>=0 (reused as pad), <s>=1 (bos), </s>=2 (eos).
    model.config.pad_token_id = tokenizer.pad_token_id = 0
    model.config.bos_token_id = 1
    model.config.eos_token_id = 2
    model.eval()
    FIRST_RUN = False
# Load the trained LoRA adapter weights on top of the base model
# (get_peft_model with a config would only attach a fresh, untrained adapter).
model = PeftModel.from_pretrained(model, "CRYSTAL-fine")
model = torch.compile(model)
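# Generator that builds a GenerationConfig from the sampling arguments and yields
# the decoded "### Response:" section, either token-by-token (streaming) or all at once.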
def evaluate(
prompt='',
temperature=0.4,
top_p=0.65,
top_k=35,
repetition_penalty=1.1,
max_new_tokens=512,
stream_output=False,
**kwargs,
):
    # Tokenize the prompt and move it to the selected device.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]
    generation_config = GenerationConfig(
        do_sample=True,  # sampling must be enabled for temperature/top_p/top_k to take effect
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        **kwargs,
    )
generate_params = {
"input_ids": input_ids,
"generation_config": generation_config,
"return_dict_in_generate": True,
"output_scores": True,
"max_new_tokens": max_new_tokens,
}
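    # Streaming path: hook a Stream callback into generate() and expose it as an iterator.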
if stream_output:
def generate_with_callback(callback=None, **kwargs):
kwargs.setdefault(
"stopping_criteria", transformers.StoppingCriteriaList()
)
kwargs["stopping_criteria"].append(
Stream(callback_func=callback)
)
with torch.no_grad():
model.generate(**kwargs)
def generate_with_streaming(**kwargs):
return Iteratorize(
generate_with_callback, kwargs, callback=None
)
with generate_with_streaming(**generate_params) as generator:
for output in generator:
decoded_output = tokenizer.decode(output)
if output[-1] in [tokenizer.eos_token_id]:
break
yield decoded_output.split("### Response:")[-1].strip()
return
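    # Non-streaming path: generate the full sequence, then decode and yield it once.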
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
        )
s = generation_output.sequences[0]
output = tokenizer.decode(s, skip_special_tokens=True)
yield output.split("### Response:")[-1].strip()
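# Default generation settings and state for the command-line chat loop below.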
temp = 0.4
topp = 0.65
topk = 35
repp = 1.1
maxt = 512
maxh = 5
stream_output = True
sys_input = ""
sys_net = False
chat_history = []
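# Append the new user message to the history and strip any <br> tags (likely left over from an HTML chat UI).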
def user(user_message, history):
for idx, content in enumerate(history):
history[idx] = [
content[0].replace('<br>', ''),
content[1].replace('<br>', '')
]
user_message = user_message.replace('<br>', '')
return "", history + [[user_message, None]]
def bot(
history,
current_events,
temperature=0.4,
top_p=0.65,
top_k=35,
repetition_penalty=1.1,
max_new_tokens=512,
maxh=10,
stream_output=False,
system_prompt="",
system_net=False,
username="User"
):
instruction = """You are an AI made by Vatsal Dutt named CRYSTAL which stands for Comprehensive Robotics Yielding Sophisticated Technology And Logistics."""
current_events = "\nThese are the latest updates:\n"+current_events
hist = f"""{username}: {instruction+current_events}\nCRYSTAL: Okay!\n"""
for idx, content in enumerate(history):
history[idx] = [
content[0].replace('<br>', ''),
None if content[1] is None else content[1].replace('<br>', '')
]
    # Chinese role-play keywords, roughly: "imitate", "play the role of", "act as" (simplified/traditional), "pretend" (simplified/traditional).
    roleplay_keywords = ["模仿", "扮演", "作为", "作為", "装作", "裝作"]
    for user_turn, assistant_turn in history[:-1]:
        hist += f'{username}: {user_turn}\nCRYSTAL: {assistant_turn}\n'
    now_prompt = PROMPT.format(
        hist,
        ("CRYSTAL: " + system_prompt if system_prompt != "" else system_prompt + '\n\n')
        + f"{username}: {history[-1][0]}"
    )
if not system_net or not any(roleplay_keyword in history[-1][0] for roleplay_keyword in roleplay_keywords):
system_prompt = ""
bot_message = evaluate(
now_prompt, temperature, top_p, top_k, repetition_penalty, max_new_tokens, stream_output
)
if stream_output:
bot_message = tqdm.tqdm(bot_message, unit='token')
    for mes in bot_message:
        history[-1][1] = mes
history = history[-maxh:]
yield [history, system_prompt]
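# Single-turn helper: append the user message (plus current-events context), stream the
# model reply to stdout, and return the final reply together with the updated history.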
def ask_crystal(user_message, chat_history, current_events, username="User"):
current_events = "Use the following data as input to answer any of the user queries:\n"+current_events
user_message = current_events+"\n"+user_message
response, updated_chat_history = user(user_message, chat_history)
bot_response_generator = bot(updated_chat_history, current_events, temp, topp,
topk, repp, maxt, maxh, stream_output, sys_input, sys_net, username=username)
    for [history, system_prompt] in bot_response_generator:
        os.system('clear')  # redraw the terminal so the streamed reply updates in place
        bot_response = history[-1][1]
        print(bot_response)
        chat_history = history
return bot_response, chat_history
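# Minimal interactive REPL; the current-events string is a hard-coded placeholder here.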
def chat():
chat_history = []
while True:
events = """Time: 9:33 AM"""
response, chat_history = ask_crystal(
input("Enter Query: "), chat_history, current_events=events)
if __name__ == "__main__":
chat()