LawVinaLlama / infer.py
haisonle001's picture
Upload infer.py
f6a1352 verified
from unsloth import FastLanguageModel
import torch
max_seq_length = 2048
dtype = None
load_in_4bit = True
import requests
question= "Đi xe đè vạch màu vàng xử lý như thế nào?"
url = <retrieval_endpoint>
retrieve = {"query": [question]}
response = requests.post(url, json=retrieve)
context= response.json()['predict'][0][0][0]['top_relevant_chunks']
import random
def process_context(docs:list) -> list:
res=[]
for context in docs:
res.append(context.get("text",""))
return res[:10]
context= '\n'.join(process_context(context))
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = 'NaverHustQA/LawVinaLlama',
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
)
prompt_context = """Bạn là một tư vấn viên hữu ích về luật.
### Instruction and Input:
Dựa vào ngữ cảnh/tài liệu sau:
{}
Hãy trả lời câu hỏi: {}
### Câu trả lời:
{}
"""
FastLanguageModel.for_inference(model) # Enable native 2x faster inference
inputs = tokenizer(
[
prompt_context.format(
context, # instruction
question, # input
"",
)
], return_tensors = "pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens = 1024, use_cache = True)
print(tokenizer.batch_decode(outputs)[0].split("### Câu trả lời:")[1])