from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    BitsAndBytesConfig,
    GenerationConfig,
)
from Perceptrix.streamer import TextStreamer
from utils import setup_device
import torch
import time
import os

# Resolve the model path from the CHAT_MODEL environment variable,
# falling back to the bundled CRYSTAL-chat checkpoint.
model_name = os.environ.get('CHAT_MODEL')
model_path = "models/CRYSTAL-chat" if model_name is None else model_name

config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

device = setup_device()  # e.g. "cuda", "mps", or "cpu"

# Optional 4-bit quantization settings; only applied if quantization_config
# is passed to from_pretrained() below.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    config=config,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,
    # quantization_config=bnb_config,  # uncomment to load in 4-bit
    offload_folder="offloads",
)

tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
)

if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

model.eval()

# Streams decoded tokens to stdout as they are generated and mirrors
# the reply to reply.txt.
streamer = TextStreamer(tokenizer, skip_prompt=True,
                        skip_special_tokens=True, save_file="reply.txt")


def evaluate(
    prompt='',
    temperature=0.4,
    top_p=0.65,
    top_k=35,
    repetition_penalty=1.1,
    max_new_tokens=512,
    **kwargs,
):
    """Generate a completion for `prompt`, yielding the final text once."""
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        **kwargs,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            streamer=streamer,
        )
    sequence = generation_output.sequences[0]
    output = tokenizer.decode(sequence, skip_special_tokens=True)
    # Keep only the text after the last response marker, if present.
    yield output.split("### Response:")[-1].strip()


def predict(
    inputs,
    temperature=0.4,
    top_p=0.65,
    top_k=35,
    repetition_penalty=1.1,
    max_new_tokens=512,
):
    """Run evaluate() with sampling enabled and return the final output."""
    result = ""
    for chunk in evaluate(
        inputs,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
        max_new_tokens,
        do_sample=True,
    ):
        print(chunk)
        result = chunk
    return result


instructions = "You are Comprehensive Robotics Yielding Sophisticated Technology And Logistics (CRYSTAL), an AI robot developed by Vatsal Dutt to be the most advanced robot in the world. You will be provided with prompts and other information to help the user."


def perceptrix(prompt):
    """Prepend the system instructions, generate, and strip the echoed prompt."""
    prompt = instructions + "\n" + prompt
    response = predict(
        inputs=prompt,
        temperature=0.2,
        top_p=0.9,
        max_new_tokens=512,
    )
    # The decoded output echoes the prompt without its special tokens,
    # so remove that prefix to leave only the model's reply.
    spl_tokens = ["<|im_start|>", "<|im_end|>"]
    clean_prompt = prompt.replace(spl_tokens[0], "").replace(spl_tokens[1], "")
    return response[len(clean_prompt):]


if __name__ == "__main__":
    # Simple interactive chat loop that accumulates ChatML-style history.
    history = ""
    while True:
        user_input = input("User: ")
        start = time.time()
        user_input = "<|im_start|>User\n" + user_input + \
            "<|im_end|>\n<|im_start|>CRYSTAL\n"
        result = perceptrix(history + user_input)
        history += user_input + result + "<|im_end|>\n"
        print("Answer completed in ~", round(time.time() - start), "seconds")
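

# --- Usage sketch (hypothetical) ---
# A minimal example of calling perceptrix() from other code instead of the
# interactive loop above, assuming this file is importable as a module named
# `chat` (the module name and prompt text are illustrative, not part of the
# original script). The ChatML markers mirror the ones the loop builds.
#
#   from chat import perceptrix
#
#   turn = "<|im_start|>User\nIntroduce yourself.<|im_end|>\n<|im_start|>CRYSTAL\n"
#   print(perceptrix(turn))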