chaseharmon committed on
Commit 8195b87 · 1 Parent(s): c40f3f4
Files changed (1)
  1. app.py +13 -5
app.py CHANGED
@@ -1,7 +1,6 @@
 import streamlit as st
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 import torch
-from peft import PeftModel
 
 base_model_name = "chaseharmon/Rap-Mistral-Big"
 
@@ -14,7 +13,7 @@ def load_model():
         bnb_4bit_use_double_quant=False,
         bnb_4bit_compute_dtype="float16"
     )
-
+
     model = AutoModelForCausalLM.from_pretrained(
         base_model_name,
         device_map='auto',
@@ -33,15 +32,24 @@ def load_tokenizer():
 
     return tokenizer
 
-
+def build_prompt(question):
+    prompt=f"[INST] {question} [/INST] "
+    return prompt
 
 model = load_model()
+model.eval()
+
 tokenizer = load_tokenizer
 
 st.title("Rap Verse Generation V1 Demo")
 st.header("Supported Artists")
 st.write("Lupe Fiasco, Common, Jay-Z, Yasiin Bey, Ab-Soul, Rakim")
 
-prompt = st.chat_input("Write a verse in the style of Lupe Fiasco")
+question = st.chat_input("Write a verse in the style of Lupe Fiasco")
+if question:
+    prompt = build_prompt(question)
+    inputs = tokenizer(prompt, return_tensors="pt")
+    model_inputs = inputs.to('cuda')
+    generated_ids = model.generate(**model_inputs, max_new_tokens=300, do_sample=True, pad_token_id=tokenizer.eos_token_id)
+    decoded_output = tokenizer.batch_decode(generated_ids)
 
-st.write(prompt)
55