import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
base_model_name = "chaseharmon/Rap-Mistral-Big"
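# Load the base model in 4-bit quantization once and cache it across Streamlit reruns.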
@st.cache_resource
def load_model():
    # NF4 4-bit quantization (bitsandbytes) so the model fits in GPU memory.
    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=False,
        bnb_4bit_compute_dtype=torch.float16
    )
    model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        device_map='auto',
        quantization_config=nf4_config,
    )
    model.config.use_cache = False
    model.config.pretraining_tp = 1
    return model
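# Load the tokenizer once and cache it; pad with the EOS token since the base tokenizer defines no pad token.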
@st.cache_resource
def load_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    return tokenizer
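# Wrap the user's question in Mistral's [INST] ... [/INST] instruction format.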
def build_prompt(question):
    prompt = f"[INST] {question} [/INST] "
    return prompt
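# Instantiate the model and tokenizer at startup (cached by st.cache_resource on reruns).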
model = load_model()
model.eval()
tokenizer = load_tokenizer()
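# Page layout: title, supported artists, and placeholders for the prompt and the generated verse.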
st.title("Rap Verse Generation V1 Demo")
st.header("Supported Artists")
st.write("Lupe Fiasco, Common, Jay-Z, Yasiin Bey, Ab-Soul, Rakim")
prompt_placeholder = st.empty()
display_placeholder = st.empty()
prompt_placeholder.write("Ask the AI to write a verse")
display_placeholder.write("")
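# Read the user's request from the chat box and generate a verse when one is submitted.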
question = st.chat_input("Write a verse in the style of Lupe Fiasco")
if question:
display_placeholder.write("Loading...")
prompt_placeholder.write(question)
prompt = build_prompt(question)
inputs = tokenizer(prompt, return_tensors="pt")
model_inputs = inputs.to('cuda')
generated_ids = model.generate(**model_inputs, max_new_tokens=300, do_sample=True, pad_token_id=tokenizer.eos_token_id)
response = tokenizer.batch_decode(generated_ids)[0]
end_of_inst = response.find("[/INST]") + len("[/INST]")
if end_of_inst > -1:
actual_response = response[end_of_inst:].strip()
else:
actual_response = response
actual_response = actual_response.replace("\n", " \n")
actual_response = actual_response.replace("nigga", "brotha")
actual_response = actual_response.replace("Nigga", "Brotha")
actual_response = actual_response.replace("faggot", "f--got")
display_placeholder.write(actual_response) |