from .poet_model_utils import PoetModelInterface
from .poet_utils import TextAnalysis, StropheParams

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.utils import ModelOutput

import random
import torch


class PoetModelFunctionalInterface(PoetModelInterface):
    """Poet Model Functional Interface. Abstract class implementing the core prompt-analysis and generation methods.

    Args:
        PoetModelInterface (_type_): Parent class; this class is its child and carries the core methods
    """
    def __init__(self, *args, **kwargs) -> None:
        """Constructor. As a child class, it needs to construct its parent."""
        super().__init__(*args, **kwargs)

    def analyze_prompt(self, prompt) -> dict:
        """Analysis of the user's prompt

        Args:
            prompt (_type_): dict or string carrying the user's intent

        Returns:
            dict: Analysis of the user's intended input
        """
        if isinstance(prompt, dict):
            return prompt
        features_dict = {}
        lines = prompt.splitlines()
        lines = list(map(str.strip, lines))
        # Drop empty lines
        i = 0
        while i < len(lines):
            if not lines[i]:
                lines.pop(i)
                i -= 1
            i += 1
        cont_line = 0
        for line in lines:
            if TextAnalysis._is_param_line(line):
                for key, value in TextAnalysis._first_line_analysis(line).items():
                    features_dict[key] = value
            else:
                val = cont_line
                if "RHYME" in features_dict.keys() and cont_line < len(features_dict['RHYME']):
                    if features_dict["RHYME"][cont_line] == "A":
                        val = 0
                    elif features_dict["RHYME"][cont_line] == "B":
                        val = 1
                    elif features_dict["RHYME"][cont_line] == "C":
                        val = 2
                    elif features_dict["RHYME"][cont_line] == "D":
                        val = 3
                for key, value in TextAnalysis._continuos_line_analysis(line).items():
                    features_dict[f"{key}_{val}"] = value
                cont_line += 1
        return features_dict

    def generate_forced(self, prompt, tokenizer: AutoTokenizer, sample: bool = True, format: str = 'METER_VERSE', device=torch.device('cpu'), *args, **kwargs) -> str:
        """Generate a strophe using FORCED generation

        Args:
            prompt (_type_): dict or string of the user's intended parameters for the strophe start
            tokenizer (AutoTokenizer): tokenizer to be used during generation. Should be model specific.
            sample (bool, optional): Whether to sample. Defaults to True.
            format (str, optional): Format of generation to be used. Should be the same the model was trained on.
                Possible formats: BASIC, VERSE_PAR, METER_VERSE, OLD (DEPRECATED! For old models compatibility only). Defaults to 'METER_VERSE'.
            device (_type_, optional): Device to generate on. Defaults to torch.device('cpu').

        Returns:
            str: Generated strophe
        """
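        # Overview of the forced-generation flow (descriptive comment added for clarity):
        #   1. Analyze the prompt into a feature dict (rhyme scheme, year, per-verse features).
        #   2. Generate or complete the parameter line so a rhyme scheme is always available.
        #   3. Let the model finish any partially written verses from the prompt.
        #   4. Generate the remaining verses one by one, forcing each verse to start with the
        #      meter/length/end-word features already collected for its rhyme letter.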
        features_dict_init = self.analyze_prompt(prompt)
        # If the user passed parameters as a dict, start with an empty list to carry future verses.
        if isinstance(prompt, dict):
            prompt_list = []
        else:
            prompt_list = prompt.splitlines()

        # GENERATE FOR POSSIBLE MISSING POET PARAM
        token_gen_rhyme = tokenizer.encode("#", return_tensors='pt')
        if sample:
            rhyme_line = self.model.generate(token_gen_rhyme.to(device),
                                             max_new_tokens=100,
                                             do_sample=True,
                                             top_k=50,
                                             early_stopping=True,
                                             pad_token_id=tokenizer.pad_token_id,
                                             eos_token_id=tokenizer.eos_token_id)
        else:
            rhyme_line = self.model.generate(token_gen_rhyme.to(device),
                                             max_new_tokens=100,
                                             num_beams=8,
                                             no_repeat_ngram_size=2,
                                             early_stopping=True,
                                             pad_token_id=tokenizer.pad_token_id,
                                             eos_token_id=tokenizer.eos_token_id)
        rhyme_dec = tokenizer.decode(rhyme_line.cpu()[0], skip_special_tokens=True).splitlines()[0]
        features_dict = TextAnalysis._first_line_analysis(rhyme_dec)
        # User-provided features take precedence over the generated ones.
        for key, value in features_dict_init.items():
            features_dict[key] = value

        # CONSTRUCT BEST INPUT LINE
        # BACKUP RHYME
        if "RHYME" not in features_dict.keys():
            features_dict["RHYME"] = random.choice(StropheParams.RHYME[:-1])

        # OLD
        if format == 'OLD':
            poet_param_str = ""
            if "RHYME" in features_dict.keys():
                poet_param_str += features_dict["RHYME"]
            if "YEAR" in features_dict.keys():
                poet_param_str += f" # {features_dict['YEAR']}"
            if 'STROPHE_METER' in features_dict.keys():
                poet_param_str += f" # {features_dict['STROPHE_METER']}"
        elif format != 'METER_VERSE':
            poet_param_str = "# "
            if "RHYME" in features_dict.keys():
                poet_param_str += features_dict["RHYME"]
            if "YEAR" in features_dict.keys():
                poet_param_str += f" # {features_dict['YEAR']}"
            if 'STROPHE_METER' in features_dict.keys():
                poet_param_str += f" # {features_dict['STROPHE_METER']}"
        # NEW
        else:
            poet_param_str = "# "
            if "RHYME" in features_dict.keys():
                poet_param_str += features_dict["RHYME"]
            if "YEAR" in features_dict.keys():
                poet_param_str += f" # {features_dict['YEAR']}"

        # REPLACE OR INSERT BASED ON PRESENCE
        if len(features_dict_init.keys()) == 0:
            # Weird input: nothing usable was given, keep only the constructed parameter line
            prompt_list = [poet_param_str]
        elif len(prompt_list) == 0:
            # Input was given as a dict
            prompt_list.append(poet_param_str)
        elif "RHYME" not in features_dict_init.keys():
            if "YEAR" in features_dict_init.keys() or 'STROPHE_METER' in features_dict_init.keys():
                # Replace the incomplete first line
                prompt_list[0] = poet_param_str
            else:
                prompt_list.insert(0, poet_param_str)
        else:
            prompt_list[0] = poet_param_str

        verse_len = len(features_dict["RHYME"])

        # Finish possibly incomplete lines
        base_prompt_len = len(prompt_list)
        for i in range(2, base_prompt_len + 1):
            token_gen_finish = tokenizer.encode("\n".join(prompt_list[:i]), return_tensors='pt')
            if sample:
                finish_line = self.model.generate(token_gen_finish.to(device),
                                                  max_new_tokens=100,
                                                  do_sample=True,
                                                  top_k=50,
                                                  early_stopping=True,
                                                  pad_token_id=tokenizer.pad_token_id,
                                                  eos_token_id=tokenizer.eos_token_id)
            else:
                finish_line = self.model.generate(token_gen_finish.to(device),
                                                  max_new_tokens=100,
                                                  num_beams=8,
                                                  no_repeat_ngram_size=2,
                                                  early_stopping=True,
                                                  pad_token_id=tokenizer.pad_token_id,
                                                  eos_token_id=tokenizer.eos_token_id)
            decoded = tokenizer.decode(finish_line.cpu()[0], skip_special_tokens=True).splitlines()
            to_dec = min(i, len(decoded))
            prompt_list[:to_dec] = decoded[:to_dec]
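            # Map the rhyme letter of the just-completed verse onto the index used for
            # per-rhyme features (A/B/C/D -> 0/1/2/3, X -> -1 meaning "unrhymed, do not force").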
            rhyme_char = 0
            if features_dict["RHYME"][(to_dec - 2) % len(features_dict["RHYME"])] == "B":
                rhyme_char = 1
            elif features_dict["RHYME"][(to_dec - 2) % len(features_dict["RHYME"])] == "C":
                rhyme_char = 2
            elif features_dict["RHYME"][(to_dec - 2) % len(features_dict["RHYME"])] == "D":
                rhyme_char = 3
            elif features_dict["RHYME"][(to_dec - 2) % len(features_dict["RHYME"])] == "X":
                rhyme_char = -1
            if to_dec - 1 < len(prompt_list):
                dec_line = prompt_list[to_dec - 1]
                # OLD
                if format == 'VERSE_PAR' or format == 'OLD':
                    if f"END_{rhyme_char}" not in features_dict.keys() and len(dec_line.split()) > 1 and rhyme_char >= 0 and dec_line.count("#") <= 1:
                        features_dict[f'LENGTH_{rhyme_char}'] = dec_line.split()[0]
                        features_dict[f'END_{rhyme_char}'] = dec_line.split()[1]
                    elif f"END_{rhyme_char}" not in features_dict.keys() and len(dec_line.split()) > 2 and rhyme_char >= 0:
                        features_dict[f'LENGTH_{rhyme_char}'] = dec_line.split()[0]
                        features_dict[f'END_{rhyme_char}'] = dec_line.split()[2]
                # NEW
                elif format == 'METER_VERSE':
                    if f"END_{rhyme_char}" not in features_dict.keys() and len(dec_line.split()) > 4 and rhyme_char >= 0:
                        features_dict[f'METER_{rhyme_char}'] = dec_line.split()[0]
                        features_dict[f'LENGTH_{rhyme_char}'] = dec_line.split()[2]
                        features_dict[f'END_{rhyme_char}'] = dec_line.split()[4]

        # Generate verses until the rhyme scheme is filled
        has_rep = False
        has_rep_again = False
        while len(prompt_list) <= verse_len:
            j = 0
            if features_dict["RHYME"][(len(prompt_list) - 1) % len(features_dict["RHYME"])] == "B":
                j = 1
            elif features_dict["RHYME"][(len(prompt_list) - 1) % len(features_dict["RHYME"])] == "C":
                j = 2
            elif features_dict["RHYME"][(len(prompt_list) - 1) % len(features_dict["RHYME"])] == "D":
                j = 3
            elif features_dict["RHYME"][(len(prompt_list) - 1) % len(features_dict["RHYME"])] == "X":
                j = -1
            # OLD
            if format == 'BASIC':
                line_start = ""
            elif format == 'OLD':
                line_start = (f"{features_dict[f'LENGTH_{j}']} " if f"LENGTH_{j}" in features_dict.keys() else "") + \
                             (f" {features_dict[f'END_{j}']} #" if f"END_{j}" in features_dict.keys() else "")
            elif format == 'VERSE_PAR':
                line_start = (f"{features_dict[f'LENGTH_{j}']} #" if f"LENGTH_{j}" in features_dict.keys() else "") + \
                             (f" {features_dict[f'END_{j}']} #" if f"END_{j}" in features_dict.keys() else "")
            else:
                line_start = (f"{features_dict[f'METER_{j}']} #" if f"METER_{j}" in features_dict.keys() else "") + \
                             (f" {features_dict[f'LENGTH_{j}']} #" if f"LENGTH_{j}" in features_dict.keys() else "") + \
                             (f" {features_dict[f'END_{j}']} #" if f"END_{j}" in features_dict.keys() else "")
            tokenized_poet_start = tokenizer.encode("\n".join(prompt_list) + "\n" + line_start, return_tensors='pt')
            if sample:
                out_line = self.model.generate(tokenized_poet_start.to(device),
                                               max_new_tokens=100,
                                               do_sample=True,
                                               top_k=50,
                                               early_stopping=True,
                                               pad_token_id=tokenizer.pad_token_id,
                                               eos_token_id=tokenizer.eos_token_id)
            else:
                out_line = self.model.generate(tokenized_poet_start.to(device),
                                               max_new_tokens=100,
                                               num_beams=2,
                                               no_repeat_ngram_size=2,
                                               early_stopping=True,
                                               pad_token_id=tokenizer.pad_token_id,
                                               eos_token_id=tokenizer.eos_token_id)
            decoded_lines = tokenizer.decode(out_line.cpu()[0], skip_special_tokens=True).splitlines()
            # Repetition catcher: if no new line was produced, retry (at most twice)
            # before falling back to the last decoded line.
            if len(decoded_lines) <= len(prompt_list) and not (has_rep_again and has_rep):
                if has_rep:
                    prompt_list.pop()
                    has_rep = False
                    has_rep_again = True
                else:
                    has_rep = True
                continue
            if has_rep_again and has_rep:
                decoded_line: str = decoded_lines[-1]
            else:
                decoded_line: str = decoded_lines[len(prompt_list)]
            # OLD
            if format == 'VERSE_PAR' or format == 'OLD':
                if f"END_{j}" not in features_dict.keys() and len(decoded_line.split()) > 1 and j >= 0 and decoded_line.count("#") <= 1:
                    features_dict[f'LENGTH_{j}'] = decoded_line.split()[0]
                    features_dict[f'END_{j}'] = decoded_line.split()[1]
                elif f"END_{j}" not in features_dict.keys() and len(decoded_line.split()) > 2 and j >= 0:
                    features_dict[f'LENGTH_{j}'] = decoded_line.split()[0]
                    features_dict[f'END_{j}'] = decoded_line.split()[2]
            # NEW
            elif format == 'METER_VERSE':
                if f"END_{j}" not in features_dict.keys() and len(decoded_line.split()) > 4 and j >= 0:
                    features_dict[f'METER_{j}'] = decoded_line.split()[0]
                    features_dict[f'LENGTH_{j}'] = decoded_line.split()[2]
                    features_dict[f'END_{j}'] = decoded_line.split()[4]
            prompt_list.append(decoded_line)

        return "\n".join(prompt_list)
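# Illustrative note (reconstructed from the parsing code above, not quoted from training
# data): in the METER_VERSE format every verse line encodes its parameters before the
# text, roughly "METER # SYLLABLE_COUNT # END_WORD # verse text", and the strophe is
# preceded by a parameter line such as "# ABAB # 1900". generate_forced() reads the
# generated values back from exactly those token positions (split()[0], [2] and [4]).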


class PoetModelBase(PoetModelFunctionalInterface):
    def __init__(self, pretrainedModel, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(pretrainedModel, output_hidden_states=True)
        model_config = self.model.config
        self.model_size = 1
        # Check for Hidden layer size by Attribute Name
        if hasattr(model_config, "n_embd"):
            self.model_size = model_config.n_embd
        elif hasattr(model_config, "hidden_size"):
            self.model_size = model_config.hidden_size

    def forward(self, input_ids=None, labels=None, attention_mask=None, *args, **kwargs):
        outputs = self.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
        return ModelOutput(loss=outputs.loss, model_output=outputs)  # {"model_output" : outputs, "loss" : outputs.loss}

    def save_LM(self, LM_path):
        self.model.save_pretrained(LM_path, safe_serialization=False)
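

# Illustrative usage sketch (not part of the library's pipeline): how a trained
# checkpoint might be loaded and used to generate one strophe. The checkpoint path is
# a placeholder and the tokenizer location is an assumption (it is assumed to be saved
# alongside the model); adjust both to the actual training setup.
def _example_generate_strophe(checkpoint_path: str = "path/to/checkpoint") -> str:
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    model = PoetModelBase(checkpoint_path)
    # Ask for an ABAB strophe; any missing parameters (year, meter, ...) are filled in
    # by the model itself inside generate_forced().
    return model.generate_forced({"RHYME": "ABAB"}, tokenizer, sample=True, format='METER_VERSE')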


class PoetModelAllTasks(PoetModelFunctionalInterface):
    def __init__(self, pretrainedModel, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(pretrainedModel, output_hidden_states=True)
        model_config = self.model.config
        self.model_size = 1
        # Check for Hidden layer size by Attribute Name
        if hasattr(model_config, "n_embd"):
            self.model_size = model_config.n_embd
        elif hasattr(model_config, "hidden_size"):
            self.model_size = model_config.hidden_size
        self.vowels_regressor = torch.nn.Linear(self.model_size, 1)                        # Vowel Count
        self.rhyme_regressor = torch.nn.Linear(self.model_size, len(StropheParams.RHYME))  # Rhyme Type
        self.verse_endings = torch.nn.Linear(self.model_size, len(StropheParams.ENDS))     # Verse End Syllable
        self.metre_regressor = torch.nn.Linear(self.model_size, len(StropheParams.METER))  # Meter Type
        self.year_regressor = torch.nn.Linear(self.model_size, len(StropheParams.YEAR))    # Year Bucket

    def forward(self, input_ids=None, labels=None, attention_mask=None, nums=None, rhyme=None, verse_end=None, year=None, metre=None, *args, **kwargs):
        outputs = self.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
        # The hidden state of the first token position is used as the pooled representation.
        last_hidden = outputs['hidden_states'][-1]
        vowel_regression = self.vowels_regressor(last_hidden[:, 0, :].view(-1, self.model_size))
        rhyme_regression = self.rhyme_regressor(last_hidden[:, 0, :].view(-1, self.model_size))
        verse_end_reg = self.verse_endings(last_hidden[:, 0, :].view(-1, self.model_size))
        metre_regression = self.metre_regressor(last_hidden[:, 0, :].view(-1, self.model_size))
        year_regression = self.year_regressor(last_hidden[:, 0, :].view(-1, self.model_size))
        full_loss = outputs.loss

        # Auxiliary losses are down-weighted by 0.1 relative to the language-modelling loss.
        vowel_loss = None
        if nums is not None:
            loss_fct = torch.nn.MSELoss()
            vowel_loss = loss_fct(vowel_regression.view(-1, 1), nums.view(-1, 1))
            full_loss = full_loss + 0.1 * vowel_loss
        rhyme_loss = None
        if rhyme is not None:
            softmaxed = torch.softmax(rhyme_regression, dim=1)
            loss_fct = torch.nn.CrossEntropyLoss()
            rhyme_loss = loss_fct(softmaxed, rhyme)
            full_loss = full_loss + 0.1 * rhyme_loss
        verse_loss = None
        if verse_end is not None:
            softmaxed = torch.softmax(verse_end_reg, dim=1)
            loss_fct = torch.nn.CrossEntropyLoss()
            verse_loss = loss_fct(softmaxed, verse_end)
            full_loss = full_loss + 0.1 * verse_loss
        metre_loss = None
        if metre is not None:
            softmaxed = torch.softmax(metre_regression, dim=1)
            loss_fct = torch.nn.CrossEntropyLoss()
            metre_loss = loss_fct(softmaxed, metre)
            full_loss = full_loss + 0.1 * metre_loss
        year_loss = None
        if year is not None:
            softmaxed = torch.softmax(year_regression, dim=1)
            loss_fct = torch.nn.CrossEntropyLoss()
            year_loss = loss_fct(softmaxed, year)
            full_loss = full_loss + 0.1 * year_loss

        return {"model_output": outputs,
                "vowel_regression_output": vowel_regression,
                "vowel_regression_loss": vowel_loss,
                "rhyme_regression_output": rhyme_regression,
                "rhyme_regression_loss": rhyme_loss,
                "verse_end_regression_output": verse_end_reg,
                "verse_end_regression_loss": verse_loss,
                "metre_regression_output": metre_regression,
                "metre_regression_loss": metre_loss,
                "year_regression_output": year_regression,
                "year_regression_loss": year_loss,
                "loss": full_loss}

    def save_LM(self, LM_path):
        self.model.save_pretrained(LM_path, safe_serialization=False)


from .poet_model_utils import ContextModule


class PoetModelContextInput(PoetModelFunctionalInterface):
    def __init__(self, pretrainedModel, context_input_size: int = 2048, block_count: int = 3, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(pretrainedModel, output_hidden_states=True)
        model_config = self.model.config
        self.model_size = -1
        # Check for Hidden layer size by Attribute Name
        if hasattr(model_config, "n_embd"):
            self.model_size = model_config.n_embd
        elif hasattr(model_config, "hidden_size"):
            self.model_size = model_config.hidden_size
        # Number of Embeddings taken from config
        self.context_size = context_input_size
        self.model.base_model.h.insert(3, ContextModule(block_count, context_input_size, self.model_size, self.model_size))
        # Because of the Inserted Layer, Head Masks don't match => Add 1 more
        self.model.base_model.config.n_layer += 1
        self.rhyme_regressor = torch.nn.Linear(self.model_size, len(StropheParams.RHYME))  # Rhyme Type

    def forward(self, input_ids=None, labels=None, attention_mask=None, rhyme=None, context_ids=None, context_attention_mask=None, *args, **kwargs):
        # Inject Context to bypass GPT2Blocks (Can't Forward it)
        self.model.base_model.h[3].context_ids = context_ids
        self.model.base_model.h[3].context_attention_mask = context_attention_mask
        outputs = self.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
        last_hidden = outputs['hidden_states'][-1]
        rhyme_regression = self.rhyme_regressor(last_hidden[:, 0, :].view(-1, self.model_size))
        full_loss = outputs.loss
        rhyme_loss = None
        if rhyme is not None:
            softmaxed = torch.softmax(rhyme_regression, dim=1)
            loss_fct = torch.nn.CrossEntropyLoss()
            rhyme_loss = loss_fct(softmaxed, rhyme)
            full_loss = full_loss + rhyme_loss
        # Delete the Injection to prevent stale data from leaking into later forward passes
        self.model.base_model.h[3].context_ids = None
        self.model.base_model.h[3].context_attention_mask = None
        return {"model_output": outputs,
                "rhyme_regression_output": rhyme_regression,
                "rhyme_regression_loss": rhyme_loss,
                "loss": full_loss}

    def save_LM(self, LM_path):
        self.model.save_pretrained(LM_path)
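

# Design note (descriptive comment, added for clarity): the context models pass their
# extra inputs to the inserted modules by assigning attributes on them before calling
# self.model(...), because the surrounding GPT-2 blocks cannot forward additional
# tensors. PoetModelContextYear below applies the same pattern with a second inserted
# module (PoetTypeModule) that also contributes its own loss term.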


from .poet_model_utils import PoetTypeModule


class PoetModelContextYear(PoetModelFunctionalInterface):
    def __init__(self, pretrainedModel, context_input_size: int = 2048, block_count: int = 3, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(pretrainedModel, output_hidden_states=True)
        model_config = self.model.config
        self.model_size = -1
        # Check for Hidden layer size by Attribute Name
        if hasattr(model_config, "n_embd"):
            self.model_size = model_config.n_embd
        elif hasattr(model_config, "hidden_size"):
            self.model_size = model_config.hidden_size
        # Number of Embeddings taken from config
        self.context_size = context_input_size
        self.model.base_model.h.insert(3, ContextModule(block_count, context_input_size, self.model_size, self.model_size))
        self.model.base_model.h.insert(3, PoetTypeModule(block_count, context_input_size, self.model_size, self.model_size))
        # Because of the two Inserted Layers, Head Masks don't match => Add 2 more
        self.model.base_model.config.n_layer += 2
        self.rhyme_regressor = torch.nn.Linear(self.model_size, len(StropheParams.RHYME))  # Rhyme Type
        self.year_regressor = torch.nn.Linear(self.model_size, len(StropheParams.YEAR))    # Year Bucket

    def forward(self, input_ids=None, labels=None, attention_mask=None, rhyme=None, context_ids=None, context_attention_mask=None, year=None, *args, **kwargs):
        # Inject Context to bypass GPT2Blocks (Can't Forward it)
        self.model.base_model.h[3].context_ids = context_ids
        self.model.base_model.h[3].context_attention_mask = context_attention_mask
        self.model.base_model.h[3].type_labels = year
        self.model.base_model.h[4].context_ids = context_ids
        self.model.base_model.h[4].context_attention_mask = context_attention_mask
        outputs = self.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
        last_hidden = outputs['hidden_states'][-1]
        rhyme_regression = self.rhyme_regressor(last_hidden[:, 0, :].view(-1, self.model_size))
        full_loss = outputs.loss
        rhyme_loss = None
        if rhyme is not None:
            softmaxed = torch.softmax(rhyme_regression, dim=1)
            loss_fct = torch.nn.CrossEntropyLoss()
            rhyme_loss = loss_fct(softmaxed, rhyme)
            full_loss = full_loss + rhyme_loss
        year_regression = self.year_regressor(last_hidden[:, 0, :].view(-1, self.model_size))
        year_loss = None
        if year is not None:
            softmaxed = torch.softmax(year_regression, dim=1)
            loss_fct = torch.nn.CrossEntropyLoss()
            year_loss = loss_fct(softmaxed, year)
            full_loss = full_loss + year_loss + self.model.base_model.h[3].indiv_loss
        # Delete the Injection to prevent stale data from leaking into later forward passes
        self.model.base_model.h[3].context_ids = None
        self.model.base_model.h[3].context_attention_mask = None
        self.model.base_model.h[3].type_labels = None
        # Delete Loss
        self.model.base_model.h[3].indiv_loss = None
        self.model.base_model.h[4].context_ids = None
        self.model.base_model.h[4].context_attention_mask = None
        return {"model_output": outputs,
                "rhyme_regression_output": rhyme_regression,
                "rhyme_regression_loss": rhyme_loss,
                "year_regression_output": year_regression,
                "year_loss": year_loss,
                "loss": full_loss}

    def save_LM(self, LM_path):
        self.model.save_pretrained(LM_path)
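

# Descriptive note (added for clarity): DistilModel below builds a smaller student model
# by keeping only every other transformer block of the pretrained checkpoint (indices
# 1, 3, 5, 7, 9, 11) and is trained to match the teacher's hidden states with an MSE
# loss on top of the usual language-modelling loss.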


class DistilModel(PoetModelFunctionalInterface):
    def __init__(self, pretrainedModel, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(pretrainedModel, output_hidden_states=True)
        model_config = self.model.config
        self.model_size = 1
        # Check for Hidden layer size by Attribute Name
        if hasattr(model_config, "n_embd"):
            self.model_size = model_config.n_embd
        elif hasattr(model_config, "hidden_size"):
            self.model_size = model_config.hidden_size
        self.kept_states = [1, 3, 5, 7, 9, 11]
        # Drop all blocks that are not in kept_states
        for pop_index in sorted(list(set(range(len(self.model.base_model.h))) - set(self.kept_states)), reverse=True):
            self.model.base_model.h.pop(pop_index)
        # Update the config so the layer count matches the reduced block stack
        self.model.base_model.config.n_layer = len(self.kept_states)
        self.loss_fnc = torch.nn.MSELoss()

    def forward(self, input_ids=None, labels=None, attention_mask=None, to_replicate_states=None, *args, **kwargs):
        outputs = self.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
        loss = outputs.loss
        # The 6 kept layers + embeddings (add +1 to shift the original_index)
        for distil_index, original_index in enumerate([-1] + self.kept_states):
            loss += self.loss_fnc(outputs['hidden_states'][distil_index], to_replicate_states[original_index + 1])
        return {"model_output": outputs, "loss": loss}

    def save_LM(self, LM_path):
        self.model.save_pretrained(LM_path, safe_serialization=False)

    def generate_forced(self, *args, **kwargs):
        raise NotImplementedError("Forced generation is currently not implemented for DistilModel")


class PoetModelHalfBase(PoetModelFunctionalInterface):
    def __init__(self, pretrainedModel, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(pretrainedModel, output_hidden_states=True, torch_dtype=torch.float16)
        model_config = self.model.config
        self.model_size = -1
        # Check for Hidden layer size by Attribute Name
        if hasattr(model_config, "n_embd"):
            self.model_size = model_config.n_embd
        elif hasattr(model_config, "hidden_size"):
            self.model_size = model_config.hidden_size

    def forward(self, input_ids=None, labels=None, attention_mask=None, *args, **kwargs):
        outputs = self.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
        return {"model_output": outputs, "loss": outputs.loss}

    def save_LM(self, LM_path):
        self.model.save_pretrained(LM_path)


class PoetModelSecondaryTasks(PoetModelFunctionalInterface):
    def __init__(self, pretrainedModel, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(pretrainedModel, output_hidden_states=True)
        model_config = self.model.config
        self.model_size = -1
        # Check for Hidden layer size by Attribute Name
        if hasattr(model_config, "n_embd"):
            self.model_size = model_config.n_embd
        elif hasattr(model_config, "hidden_size"):
            self.model_size = model_config.hidden_size
        # Number of Embeddings taken from config
        self.vowels_regressor = torch.nn.Linear(self.model_size, 1)                        # Vowel count
        self.rhyme_regressor = torch.nn.Linear(self.model_size, len(StropheParams.RHYME))  # Rhyme Type

    def forward(self, input_ids=None, labels=None, attention_mask=None, nums=None, rhyme=None, *args, **kwargs):
        outputs = self.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
        last_hidden = outputs['hidden_states'][-1]
        vowel_regression = self.vowels_regressor(last_hidden[:, 0, :].view(-1, self.model_size))
        rhyme_regression = self.rhyme_regressor(last_hidden[:, 0, :].view(-1, self.model_size))
        full_loss = outputs.loss
        vowel_loss = None
        if nums is not None:
            loss_fct = torch.nn.MSELoss()
            vowel_loss = loss_fct(vowel_regression.view(-1, 1), nums.view(-1, 1))
            full_loss = full_loss + vowel_loss
        rhyme_loss = None
        if rhyme is not None:
            softmaxed = torch.softmax(rhyme_regression, dim=1)
            loss_fct = torch.nn.CrossEntropyLoss()
            rhyme_loss = loss_fct(softmaxed, rhyme)
            full_loss = full_loss + rhyme_loss
        return {"model_output": outputs,
                "vowel_regression_output": vowel_regression,
                "vowel_regression_loss": vowel_loss,
                "rhyme_regression_output": rhyme_regression,
                "rhyme_regression_loss": rhyme_loss,
                "loss": full_loss}

    def save_LM(self, LM_path):
        self.model.save_pretrained(LM_path)
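

# Descriptive note (added for clarity): PoetModelVerseEnd below extends the two
# secondary tasks above (vowel count and rhyme type) with a third classification
# head that predicts the verse-ending syllable class.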


class PoetModelVerseEnd(PoetModelFunctionalInterface):
    def __init__(self, pretrainedModel, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(pretrainedModel, output_hidden_states=True)
        model_config = self.model.config
        self.model_size = -1
        # Check for Hidden layer size by Attribute Name
        if hasattr(model_config, "n_embd"):
            self.model_size = model_config.n_embd
        elif hasattr(model_config, "hidden_size"):
            self.model_size = model_config.hidden_size
        # Number of Embeddings taken from config
        self.vowels_regressor = torch.nn.Linear(self.model_size, 1)                        # Vowel count
        self.rhyme_regressor = torch.nn.Linear(self.model_size, len(StropheParams.RHYME))  # Rhyme Type
        self.verse_endings = torch.nn.Linear(self.model_size, len(StropheParams.ENDS))     # Verse End Syllable

    def forward(self, input_ids=None, labels=None, attention_mask=None, nums=None, rhyme=None, verse_end=None, *args, **kwargs):
        outputs = self.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
        last_hidden = outputs['hidden_states'][-1]
        vowel_regression = self.vowels_regressor(last_hidden[:, 0, :].view(-1, self.model_size))
        rhyme_regression = self.rhyme_regressor(last_hidden[:, 0, :].view(-1, self.model_size))
        verse_end_reg = self.verse_endings(last_hidden[:, 0, :].view(-1, self.model_size))
        full_loss = outputs.loss
        vowel_loss = None
        if nums is not None:
            loss_fct = torch.nn.MSELoss()
            vowel_loss = loss_fct(vowel_regression.view(-1, 1), nums.view(-1, 1))
            full_loss = full_loss + vowel_loss
        rhyme_loss = None
        if rhyme is not None:
            softmaxed = torch.softmax(rhyme_regression, dim=1)
            loss_fct = torch.nn.CrossEntropyLoss()
            rhyme_loss = loss_fct(softmaxed, rhyme)
            full_loss = full_loss + rhyme_loss
        verse_loss = None
        if verse_end is not None:
            softmaxed = torch.softmax(verse_end_reg, dim=1)
            loss_fct = torch.nn.CrossEntropyLoss()
            verse_loss = loss_fct(softmaxed, verse_end)
            full_loss = full_loss + verse_loss
        return {"model_output": outputs,
                "vowel_regression_output": vowel_regression,
                "vowel_regression_loss": vowel_loss,
                "rhyme_regression_output": rhyme_regression,
                "rhyme_regression_loss": rhyme_loss,
                "verse_end_regression_output": verse_end_reg,
                "verse_end_regression_loss": verse_loss,
                "loss": full_loss}

    def save_LM(self, LM_path):
        self.model.save_pretrained(LM_path)