|
from transformers import LlamaForCausalLM, LlamaTokenizer, pipeline |
|
import torch |
|
|
|
import gradio as gr |
|
|
|
|
|
def get_response_text(data):
    """Extract the assistant's latest reply from text-generation output.

    Args:
        data: Indexable output whose first element is a mapping with a
            ``"generated_text"`` entry (the text that is searched).

    Returns:
        The text following the last ``### RESPONSE:`` marker, with
        surrounding whitespace stripped. If the marker does not occur, the
        generated text is returned unchanged (not stripped).
    """
    text = data[0]["generated_text"]

    # rpartition splits on the LAST marker, so earlier turns that are echoed
    # back in the generated text are discarded. When the marker is absent the
    # separator slot comes back empty and the full text is returned as-is.
    _, marker, reply = text.rpartition('### RESPONSE:')
    if marker:
        return reply.strip()
    return text
|
|
|
def get_llm_response(prompt, pipe):
    """Run ``prompt`` through ``pipe`` and return only the reply text.

    Args:
        prompt: Prompt string handed to ``pipe`` unchanged.
        pipe: Callable invoked as ``pipe(prompt)``; its result is passed to
            ``get_response_text``, which reads ``result[0]["generated_text"]``.

    Returns:
        Whatever ``get_response_text`` extracts from the raw output.
    """
    raw_output = pipe(prompt)
    return get_response_text(raw_output)
|
|
|
|
|
# Checkpoint identifier used for both the tokenizer and the model.
model_id = "georgesung/open_llama_7b_qlora_uncensored"

tokenizer = LlamaTokenizer.from_pretrained(model_id)
# Loaded with 8-bit weights and automatic device placement, as requested by
# the load_in_8bit / device_map arguments.
model = LlamaForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    load_in_8bit=True,
)

# Register '[PAD]' as the tokenizer's pad token.
# NOTE(review): if '[PAD]' is not already in the vocabulary this adds a new
# token id, and the model's embeddings are not resized here. That looks
# harmless for single, unpadded prompts but should be confirmed before
# enabling batched/padded generation.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# Shared text-generation pipeline used by the chat handlers.
# NOTE(review): temperature / top_p normally only take effect when sampling
# is enabled (do_sample=True); confirm whether the checkpoint's generation
# config enables it, otherwise these two settings are presumably ignored.
# NOTE(review): max_length usually counts prompt tokens as well as generated
# ones, so a long chat history leaves less room for the reply - verify.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=2048,
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.15,
)
|
|
|
with gr.Blocks() as demo:
    # Widgets: the conversation view, the input box and a reset button.
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def hist_to_prompt(history):
        """Render the chat history as a single prompt string.

        Args:
            history: Iterable of ``(human_text, bot_text)`` pairs. A falsy
                ``bot_text`` (e.g. ``None`` for the turn awaiting a reply)
                leaves the prompt ending right after ``### RESPONSE:``.

        Returns:
            The concatenated prompt; an empty string for an empty history.
        """
        segments = []
        for human_text, bot_text in history:
            segments.append(f"### HUMAN:\n{human_text}\n\n### RESPONSE:\n")
            if bot_text:
                segments.append(f"{bot_text}\n\n")
        return "".join(segments)

    def get_bot_response(text):
        """Return the text after the last ``### RESPONSE:`` marker in ``text``.

        The result is stripped when the marker is found; otherwise ``text``
        is returned unchanged. Not called anywhere in the visible source;
        kept so existing references keep working.
        """
        bot_text_index = text.rfind('### RESPONSE:')
        if bot_text_index != -1:
            text = text[bot_text_index + len('### RESPONSE:'):].strip()
        return text

    def user(user_message, history):
        """Record the submitted message as a new turn with no reply yet.

        Returns:
            A pair of: an empty string (the new textbox value) and a new
            history list with ``[user_message, None]`` appended.
        """
        return "", history + [[user_message, None]]

    def bot(history):
        """Generate the reply for the last turn and store it in ``history``.

        The history is turned into a prompt, printed, and sent through the
        module-level ``pipe``. ``history`` is modified in place and returned.
        """
        hist_text = hist_to_prompt(history)
        print(hist_text)

        # The EOS token is appended to the stored reply, so it is both part
        # of the value handed back to the chatbot and part of the prompt
        # rebuilt from history on the next turn.
        bot_message = get_llm_response(hist_text, pipe) + tokenizer.eos_token
        history[-1][1] = bot_message

        return history

    # On submit: first record the user's turn (unqueued), then generate and
    # fill in the bot's reply.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    # The button resets the chatbot value to None.
    clear.click(lambda: None, None, chatbot, queue=False)
|
|
|
# Enable Gradio's event queue, then start the app.
demo.queue()
demo.launch()
|
|
|
|