Spaces:
Runtime error
Runtime error
File size: 8,732 Bytes
6455751 c5e8d64 394bbaa 6455751 c5e8d64 6455751 c5e8d64 6455751 c5e8d64 394bbaa c5e8d64 6455751 c5e8d64 394bbaa 5d59d15 394bbaa c5e8d64 6455751 394bbaa c5e8d64 394bbaa c5e8d64 394bbaa c5e8d64 394bbaa c5e8d64 394bbaa 1719da7 394bbaa c5e8d64 394bbaa c5e8d64 394bbaa c5e8d64 394bbaa c5e8d64 394bbaa c5e8d64 394bbaa c5e8d64 1719da7 5d59d15 6455751 394bbaa 6455751 394bbaa 6455751 394bbaa 6455751 394bbaa 5d59d15 394bbaa 1719da7 6455751 c5e8d64 394bbaa 6455751 394bbaa 6455751 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import functools
import random

import gradio as gr
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import Trainer, TrainingArguments
# Song-lyrics dataset from the Hugging Face Hub. The 'Artist' column of its
# 'train' split supplies the pool of artists used by the guessing game.
DATASET_ID = "SpartanCinder/song-lyrics-artist-classifier"
dataset = load_dataset(DATASET_ID)
# Radio-button label -> Hugging Face model id to load.
_MODEL_NAMES = {
    "Custom Gpt2": "SpartanCinder/GPT2-finetuned-lyric-generation",
    "Gpt2-Medium": "gpt2-medium",
    "facebook/bart-base": "facebook/bart-base",
    "Gpt-Neo": "EleutherAI/gpt-neo-1.3B",
}
# Used for "Customized Models" (and any unrecognised label). This is a
# placeholder: it only loads if a model id / local directory of this name exists.
_CUSTOM_MODEL_NAME = "customized-models"
# Maximum total length in tokens (prompt + generated continuation).
_MAX_LENGTH = 128


