poet-validators / simple_generation_player.py
jinymusim's picture
Upload simple_generation_player.py
2fead0d verified
import argparse
import os
import torch
import numpy as np
import sys
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
from utils.poet_utils import StropheParams, Tokens, TextManipulation, TextAnalysis
from utils.base_poet_models import PoetModelBase
from utils.validators import ValidatorInterface
from corpus_capsulated_datasets import CorpusDatasetPytorch
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, which makes it impossible to turn
    a flag off from the command line.

    Args:
        value: Raw token from the command line (or an actual ``bool``).

    Returns:
        The boolean the token stands for. An empty string maps to ``False``,
        which matches what ``bool("")`` used to produce.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Directory used to resolve the bundled validator checkpoints. ``__file__`` is
# missing in interactive sessions (the argument parsing below already accounts
# for that case), so fall back to the current working directory there.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) if "__file__" in globals() else os.getcwd()


def _default_validator_path(kind, checkpoint):
    """Return the absolute path ``<script dir>/utils/validators/<kind>/<checkpoint>``."""
    return os.path.abspath(os.path.join(_SCRIPT_DIR, 'utils', 'validators', kind, checkpoint))


# Command-line interface: generator model, the three validator checkpoints,
# the tokenizers matching those validators and their syllable-mode switches.
parser = argparse.ArgumentParser()

parser.add_argument("--model_path_full", default='jinymusim/gpt-czech-poet', type=str, help="Path to Model")
parser.add_argument("--rhyme_model_path_full", default=_default_validator_path('rhyme', 'distilroberta-base_BPE_validator_1704126399565'), type=str, help="Path to Model")
parser.add_argument("--metre_model_path_full", default=_default_validator_path('meter', 'ufal-robeczech-base_BPE_validator_1704126400265'), type=str, help="Path to Model")
parser.add_argument("--year_model_path_full", default=_default_validator_path('year', 'ufal-robeczech-base_BPE_validator_1702393305267'), type=str, help="Path to Model")

parser.add_argument("--validator_tokenizer_model_rhyme", default='distilroberta-base', type=str, help="Validator tokenizer")
parser.add_argument("--validator_tokenizer_model_meter", default='ufal/robeczech-base', type=str, help="Validator tokenizer")
parser.add_argument("--validator_tokenizer_model_year", default='ufal/robeczech-base', type=str, help="Validator tokenizer")

# Booleans go through _str2bool so that "--val_syllables_* False" really is False.
parser.add_argument("--val_syllables_rhyme", default=False, type=_str2bool, help="Does validator use syllables")
parser.add_argument("--val_syllables_meter", default=False, type=_str2bool, help="Does validator use syllables")
parser.add_argument("--val_syllables_year", default=False, type=_str2bool, help="Does validator use syllables")
def build_help_text(generation):
    """Return the usage banner for the interactive session.

    The banner is rendered on demand so that it always names the generation
    mode that is active right now.

    Args:
        generation: Name of the active generation mode ("BASIC" or "FORCED").

    Returns:
        The multi-line help text.
    """
    return f"""Current setting is {generation} generating.
Change it by writing FORCED/BASIC to input.
Type HELP for HELP.
Type EXIT to exit.
Supported input consist of:
# RHYME_SCHEMA # YEAR
METER # NUM_SYLLABLES # ENDING #
FORCED generation supports incomplete verse input of
J # 13 #
D # 11 # na #
Each verse parameter must be followed by # to be properly analyzed.
IMPORTANT! After inputing prompt (Excluding HELP/EXIT/FORCED/BASIC), user MUST input empty line to finish the prompt!
"""


def _load_validator(model_path, device):
    """Load one validator checkpoint, move it to ``device`` and switch it to eval mode.

    Args:
        model_path: Checkpoint path; a falsy value disables this validator.
        device: Torch device the validator is moved to.

    Returns:
        Tuple ``(validator, checkpoint_name)``; ``(None, "")`` when disabled.
    """
    if not model_path:
        return None, ""
    # SECURITY: torch.load unpickles arbitrary Python objects. Only point the
    # *_model_path_full arguments at checkpoints that come from a trusted source.
    validator = torch.load(model_path, map_location=torch.device('cpu')).to(device)
    validator.eval()
    _, checkpoint_name = os.path.split(model_path)
    return validator, checkpoint_name


