import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

base_model_name = "chaseharmon/Rap-Mistral-Big"

@st.cache_resource  # cache the model so Streamlit doesn't reload it on every rerun
def load_model():
    # 4-bit NF4 quantization keeps the model small enough to fit on a single GPU
    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=False,
        bnb_4bit_compute_dtype=torch.float16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        quantization_config=nf4_config,  # the config was built but never passed to from_pretrained
        device_map="auto",
    )
    model.config.use_cache = False
    model.config.pretraining_tp = 1
    return model

@st.cache_resource  # likewise cache the tokenizer across reruns
def load_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    tokenizer.pad_token = tokenizer.eos_token  # Mistral ships without a pad token; reuse EOS
    tokenizer.padding_side = "right"
    return tokenizer

def build_prompt(question):
    # Wrap the question in Mistral's [INST] ... [/INST] instruction format
    prompt = f"[INST] {question} [/INST] "
    return prompt

model = load_model()
model.eval()
tokenizer = load_tokenizer()  # was `load_tokenizer` without parentheses, which assigns the function itself

st.title("Rap Verse Generation V1 Demo")
st.header("Supported Artists")
st.write("Lupe Fiasco, Common, Jay-Z, Yasiin Bey, Ab-Soul, Rakim")

question = st.chat_input("Write a verse in the style of Lupe Fiasco")

if question:
    prompt = build_prompt(question)
    inputs = tokenizer(prompt, return_tensors="pt")
    model_inputs = inputs.to("cuda")
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=300,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens; generate() returns the prompt plus the completion
    new_tokens = generated_ids[:, model_inputs["input_ids"].shape[1]:]
    decoded_output = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
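    # Render the verse in the app; the original snippet decoded the output but never displayed it
    st.write(decoded_output[0])

# Launch the demo from a terminal (assuming this script is saved as app.py):
#   streamlit run app.py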