LexAIcon / app.py
manuelcozar55's picture
Update app.py
ad0aba7 verified
raw
history blame
7.62 kB
import streamlit as st
from PyPDF2 import PdfReader
from docx import Document
import csv
import json
import os
import torch
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from huggingface_hub import login, InferenceClient
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Log in to the Hugging Face Hub only when a token is present in the
# environment, so the module can still be imported without credentials.
huggingface_token = os.environ.get('HUGGINGFACE_TOKEN')
if huggingface_token:
    login(token=huggingface_token)
# Inference client setup.
@st.cache_resource
def load_inference_client():
    """Create the Hub inference client for the Mistral instruct model.

    The function is decorated with ``st.cache_resource``, so Streamlit
    hands back the same client object on later calls instead of
    constructing a new one.

    Returns:
        An ``InferenceClient`` bound to
        ``mistralai/Mistral-7B-Instruct-v0.3``.
    """
    return InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.3")


client = load_inference_client()
# Classification model setup.
# Maps the classifier's output index to the document category it stands for.
# Defined before the loader because the loader sizes the model head from it.
id2label = {0: "multas", 1: "politicas_de_privacidad", 2: "contratos", 3: "denuncias", 4: "otros"}


@st.cache_resource
def load_classification_model():
    """Load the Spanish legal Longformer classifier and its tokenizer.

    The classification head is sized from ``id2label``. Without an explicit
    ``num_labels`` the library falls back to a two-label head, which would
    make every category beyond the first two unreachable.

    NOTE(review): the checkpoint name suggests a base (not fine-tuned)
    model, in which case the classification head is presumably freshly
    initialised and predictions are not meaningful until it is trained.
    Confirm, or point this at a fine-tuned checkpoint.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model_name = "mrm8488/legal-longformer-base-8192-spanish"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=len(id2label),
        id2label=id2label,
        label2id={label: index for index, label in id2label.items()},
    )
    return model, tokenizer


classification_model, classification_tokenizer = load_classification_model()
# Load the JSON question/answer documents for every category.
@st.cache_resource
def load_json_documents():
    """Read each category's Q&A file into a list of plain-text entries.

    For every category a file named ``./<category>.json`` is opened
    (relative to the current working directory). Its
    ``"questions_and_answers"`` list is flattened so that each entry
    becomes ``"<question> <answer>"``.

    Returns:
        A dict mapping category name to its list of entry strings.

    Raises:
        Whatever ``open`` / ``json.load`` raise for a missing or malformed
        file, and ``KeyError`` when an expected key is absent.
    """

    def _read_category(name):
        # One file per category, located next to the working directory.
        with open(f"./{name}.json", "r", encoding="utf-8") as handle:
            entries = json.load(handle)["questions_and_answers"]
        return [" ".join((item["question"], item["answer"])) for item in entries]

    category_names = ("multas", "politicas_de_privacidad", "contratos", "denuncias", "otros")
    return {name: _read_category(name) for name in category_names}


json_documents = load_json_documents()
# Embeddings and vector store setup.
@st.cache_resource
def create_vector_store():
    """Build one FAISS index per document category.

    Every text in the module-level ``json_documents`` mapping is cut into
    overlapping chunks (1000 characters, 150 overlap). The chunks are
    embedded on CPU with a sentence-transformers MiniLM model and indexed
    with FAISS.

    Returns:
        A dict mapping category name to its FAISS vector store.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-l6-v2",
        model_kwargs={"device": "cpu"},
    )
    # The splitter configuration is the same for every category, so build it once.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)

    stores = {}
    for category, docs in json_documents.items():
        # split_text() takes ONE string; `docs` is a list of strings, so each
        # document is split on its own and the resulting chunks are flattened.
        chunks = [chunk for doc in docs for chunk in text_splitter.split_text(doc)]
        stores[category] = FAISS.from_texts(chunks, embeddings)
    return stores


