File size: 10,872 Bytes
13cea7c 2fead0d 13cea7c 2fead0d 13cea7c 2fead0d 13cea7c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 |
import argparse
import os
import torch
import numpy as np
import sys
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
from utils.poet_utils import StropheParams, Tokens, TextManipulation, TextAnalysis
from utils.base_poet_models import PoetModelBase
from utils.validators import ValidatorInterface
from corpus_capsulated_datasets import CorpusDatasetPytorch
def _str2bool(value):
    """Convert a command-line string into a real boolean.

    ``type=bool`` is an argparse trap: the converter receives the raw string
    and ``bool("False")`` is ``True`` (every non-empty string is truthy), so
    the flag could never be switched off from the command line.

    Args:
        value: Raw option value (or an already parsed ``bool``).

    Returns:
        ``True`` for 1/true/t/yes/y/on and ``False`` for 0/false/f/no/n/off
        or an empty string (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def _bundled_validator(kind, checkpoint):
    """Return the absolute path ``<script dir>/utils/validators/<kind>/<checkpoint>``."""
    return os.path.abspath(
        os.path.join(os.path.dirname(__file__), 'utils', 'validators', kind, checkpoint)
    )


# Command-line interface: generator model, the three validator checkpoints,
# the tokenizers matching those validators, and per-validator syllable flags.
parser = argparse.ArgumentParser()
parser.add_argument("--model_path_full", default='jinymusim/gpt-czech-poet', type=str, help="Path to Model")

# Validator checkpoints; passing an empty string disables the validator.
parser.add_argument("--rhyme_model_path_full",
                    default=_bundled_validator('rhyme', 'distilroberta-base_BPE_validator_1704126399565'),
                    type=str, help="Path to Model")
parser.add_argument("--metre_model_path_full",
                    default=_bundled_validator('meter', 'ufal-robeczech-base_BPE_validator_1704126400265'),
                    type=str, help="Path to Model")
parser.add_argument("--year_model_path_full",
                    default=_bundled_validator('year', 'ufal-robeczech-base_BPE_validator_1702393305267'),
                    type=str, help="Path to Model")

# Tokenizers used to encode text for the respective validator.
parser.add_argument("--validator_tokenizer_model_rhyme", default='distilroberta-base', type=str, help="Validator tokenizer")
parser.add_argument("--validator_tokenizer_model_meter", default='ufal/robeczech-base', type=str, help="Validator tokenizer")
parser.add_argument("--validator_tokenizer_model_year", default='ufal/robeczech-base', type=str, help="Validator tokenizer")

# Boolean switches are parsed with _str2bool so that "False"/"0" really mean False.
parser.add_argument("--val_syllables_rhyme", default=False, type=_str2bool, help="Does validator use syllables")
parser.add_argument("--val_syllables_meter", default=False, type=_str2bool, help="Does validator use syllables")
parser.add_argument("--val_syllables_year", default=False, type=_str2bool, help="Does validator use syllables")
if __name__ == "__main__":
    # When no ``__file__`` global exists an empty argument list is parsed, so
    # every option takes its default; otherwise ``None`` lets argparse read
    # the real command line.
    args = parser.parse_args(None if "__file__" in globals() else [])

# NOTE(review): assumes only the argument parsing sits under the __main__
# guard and the rest is top-level script code - confirm against the
# original layout.

# Last path component of the generator model location.
model_rel_name = os.path.basename(args.model_path_full)

# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build the generating model, place it on the chosen device and switch it
# to evaluation mode.
model = PoetModelBase(args.model_path_full).to(device)
model.eval()
def _load_validator(checkpoint_path, target_device):
    """Load one validator checkpoint and prepare it for inference.

    Args:
        checkpoint_path: File given to ``torch.load``; a falsy value means
            the validator is disabled.
        target_device: Device the loaded model is moved to.

    Returns:
        ``(model, file_name)`` where ``file_name`` is the last component of
        ``checkpoint_path``, or ``(None, "")`` when no path was given.
    """
    if not checkpoint_path:
        return None, ""
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code - only load checkpoints from a trusted source.
    # NOTE(review): newer PyTorch releases default to ``weights_only=True``;
    # loading a whole pickled module may then need ``weights_only=False`` -
    # confirm against the installed version.
    # The checkpoint is mapped onto the CPU first and only then moved, so
    # loading does not depend on the device it was saved from.
    validator: ValidatorInterface = torch.load(
        checkpoint_path, map_location=torch.device('cpu')
    ).to(target_device)
    validator.eval()
    return validator, os.path.basename(checkpoint_path)


# Each validator is optional: model is None and name is "" when its path is empty.
rhyme_model, rhyme_model_name = _load_validator(args.rhyme_model_path_full, device)
meter_model, meter_model_name = _load_validator(args.metre_model_path_full, device)
year_model, year_model_name = _load_validator(args.year_model_path_full, device)
def _load_validator_tokenizer(name_or_path, **pretrained_kwargs):
    """Load a validator tokenizer, falling back to a raw tokenizer file.

    Args:
        name_or_path: Passed to ``AutoTokenizer.from_pretrained``; if that
            fails it is used as ``tokenizer_file`` for
            ``PreTrainedTokenizerFast``. A falsy value disables the tokenizer.
        **pretrained_kwargs: Extra keyword arguments forwarded to
            ``AutoTokenizer.from_pretrained`` (e.g. ``revision``).

    Returns:
        The loaded tokenizer, or ``None`` when ``name_or_path`` is falsy.
    """
    if not name_or_path:
        return None
    try:
        return AutoTokenizer.from_pretrained(name_or_path, **pretrained_kwargs)
    except Exception:
        # ``except Exception`` instead of a bare ``except`` so that Ctrl-C and
        # SystemExit are not swallowed while a tokenizer is being fetched.
        fallback = PreTrainedTokenizerFast(tokenizer_file=name_or_path)

    # NOTE(review): the special tokens are assumed to be set only on the
    # fallback tokenizer built from a raw file - confirm against the
    # original layout.
    special_tokens = (
        ("eos", Tokens.EOS, Tokens.EOS_ID),
        ("pad", Tokens.PAD, Tokens.PAD_ID),
        ("unk", Tokens.UNK, Tokens.UNK_ID),
        ("cls", Tokens.CLS, Tokens.CLS_ID),
        ("sep", Tokens.SEP, Tokens.SEP_ID),
    )
    for prefix, token, token_id in special_tokens:
        # Same order as before: the token string first, then its id.
        setattr(fallback, f"{prefix}_token", token)
        setattr(fallback, f"{prefix}_token_id", token_id)
    return fallback


# Load Rhyme tokenizer
validator_tokenizer_rhyme: PreTrainedTokenizerBase = _load_validator_tokenizer(
    args.validator_tokenizer_model_rhyme
)
# Load Meter tokenizer (pinned to revision 'v1.0')
validator_tokenizer_meter: PreTrainedTokenizerBase = _load_validator_tokenizer(
    args.validator_tokenizer_model_meter, revision='v1.0'
)
# Load Year tokenizer (pinned to revision 'v1.0')
validator_tokenizer_year: PreTrainedTokenizerBase = _load_validator_tokenizer(
    args.validator_tokenizer_model_year, revision='v1.0'
)
# Tokenizer of the generating language model, loaded from the same location
# as the model itself.
tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
    args.model_path_full,
)
# Decoding mode in effect until the user switches it ("BASIC" or "FORCED").
generation = "BASIC"
def decoder_helper(type, user_input):
    """Generate text for ``user_input`` with the selected decoding mode.

    Args:
        type: ``"BASIC"`` to sample with the wrapped model's ``generate``, or
            ``"FORCED"`` to delegate to ``model.generate_forced``. (The name
            shadows the builtin; it is kept for interface compatibility.)
        user_input: Prompt text handed to the model.

    Returns:
        For ``"BASIC"``, the decoded first generated sequence with special
        tokens removed; for ``"FORCED"``, whatever ``model.generate_forced``
        returns.

    Raises:
        ValueError: If ``type`` is neither ``"BASIC"`` nor ``"FORCED"``
            (previously this silently returned ``None``).
    """
    if type == "BASIC":
        tokenized = tokenizer.encode(user_input, return_tensors='pt', truncation=True)
        # Top-k sampling, capped at 512 tokens in total.
        out = model.model.generate(
            tokenized.to(device),
            max_length=512,
            do_sample=True,
            top_k=50,
            eos_token_id=tokenizer.eos_token_id,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
        )
        return tokenizer.decode(out.cpu()[0], skip_special_tokens=True)
    if type == "FORCED":
        return model.generate_forced(user_input, tokenizer, sample=True, device=device)
    raise ValueError(f"Unknown generation type {type!r}, expected 'BASIC' or 'FORCED'")
def _help_text(mode):
    """Return the usage text, naming the generation mode active right now.

    The text is rendered on demand; a string built once at start-up would
    keep reporting the initial mode after the user switched it.
    """
    return f"""Current setting is {mode} generating.