@functools.lru_cache(maxsize=1)
def _load_model(model_name):
    """Load the tokenizer and model for *model_name* onto the best device.

    Only the most recently used model is cached, so repeated generations with
    the same model skip reloading, while switching models drops the cache's
    reference to the previous one. Loading errors (e.g. ``OSError`` for an
    unknown model id) propagate and are not cached.

    Returns:
        ``(tokenizer, model, device)`` with the model already on ``device``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # The model must live on the same device as the input ids.
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    return tokenizer, model, device


def _generate_text(language_model, tokenizer, model, device, prompt):
    """Continue *prompt* with the decoding strategy chosen for *language_model*.

    Returns only the newly generated text (the prompt, which names the
    correct artist, is excluded), stripped of surrounding whitespace.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # GPT-2-style tokenizers define no pad token; fall back to EOS so
    # generate() does not have to guess.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # Only the first sequence is ever shown, so generate just one.
    generation_kwargs = {
        "max_length": _MAX_LENGTH,
        "num_return_sequences": 1,
        "pad_token_id": pad_token_id,
    }
    if language_model == "Custom Gpt2":
        # Nucleus sampling with a slightly wider nucleus for the fine-tuned model.
        generation_kwargs.update(do_sample=True, top_p=0.95)
    elif language_model in _MODEL_NAMES:
        # Stock hub models: deterministic beam search; forbidding repeated
        # bigrams counters the looping output beam search tends to produce.
        generation_kwargs.update(num_beams=5, do_sample=False, no_repeat_ngram_size=2)
    else:
        # Nucleus sampling: sample from the smallest token set whose cumulative
        # probability reaches 0.9, giving more diverse, less repetitive text.
        generation_kwargs.update(do_sample=True, top_p=0.9)

    encoded_output = model.generate(input_ids, **generation_kwargs)
    # Causal LMs return prompt + continuation; drop the prompt tokens so the
    # artist name in the prompt is never shown to the player.
    new_tokens = encoded_output[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def generate_song(state, language_model, generate_song):
    """Generate lyrics in the style of a random artist and set up the quiz.

    Gradio click handler for the "Generate Song" button.

    Args:
        state: Per-session dict; this function writes ``'options'``,
            ``'multiple_choice_check'`` and ``'correct_choice'`` into it.
        language_model: Label selected in the model radio, or ``None``/empty
            when nothing is selected.
        generate_song: Value of the Generate Song button passed as an input
            (presumably its label, so normally truthy). A falsy value means no
            generation was requested. The name shadows this function inside
            the body and is kept only for interface compatibility.

    Returns:
        ``(state, song_text, options_text)`` — one value per output component:
        the session state, the generated lyrics (or a user-facing message),
        and the lettered multiple-choice options.
    """
    if not generate_song:
        return state, "", ""
    if not language_model:
        return state, "Please select a language model before generating a song.", ""

    model_name = _MODEL_NAMES.get(language_model, _CUSTOM_MODEL_NAME)
    try:
        tokenizer, model, device = _load_model(model_name)
    except OSError:
        # from_pretrained raises OSError when the id / directory does not
        # exist (e.g. the "Customized Models" placeholder).
        return state, f"Could not load the model '{model_name}'. Please choose another model.", ""

    # Pick a random artist from the dataset; it is the quiz answer.
    correct_choice = pick_artist(dataset)
    song_text = _generate_text(
        language_model, tokenizer, model, device,
        f"A Song in the style of {correct_choice}:",
    )
    if not song_text:
        song_text = "The model did not produce any lyrics this time. Please try generating again."

    options = generate_artist_options(dataset, correct_choice)
    state['options'] = options
    state['multiple_choice_check'] = generate_multiple_choice_check(options, correct_choice)
    state['correct_choice'] = correct_choice

    # Label each option with the letter the answer radio uses (A -> index 0, ...).
    options_text = ", ".join(f"{letter}: {artist}" for letter, artist in zip("ABCD", options))
    return state, song_text, options_text
def on_submit_answer(state, user_choice):
    """Check the player's guess against the artist behind the last song.

    Args:
        state: Per-session dict; reads the ``'options'`` list and the
            ``'correct_choice'`` artist stored by the song generator.
        user_choice: Selected radio letter ("A"-"D"), or ``None`` when
            nothing is selected.

    Returns:
        A plain result string for the "Results" textbox. (A dict return would
        be interpreted by Gradio as a component -> value mapping.)
    """
    options = state.get('options') or []
    if not options or 'correct_choice' not in state:
        return "Please generate a song before submitting an answer."

    # Map the user's letter to a position in the options list.
    choice_to_index = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
    index = choice_to_index.get(user_choice)
    if index is None or index >= len(options):
        return "Please select one of the options (A, B, C, or D) before submitting."

    user_artist = options[index]
    correct_artist = state['correct_choice']
    if user_artist == correct_artist:
        return f"CORRECT: You guessed the right artist: {correct_artist}"
    return (
        f"INCORRECT: You selected {user_choice} ({user_artist}), "
        f"but the correct answer is {correct_artist}"
    )
def pick_artist(dataset):
    """Return one distinct artist chosen at random from the training split.

    Args:
        dataset: Mapping whose ``'train'`` split exposes an ``'Artist'``
            column (e.g. the object returned by ``load_dataset``).

    Raises:
        ValueError: if the ``'Artist'`` column is empty.
    """
    # De-duplicate, then sort so the pick depends only on the random module's
    # state (reproducible under random.seed), not on per-process set ordering.
    artists = sorted(set(dataset['train']['Artist']), key=str)
    if not artists:
        raise ValueError("The dataset's 'train' split has no artists to choose from.")
    return random.choice(artists)
def generate_artist_options(dataset, correct_artist, num_incorrect=3):
    """Build a shuffled multiple-choice list that contains *correct_artist*.

    Args:
        dataset: Mapping whose ``'train'`` split exposes an ``'Artist'`` column.
        correct_artist: The artist that must appear among the options.
        num_incorrect: Number of distinct distractor artists to add. The
            default of 3 yields the four A-D choices the UI offers. If the
            dataset has fewer other artists, all of them are used instead of
            raising.

    Returns:
        A list of distinct artists in random order: up to ``num_incorrect``
        distractors plus ``correct_artist``.
    """
    # Sorting the de-duplicated pool makes sampling reproducible under
    # random.seed(); raw set iteration order varies between processes.
    distractors = sorted(set(dataset['train']['Artist']) - {correct_artist}, key=str)
    options = random.sample(distractors, min(num_incorrect, len(distractors)))
    options.append(correct_artist)
    random.shuffle(options)
    return options
def generate_multiple_choice_check(options, correct_choice):
    """Map each option to whether it is the correct answer.

    Args:
        options: Iterable of candidate artists.
        correct_choice: The artist considered correct.

    Returns:
        Dict of ``option -> bool``; only entries equal to ``correct_choice``
        are True.
    """
    verdicts = {}
    for candidate in options:
        verdicts[candidate] = candidate == correct_choice
    return verdicts
# Build the Gradio UI: model picker, generated lyrics, lettered artist options,
# answer radio and result box, with event listeners wired inside the Blocks
# context.
with gr.Blocks(title="Song Generator Guessing Game") as game_interface:
    gr.Markdown(" # Song Generator Guessing Game")
    gr.Markdown("""
## Instructions
1. Select a language model from the list.
2. Click the 'Generate Song' button to generate a song.
3. Guess the artist of the generated song by selecting an option from the radio buttons.
4. Click the 'Submit Answer' button to submit your guess.
""")

    # Per-session quiz state; the handlers store the options and the answer here.
    state = gr.State({'options': []})
    language_model = gr.Radio(["Custom Gpt2", "Gpt2-Medium", "facebook/bart-base", "Gpt-Neo", "Customized Models"], label="Model Selection", info="Select the language model to generate the song.")
    generate_song_button = gr.Button("Generate Song")
    generated_song = gr.Textbox(label="Generated Song")
    artist_choice_display = gr.Textbox(interactive=False, label="Multiple-Choice Options")
    user_choice = gr.Radio(["A", "B", "C", "D"], label="Your Guess", info="Select the artist that you suspect is the correct artist for the song.")
    submit_answer_button = gr.Button("Submit Answer")
    correct_answer = gr.Textbox(label="Results")

    gr.Markdown("""
## Developer Notes:
- The 'Custom Gpt2' model is a custom fine-tuned GPT-2 model for generating song lyrics.
  - It was trained using a custom dataset of song lyrics from various artists that I created using the Genius API.
  - It uses beam search to generate text, and is much faster compared to the other models.
  - However, it still has trouble producing prompts.
  - I found that artists like "Adele" and "Taylor Swift" are more likely to have coherent lyrics.
- The 'Gpt2-Medium' model is the GPT-2 medium model from the Hugging Face model hub.
  - It uses beam search to generate text, and is slower compared to the custom GPT-2 model.
  - Without tuning, it is more likely to produce a general response to the prompt.
  - Oddly enough, it had a tendency to produce lyrics that were more coherent than the full GPT-2 model.
- The 'facebook/bart-base' model is the BART base model from the Hugging Face model hub.
  - The model only works 20% of the time.
- The 'Gpt-Neo' model is the GPT-Neo 1.3B model from the EleutherAI model hub.
  - It performs well, but is slower compared to the GPT-2 models.
- The 'Customized Models' option is a placeholder for any other custom models that you may have.

#### Known Issues:
- The 'facebook/bart-base' model has a tendency to produce empty responses.
  - This is likely due to the model's architecture and the way it processes the input data.
- Occasionally, the Custom Gpt2 model will produce a result that is just numbers. This has only happened once or twice, and both times were when it was generating a song for The Weeknd.
""")

    generate_song_button.click(
        generate_song,
        [state, language_model, generate_song_button],
        [state, generated_song, artist_choice_display],
    )
    submit_answer_button.click(
        on_submit_answer,
        [state, user_choice],
        [correct_answer],
    )

game_interface.launch()