chatbot / app.py
trungtienluong's picture
Update app.py
caad00a verified
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftConfig, PeftModel
import pandas as pd
from datasets import Dataset, load_dataset
from sklearn.model_selection import train_test_split
from accelerate import Accelerator
# Initialize the accelerator
accelerator = Accelerator()
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
# Load the base model with accelerate
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
trust_remote_code=True
)
model = accelerator.prepare(model)
# Load the pre-trained model with PEFT
peft_config = PeftConfig.from_pretrained("trungtienluong/experiments500czephymodelngay11t6l1")
model = PeftModel.from_pretrained(model, "trungtienluong/experiments500czephymodelngay11t6l1")
model = accelerator.prepare(model)
# Enable gradient checkpointing
model.gradient_checkpointing_enable()
# Load the dataset
dataset = load_dataset("trungtienluong/500cau")
data = pd.DataFrame(dataset['train'])
train_samples, temp_samples = train_test_split(data, test_size=0.2, random_state=42)
val_samples, test_samples = train_test_split(temp_samples, test_size=0.5, random_state=42)
train_dataset = Dataset.from_pandas(train_samples)
val_dataset = Dataset.from_pandas(val_samples)
test_dataset = Dataset.from_pandas(test_samples)
def create_prompt(question):
prompt_messages = [
{"role": "system", "content": "Bạn là một chuyên gia trong lĩnh vực nhi khoa. Hãy trả lời chính xác theo explanation của từng câu. Không thêm thông tin bên ngoài."},
{"role": "user", "content": "Nhiễm trùng sơ sinh là gì?"},
{"role": "assistant", "content": "Nhiễm trùng sơ sinh là tình trạng mà một em bé mới sinh bị nhiễm khuẩn hoặc vi rút. Đây là một vấn đề nghiêm trọng có thể ảnh hưởng đến sức khỏe và thậm chí là tính mạng của trẻ sơ sinh."},
{"role": "user", "content": question}
]
prompt = tokenizer.apply_chat_template(prompt_messages, tokenize=False, add_generation_prompt=True)
return prompt
def post_process_answer(answer):
lines = answer.split('\n')
unwanted_tags = ["<system>", "<user>", "<assistant>"]
filtered_lines = [line for line in lines if not any(tag in line for tag in unwanted_tags)]
return "\n".join(filtered_lines).strip()
def generate_answer(question):
try:
prompt = create_prompt(question)
encoding = tokenizer(prompt, return_tensors="pt")
encoding = accelerator.prepare(encoding)
with torch.inference_mode():
outputs = model.generate(
input_ids=encoding.input_ids,
attention_mask=encoding.attention_mask,
max_new_tokens=150
)
answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
processed_answer = post_process_answer(answer)
print(f"Generated answer: {processed_answer}")
return processed_answer
except Exception as e:
print(f"Error generating answer: {e}")
return "Error"
def get_random_test_question():
random_question = test_dataset.shuffle(seed=42)['Question'][0]
return random_question
def interface_generate_answer(question, use_test_question):
if use_test_question:
question = get_random_test_question()
answer = generate_answer(question)
return question, answer
iface = gr.Interface(
fn=interface_generate_answer,
inputs=[
gr.Textbox(lines=2, placeholder="Nhập câu hỏi của bạn ở đây...", label="Câu hỏi"),
gr.Checkbox(label="Sử dụng câu hỏi từ tập kiểm tra")
],
outputs=[
gr.Textbox(label="Câu hỏi đã nhập hoặc từ tập kiểm tra"),
gr.Textbox(label="Câu trả lời")
],
title="Chatbot Nhi khoa",
description="Hỏi bất kỳ câu hỏi nào về nhi khoa. Bạn có thể chọn sử dụng câu hỏi từ tập kiểm tra.",
theme="default"
)
iface.launch(share=True)