# JuJitsuPOC / app.py
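"""Streamlit proof of concept that answers questions about historical JuJutsu
with GPT-4o in three ways (no RAG, RAG over a local PDF, and RAG with an
improved prompt) so the responses can be compared side by side."""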
import os

# Redirect all cache directories to /tmp *before* importing the Hugging Face
# libraries, so the settings are picked up when those modules initialize.
os.environ["TRANSFORMERS_CACHE"] = "/tmp"
os.environ["HF_HOME"] = "/tmp"
os.environ["XDG_CACHE_HOME"] = "/tmp"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp"
os.environ["STREAMLIT_HOME"] = "/tmp"

import faiss
import joblib  # only needed for the optional index persistence below
import numpy as np
import streamlit as st
from langchain.text_splitter import RecursiveCharacterTextSplitter
from openai import OpenAI
from PyPDF2 import PdfReader
from sentence_transformers import SentenceTransformer

client = OpenAI()  # Uses env var OPENAI_API_KEY
@st.cache_data
def load_pdf_chunks(pdf_path):
    """Extract the PDF's text and split it into overlapping chunks."""
    reader = PdfReader(pdf_path)
    raw_text = ""
    for page in reader.pages:
        # extract_text() can return None for pages with no extractable text.
        raw_text += (page.extract_text() or "") + "\n"
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    return splitter.split_text(raw_text)
@st.cache_resource
def load_model_and_index(chunks):
    """Embed the chunks and build an in-memory FAISS index over them."""
    model = SentenceTransformer('models/all-MiniLM-L6-v2')
    embeddings = model.encode(chunks)
    faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
    faiss_index.add(np.array(embeddings))
    # joblib.dump((model, chunks, faiss_index), "rag_model.joblib")
    return model, chunks, faiss_index
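# Optional persistence (a sketch, not wired in): pairing the joblib.dump
# above with a load at startup would skip re-embedding across restarts,
# e.g. model, chunks, faiss_index = joblib.load("rag_model.joblib").
# Within a single running session, @st.cache_resource already memoizes.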
def search(query, model, chunks, index, k=3):
    """Return the k chunks closest to the query in embedding space."""
    query_vec = model.encode([query])
    scores, indices = index.search(np.array(query_vec), k)
    return [chunks[i] for i in indices[0]]
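# Note: IndexFlatL2 performs exact nearest-neighbor search and returns
# squared L2 distances, so smaller scores mean more similar chunks; the
# scores themselves are unused here, only the ranking matters.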
def chat_no_rag(question, max_tokens=250):
    """Baseline: ask GPT-4o directly, with no retrieved context."""
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "user", "content": question}
        ],
        temperature=0.5,
        max_tokens=max_tokens
    )
    return response.choices[0].message.content
def chat_with_rag(question, retrieved_chunks, max_tokens=300):
    """Ask GPT-4o with the retrieved chunks prepended as context."""
    context = "\n".join(retrieved_chunks)
    prompt = f"Use the following context to answer the question:\n\n{context}\n\nQuestion: {question}"
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.3,
        max_tokens=max_tokens
    )
    return response.choices[0].message.content
def chat_with_rag_enhanced(question, retrieved_chunks, max_tokens=300):
    """Like chat_with_rag, but with a role instruction and lower temperature."""
    context = "\n".join(retrieved_chunks)
    prompt = (
        "You are an expert in martial arts history. "
        "Use the following historical context to answer precisely and in detail.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}\nAnswer:"
    )
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.2,
        max_tokens=max_tokens
    )
    return response.choices[0].message.content
# Streamlit UI
st.title("📜 Historical JuJutsu RAG - ChatGPT + HF + Streamlit")

if "model" not in st.session_state:
    with st.spinner("Loading and processing the PDF..."):
        chunks = load_pdf_chunks("JuJutsu-Contexto-Significado-Conexiones-Historia.pdf")
        model, chunks, index = load_model_and_index(chunks)
        st.session_state.model = model
        st.session_state.chunks = chunks
        st.session_state.index = index

query = st.text_input("Ask a question about historical JuJutsu:")
max_tokens = st.slider("Maximum response tokens", 50, 1000, 300, step=50)
if query:
    model = st.session_state.model
    chunks = st.session_state.chunks
    index = st.session_state.index

    st.subheader("🔹 Answer without RAG:")
    st.write(chat_no_rag(query, max_tokens=max_tokens))

    st.subheader("🔹 Answer with RAG:")
    retrieved = search(query, model, chunks, index)
    st.write(chat_with_rag(query, retrieved, max_tokens=max_tokens))

    st.subheader("🔹 Answer with RAG + improved prompt:")
    st.write(chat_with_rag_enhanced(query, retrieved, max_tokens=max_tokens))
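# To run locally (assuming dependencies are installed and OPENAI_API_KEY is set):
#   streamlit run app.py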