import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from datasets import load_dataset
from tqdm import tqdm

# Load the base model in half precision, sharded across the available devices.
device_map = "auto"
model = AutoModelForCausalLM.from_pretrained(
    "/path/to/meta-llama3-8b/",
    return_dict=True,
    torch_dtype=torch.float16,
    device_map=device_map)

tokenizer = AutoTokenizer.from_pretrained("/path/to/meta-llama3-8b/", add_eos_token=True)
# Use a pad token id distinct from EOS (here, the id immediately after it) so
# that padding is never conflated with end-of-sequence.
tokenizer.pad_token_id = tokenizer.eos_token_id + 1
tokenizer.padding_side = "right"

pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer,
                pad_token_id=tokenizer.pad_token_id, max_new_tokens=100)

test_dataset = load_dataset("json", data_files={'test': '/path/to/parser_test_moves_15.jsonl'})["test"]
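
# Each record is assumed to expose a "sample" field (it is accessed below) that
# holds a Context/Structure excerpt. A hypothetical jsonl line, for
# illustration only:
#   {"sample": "Context: 0 <Buil> Mission has started.\n...\nStructure: "}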


def is_first_moves(sample):
    """Return 1 if the sample opens a mission and its Structure field is still
    empty (there is nothing predicted yet to carry over), else 0."""
    answer = 0
    slist = sample.split('\n')
    if slist[0].startswith('Context: 0 <Buil> Mission has started.'):
        struct = [i for i in slist if i.startswith('Structure:')]
        rels = struct[0].split(':')[1].strip()
        if len(rels) == 0:
            answer = 1
    return answer
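

# For example (hypothetical sample), an opening move with an empty Structure
# field is flagged as a first move:
#   is_first_moves('Context: 0 <Buil> Mission has started.\nStructure: ')  # -> 1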


def check_endpoints(struct, head):
    """
    takes a struct string and a head int and returns only
    the struct rels with sources that are >= head
    """
    new_rels_list = []
    new_rels = None
    if struct:
        rels = struct.split(' ')
        for rel in rels:
            if len(rel) > 0:
                source = int(rel.split('(')[1].split(',')[0].strip())
                if source >= head:
                    new_rels_list.append(rel)
    if len(new_rels_list) > 0:
        new_rels = ' '.join(new_rels_list)
    return new_rels
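

# For example (hypothetical relations), with head = 5 only relations whose
# source index is at least 5 are kept:
#   check_endpoints('ELAB(3,4) NARR(7,8)', 5)  # -> 'NARR(7,8)'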


def add_previous(sample, previous, predictions):
    """Rewrite the sample's Structure line so that it carries the previously
    kept relations (filtered against the current head) plus the new
    predictions, and return both the kept string and the amended sample."""
    new_output = []
    keep_str = None

    slist = sample.split('\n')
    head = int(slist[0].split('Context:')[1].split('<')[0].strip())

    for s in slist:
        if s.startswith('Structure:'):
            new_structure = check_endpoints(previous, head)
            if new_structure:
                s = 'Structure: ' + new_structure + ' ' + predictions
                keep_str = new_structure + ' ' + predictions
            else:
                s = 'Structure: ' + predictions
                keep_str = predictions
        new_output.append(s)
    new_output_string = '\n'.join(new_output)
    return keep_str, new_output_string
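

# For example (hypothetical values), for a sample whose context head is 5,
# previous = 'ELAB(3,4) NARR(7,8)' and predictions = 'ACK(9,10)' produce the
# kept string 'NARR(7,8) ACK(9,10)', and the sample's Structure line becomes
# 'Structure: NARR(7,8) ACK(9,10)'.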


def format_gen(preds):
    """Keep only well-formed relations of the form LABEL(int,int) whose label
    belongs to the inventory below; anything else the model generated is
    dropped with a short notice."""
    labels = ['COM', 'CONTR', 'CORR', 'QAP', 'ACK', 'ELAB', 'CLARIFQ', 'COND', 'CONTIN',
              'RES', 'EXPL', 'QELAB', 'ALT', 'NARR', 'CONFQ', 'SEQ']
    split_list = [st.strip() for st in preds.split(' ')]
    clean_list = []
    for a in split_list:
        s_tuple = None
        rel = None
        try:
            s = a.split('(')[1].split(')')[0].split(',')
            r = a.split('(')[0].strip()
        except IndexError:
            print('split error one')
        else:
            try:
                s_tuple = (int(s[0]), int(s[1]))
            except IndexError:
                print('split error two')
            except ValueError:
                print('value error three')
            # The label check stays in this else-branch so that r is never
            # referenced when the first split failed.
            if r in labels:
                rel = r
        if rel is not None and s_tuple is not None:
            clean_list.append(rel + '(' + str(s_tuple[0]) + ',' + str(s_tuple[1]) + ')')
    clean_preds = ' '.join(clean_list)
    return clean_preds
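

# For example (hypothetical generation), an unknown label and a non-integer
# endpoint are both filtered out:
#   format_gen('NARR(3,4) FOO(1,2) ELAB(5,x)')  # -> 'NARR(3,4)'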


def formatting_prompts_func(example):
    """Wrap a sample in the instruction prompt; the generation loop later
    splits the model output on '### DS:' to recover the prediction."""
    output_text = '<|begin_of_text|>Identify the discourse structure (DS) for the new turn in the following excerpt :\n' + example + '\n ### DS:'
    return output_text
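

# The resulting prompt looks like this (hypothetical excerpt):
#   <|begin_of_text|>Identify the discourse structure (DS) for the new turn in the following excerpt :
#   Context: 5 <Buil> ...
#   Structure: NARR(3,4)
#    ### DS: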


f = open("/path/to/val-output-file.txt", "w")

# Decode incrementally: the first moves of a mission are prompted as-is; every
# later sample first has the structure predicted so far spliced back into its
# Structure line before the model is prompted again.
new_generations = None
previous_generations = None
for datum in tqdm(test_dataset['sample']):
    if is_first_moves(datum):
        text = formatting_prompts_func(datum)
        previous_generations = None
    else:
        update_prev, amended_text = add_previous(datum, previous_generations, new_generations)
        previous_generations = update_prev
        text = formatting_prompts_func(amended_text)
    generated = pipe(text)[0]['generated_text']
    print(generated, file=f)
    # Everything after the '### DS:' marker is the model's raw prediction.
    new_generations = format_gen(generated.split('### DS:')[1])

f.close()