def _load_validator_tokenizer(name_or_path, revision=None):
    """Load the tokenizer that belongs to one validator.

    ``AutoTokenizer.from_pretrained`` is tried first. If that fails, the value
    is treated as a raw tokenizer file, loaded through ``PreTrainedTokenizerFast``
    and given the project's special tokens.

    Args:
        name_or_path: Tokenizer identifier or tokenizer file; a falsy value
            disables the tokenizer.
        revision: Optional revision forwarded to ``from_pretrained``.

    Returns:
        The tokenizer, or ``None`` when ``name_or_path`` is falsy.
    """
    if not name_or_path:
        return None
    hub_kwargs = {} if revision is None else {"revision": revision}
    try:
        return AutoTokenizer.from_pretrained(name_or_path, **hub_kwargs)
    except Exception:
        # Deliberate fallback, but no longer a bare ``except`` so Ctrl-C and
        # SystemExit are not swallowed.
        fallback = PreTrainedTokenizerFast(tokenizer_file=name_or_path)
    # NOTE(review): assumes the special-token overrides belong to the fallback
    # branch only (tokenizers loaded by AutoTokenizer keep their own special
    # tokens) - confirm against the upstream file.
    fallback.eos_token = Tokens.EOS
    fallback.eos_token_id = Tokens.EOS_ID
    fallback.pad_token = Tokens.PAD
    fallback.pad_token_id = Tokens.PAD_ID
    fallback.unk_token = Tokens.UNK
    fallback.unk_token_id = Tokens.UNK_ID
    fallback.cls_token = Tokens.CLS
    fallback.cls_token_id = Tokens.CLS_ID
    fallback.sep_token = Tokens.SEP
    fallback.sep_token_id = Tokens.SEP_ID
    return fallback


def _predict_scores(validator, batch, device):
    """Run ``validator.predict_state`` on ``batch['input_ids']`` and return a flat NumPy array."""
    prediction = validator.predict_state(input_ids=batch['input_ids'].to(device))
    return prediction.detach().flatten().cpu().numpy()


