# georgesung's picture
# Increase context window to 2048
# 26dde52
from transformers import LlamaForCausalLM, LlamaTokenizer, pipeline
import torch
import gradio as gr
# LLM helper functions
def get_response_text(data):
    """Return the model's reply extracted from text-generation output.

    Args:
        data: Indexable whose first element is a mapping with a
            ``"generated_text"`` entry holding the generated string.

    Returns:
        The text following the last ``### RESPONSE:`` marker, with
        surrounding whitespace stripped. When the marker does not occur,
        the generated text is returned unchanged (not stripped).
    """
    marker = '### RESPONSE:'
    generated = data[0]["generated_text"]

    # rpartition splits on the LAST occurrence; an empty separator in the
    # result means the marker was not found at all.
    _, found, reply = generated.rpartition(marker)
    if not found:
        return generated
    return reply.strip()
def get_llm_response(prompt, pipe):
    """Generate a reply for ``prompt`` and return only the response text.

    Args:
        prompt: Prompt string handed to ``pipe``.
        pipe: Callable invoked as ``pipe(prompt)``; its result is passed
            straight to ``get_response_text``.

    Returns:
        Whatever ``get_response_text`` extracts from the pipeline output.
    """
    return get_response_text(pipe(prompt))
# Load LLM
model_id = "georgesung/open_llama_7b_qlora_uncensored"
tokenizer = LlamaTokenizer.from_pretrained(model_id)
model = LlamaForCausalLM.from_pretrained(model_id, device_map="auto", load_in_8bit=True)

# Llama tokenizer missing pad token
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# Text-generation pipeline used by the chat handler further down.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=2048,  # LLaMA default context window
    # temperature and top_p are sampling parameters: generation is greedy
    # unless sampling is switched on, in which case they are ignored.
    # Enable sampling explicitly so the two settings below take effect.
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.15,
)
# Chat UI: a chatbot pane, a text box for the user's message, and a button
# that clears the conversation.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def hist_to_prompt(history):
        """Render the chat history as a single prompt string.

        Each ``(human_text, bot_text)`` pair contributes a ``### HUMAN:``
        section followed by a ``### RESPONSE:`` header; the bot text is
        appended only when it is truthy, so a pending turn ends with an
        open response header.
        """
        pieces = []
        for human_text, bot_text in history:
            pieces.append(f"### HUMAN:\n{human_text}\n\n### RESPONSE:\n")
            if bot_text:
                pieces.append(f"{bot_text}\n\n")
        return "".join(pieces)

    def get_bot_response(text):
        """Return the stripped text after the last ``### RESPONSE:`` marker.

        ``text`` is returned unchanged when the marker is absent.
        NOTE(review): not referenced by the other code visible in this file.
        """
        marker = '### RESPONSE:'
        start = text.rfind(marker)
        if start == -1:
            return text
        return text[start + len(marker):].strip()

    def user(user_message, history):
        """Append the user's message as a new turn with no reply yet.

        Returns an empty string (the new value for the text box) and a new
        history list ending in ``[user_message, None]``.
        """
        extended = history + [[user_message, None]]
        return "", extended

    def bot(history):
        """Fill in the reply for the last turn of ``history`` and return it."""
        prompt = hist_to_prompt(history)
        print(prompt)
        reply = get_llm_response(prompt, pipe)
        # The EOS token is stored with the reply, so it is also part of the
        # prompts built from this history on later turns.
        history[-1][1] = reply + tokenizer.eos_token  # add bot message to overall history
        return history

    # On submit: record the user's turn first, then chain the model reply.
    submitted = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False)
    submitted.then(bot, chatbot, chatbot)

    # The callback returns None as the chatbot's new value.
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue()
demo.launch()