"""Run an llmware BLING model over the RAG instruct benchmark tester dataset.

For each benchmark sample, the context and query are wrapped in the
``<human>: ... <bot>:`` prompt packaging used in fine-tuning. The model
generates an answer, and the lightly cleaned answer is printed next to the
sample's gold answer.
"""

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Prompt packaging tags used in the BLING fine-tuning process.
_HUMAN_TAG = "<human>:"
_BOT_TAG = "<bot>:"


def load_rag_benchmark_tester_ds():
    """Download the RAG benchmark tester dataset and return its samples.

    Pulls ``llmware/rag_instruct_benchmark_tester`` from the Hugging Face Hub.

    Returns:
        list: every row of the dataset's ``train`` split, in order. Each row
        is indexed by ``"context"``, ``"query"`` and ``"answer"`` in
        ``run_test``.
    """
    # Imported here so `datasets` is only needed when the dataset is loaded.
    from datasets import load_dataset

    ds_name = "llmware/rag_instruct_benchmark_tester"
    dataset = load_dataset(ds_name)
    print("update: loading test dataset - ", dataset)

    return list(dataset["train"])


def _build_prompt(context, query):
    """Wrap a context passage and a question in the fine-tuning prompt format."""
    return _HUMAN_TAG + " " + context + "\n" + query + "\n" + _BOT_TAG


def _clean_output(text):
    """Remove possible fine-tuning artifacts from a decoded model answer."""
    # Keep only the text before a literal end-of-text marker, if one survived decoding.
    eot = text.find("<|endoftext|>")
    if eot > -1:
        text = text[:eot]
    # If the bot tag appears in the answer, keep only what follows it.
    # (Searching for a bare ":" here would truncate any answer containing a colon.)
    bot = text.find(_BOT_TAG)
    if bot > -1:
        text = text[bot + len(_BOT_TAG):]
    return text


def run_test(model_name, test_ds, *, max_new_tokens=100, temperature=0.3):
    """Generate an answer for every benchmark sample and print it beside the gold answer.

    Args:
        model_name: Hugging Face model id or local path passed to ``from_pretrained``.
        test_ds: iterable of samples indexable by ``"context"``, ``"query"`` and
            ``"answer"`` (e.g. the output of ``load_rag_benchmark_tester_ds``).
        max_new_tokens: generation length cap; 100 may cut off a few of the
            longer summaries.
        temperature: sampling temperature. Sampling is enabled, so outputs can
            still differ between runs; a low value keeps them fairly consistent.

    Returns:
        int: always 0.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("update: model will be loaded on device - ", device)

    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
    model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    for i, entries in enumerate(test_ds):
        new_prompt = _build_prompt(entries["context"], entries["query"])
        inputs = tokenizer(new_prompt, return_tensors="pt")
        # Generated tokens start right after the prompt tokens.
        start_of_output = inputs.input_ids.shape[1]

        # Pass the mask explicitly: pad == eos here, so generate() cannot infer it reliably.
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

        outputs = model.generate(
            inputs.input_ids.to(device),
            attention_mask=attention_mask,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
        )

        output_only = tokenizer.decode(
            outputs[0][start_of_output:], skip_special_tokens=True
        )
        output_only = _clean_output(output_only)

        print("\n")
        print(i, "llm_response - ", output_only)
        print(i, "gold_answer - ", entries["answer"])

    return 0


if __name__ == "__main__":
    test_ds = load_rag_benchmark_tester_ds()
    model_name = "llmware/bling-falcon-1b-0.1"
    run_test(model_name, test_ds)