from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, BitsAndBytesConfig, GenerationConfig
from Perceptrix.streamer import TextStreamer
from utils import setup_device
import torch
import time
import os

model_name = os.environ.get('CHAT_MODEL')

model_path = "models/CRYSTAL-chat" if model_name is None else model_name
config = AutoConfig.from_pretrained(
    model_path, trust_remote_code=True)

device = setup_device()
# device = "mps"  # uncomment to force Apple Silicon; overrides setup_device()

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    config=config,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,
    offload_folder="offloads",
)
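
# Note: bnb_config above is never passed to from_pretrained, so the model is
# loaded in fp16. A minimal sketch of a 4-bit load, assuming a CUDA GPU with
# bitsandbytes installed:
# model = AutoModelForCausalLM.from_pretrained(
#     model_path,
#     config=config,
#     quantization_config=bnb_config,
#     trust_remote_code=True,
#     device_map="auto",
# )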

tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
)

if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

tokenizer.padding_side = "left"
model.eval()
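
# Assumption: the project-specific TextStreamer mirrors transformers' streamer
# (printing tokens as they are generated) and, given save_file="reply.txt",
# appears to also persist the streamed reply to disk.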
streamer = TextStreamer(tokenizer, skip_prompt=True,
                        skip_special_tokens=True, save_file="reply.txt")


def evaluate(
        prompt='',
        temperature=0.4,
        top_p=0.65,
        top_k=35,
        repetition_penalty=1.1,
        max_new_tokens=512,
        **kwargs,
):
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        **kwargs,
    )

    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            streamer=streamer,
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s, skip_special_tokens=True)
    yield output.split("### Response:")[-1].strip()
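

# Note: temperature/top_p/top_k only take effect when do_sample=True, which
# predict() passes through to evaluate() below.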
def predict(
        inputs,
        temperature=0.4,
        top_p=0.65,
        top_k=35,
        repetition_penalty=1.1,
        max_new_tokens=512,
):
    now_prompt = inputs

    response = evaluate(
        now_prompt, temperature, top_p, top_k, repetition_penalty, max_new_tokens, do_sample=True
    )

    output = ""
    for chunk in response:
        print(chunk)
        output = chunk

    return output
instructions = "You are Comprehensive Robotics Yielding Sophisticated Technology And Logistics (CRYSTAL), an AI robot developed by Vatsal Dutt to be the most advanced robot in the world. You will be provided with prompts and other information to help the user." |
|
|
|
def perceptrix(prompt): |
|
prompt = instructions+"\n"+prompt |
|
response = predict( |
|
inputs=prompt, temperature=0.2, top_p=0.9, max_new_tokens=512 |
|
) |
|
spl_tokens = ["<|im_start|>", "<|im_end|>"] |
|
clean_prompt = prompt.replace(spl_tokens[0], "").replace(spl_tokens[1], "") |
|
return response[len(clean_prompt):] |
|
|
|
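

# Note: history below grows without truncation, so very long sessions may
# eventually exceed the model's context window.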
if __name__ == "__main__":
    history = ""
    while True:
        user_input = input("User: ")
        start = time.time()
        user_input = "<|im_start|>User\n"+user_input+"<|im_end|>\n<|im_start|>CRYSTAL\n"
        result = perceptrix(history+user_input)
        history += user_input + result + "<|im_end|>\n"
        print("Answer completed in ~", round(time.time()-start), "seconds")