|
import argparse
import os
import sys
from typing import Optional

import numpy as np
import torch
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast

from utils.poet_utils import StropheParams, Tokens, TextManipulation, TextAnalysis
from utils.base_poet_models import PoetModelBase
from utils.validators import ValidatorInterface

from corpus_capsulated_datasets import CorpusDatasetPytorch
|
|
|
parser = argparse.ArgumentParser() |
|
|
|
parser.add_argument("--model_path_full", default='jinymusim/gpt-czech-poet', type=str, help="Path to Model") |
|
|
|
parser.add_argument("--rhyme_model_path_full", default=os.path.abspath(os.path.join(os.path.dirname(__file__), 'utils', 'validators', 'rhyme', 'distilroberta-base_BPE_validator_1704126399565')), type=str, help="Path to Model") |
|
parser.add_argument("--metre_model_path_full", default=os.path.abspath(os.path.join(os.path.dirname(__file__), 'utils' ,"validators", 'meter', 'ufal-robeczech-base_BPE_validator_1704126400265')), type=str, help="Path to Model") |
|
parser.add_argument("--year_model_path_full", default=os.path.abspath(os.path.join(os.path.dirname(__file__), 'utils' ,"validators", 'year', 'ufal-robeczech-base_BPE_validator_1702393305267')), type=str, help="Path to Model") |
|
|
|
parser.add_argument("--validator_tokenizer_model_rhyme", default='distilroberta-base', type=str, help="Validator tokenizer") |
|
parser.add_argument("--validator_tokenizer_model_meter", default='ufal/robeczech-base', type=str, help="Validator tokenizer") |
|
parser.add_argument("--validator_tokenizer_model_year", default='ufal/robeczech-base', type=str, help="Validator tokenizer") |
|
parser.add_argument("--val_syllables_rhyme", default=False, type=bool, help="Does validator use syllables") |
|
parser.add_argument("--val_syllables_meter", default=False, type=bool, help="Does validator use syllables") |
|
parser.add_argument("--val_syllables_year", default=False, type=bool, help="Does validator use syllables") |
|
|
|
|
|
if __name__ == "__main__": |
|
args = parser.parse_args([] if "__file__" not in globals() else None) |
|
|
|
_ ,model_rel_name = os.path.split(args.model_path_full) |
|
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
|
|
model = PoetModelBase(args.model_path_full).to(device) |
|
model.eval() |
|
|
|
rhyme_model, meter_model, year_model = None, None, None |
|
rhyme_model_name, meter_model_name, year_model_name = "", "", "" |
|
if args.rhyme_model_path_full: |
|
rhyme_model: ValidatorInterface = (torch.load(args.rhyme_model_path_full, map_location=torch.device('cpu'))).to(device) |
|
rhyme_model.eval() |
|
_, rhyme_model_name = os.path.split(args.rhyme_model_path_full) |
|
|
|
if args.metre_model_path_full: |
|
meter_model: ValidatorInterface = (torch.load(args.metre_model_path_full, map_location=torch.device('cpu'))).to(device) |
|
meter_model.eval() |
|
_, meter_model_name = os.path.split(args.metre_model_path_full) |
|
|
|
if args.year_model_path_full: |
|
year_model: ValidatorInterface = (torch.load(args.year_model_path_full, map_location=torch.device('cpu'))).to(device) |
|
year_model.eval() |
|
_, year_model_name = os.path.split(args.year_model_path_full) |
|
|
|
validator_tokenizer_rhyme: PreTrainedTokenizerBase = None |
|
if args.validator_tokenizer_model_rhyme: |
|
try: |
|
validator_tokenizer_rhyme = AutoTokenizer.from_pretrained(args.validator_tokenizer_model_rhyme) |
|
except: |
|
validator_tokenizer_rhyme: PreTrainedTokenizerBase = PreTrainedTokenizerFast(tokenizer_file=args.validator_tokenizer_model_rhyme) |
|
validator_tokenizer_rhyme.eos_token = Tokens.EOS |
|
validator_tokenizer_rhyme.eos_token_id = Tokens.EOS_ID |
|
validator_tokenizer_rhyme.pad_token = Tokens.PAD |
|
validator_tokenizer_rhyme.pad_token_id = Tokens.PAD_ID |
|
validator_tokenizer_rhyme.unk_token = Tokens.UNK |
|
validator_tokenizer_rhyme.unk_token_id = Tokens.UNK_ID |
|
validator_tokenizer_rhyme.cls_token = Tokens.CLS |
|
validator_tokenizer_rhyme.cls_token_id = Tokens.CLS_ID |
|
validator_tokenizer_rhyme.sep_token = Tokens.SEP |
|
validator_tokenizer_rhyme.sep_token_id = Tokens.SEP_ID |
|
|
|
|
|
validator_tokenizer_meter: PreTrainedTokenizerBase = None |
|
if args.validator_tokenizer_model_meter: |
|
try: |
|
validator_tokenizer_meter = AutoTokenizer.from_pretrained(args.validator_tokenizer_model_meter, revision='v1.0') |
|
except: |
|
validator_tokenizer_meter: PreTrainedTokenizerBase = PreTrainedTokenizerFast(tokenizer_file=args.validator_tokenizer_model_meter) |
|
validator_tokenizer_meter.eos_token = Tokens.EOS |
|
validator_tokenizer_meter.eos_token_id = Tokens.EOS_ID |
|
validator_tokenizer_meter.pad_token = Tokens.PAD |
|
validator_tokenizer_meter.pad_token_id = Tokens.PAD_ID |
|
validator_tokenizer_meter.unk_token = Tokens.UNK |
|
validator_tokenizer_meter.unk_token_id = Tokens.UNK_ID |
|
validator_tokenizer_meter.cls_token = Tokens.CLS |
|
validator_tokenizer_meter.cls_token_id = Tokens.CLS_ID |
|
validator_tokenizer_meter.sep_token = Tokens.SEP |
|
validator_tokenizer_meter.sep_token_id = Tokens.SEP_ID |
|
|
|
|
|
validator_tokenizer_year: PreTrainedTokenizerBase = None |
|
if args.validator_tokenizer_model_year: |
|
try: |
|
validator_tokenizer_year = AutoTokenizer.from_pretrained(args.validator_tokenizer_model_year, revision='v1.0') |
|
except: |
|
validator_tokenizer_year: PreTrainedTokenizerBase = PreTrainedTokenizerFast(tokenizer_file=args.validator_tokenizer_model_year) |
|
validator_tokenizer_year.eos_token = Tokens.EOS |
|
validator_tokenizer_year.eos_token_id = Tokens.EOS_ID |
|
validator_tokenizer_year.pad_token = Tokens.PAD |
|
validator_tokenizer_year.pad_token_id = Tokens.PAD_ID |
|
validator_tokenizer_year.unk_token = Tokens.UNK |
|
validator_tokenizer_year.unk_token_id = Tokens.UNK_ID |
|
validator_tokenizer_year.cls_token = Tokens.CLS |
|
validator_tokenizer_year.cls_token_id = Tokens.CLS_ID |
|
validator_tokenizer_year.sep_token = Tokens.SEP |
|
validator_tokenizer_year.sep_token_id = Tokens.SEP_ID |
|
|
|
|
|
tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(args.model_path_full) |
|
|
|
generation = "BASIC" |
|
|
|
def decoder_helper(type, user_input): |
|
if type == "BASIC": |
|
tokenized = tokenizer.encode(user_input, return_tensors='pt', truncation=True) |
|
out = model.model.generate(tokenized.to(device), |
|
max_length=512, |
|
do_sample=True, |
|
top_k=50, |
|
eos_token_id = tokenizer.eos_token_id, |
|
early_stopping=True, |
|
pad_token_id= tokenizer.pad_token_id) |
|
return tokenizer.decode(out.cpu()[0], skip_special_tokens=True) |
|
if type=="FORCED": |
|
return model.generate_forced(user_input, tokenizer, sample=True, device=device) |
|
|
|
help = f"""Current setting is {generation} generating. |
|
Change it by writing FORCED/BASIC to input. |
|
Type HELP for HELP. |
|
Type EXIT to exit. |
|
|
|
Supported input consist of: |
|
# RHYME_SCHEMA # YEAR |
|
METER # NUM_SYLLABLES # ENDING # |
|
|
|
FORCED generation supports incomplete verse input of |
|
J # 13 # |
|
D # 11 # na # |
|
|
|
Each verse parameter must be followed by # to be properly analyzed. |
|
|
|
IMPORTANT! After inputing prompt (Excluding HELP/EXIT/FORCED/BASIC), user MUST input empty line to finish the prompt! |
|
""" |
|
|
|
print("Welcome to simple czech strophe generation.", help) |
|
|
|
while True: |
|
|
|
user_input = "" |
|
while True: |
|
curr_line = input(">").strip() |
|
if curr_line == 'EXIT': |
|
sys.exit() |
|
elif curr_line == "HELP": |
|
print(help) |
|
continue |
|
elif curr_line == "BASIC": |
|
print("Changed to BASIC") |
|
generation = 'BASIC' |
|
continue |
|
elif curr_line == "FORCED": |
|
print("Changed to FORCED") |
|
generation = "FORCED" |
|
continue |
|
if not curr_line: |
|
break |
|
user_input += curr_line + "\n" |
|
|
|
user_input = user_input.strip() |
|
user_reqs = model.analyze_prompt(user_input) |
|
|
|
if "RHYME" not in user_reqs.keys() and generation == "BASIC" and user_input: |
|
print("BASIC generation can't work with imputed format.", help) |
|
print("User input is substituted for #") |
|
user_input = '#' |
|
if not user_input: |
|
print("No input, reverting to #") |
|
user_input = '#' |
|
|
|
generated_poem:str = decoder_helper(generation, user_input) |
|
|
|
|
|
meters = [] |
|
rhyme_pred = '' |
|
year_pred = 0 |
|
for line in generated_poem.splitlines(): |
|
|
|
if not line.strip(): |
|
break |
|
if not (TextManipulation._remove_most_nonchar(line)).strip(): |
|
break |
|
|
|
if TextAnalysis._is_param_line(line): |
|
data = CorpusDatasetPytorch.collate_validator([{"input_ids" :[generated_poem]}],tokenizer=validator_tokenizer_rhyme, |
|
is_syllable=False, syllables=args.val_syllables_rhyme, |
|
max_len=rhyme_model.model.config.max_position_embeddings - 2) |
|
rhyme_pred =StropheParams.RHYME[np.argmax(rhyme_model.predict_state(input_ids=data['input_ids'].to(device)).detach().flatten().cpu().numpy())] |
|
data = CorpusDatasetPytorch.collate_validator([{"input_ids" :[generated_poem]}],tokenizer=validator_tokenizer_year, |
|
is_syllable=False, syllables=args.val_syllables_year, |
|
max_len=year_model.model.config.max_position_embeddings - 2) |
|
year_pred = round(year_model.predict_state(input_ids=data['input_ids'].to(device)).detach().flatten().cpu().numpy()[0]) |
|
continue |
|
data = CorpusDatasetPytorch.collate_meter([{"input_ids" :["FIRST LINE SKIP!\n" + line]}],tokenizer=validator_tokenizer_meter, |
|
is_syllable=False, syllables=args.val_syllables_meter, |
|
max_len=meter_model.model.config.max_position_embeddings - 2) |
|
meters.append( |
|
StropheParams.METER[np.argmax(meter_model.predict_state(input_ids=data['input_ids'].to(device)).detach().flatten().cpu().numpy())] |
|
) |
|
print(f"REQUESTED: {user_reqs}, GENERATED USING: {generation}\n") |
|
print(generated_poem.strip()) |
|
print(f"PREDICTED: {rhyme_pred}, {year_pred}, {meters}\n\n") |