if __name__ == "__main__":
    # Without __file__ (interactive session) parse an empty argument list
    # instead of whatever the host process put into sys.argv.
    args = parser.parse_args([] if "__file__" not in globals() else None)

    _, model_rel_name = os.path.split(args.model_path_full)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Generator model
    model = PoetModelBase(args.model_path_full).to(device)
    model.eval()

    # Validators; each one is None when its path argument is empty
    rhyme_model, rhyme_model_name = _load_validator(args.rhyme_model_path_full, device)
    meter_model, meter_model_name = _load_validator(args.metre_model_path_full, device)
    year_model, year_model_name = _load_validator(args.year_model_path_full, device)

    # Validator tokenizers (meter and year are pinned to revision 'v1.0')
    validator_tokenizer_rhyme = _load_validator_tokenizer(args.validator_tokenizer_model_rhyme)
    validator_tokenizer_meter = _load_validator_tokenizer(args.validator_tokenizer_model_meter, revision='v1.0')
    validator_tokenizer_year = _load_validator_tokenizer(args.validator_tokenizer_model_year, revision='v1.0')

    # Load LM tokenizers
    tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(args.model_path_full)

    generation = "BASIC"

    def decoder_helper(generation_type, user_input):
        """Generate a poem from ``user_input`` with the requested generation mode.

        Args:
            generation_type: "BASIC" (sampling through ``model.model.generate``)
                or "FORCED" (``model.generate_forced``).
            user_input: Prompt text.

        Returns:
            The generated text.

        Raises:
            ValueError: If ``generation_type`` is not a known mode.
        """
        if generation_type == "BASIC":
            tokenized = tokenizer.encode(user_input, return_tensors='pt', truncation=True)
            out = model.model.generate(tokenized.to(device),
                                       max_length=512,
                                       do_sample=True,
                                       top_k=50,
                                       eos_token_id=tokenizer.eos_token_id,
                                       early_stopping=True,
                                       pad_token_id=tokenizer.pad_token_id)
            return tokenizer.decode(out.cpu()[0], skip_special_tokens=True)
        if generation_type == "FORCED":
            return model.generate_forced(user_input, tokenizer, sample=True, device=device)
        raise ValueError(f"Unknown generation type: {generation_type!r}")

    print("Welcome to simple czech strophe generation.", build_help_text(generation))

    while True:
        # Collect prompt lines until the user submits an empty line.
        prompt_lines = []
        while True:
            try:
                curr_line = input(">").strip()
            except EOFError:
                # stdin was closed (piped input, Ctrl-D): leave without a traceback
                sys.exit()
            if curr_line == 'EXIT':
                sys.exit()
            elif curr_line == "HELP":
                print(build_help_text(generation))
                continue
            elif curr_line == "BASIC":
                print("Changed to BASIC")
                generation = 'BASIC'
                continue
            elif curr_line == "FORCED":
                print("Changed to FORCED")
                generation = "FORCED"
                continue
            if not curr_line:
                break
            prompt_lines.append(curr_line)
        user_input = "\n".join(prompt_lines)

        # The requirements are analyzed on what the user typed, before any substitution.
        user_reqs = model.analyze_prompt(user_input)

        if "RHYME" not in user_reqs and generation == "BASIC" and user_input:
            print("BASIC generation can't work with imputed format.", build_help_text(generation))
            print("User input is substituted for #")
            user_input = '#'

        if not user_input:
            print("No input, reverting to #")
            user_input = '#'

        generated_poem: str = decoder_helper(generation, user_input)

        # Predictions
        meters = []
        rhyme_pred = ''
        year_pred = 0
        # Pure inference: no autograd graph is needed for the validators.
        with torch.no_grad():
            for line in generated_poem.splitlines():
                # Stop at the first blank line, or at a line with nothing left
                # after TextManipulation._remove_most_nonchar.
                if not line.strip():
                    break
                if not (TextManipulation._remove_most_nonchar(line)).strip():
                    break

                # Validate for Strophe Parameters
                if TextAnalysis._is_param_line(line):
                    # Rhyme and year are predicted from the whole generated poem.
                    if rhyme_model is not None and validator_tokenizer_rhyme is not None:
                        data = CorpusDatasetPytorch.collate_validator([{"input_ids": [generated_poem]}], tokenizer=validator_tokenizer_rhyme,
                                                                      is_syllable=False, syllables=args.val_syllables_rhyme,
                                                                      max_len=rhyme_model.model.config.max_position_embeddings - 2)
                        rhyme_pred = StropheParams.RHYME[np.argmax(_predict_scores(rhyme_model, data, device))]
                    if year_model is not None and validator_tokenizer_year is not None:
                        data = CorpusDatasetPytorch.collate_validator([{"input_ids": [generated_poem]}], tokenizer=validator_tokenizer_year,
                                                                      is_syllable=False, syllables=args.val_syllables_year,
                                                                      max_len=year_model.model.config.max_position_embeddings - 2)
                        year_pred = round(_predict_scores(year_model, data, device)[0])
                    continue

                # Every other line is a verse: predict its meter on its own,
                # prefixed with a placeholder first line.
                if meter_model is not None and validator_tokenizer_meter is not None:
                    data = CorpusDatasetPytorch.collate_meter([{"input_ids": ["FIRST LINE SKIP!\n" + line]}], tokenizer=validator_tokenizer_meter,
                                                              is_syllable=False, syllables=args.val_syllables_meter,
                                                              max_len=meter_model.model.config.max_position_embeddings - 2)
                    meters.append(
                        StropheParams.METER[np.argmax(_predict_scores(meter_model, data, device))]
                    )

        print(f"REQUESTED: {user_reqs}, GENERATED USING: {generation}\n")
        print(generated_poem.strip())
        print(f"PREDICTED: {rhyme_pred}, {year_pred}, {meters}\n\n")