"""Gadsby: a Streamlit demo of constrained text generation with transformers.

The sidebar form collects a model name, a load precision, token-level
(lipogram) constraints and sequence-level (phrasal / disjunctive) constraints.
The main page collects ``model.generate()`` arguments and a prompt, then
writes the generated sequence(s).
"""

# NOTE(review): several of the names imported below are not referenced anywhere
# in this script; they are kept untouched so nothing that relies on them breaks.
import ast
import re
import string
from unittest import result

import streamlit as st
import torch
from torch.nn import functional as F
from transformers import (AutoModelForCausalLM, AutoModelForQuestionAnswering,
                          AutoModelForSeq2SeqLM,
                          AutoModelForSequenceClassification, AutoTokenizer,
                          GPT2Tokenizer, LogitsProcessor, LogitsProcessorList,
                          pipeline, top_k_top_p_filtering, PhrasalConstraint,
                          DisjunctiveConstraint)


class ModifyLogitsProcessor(LogitsProcessor):
    """Ban or re-weight every vocabulary token containing given substrings.

    For example, with ``{"e": -99}`` every token whose text contains the
    letter "e" is affected.

    Args:
        tokenizer: Tokenizer whose ``get_vocab()`` maps token text to token id.
        chars_to_modify: Mapping of substring -> factor. The factor is added
            to the matching tokens' scores when ``filter_mode`` is false and
            is unused when ``filter_mode`` is true.
        filter_mode: When true, matching tokens' scores are set to ``-inf``
            instead of being shifted by the factor.
    """

    def __init__(self, tokenizer, chars_to_modify, filter_mode=True):
        super().__init__()
        self.tokenizer = tokenizer
        self.filter_mode = filter_mode
        self.chars_to_modify = chars_to_modify

        # Build the vocabulary once, and read the real token id from the
        # mapping's values: the position of a key in the dict is not
        # guaranteed to equal its id.
        vocab = self.tokenizer.get_vocab()

        # Compute the tokens to modify at initialization
        self.tokens_to_modify = {}
        for char in chars_to_modify:
            if char:
                mod_tokens = [
                    token_id for token, token_id in vocab.items() if char in token
                ]
            else:
                # The empty string is a substring of every token; treating it
                # as a match would mask the whole vocabulary in filter mode.
                mod_tokens = []
            self.tokens_to_modify[char] = mod_tokens

    def __call__(self, input_ids, scores):
        """Modify ``scores`` in place and return it.

        Args:
            input_ids: Unused; accepted because the generation loop passes it.
            scores: Tensor indexed as ``scores[:, token_ids]``.

        Returns:
            The same ``scores`` object after modification.
        """
        for char, tokens in self.tokens_to_modify.items():
            if not tokens:
                continue  # nothing in the vocabulary matched this entry
            if self.filter_mode:
                scores[:, tokens] = -float('inf')
            else:
                # Fetch the corresponding factor from chars_to_modify dictionary
                scores[:, tokens] += self.chars_to_modify[char]
        return scores


st.set_page_config(page_title="Gadsby")
st.title("Gadsby - Constrained Text Generation with Transformers")
st.image("https://upload.wikimedia.org/wikipedia/commons/1/1d/Gadsby_%28book_cover%29.jpg")
st.caption("The inspiration for this space: https://en.wikipedia.org/wiki/Gadsby_(novel)")

# ---------------------------------------------------------------------------
# Sidebar form: model and constraint settings
# ---------------------------------------------------------------------------
form = st.sidebar.form("choose_settings")
form.header("Model Settings")

model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Text Generation", value = "facebook/opt-1.3b")
form.caption("This will download a new model, so it may take awhile or even break if the model is too large")
percision = form.selectbox("What percision are we loading the model with?", ["8bit", "16bit", "32bit"], )
form.caption("The lower the percision, the less ram the model takes and the faster it runs, but the quality is reduced")

form.header("Token Level Constraint Settings")
form.subheader("Lipogram Constraint")
form.caption("Lipograms are compositions where a certain letter or certain letters of the alphabet are omitted or discouraged")
filter_mode = form.checkbox("Filter Mode?", value=False)
form.caption("Enabling filter mode sets all selected tokens probabilities to negative infinity")
naughty_strings_list = form.text_input('Enter letters or words to filter or modify the probabilities of (comma separated):', value = "that,e")
factor_input = form.text_input('Enter corresponding factors to add to the logits (comma separated, ignored if in filter mode):', value = "5,-99")

form.header("Sequence Level Constraint Settings")
form.header("Phrasal Constraint")
force_word = form.text_input("Enter a word or sentence that is guaranteed to appear in the output", value = "lipogram")
form.header("Disjunctive Constraint")
force_flexible_input = form.text_input('Enter a list of words or sentences that the model must include at least one item from (in Python list format)', '["constraint", "banana"]')

# ---------------------------------------------------------------------------
# Parse the disjunctive constraint list
# ---------------------------------------------------------------------------
# Always defined, so the generation step can simply test its truthiness.
force_flexible = []
if force_flexible_input:
    try:
        parsed_flexible = ast.literal_eval(force_flexible_input)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
        st.write('Failed to parse the list. Please check your input.')
        st.write('Error:', e)
    else:
        is_string_list = isinstance(parsed_flexible, (list, tuple)) and all(
            isinstance(item, str) and item for item in parsed_flexible
        )
        if is_string_list:
            force_flexible = list(parsed_flexible)
        else:
            st.write('Failed to parse the list. Please check your input.')
            st.write('Error:', 'expected a list of non-empty strings')

