import gradio as gr
from datasets import load_dataset
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import torch
from threading import Thread
from sentence_transformers import SentenceTransformer
import faiss
import fitz  # PyMuPDF

# Get the Hugging Face token from the environment
token = os.environ.get("HF_TOKEN")
if not token:
    raise ValueError("Hugging Face token is missing. Please set it in your environment variables.")
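# Example (assumed shell setup; substitute your own token before running):
#   export HF_TOKEN=hf_your_token_here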

# Load the embedding model
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

# Extract text from a PDF
def extract_text_from_pdf(pdf_path):
    doc = fitz.open(pdf_path)
    text = ""
    for page in doc:
        text += page.get_text()
    return text

# Path to the legal-document PDF and text extraction
pdf_path = "laws.pdf"  # Replace with the actual PDF path.
law_text = extract_text_from_pdf(pdf_path)

# Split the legal text into lines and embed them (skip empty lines,
# which would otherwise be embedded and pollute the index)
law_sentences = [s for s in law_text.split('\n') if s.strip()]
law_embeddings = ST.encode(law_sentences)

# Create a FAISS index and add the embeddings
index = faiss.IndexFlatL2(law_embeddings.shape[1])
index.add(law_embeddings)
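# Note: IndexFlatL2 ranks by Euclidean distance. If cosine similarity is preferred
# (a common choice for mxbai embeddings), a minimal sketch under that assumption:
#   faiss.normalize_L2(law_embeddings)
#   index = faiss.IndexFlatIP(law_embeddings.shape[1])
#   index.add(law_embeddings)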

# Load the legal-consultation QA dataset from Hugging Face
dataset = load_dataset("jihye-moon/LawQA-Ko")
data = dataset["train"]

# Embed the question column and store it in a new column
data = data.map(lambda x: {"question_embedding": ST.encode(x["question"])}, batched=True)
data.add_faiss_index(column="question_embedding")
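# Optional (assumed workflow): persist the index so it is not rebuilt on every start:
#   data.save_faiss_index("question_embedding", "qa_index.faiss")
#   # ...and on later runs: data.load_faiss_index("question_embedding", "qa_index.faiss")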

# LLaMA model setup
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config,
    token=token
)
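# With 4-bit NF4 quantization the 8B model's weights take roughly 5-6 GB of GPU
# memory (an estimate; actual usage also depends on context length and CUDA overhead).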

SYS_PROMPT = """You are an assistant for answering legal questions.
You are given the extracted parts of legal documents and a question. Provide a conversational answer.
If you don't know the answer, just say "I do not know." Don't make up an answer."""

# Legal-document search function
def search_law(query, k=5):
    query_embedding = ST.encode([query])
    D, I = index.search(query_embedding, k)
    return [(law_sentences[i], D[0][idx]) for idx, i in enumerate(I[0])]

# Legal QA data search function
def search_qa(query, k=3):
    scores, retrieved_examples = data.get_nearest_examples(
        "question_embedding", ST.encode(query), k=k
    )
    return [retrieved_examples["answer"][i] for i in range(k)]
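# Quick retrieval sanity check (assumed usage; uncomment to inspect results without the UI):
#   for sentence, dist in search_law("개인정보 보호", k=2):
#       print(dist, sentence)
#   print(search_qa("개인정보 보호", k=2))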

# Build the final prompt
def format_prompt(prompt, law_docs, qa_docs):
    PROMPT = f"Question: {prompt}\n\nLegal Context:\n"
    for doc in law_docs:
        PROMPT += f"{doc[0]}\n"
    PROMPT += "\nLegal QA:\n"
    for doc in qa_docs:
        PROMPT += f"{doc}\n"
    return PROMPT

# Chatbot response function: retrieve context, then stream the model's answer
def talk(prompt, history):
    law_results = search_law(prompt, k=3)
    qa_results = search_qa(prompt, k=3)

    retrieved_law_docs = [result[0] for result in law_results]
    formatted_prompt = format_prompt(prompt, retrieved_law_docs, qa_results)
    formatted_prompt = formatted_prompt[:2000]  # Truncate the prompt to avoid running out of GPU memory
    messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]

    # Build the model inputs: a list of chat messages must go through the chat
    # template; it cannot be passed to the tokenizer directly
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    # Stream tokens from a background thread so the UI updates as text is generated
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        temperature=0.75,
        # Llama 3 ends its turns with <|eot_id|>, so include it alongside the default EOS
        eos_token_id=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")],
    )

    try:
        thread = Thread(target=model.generate, kwargs=generate_kwargs)
        thread.start()
        response = ""
        for new_text in streamer:
            response += new_text
            yield response
    except Exception as e:
        yield f"Error: {e}"


# Gradio interface setup
TITLE = "Legal RAG Chatbot"

DESCRIPTION = """
A chatbot that uses Retrieval-Augmented Generation (RAG) for legal consultation.
This chatbot can search legal documents and previous legal QA pairs to provide answers.
"""

demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme=gr.themes.Soft(),
    examples=[["What are the regulations on data privacy?"]],
    title=TITLE,
    description=DESCRIPTION,
)

# Launch the Gradio demo
demo.launch(debug=True)