Change it by writing FORCED/BASIC to input.
Type HELP for HELP.
Type EXIT to exit.
Supported input consist of:
# RHYME_SCHEMA # YEAR
METER # NUM_SYLLABLES # ENDING #
FORCED generation supports incomplete verse input of
J # 13 #
D # 11 # na #
Each verse parameter must be followed by # to be properly analyzed.
IMPORTANT! After inputing prompt (Excluding HELP/EXIT/FORCED/BASIC), user MUST input empty line to finish the prompt!
"""


def _validator_output(validator, batch):
    """Run ``validator`` on a collated batch and return its flattened NumPy output."""
    return validator.predict_state(input_ids=batch['input_ids'].to(device)).detach().flatten().cpu().numpy()


print("Welcome to simple czech strophe generation.", _help_text(generation))
while True:
    # --- Read one prompt: lines are collected until an empty line arrives ---
    prompt_lines = []
    while True:
        try:
            curr_line = input(">").strip()
        except EOFError:
            # Input stream closed: leave the same way as EXIT instead of
            # crashing with a traceback.
            sys.exit()
        if curr_line == 'EXIT':
            sys.exit()
        elif curr_line == "HELP":
            print(_help_text(generation))
            continue
        elif curr_line == "BASIC":
            print("Changed to BASIC")
            generation = 'BASIC'
            continue
        elif curr_line == "FORCED":
            print("Changed to FORCED")
            generation = "FORCED"
            continue
        if not curr_line:
            break
        prompt_lines.append(curr_line)
    # Every collected line is already stripped and non-empty, so joining
    # gives the same text as appending "\n" per line and stripping the result.
    user_input = "\n".join(prompt_lines)

    user_reqs = model.analyze_prompt(user_input)
    # A non-empty prompt without a "RHYME" entry is replaced by the bare
    # "#" prompt when BASIC generation is active.
    if "RHYME" not in user_reqs.keys() and generation == "BASIC" and user_input:
        print("BASIC generation can't work with imputed format.", _help_text(generation))
        print("User input is substituted for #")
        user_input = '#'
    if not user_input:
        print("No input, reverting to #")
        user_input = '#'

    generated_poem: str = decoder_helper(generation, user_input)

    # --- Predictions of the validators on the generated text ---
    meters = []
    rhyme_pred = ''
    year_pred = 0
    for line in generated_poem.splitlines():
        # NOTE(review): the original comment said "Skip Empty lines", but the
        # code stops at the first blank line; that behaviour is kept.
        if not line.strip():
            break
        if not (TextManipulation._remove_most_nonchar(line)).strip():
            break
        # Validate for Strophe Parameters
        if TextAnalysis._is_param_line(line):
            # Validators are optional (an empty path leaves them as None),
            # so each is only consulted when it was actually loaded.
            if rhyme_model is not None and validator_tokenizer_rhyme is not None:
                data = CorpusDatasetPytorch.collate_validator(
                    [{"input_ids": [generated_poem]}], tokenizer=validator_tokenizer_rhyme,
                    is_syllable=False, syllables=args.val_syllables_rhyme,
                    max_len=rhyme_model.model.config.max_position_embeddings - 2)
                rhyme_pred = StropheParams.RHYME[np.argmax(_validator_output(rhyme_model, data))]
            if year_model is not None and validator_tokenizer_year is not None:
                data = CorpusDatasetPytorch.collate_validator(
                    [{"input_ids": [generated_poem]}], tokenizer=validator_tokenizer_year,
                    is_syllable=False, syllables=args.val_syllables_year,
                    max_len=year_model.model.config.max_position_embeddings - 2)
                year_pred = round(_validator_output(year_model, data)[0])
            continue
        if meter_model is not None and validator_tokenizer_meter is not None:
            # The meter collator is given a dummy first line in front of the verse.
            data = CorpusDatasetPytorch.collate_meter(
                [{"input_ids": ["FIRST LINE SKIP!\n" + line]}], tokenizer=validator_tokenizer_meter,
                is_syllable=False, syllables=args.val_syllables_meter,
                max_len=meter_model.model.config.max_position_embeddings - 2)
            meters.append(StropheParams.METER[np.argmax(_validator_output(meter_model, data))])

    print(f"REQUESTED: {user_reqs}, GENERATED USING: {generation}\n")
    print(generated_poem.strip())
    print(f"PREDICTED: {rhyme_pred}, {year_pred}, {meters}\n\n")