import streamlit as st
from PyPDF2 import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import os
import joblib  # only needed if the joblib.dump cache line below is re-enabled
from openai import OpenAI
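# Hugging Face Spaces containers are typically writable only under /tmp, so the
# block below points every cache and home directory there. TRANSFORMERS_CACHE is
# the older variable name (HF_HOME supersedes it) but is still honored, so
# setting both is a harmless belt-and-braces choice.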
os.environ["TRANSFORMERS_CACHE"] = "/tmp"
os.environ["HF_HOME"] = "/tmp"
os.environ["XDG_CACHE_HOME"] = "/tmp"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp"
os.environ["STREAMLIT_HOME"] = "/tmp"

client = OpenAI()  # uses the OPENAI_API_KEY environment variable
def load_pdf_chunks(pdf_path):
    """Extract the PDF's text and split it into overlapping chunks."""
    reader = PdfReader(pdf_path)
    raw_text = ""
    for page in reader.pages:
        # extract_text() returns None for pages with no extractable text,
        # so guard against it before concatenating.
        raw_text += (page.extract_text() or "") + "\n"
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    return splitter.split_text(raw_text)
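# Rough sizing (illustrative arithmetic, not measured on this PDF): with
# chunk_size=500 and chunk_overlap=50 each new chunk advances ~450 characters,
# so a 10,000-character document yields on the order of 10000 / 450 ≈ 22 chunks,
# with neighbouring chunks sharing ~50 characters so that sentences cut at a
# boundary survive intact in at least one chunk.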
def load_model_and_index(chunks):
    """Embed the chunks and build an exact (brute-force) L2 FAISS index."""
    # Loads from a local directory shipped with the app; the hub id would be
    # 'sentence-transformers/all-MiniLM-L6-v2'.
    model = SentenceTransformer('models/all-MiniLM-L6-v2')
    embeddings = model.encode(chunks)
    faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
    faiss_index.add(np.array(embeddings))
    # joblib.dump((model, chunks, faiss_index), "rag_model.joblib")
    return model, chunks, faiss_index
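# A minimal sketch of an alternative index build (the function name is ours, not
# part of the original app): MiniLM embeddings are usually compared by cosine
# similarity, which FAISS expresses as inner product over unit-length vectors.
def load_model_and_index_cosine(chunks):
    model = SentenceTransformer('models/all-MiniLM-L6-v2')
    # normalize_embeddings=True makes every vector unit-length, so the inner
    # product computed by IndexFlatIP equals cosine similarity.
    embeddings = model.encode(chunks, normalize_embeddings=True)
    faiss_index = faiss.IndexFlatIP(embeddings.shape[1])
    faiss_index.add(np.array(embeddings))
    return model, chunks, faiss_index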
def search(query, model, chunks, index, k=3):
    """Return the k chunks whose embeddings lie closest to the query's."""
    query_vec = model.encode([query])
    scores, indices = index.search(np.array(query_vec), k)
    return [chunks[i] for i in indices[0]]
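# Example call (illustrative):
#   hits = search("Where did JuJutsu originate?", model, chunks, index, k=3)
# IndexFlatL2 reports squared L2 distances, so smaller scores mean closer
# matches. With the cosine sketch above, the query would need to be encoded
# with normalize_embeddings=True too, and larger scores would mean closer.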
def chat_no_rag(question, max_tokens=250):
    """Baseline: ask the model directly, with no retrieved context."""
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "user", "content": question}
        ],
        temperature=0.5,
        max_tokens=max_tokens
    )
    return response.choices[0].message.content
def chat_with_rag(question, retrieved_chunks, max_tokens=300):
    """Prepend the retrieved chunks to the question as plain context."""
    context = "\n".join(retrieved_chunks)
    prompt = f"Use the following context to answer the question:\n\n{context}\n\nQuestion: {question}"
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.3,
        max_tokens=max_tokens
    )
    return response.choices[0].message.content
def chat_with_rag_enhanced(question, retrieved_chunks, max_tokens=300):
    """Like chat_with_rag, but with a persona, stricter instructions,
    and a lower temperature for more deterministic answers."""
    context = "\n".join(retrieved_chunks)
    prompt = (
        "You are an expert in martial arts history. "
        "Use the following historical context to answer precisely and in detail.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}\nAnswer:"
    )
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.2,
        max_tokens=max_tokens
    )
    return response.choices[0].message.content
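# All three chat_* variants stuff retrieved text straight into the user message.
# A minimal guard against oversized prompts (a sketch: truncate_context and its
# 8,000-character budget are our assumptions, not something this app enforces):
def truncate_context(retrieved_chunks, max_chars=8000):
    """Join chunks in retrieval order, stopping before max_chars is exceeded."""
    kept, total = [], 0
    for chunk in retrieved_chunks:
        if total + len(chunk) > max_chars:
            break
        kept.append(chunk)
        total += len(chunk)
    return "\n".join(kept)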
# Streamlit UI
st.title("📜 Historical JuJutsu RAG - ChatGPT + HF + Streamlit")

# st.session_state survives Streamlit's script reruns, so the PDF is parsed and
# the embedding model loaded once per session rather than on every interaction.
if "model" not in st.session_state:
    with st.spinner("Loading and processing the PDF..."):
        chunks = load_pdf_chunks("JuJutsu-Contexto-Significado-Conexiones-Historia.pdf")
        model, chunks, index = load_model_and_index(chunks)
        st.session_state.model = model
        st.session_state.chunks = chunks
        st.session_state.index = index

query = st.text_input("Type your question about historical JuJutsu:")
max_tokens = st.slider("Maximum response tokens", 50, 1000, 300, step=50)
if query:
    model = st.session_state.model
    chunks = st.session_state.chunks
    index = st.session_state.index

    st.subheader("🔹 Answer without RAG:")
    st.write(chat_no_rag(query, max_tokens=max_tokens))

    st.subheader("🔹 Answer with RAG:")
    retrieved = search(query, model, chunks, index)
    st.write(chat_with_rag(query, retrieved, max_tokens=max_tokens))

    st.subheader("🔹 Answer with RAG + improved prompt:")
    st.write(chat_with_rag_enhanced(query, retrieved, max_tokens=max_tokens))