# Anthony G — "try" (commit 44ceb10)
# raw | history blame | 2.7 kB
import gradio as gr
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftConfig, PeftModel
import warnings
# Silence every warning for the whole process.
# NOTE(review): this also hides actionable warnings (e.g. generation-length
# warnings from transformers); consider narrowing with `category=`/`module=`.
warnings.filterwarnings("ignore")

# Hugging Face Hub id of the PEFT adapter; its adapter config names the base
# model that gets loaded underneath it.
PEFT_MODEL = "givyboy/phi-2-finetuned-mental-health-conversational"

# Instruction text placed at the start of every prompt sent to the model.
SYSTEM_PROMPT = """Answer the following question truthfully.
If you don't know the answer, respond 'Sorry, I don't know the answer to this question.'.
If the question is too complex, respond 'Kindly, consult a psychiatrist for further queries.'."""


def USER_PROMPT(x: str) -> str:  # noqa: N802 - upper-case name kept for callers
    """Format a user turn that is still waiting for the model's answer.

    The result ends with ``"<ASSISTANT>: "`` (note the trailing space) so the
    model's reply is generated directly after the assistant tag.
    """
    return f"<HUMAN>: {x}\n<ASSISTANT>: "


def ADD_RESPONSE(x: str, y: str) -> str:  # noqa: N802 - upper-case name kept for callers
    """Format a completed exchange: user message *x* followed by answer *y*."""
    return f"<HUMAN>: {x}\n<ASSISTANT>: {y}"
# Flip to True to load the base model in 4-bit NF4 through bitsandbytes
# (requires a CUDA GPU and the `bitsandbytes` package). The default False
# passes quantization_config=None, i.e. the base model is loaded unquantized.
USE_4BIT_QUANTIZATION = False

bnb_config = (
    BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )
    if USE_4BIT_QUANTIZATION
    else None
)

# The adapter's config records which base checkpoint it was trained on.
config = PeftConfig.from_pretrained(PEFT_MODEL)

# Load the base model. `offload_folder` gives device_map="auto" a place on
# disk for weights it cannot fit on the available devices.
# NOTE(review): trust_remote_code=True runs Python code shipped with the
# checkpoint repo; keep it only if the base model actually requires it.
peft_base_model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    return_dict=True,
    quantization_config=bnb_config,  # None -> no quantization
    device_map="auto",
    trust_remote_code=True,
    offload_folder="offload/",
    offload_state_dict=True,
)

# Wrap the base model with the fine-tuned adapter weights from PEFT_MODEL.
peft_model = PeftModel.from_pretrained(
    peft_base_model,
    PEFT_MODEL,
    offload_folder="offload/",
    offload_state_dict=True,
)

peft_tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
# Reuse EOS as the padding token (presumably the base tokenizer defines no
# pad token of its own — confirm if switching base models).
peft_tokenizer.pad_token = peft_tokenizer.eos_token
# Text-generation pipeline built around the already-loaded adapter model.
# NOTE(review): because `model` is an instance, not a Hub id, transformers
# presumably does not reload it, so `torch_dtype` and `device_map` below likely
# have no effect on the weights (the base model was loaded without a dtype).
# Confirm before relying on bfloat16 inference.
pipeline = transformers.pipeline(
    task="text-generation",
    tokenizer=peft_tokenizer,
    model=peft_model,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
def format_message(
    message: str, history: list[list[str]], memory_limit: int = 3
) -> str:
    """Build the complete prompt for one chat turn.

    The prompt is ``SYSTEM_PROMPT``, then the most recent ``memory_limit``
    completed exchanges from ``history`` (each rendered by ``ADD_RESPONSE``),
    then the new ``message`` rendered by ``USER_PROMPT``. The parts are joined
    with single newlines.

    Args:
        message: The user's new message.
        history: Earlier exchanges as ``[user_message, model_answer]`` pairs,
            oldest first. This is assumed to be the tuple-style history that
            ``gr.ChatInterface`` passes — TODO confirm for the installed
            Gradio version.
        memory_limit: Maximum number of past exchanges to include. Values
            ``<= 0`` include none.

    Returns:
        The prompt string to feed to the text-generation pipeline.
    """
    # history[-0:] is the *whole* list, so a limit of 0 (or a negative limit)
    # must be handled explicitly rather than through the slice.
    recent = history[-memory_limit:] if memory_limit > 0 else []

    parts = [SYSTEM_PROMPT]
    parts.extend(ADD_RESPONSE(msg, ans) for msg, ans in recent)
    parts.append(USER_PROMPT(message))
    return "\n".join(parts)
def get_model_response(
    message: str, history: list[list[str]], max_new_tokens: int = 256
) -> str:
    """Generate the assistant's reply to ``message``; used as the ChatInterface fn.

    Args:
        message: The user's new message.
        history: Earlier ``[user_message, model_answer]`` pairs, forwarded to
            ``format_message``, which keeps only the most recent few.
        max_new_tokens: Upper bound on tokens generated for the reply. It does
            not count prompt tokens, so a long history cannot crowd out the
            answer.

    Returns:
        The model's reply without the prompt, cut before any further
        ``<HUMAN>:`` turn the model goes on to write, stripped of surrounding
        whitespace.
    """
    formatted_message = format_message(message, history)
    sequences = pipeline(
        formatted_message,
        do_sample=True,
        top_k=10,
        num_return_sequences=1,
        eos_token_id=peft_tokenizer.eos_token_id,
        # The previous max_length=600 included the prompt, so a long history
        # left little or no room for the reply; bound only new tokens instead.
        max_new_tokens=max_new_tokens,
        # Return only the continuation, so the reply need not be recovered by
        # searching the echoed prompt for the last "<ASSISTANT>:" marker.
        return_full_text=False,
    )[0]
    generated = sequences["generated_text"]
    # Debug echo of the raw generation (now only the continuation).
    print(generated)
    # The model may keep writing the dialogue ("<HUMAN>: ... <ASSISTANT>: ...");
    # keep only its answer to the actual user message.
    reply = generated.split("<HUMAN>:", 1)[0]
    return reply.strip()
# Build the chat UI around get_model_response. Binding it to `demo` lets
# Gradio's reload mode (`gradio app.py`) find it by its default name. The
# __main__ guard keeps a plain import of this module from starting the
# blocking server.
demo = gr.ChatInterface(fn=get_model_response)

if __name__ == "__main__":
    demo.launch()