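"""Inference script for CRYSTAL-chat.

Loads the model and tokenizer with Hugging Face transformers, streams
generated tokens through Perceptrix's TextStreamer, and runs an
interactive ChatML-style chat loop when executed directly.
"""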
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, BitsAndBytesConfig, GenerationConfig
from Perceptrix.streamer import TextStreamer
from utils import setup_device
import torch
import time
import os
# Use the CHAT_MODEL environment variable if set, otherwise the bundled checkpoint.
model_name = os.environ.get('CHAT_MODEL')
model_path = "models/CRYSTAL-chat" if model_name is None else model_name
config = AutoConfig.from_pretrained(
    model_path, trust_remote_code=True)
device = setup_device()
# device = "mps"  # uncomment to force Apple Silicon (Metal) instead of the detected device
# 4-bit quantization settings; only used if quantization_config is passed
# to from_pretrained below (currently commented out).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    config=config,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,
    # quantization_config=bnb_config,
    offload_folder="offloads",
)
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
)
# Decoder-only models often lack a pad token; reuse EOS and pad on the left.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
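# Inference mode: disables dropout and other training-only behavior.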
model.eval()
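# Stream decoded tokens as they are generated; save_file mirrors the reply to reply.txt.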
streamer = TextStreamer(tokenizer, skip_prompt=True,
                        skip_special_tokens=True, save_file="reply.txt")
def evaluate(
    prompt='',
    temperature=0.4,
    top_p=0.65,
    top_k=35,
    repetition_penalty=1.1,
    max_new_tokens=512,
    **kwargs,
):
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        **kwargs,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            streamer=streamer,
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s, skip_special_tokens=True)
    # Keep only the text after the last "### Response:" marker, if present.
    yield output.split("### Response:")[-1].strip()
def predict(
    inputs,
    temperature=0.4,
    top_p=0.65,
    top_k=35,
    repetition_penalty=1.1,
    max_new_tokens=512,
):
    response = evaluate(
        inputs, temperature, top_p, top_k, repetition_penalty, max_new_tokens, do_sample=True
    )
    # Drain the generator, keeping the last yielded chunk as the final reply.
    result = ""
    for chunk in response:
        print(chunk)
        result = chunk
    return result
instructions = "You are Comprehensive Robotics Yielding Sophisticated Technology And Logistics (CRYSTAL), an AI robot developed by Vatsal Dutt to be the most advanced robot in the world. You will be provided with prompts and other information to help the user."
def perceptrix(prompt):
    prompt = instructions + "\n" + prompt
    response = predict(
        inputs=prompt, temperature=0.2, top_p=0.9, max_new_tokens=512
    )
    # Strip the ChatML markers so the prompt's length matches the decoded
    # output, then cut the echoed prompt off the front of the reply.
    spl_tokens = ["<|im_start|>", "<|im_end|>"]
    clean_prompt = prompt.replace(spl_tokens[0], "").replace(spl_tokens[1], "")
    return response[len(clean_prompt):]
if __name__ == "__main__":
    history = ""
    while True:
        user_input = input("User: ")
        start = time.time()
        user_input = "<|im_start|>User\n" + user_input + "<|im_end|>\n<|im_start|>CRYSTAL\n"
        result = perceptrix(history + user_input)
        history += user_input + result + "<|im_end|>\n"
        print("Answer completed in ~", round(time.time() - start), "seconds")