|
import re

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
|
|
# Hugging Face Hub id of the fine-tuned checkpoint used for both the model
# and its tokenizer.
base_model_name: str = "chaseharmon/Rap-Mistral-Big"
|
|
|
|
|
@st.cache_resource
def load_model(model_name: str = base_model_name):
    """Load the NF4 4-bit quantized causal LM, cached by Streamlit.

    ``st.cache_resource`` keeps the returned object across reruns. As a
    result, the download and quantization happen once, not on every
    interaction.

    Args:
        model_name: Hub id or local path of the checkpoint. Defaults to the
            module-level ``base_model_name``.

    Returns:
        The loaded model in eval mode, placed on devices by
        ``device_map="auto"``.
    """
    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=False,
        bnb_4bit_compute_dtype=torch.float16,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        quantization_config=nf4_config,
    )

    # This app only runs inference, so keep the KV cache on. Disabling it is
    # a fine-tuning setting. With it off, every newly generated token
    # recomputes attention over the full sequence, which makes `generate`
    # far slower.
    model.config.use_cache = True
    # Kept from the fine-tuning setup. Presumably 1 means "no emulated
    # tensor-parallel slicing of linear layers" (TODO confirm against the
    # model config).
    model.config.pretraining_tp = 1
    model.eval()

    return model
|
|
|
@st.cache_resource
def load_tokenizer(model_name: str = base_model_name):
    """Load the tokenizer for the checkpoint, cached by Streamlit.

    Args:
        model_name: Hub id or local path of the checkpoint. Defaults to the
            module-level ``base_model_name``.

    Returns:
        The tokenizer, configured with EOS as the pad token and left padding.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Reuse EOS as the pad token (presumably the checkpoint ships without a
    # dedicated one; TODO confirm).
    tokenizer.pad_token = tokenizer.eos_token
    # Decoder-only models generate by extending the right edge of each
    # sequence. Padding must therefore go on the left at inference time.
    # Right padding is only appropriate for training.
    tokenizer.padding_side = "left"

    return tokenizer
|
|
|
def build_prompt(question):
    """Return *question* wrapped in the ``[INST] ... [/INST]`` template.

    For example, ``build_prompt("hi")`` yields ``"[INST] hi [/INST] "``.
    Note the single trailing space after the closing tag.
    """
    template = "[INST] {} [/INST] "
    return template.format(question)
|
|
|
# Upper bound on the length of a generated verse, in tokens.
MAX_NEW_TOKENS = 300

# Words masked in model output before display, mapped to their lowercase
# replacement. Matching is case-insensitive; see `_censor`.
_CENSORED_WORDS = {
    "nigga": "brotha",
    "faggot": "f--got",
}
_CENSOR_PATTERN = re.compile(
    "|".join(map(re.escape, _CENSORED_WORDS)), re.IGNORECASE
)


def _censor(text):
    """Replace each censored word in *text*, keeping its capitalization style."""

    def _replace(match):
        word = match.group(0)
        replacement = _CENSORED_WORDS[word.lower()]
        if word.isupper():
            return replacement.upper()
        if word[0].isupper():
            return replacement.capitalize()
        return replacement

    return _CENSOR_PATTERN.sub(_replace, text)


model = load_model()
model.eval()

tokenizer = load_tokenizer()

st.title("Rap Verse Generation V1 Demo")
st.header("Supported Artists")
st.write("Lupe Fiasco, Common, Jay-Z, Yasiin Bey, Ab-Soul, Rakim")

prompt_placeholder = st.empty()
display_placeholder = st.empty()

prompt_placeholder.write("Ask the AI to write a verse")
display_placeholder.write("")

question = st.chat_input("Write a verse in the style of Lupe Fiasco")
if question:
    display_placeholder.write("Loading...")
    prompt_placeholder.write(question)

    prompt = build_prompt(question)
    # Send inputs to the device reported by the model instead of hard-coding
    # "cuda". With device_map="auto", the model may live elsewhere.
    model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = model_inputs["input_ids"].shape[-1]

    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        use_cache=True,  # KV cache: avoid recomputing the prefix per token
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the tokens produced after the prompt. This drops the echoed
    # prompt without searching for "[/INST]". The old search broke when the
    # tag was missing or appeared inside the user's question. It also strips
    # special tokens such as <s> and </s> from the displayed text.
    actual_response = tokenizer.decode(
        generated_ids[0, prompt_length:], skip_special_tokens=True
    ).strip()

    # Markdown renders a hard line break only with two trailing spaces.
    actual_response = actual_response.replace("\n", "  \n")
    actual_response = _censor(actual_response)

    display_placeholder.write(actual_response)