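# Evaluation script: run the fine-tuned "cjsanjay/llama-3-8B-gorilla-meraki_v2"
# model (loaded in 4-bit) over a Meraki function-calling test set and score the
# extracted API calls against the expected outputs.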
import json
import re
import time

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
    logging,
)

# bitsandbytes 4-bit quantization settings
use_4bit = True
bnb_4bit_compute_dtype = "float16"
bnb_4bit_quant_type = "nf4"
use_nested_quant = False

# Place the entire model on GPU 0
device_map = {"": 0}

compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

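# Build the quantization config, then load the fine-tuned model and tokenizer
# with it so the 8B checkpoint runs in 4-bit precision.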
bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)

model_name = "cjsanjay/llama-3-8B-gorilla-meraki_v2"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=device_map,
)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

logging.set_verbosity(logging.CRITICAL)

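# Load the Meraki function-calling test set. Each entry is expected to carry
# the candidate functions, the user instruction, and the reference call under
# "Output" (roughly {"api_name": ..., "arguments": "..."}); this layout is
# inferred from how the fields are accessed below.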
with open('meraki_full_unknown_fn_dataset_llama_v1.json', 'r') as json_file:
    known_test_dataset_gorilla = json.load(json_file)

matched = 0
skipped = 0
failed = 0
total = len(known_test_dataset_gorilla)
failed_questions = []
skipped_questions = []
accuracy = {}
i = 0
processed_questions = []
# Regex that captures the assistant turn; the "|" characters in the special
# tokens are escaped so they are not treated as regex alternation.
pattern = r'<\|im_start\|>assistant(.*?)(?:<\|im_end\|>|$)'

system = ("You are an AI programming assistant, utilizing the finetuned LLM model you only answer questions related to " |
|
"function calling using the provided functions. For politically sensitive questions, security and privacy " |
|
"issues, and other non-computer science questions, you will refuse to answer. Use ") |
|
|
|
|
|
def extract_assistant_function_response(r_pattern, generated_text):
    """Pull the assistant's function call out of a generated completion.

    :param r_pattern: regex that captures the assistant turn
    :param generated_text: full text returned by the generation pipeline
    :return: the first captured block containing "api_name", or None
    """
    m_result = re.findall(r_pattern, generated_text, re.DOTALL)
    m_result = [match.strip() for match in m_result]
    for match in m_result:
        if match.find("api_name") > -1:
            return match.strip()
    return None


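# Main evaluation loop: build a chat prompt from the system message plus the
# serialized functions, generate a completion, extract the predicted call, and
# compare it with the expected output.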
for d in known_test_dataset_gorilla:
    i += 1

    functions_string = json.dumps(d['Functions'])
    messages = [
        {"role": "system", "content": f"{system}\n### Instruction: <<functions>> {functions_string}"},
        {"role": "user", "content": d['Instruction']},
    ]
    processed_questions.append(d)
    # The text-generation pipeline is rebuilt per sample and torn down again
    # below so GPU memory is released between generations.
    pipeline1 = pipeline(
        "text-generation",
        model=model,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
        tokenizer=tokenizer,
    )
    prompt = pipeline1.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    terminators = [
        pipeline1.tokenizer.eos_token_id,
        pipeline1.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]
    outputs = pipeline1(
        prompt,
        max_new_tokens=512,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    final_assistant_response = None
    assistant_raw_response = ""
    for seq in outputs:
        assistant_raw_response = seq['generated_text']
        final_assistant_response = extract_assistant_function_response(pattern, seq['generated_text'])
    try:
        output_data = final_assistant_response
        if final_assistant_response is None:
            # Nothing resembling a function call was extracted from the output.
            d["GotOutput"] = str(assistant_raw_response)
            failed_questions.append(d)
            failed += 1
            print(f"Improper response from assistant Expected: {d['Output']}, Got: {assistant_raw_response}")
        else:
            try:
                output_data_json = json.loads(final_assistant_response)
                if "arguments" in output_data_json:
                    try:
                        # Normalize quotes and booleans so both sides parse as JSON.
                        arg_dict_ans = json.loads(output_data_json["arguments"].replace("'", '"').replace("True", "true").replace("False", "false"))
                        arg_dict_input = json.loads(d["Output"]["arguments"].replace("'", '"').replace("True", "true").replace("False", "false"))
                    except Exception as ex:
                        print(f"Json loading failed for args string: {str(ex)}, Falling back to string comparison, args_string: {output_data_json['arguments']}")
                        raise
                    if output_data_json["api_name"] == d["Output"]["api_name"] and arg_dict_ans == arg_dict_input:
                        matched += 1
                        print("Matched")
                    else:
                        d["GotOutput"] = str(output_data)
                        failed_questions.append(d)
                        failed += 1
                        print(f"JSON mismatch Expected: {d['Output']}, Got: {output_data_json}")
                else:
                    if output_data_json == d["Output"]:
                        matched += 1
                        print("Matched")
                    else:
                        d["GotOutput"] = str(output_data)
                        failed_questions.append(d)
                        failed += 1
                        print(f"JSON mismatch Expected: {d['Output']}, Got: {output_data_json}")
            except Exception as ex:
                print(f"Json loading failed: {str(ex)}, Falling back to string comparison")
                if str(output_data) == str(d["Output"]):
                    matched += 1
                    print("Matched")
                else:
                    d["GotOutput"] = str(output_data)
                    failed_questions.append(d)
                    failed += 1
                    print(f"Expected: {d['Output']}, Got: {output_data}")
    except Exception as ex:
        print(f"Expected: {d['Output']}, Got: {output_data}, error: {str(ex)}")
        failed_questions.append(d)
        failed += 1

    # Free the per-sample pipeline and cached GPU memory before the next sample.
    del pipeline1
    del outputs
    pipeline1 = None
    outputs = None
    with torch.no_grad():
        torch.cuda.empty_cache()
    print(f"Done: {i}/{total}, Skipped: {skipped}, matched: {matched}, failed: {failed}")
    if len(processed_questions) >= 100:
        break
    time.sleep(1)
    # Wait for a keypress before moving on to the next sample.
    input()

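# Summarize the run and persist the failures, skips, and accuracy figures.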
accuracy["matched"] = matched |
|
accuracy["total"] = total - skipped |
|
accuracy["recall"] = float(accuracy["matched"])/accuracy["total"] |
|
|
|
with open("failed_questions_meraki_unknown_test_dataset_llama3_gorilla.json", "w") as f: |
|
json.dump(failed_questions, f, indent=4) |
|
|
|
with open("skipped_questions_meraki_unknown_test_dataset_llama3_gorilla.json", "w") as f: |
|
json.dump(skipped_questions, f, indent=4) |
|
|
|
with open("accuracy_meraki_unknown_test_dataset_llama3_gorilla", "w") as f: |
|
json.dump(accuracy, f, indent=4) |
|
|