vector_stores = create_vector_store()
def classify_text(text):
    """Return the category label predicted for *text*.

    The text is tokenized (truncated and padded to 4096 tokens), run through
    the module-level classification model with gradients disabled, and the
    index of the highest logit is looked up in ``id2label``.

    Args:
        text: The document text to classify.

    Returns:
        The category name from ``id2label``.
    """
    # NOTE(review): padding="max_length" pads every input to 4096 tokens,
    # even short ones; kept as-is here to preserve behaviour.
    encoded = classification_tokenizer(
        text,
        return_tensors="pt",
        max_length=4096,
        truncation=True,
        padding="max_length",
    )
    classification_model.eval()
    with torch.no_grad():
        logits = classification_model(**encoded).logits
    return id2label[logits.argmax(dim=-1).item()]
def translate(text, target_language, max_tokens=1024):
    """Ask the chat model to translate *text* into *target_language*.

    Args:
        text: The document to translate.
        target_language: Name of the target language, interpolated into the
            Spanish prompt as-is.
        max_tokens: Upper bound on the length of the generated reply. Set
            explicitly because some client versions default to a very small
            value, which would truncate the translation.

    Returns:
        The assistant's reply text.
    """
    template = f'''
Por favor, traduzca el siguiente documento al {target_language}:
<document>
{text}
</document>
Asegúrese de que la traducción sea precisa y conserve el significado original del documento.
'''
    messages = [{"role": "user", "content": template}]
    # InferenceClient exposes chat through chat_completion(); the reply text
    # lives in choices[0].message.content (there is no `.generated_text` on a
    # chat completion response).
    response = client.chat_completion(messages=messages, max_tokens=max_tokens)
    return response.choices[0].message.content
def summarize(text, length, max_tokens=1024):
    """Ask the chat model for a summary of *text*.

    Args:
        text: The document to summarise.
        length: A Spanish phrase describing the desired summary length; it is
            interpolated into the prompt right after the word "resumen".
        max_tokens: Upper bound on the length of the generated reply. Set
            explicitly because some client versions default to a very small
            value, which would truncate the summary.

    Returns:
        The assistant's reply text.
    """
    template = f'''
Por favor, haga un resumen {length} del siguiente documento:
<document>
{text}
</document>
Asegúrese de que el resumen sea conciso y conserve el significado original del documento.
'''
    messages = [{"role": "user", "content": template}]
    # InferenceClient exposes chat through chat_completion(); the reply text
    # lives in choices[0].message.content (there is no `.generated_text` on a
    # chat completion response).
    response = client.chat_completion(messages=messages, max_tokens=max_tokens)
    return response.choices[0].message.content
def handle_uploaded_file(uploaded_file):
    """Extract plain text from an uploaded file based on its extension.

    Supported extensions (matched case-insensitively): ``.txt``, ``.pdf``,
    ``.docx``, ``.csv`` and ``.json``.

    Args:
        uploaded_file: A binary file-like object with a ``name`` attribute
            and a ``read()`` method.

    Returns:
        The extracted text. For an unsupported extension a fixed Spanish
        notice is returned. If extraction fails, the exception message is
        returned instead of raising, so callers always receive a string.
    """
    try:
        # Lower-case once so "REPORT.PDF" is handled like "report.pdf".
        filename = uploaded_file.name.lower()
        if filename.endswith(".txt"):
            return uploaded_file.read().decode("utf-8")
        if filename.endswith(".pdf"):
            reader = PdfReader(uploaded_file)
            # `or ""` guards against a page yielding no text (presumably None
            # or an empty string, depending on the PyPDF2 version).
            return "".join(page.extract_text() or "" for page in reader.pages)
        if filename.endswith(".docx"):
            doc = Document(uploaded_file)
            return "\n".join(para.text for para in doc.paragraphs)
        if filename.endswith(".csv"):
            lines = uploaded_file.read().decode("utf-8").splitlines()
            # Flatten the table: cells joined by spaces, rows joined by spaces.
            return " ".join(" ".join(row) for row in csv.reader(lines))
        if filename.endswith(".json"):
            return json.dumps(json.load(uploaded_file), indent=4)
        return "Tipo de archivo no soportado."
    except Exception as e:
        # Deliberate best-effort: report the failure as text rather than crash the UI.
        return str(e)
