"""Interactive chat front-end for the CRYSTAL language model.

Importing this module loads the model and tokenizer (from the path in the
``CHAT_MODEL`` environment variable, or ``models/CRYSTAL-chat`` when that
is unset or empty) and exposes ``evaluate``, ``predict`` and ``perceptrix``.
Running it as a script starts a terminal chat loop.
"""

import os

import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
)

from Perceptrix.streamer import TextStreamer
from utils import setup_device

model_name = os.environ.get('CHAT_MODEL')
# Fall back to the bundled checkpoint when the variable is unset or empty.
model_path = model_name or "models/CRYSTAL-chat"

config = AutoConfig.from_pretrained(
    model_path,
    trust_remote_code=True,
)

device = setup_device()

# Decide "are we on CPU?" once, via str(), so the dtype choice and the
# quantization choice below can never disagree (a torch.device does not
# necessarily compare equal to the plain string "cpu").
_on_cpu = str(device) == "cpu"
_compute_dtype = torch.float32 if _on_cpu else torch.bfloat16

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=_compute_dtype,
)

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=_compute_dtype,
    config=config,
    device_map="auto",
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    offload_folder="offloads",
    # 4-bit quantization is only requested when not running on CPU.
    quantization_config=None if _on_cpu else bnb_config,
)

tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
)

# Reuse EOS as the padding token when the tokenizer defines none.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

tokenizer.padding_side = "left"

# NOTE: this rebinds ``device``; from here on it is the device that prompt
# tensors are moved to before generation.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.eval()

# NOTE(review): TextStreamer is the project's own class; presumably it echoes
# tokens as they are generated and mirrors them to ``save_file`` - confirm.
streamer = TextStreamer(tokenizer, skip_prompt=True,
                        skip_special_tokens=True, save_file="reply.txt")


def evaluate(
    prompt='',
    temperature=0.4,
    top_p=0.65,
    top_k=35,
    repetition_penalty=1.1,
    max_new_tokens=512,
    **kwargs,
):
    """Generate a completion for *prompt* and yield the decoded text once.

    Args:
        prompt: Full text fed to the model.
        temperature, top_p, top_k, repetition_penalty: Sampling settings
            placed in the ``GenerationConfig``.
        max_new_tokens: Upper bound on the number of generated tokens.
        **kwargs: Extra ``GenerationConfig`` fields (e.g. ``do_sample``).

    Yields:
        A single string: the whole decoded sequence (prompt included, special
        tokens skipped), reduced to the part after the last
        ``"### Response:"`` marker if one is present, and stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)

    # Forward the attention mask when the tokenizer provides one: the pad
    # token may alias EOS, so generate() cannot reliably infer the mask.
    extra_inputs = {}
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        extra_inputs["attention_mask"] = attention_mask.to(device)

    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        **kwargs,
    )

    with torch.no_grad():
        # Only the generated token ids are consumed below, so per-step
        # scores are not requested (they would just hold memory).
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            streamer=streamer,
            **extra_inputs,
        )

    sequence = generation_output.sequences[0]
    output = tokenizer.decode(sequence, skip_special_tokens=True)
    yield output.split("### Response:")[-1].strip()


def predict(
    inputs,
    temperature=0.4,
    top_p=0.65,
    top_k=35,
    repetition_penalty=1.1,
    max_new_tokens=512,
):
    """Run ``evaluate`` with sampling enabled and return its final output.

    Every value yielded by ``evaluate`` is printed; the last one is returned
    (an empty string if nothing was yielded).
    """
    response = ""
    for response in evaluate(
        inputs,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
        max_new_tokens,
        do_sample=True,
    ):
        print(response)
    return response


instructions = "You are Comprehensive Robotics Yielding Sophisticated Technology And Logistics (CRYSTAL), an AI robot developed by Vatsal Dutt to be the most advanced robot in the world. You will be provided with prompts and other information to help the user."


def perceptrix(prompt):
    """Answer *prompt* (prefixed with the system instructions).

    Returns:
        The model output with the leading prompt text cut off. The cut is
        positional: it assumes the decoded output starts with the prompt
        minus its ``<|im_start|>`` / ``<|im_end|>`` markers.
    """
    prompt = instructions + "\n" + prompt
    response = predict(
        inputs=prompt,
        temperature=0.2,
        top_p=0.9,
        max_new_tokens=512,
    )

    # The decoded output has special tokens skipped, so measure the prompt
    # length without the chat markers before slicing it away.
    clean_prompt = prompt
    for marker in ("<|im_start|>", "<|im_end|>"):
        clean_prompt = clean_prompt.replace(marker, "")

    return response[len(clean_prompt):]


def _chat_loop():
    """Read user turns from stdin until EOF / Ctrl-C, keeping chat history."""
    history = ""
    while True:
        try:
            user_input = input("User: ")
        except (EOFError, KeyboardInterrupt):
            # Leave the chat cleanly instead of dumping a traceback.
            print()
            break

        user_turn = f"<|im_start|>User\n{user_input}<|im_end|>\n<|im_start|>CRYSTAL\n"
        result = perceptrix(history + user_turn)
        history += user_turn + result + "<|im_end|>\n"


if __name__ == "__main__":
    _chat_loop()