|
from unsloth import FastLanguageModel |
|
import torch |
|
max_seq_length = 2048 |
|
dtype = None |
|
load_in_4bit = True |
|
|
|
import requests |
|
question= "Đi xe đè vạch màu vàng xử lý như thế nào?" |
|
url = <retrieval_endpoint> |
|
retrieve = {"query": [question]} |
|
response = requests.post(url, json=retrieve) |
|
context= response.json()['predict'][0][0][0]['top_relevant_chunks'] |
|
|
|
import random |
|
def process_context(docs:list) -> list: |
|
res=[] |
|
for context in docs: |
|
res.append(context.get("text","")) |
|
return res[:10] |
|
context= '\n'.join(process_context(context)) |
|
|
|
model, tokenizer = FastLanguageModel.from_pretrained( |
|
model_name = 'NaverHustQA/LawVinaLlama', |
|
max_seq_length = max_seq_length, |
|
dtype = dtype, |
|
load_in_4bit = load_in_4bit, |
|
) |
|
prompt_context = """Bạn là một tư vấn viên hữu ích về luật. |
|
### Instruction and Input: |
|
Dựa vào ngữ cảnh/tài liệu sau: |
|
{} |
|
Hãy trả lời câu hỏi: {} |
|
|
|
### Câu trả lời: |
|
{} |
|
""" |
|
|
|
FastLanguageModel.for_inference(model) |
|
inputs = tokenizer( |
|
[ |
|
prompt_context.format( |
|
context, |
|
question, |
|
"", |
|
) |
|
], return_tensors = "pt").to("cuda") |
|
|
|
outputs = model.generate(**inputs, max_new_tokens = 1024, use_cache = True) |
|
print(tokenizer.batch_decode(outputs)[0].split("### Câu trả lời:")[1]) |
|
|