# ---------------------------------------------------------------------------
# Parse the lipogram entries and their factors
# ---------------------------------------------------------------------------
if naughty_strings_list:
    # Strip so that "that, e" behaves like "that,e".
    chars = [entry.strip() for entry in naughty_strings_list.split(',')]
else:
    chars = []

if not chars or filter_mode:
    # Factors are ignored in filter mode (as the input label says), so they
    # are neither parsed nor required to match the number of entries.
    factors = [0.0] * len(chars)
else:
    try:
        factors = [float(value) for value in factor_input.split(',')]
    except ValueError as e:
        st.write('Failed to parse the factors. Please enter comma separated numbers.')
        st.write('Error:', e)
        factors = None  # signals "do not generate" below

chars_to_modify = dict(zip(chars, factors)) if factors is not None else {}

# ---------------------------------------------------------------------------
# Main page: generation arguments and prompt
# ---------------------------------------------------------------------------
generate_args = st.text_input('model.generate() arguments (in python dictionary format) ', '{"max_new_tokens": 50, "min_new_tokens": 50, "temperature": 2.0, "num_return_sequences": 1, "do_sample": False, "num_beams": 2, "repetition_penalty": 3.0}')
st.caption("For more details on what these settings mean and a complete list of all settings, see here: https://huggingface.co/blog/how-to-generate and https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig and https://huggingface.co/docs/transformers/v4.29.1/en/main_classes/text_generation#transformers.GenerationMixin.generate")

sequence = st.text_area("Enter a custom prompt", value = "Tell me about ")

form.form_submit_button("Generate some Constrained Text!")


def parse_generate_args(args_str):
    """Parse a ``key:value,key:value`` string into a dict of int values.

    The string is split on commas; each piece that contains exactly one
    colon contributes ``{left: int(right)}``. Other pieces are skipped.
    Keys are used exactly as written (no stripping).

    Raises:
        ValueError: If a value cannot be converted with ``int``.
    """
    args_dict = {}
    for arg in args_str.split(','):
        parts = arg.split(':')
        if len(parts) == 2:
            args_dict[parts[0]] = int(parts[1])
    return args_dict


@st.cache_resource
def load_the_tokenizer(model_id=None):
    """Load (and cache) the slow tokenizer for a model.

    Args:
        model_id: Name of the pre-trained model. Defaults to the module-level
            ``model_name``. Pass it explicitly so it is part of the cache
            key; otherwise a changed model name keeps returning the
            previously cached tokenizer.
    """
    if model_id is None:
        model_id = model_name
    return AutoTokenizer.from_pretrained(model_id, use_fast = False)


@st.cache_resource
def load_the_model(percision, model_id=None):
    """Load (and cache) the causal language model at the requested precision.

    Args:
        percision: ``"32bit"``, ``"16bit"``, or anything else for 8-bit
            loading.
        model_id: Name of the pre-trained model. Defaults to the module-level
            ``model_name``. Pass it explicitly so it is part of the cache key.
    """
    if model_id is None:
        model_id = model_name

    load_kwargs = {"device_map": 'auto', "load_in_8bit": False}
    if percision == "16bit":
        load_kwargs["torch_dtype"] = torch.float16
    elif percision != "32bit":
        load_kwargs["load_in_8bit"] = True
    return AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)


# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------
factors_ok = factors is not None  # a parse error was already reported above
if factors_ok and len(chars) != len(factors):
    st.write("Please ensure that the number of characters matches the number of factors.")
elif factors_ok:
    model = load_the_model(percision, model_name)
    tokenizer = load_the_tokenizer(model_name)

    constraints = []
    if force_word:
        constraints.append(PhrasalConstraint(
            tokenizer(force_word, add_special_tokens=False).input_ids
        ))
    # Test the parsed list, not the raw text: a failed parse leaves it empty
    # and an empty disjunctive constraint cannot be built.
    if force_flexible:
        constraints.append(DisjunctiveConstraint(
            tokenizer(force_flexible, add_special_tokens=False).input_ids
        ))

    logits_processor = LogitsProcessorList(
        [ModifyLogitsProcessor(tokenizer, chars_to_modify, filter_mode=bool(filter_mode))]
    )

    # Follow the model's device instead of assuming a GPU is present.
    input_ids = tokenizer.encode(sequence, return_tensors="pt").to(model.device)

    generate_kwargs = ast.literal_eval(generate_args)
    # Only pass `constraints` when there is at least one.
    constraint_kwargs = {"constraints": constraints} if constraints else {}
    output_ids = model.generate(
        input_ids,
        logits_processor=logits_processor,
        **constraint_kwargs,
        **generate_kwargs,
    )

    st.write("GENERATED SEQUENCE(s): ")
    for output in output_ids:
        st.write(tokenizer.decode(output, skip_special_tokens = True, clean_up_tokenization_spaces = True))