def _chat_reply(prompt, max_tokens=1024):
    """Send *prompt* as a single user turn and return the assistant's text."""
    # InferenceClient exposes chat through chat_completion(); the reply text
    # lives in choices[0].message.content.
    response = client.chat_completion(
        messages=[{"role": "user", "content": prompt}],
        max_tokens=max_tokens,
    )
    return response.choices[0].message.content


def main():
    """Render the LexAIcon chat page and answer the current query.

    Flow for a non-empty query:
      * if files are uploaded, each one is classified, the matching vector
        store is searched with the query, and the model is asked the query
        with the retrieved context (one reply per file, all of them shown);
      * otherwise the selected operation is applied to the query text:
        summarise, translate, or a plain chat turn.
    The user message and the reply are appended to the session history.
    """
    st.title("LexAIcon")
    st.write("Puedes conversar con este chatbot basado en Mistral-7B-Instruct y subir archivos para que el chatbot los procese.")

    if "messages" not in st.session_state:
        st.session_state["messages"] = []

    with st.sidebar:
        st.text_input("HuggingFace Token", value=huggingface_token, type="password", key="huggingface_token")
        st.caption("[Consigue un HuggingFace Token](https://huggingface.co/settings/tokens)")

    # Replay the conversation so far.
    for msg in st.session_state.messages:
        st.write(f"**{msg['role'].capitalize()}:** {msg['content']}")

    user_input = st.text_input("Introduce tu consulta:", "")
    if not user_input:
        return

    # NOTE(review): Streamlit re-runs the script on every widget interaction,
    # so this append (and the model call below) presumably repeats on each
    # rerun while the query box is non-empty. Confirm whether that is intended.
    st.session_state.messages.append({"role": "user", "content": user_input})

    operation = st.radio("Selecciona una operación", ["Resumir", "Traducir", "Explicar"])

    # Prompt fragment for each selectable summary length.
    summary_lengths = {
        "corto": "de aproximadamente 50 palabras",
        "medio": "de aproximadamente 100 palabras",
        "largo": "de aproximadamente 500 palabras",
    }
    target_language = None
    summary_length = None
    if operation == "Traducir":
        target_language = st.selectbox("Selecciona el idioma de traducción", ["español", "inglés", "francés", "alemán"])
    if operation == "Resumir":
        summary_length = st.selectbox("Selecciona la longitud del resumen", list(summary_lengths))

    if uploaded_files := st.file_uploader("Sube un archivo", type=["txt", "pdf", "docx", "csv", "json"], accept_multiple_files=True):
        replies = []
        for uploaded_file in uploaded_files:
            file_content = handle_uploaded_file(uploaded_file)
            # The predicted category selects which vector store supplies context.
            classification = classify_text(file_content)
            search_docs = vector_stores[classification].similarity_search(user_input)
            context = " ".join(doc.page_content for doc in search_docs)
            replies.append(_chat_reply(f"Contexto: {context}\n\nPregunta: {user_input}"))
        # Keep every file's reply instead of only the last one.
        bot_response = "\n\n".join(replies)
    elif operation == "Resumir":
        bot_response = summarize(user_input, summary_lengths[summary_length])
    elif operation == "Traducir":
        bot_response = translate(user_input, target_language)
    else:
        bot_response = _chat_reply(user_input)

    st.session_state.messages.append({"role": "assistant", "content": bot_response})
    st.write(f"**Assistant:** {bot_response}")


if __name__ == "__main__":
    main()