import streamlit as st
from PyPDF2 import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import joblib
from openai import OpenAI

client = OpenAI()  # Uses env var OPENAI_API_KEY

@st.cache_data
def load_pdf_chunks(pdf_path):
    """Extract text from the PDF and split it into overlapping chunks."""
    reader = PdfReader(pdf_path)
    raw_text = ""
    for page in reader.pages:
        # extract_text() can return None for pages with no extractable text
        raw_text += (page.extract_text() or "") + "\n"
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    return splitter.split_text(raw_text)

@st.cache_resource
def load_model_and_index(chunks):
    """Embed the chunks and build an in-memory FAISS L2 index."""
    model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    embeddings = model.encode(chunks)
    faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
    faiss_index.add(np.array(embeddings, dtype="float32"))
    # FAISS index objects are not picklable, so persist with write_index
    # rather than joblib; the chunks are saved alongside it.
    faiss.write_index(faiss_index, "rag.index")
    joblib.dump(chunks, "rag_chunks.joblib")
    return model, chunks, faiss_index
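
# Optional helper (a hedged sketch, not wired into the UI): reload the
# artifacts persisted by load_model_and_index, assuming the file names
# "rag.index" / "rag_chunks.joblib" written above. The embedding model
# itself is simply re-instantiated by name.
def load_persisted_index(index_path="rag.index", chunks_path="rag_chunks.joblib"):
    faiss_index = faiss.read_index(index_path)
    chunks = joblib.load(chunks_path)
    model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    return model, chunks, faiss_index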

def search(query, model, chunks, index, k=3):
    """Return the k chunks nearest to the query in embedding space."""
    query_vec = model.encode([query])
    scores, indices = index.search(np.array(query_vec), k)
    return [chunks[i] for i in indices[0]]

def chat_no_rag(question, max_tokens=250):
    """Baseline: ask the model directly, with no retrieved context."""
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "user", "content": question}
        ],
        temperature=0.5,
        max_tokens=max_tokens
    )
    return response.choices[0].message.content

def chat_with_rag(question, retrieved_chunks, max_tokens=300):
    """Answer using the retrieved chunks as grounding context."""
    context = "\n".join(retrieved_chunks)
    prompt = f"Use the following context to answer the question:\n\n{context}\n\nQuestion: {question}"

    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.3,
        max_tokens=max_tokens
    )
    return response.choices[0].message.content

def chat_with_rag_enhanced(question, retrieved_chunks, max_tokens=300):
    """Same as chat_with_rag, but with a persona and stricter instructions."""
    context = "\n".join(retrieved_chunks)
    prompt = (
        "You are an expert in martial-arts history. "
        "Use the following historical context to answer accurately and in detail.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}\nAnswer:"
    )

    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.2,
        max_tokens=max_tokens
    )
    return response.choices[0].message.content

# Streamlit UI
st.title("📜 Historical JuJutsu RAG - ChatGPT + HF + Streamlit")

if "model" not in st.session_state:
    with st.spinner("Loading and processing the PDF..."):
        chunks = load_pdf_chunks("JuJutsu-Contexto-Significado-Conexiones-Historia.pdf")
        model, chunks, index = load_model_and_index(chunks)
        st.session_state.model = model
        st.session_state.chunks = chunks
        st.session_state.index = index

query = st.text_input("Ask a question about historical JuJutsu:")
max_tokens = st.slider("Maximum response tokens", 50, 1000, 300, step=50)

if query:
    model = st.session_state.model
    chunks = st.session_state.chunks
    index = st.session_state.index

    st.subheader("🔹 Answer without RAG:")
    st.write(chat_no_rag(query, max_tokens=max_tokens))

    st.subheader("🔹 Answer with RAG:")
    retrieved = search(query, model, chunks, index)
    st.write(chat_with_rag(query, retrieved, max_tokens=max_tokens))

    st.subheader("🔹 Answer with RAG + Prompt Enhancement:")
    st.write(chat_with_rag_enhanced(query, retrieved, max_tokens=max_tokens))
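
# To launch the app (assuming this file is saved as app.py):
#   streamlit run app.py
# Requires OPENAI_API_KEY in the environment and the source PDF in the
# working directory.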