Spaces:
Paused
Paused
import streamlit as st | |
from PyPDF2 import PdfReader | |
from docx import Document | |
import csv | |
import json | |
import os | |
import torch | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain.vectorstores import FAISS | |
from huggingface_hub import login, InferenceClient | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# Read the Hugging Face access token from the environment; it is absent
# (None) when the HUGGINGFACE_TOKEN variable is not set.
huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")

# Log in to Hugging Face only when a token is available.
if huggingface_token:
    login(token=huggingface_token)
# Inference client configuration
def load_inference_client():
    """Create the inference client bound to the Mistral-7B-Instruct model.

    Returns:
        An ``InferenceClient`` configured for
        ``mistralai/Mistral-7B-Instruct-v0.3``.
    """
    return InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.3")


# Module-level client shared by the functions that query the chat model.
client = load_inference_client()
# Classification model configuration
def load_classification_model():
    """Load the document classifier and its tokenizer.

    Both are loaded from the
    ``mrm8488/legal-longformer-base-8192-spanish`` checkpoint.

    Returns:
        A ``(model, tokenizer)`` tuple, in that order.
    """
    checkpoint = "mrm8488/legal-longformer-base-8192-spanish"
    tok = AutoTokenizer.from_pretrained(checkpoint)
    clf = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    return clf, tok


classification_model, classification_tokenizer = load_classification_model()

# Maps the classifier's predicted class index to a document category name.
# NOTE(review): nothing in this file sets the number of labels on the model;
# confirm the checkpoint's classification head matches these five entries.
id2label = {
    0: "multas",
    1: "politicas_de_privacidad",
    2: "contratos",
    3: "denuncias",
    4: "otros",
}
# Load the JSON documents for each category
def load_json_documents(categories=None, base_dir="."):
    """Load the question/answer corpus of every category from JSON files.

    Each category is read from ``<base_dir>/<category>.json``. The file must
    contain a top-level ``"questions_and_answers"`` list whose entries have
    ``"question"`` and ``"answer"`` keys.

    Args:
        categories: Iterable of category names to load. Defaults to the five
            built-in categories (multas, politicas_de_privacidad, contratos,
            denuncias, otros).
        base_dir: Directory containing the JSON files. Defaults to the
            current working directory.

    Returns:
        A dict mapping each category name to a list of strings, one per
        entry, formed as ``"<question> <answer>"``.

    Raises:
        OSError: If a category file cannot be opened.
        KeyError: If a file lacks the expected keys.
    """
    if categories is None:
        categories = ["multas", "politicas_de_privacidad", "contratos", "denuncias", "otros"]

    documents = {}
    for category in categories:
        path = os.path.join(base_dir, f"{category}.json")
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)["questions_and_answers"]
        documents[category] = [entry["question"] + " " + entry["answer"] for entry in data]
    return documents
# Per-category corpus, loaded once at import time.
json_documents = load_json_documents()


# Embeddings and vector-store configuration
def create_vector_store():
    """Build one FAISS vector store per document category.

    Every document in ``json_documents`` is split into chunks of up to 1000
    characters (150 overlap), and the chunks of each category are embedded
    with ``sentence-transformers/all-MiniLM-l6-v2`` on the CPU.

    Returns:
        A dict mapping each category name to its FAISS vector store.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-l6-v2",
        model_kwargs={"device": "cpu"},
    )
    # Same splitter configuration for every category, so build it once.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)

    vector_stores = {}
    for category, docs in json_documents.items():
        # split_text() takes a single string, while ``docs`` is a list of
        # strings: split each document separately and flatten the chunks.
        split_docs = [chunk for doc in docs for chunk in text_splitter.split_text(doc)]
        vector_stores[category] = FAISS.from_texts(split_docs, embeddings)
    return vector_stores


vector_stores = create_vector_store()
def classify_text(text):
    """Predict the document category of ``text``.

    The text is tokenized (truncated and padded to 4096 tokens), run through
    the classification model in eval mode without gradient tracking, and the
    highest-scoring class index is mapped to its name via ``id2label``.

    Args:
        text: The document text to classify.

    Returns:
        The category name from ``id2label`` for the predicted class.
    """
    classification_model.eval()
    encoded = classification_tokenizer(
        text,
        return_tensors="pt",
        max_length=4096,
        truncation=True,
        padding="max_length",
    )
    with torch.no_grad():
        scores = classification_model(**encoded).logits
    best_index = scores.argmax(dim=-1).item()
    return id2label[best_index]
def translate(text, target_language, max_tokens=1024):
    """Ask the chat model to translate ``text`` into ``target_language``.

    Args:
        text: The document text to translate.
        target_language: Name of the target language, inserted verbatim into
            the (Spanish) prompt.
        max_tokens: Upper bound on the number of tokens generated for the
            reply. Passed explicitly so the translation is not cut short by a
            small library default.

    Returns:
        The text of the model's reply.
    """
    template = (
        "\n"
        f"Por favor, traduzca el siguiente documento al {target_language}:\n"
        "<document>\n"
        f"{text}\n"
        "</document>\n"
        "Asegúrese de que la traducción sea precisa y conserve el significado original del documento.\n"
    )
    messages = [{"role": "user", "content": template}]
    # InferenceClient exposes chat through chat_completion(); the reply text
    # lives in choices[0].message.content, not in a generated_text attribute.
    response = client.chat_completion(messages, max_tokens=max_tokens)
    return response.choices[0].message.content
def summarize(text, length, max_tokens=1024):
    """Ask the chat model to summarize ``text``.

    Args:
        text: The document text to summarize.
        length: Spanish phrase describing the desired summary length; it is
            inserted verbatim after "haga un resumen" in the prompt.
        max_tokens: Upper bound on the number of tokens generated for the
            reply. Passed explicitly so longer summaries are not cut short by
            a small library default.

    Returns:
        The text of the model's reply.
    """
    template = (
        "\n"
        f"Por favor, haga un resumen {length} del siguiente documento:\n"
        "<document>\n"
        f"{text}\n"
        "</document>\n"
        "Asegúrese de que el resumen sea conciso y conserve el significado original del documento.\n"
    )
    messages = [{"role": "user", "content": template}]
    # InferenceClient exposes chat through chat_completion(); the reply text
    # lives in choices[0].message.content, not in a generated_text attribute.
    response = client.chat_completion(messages, max_tokens=max_tokens)
    return response.choices[0].message.content
def handle_uploaded_file(uploaded_file):
    """Extract plain text from an uploaded file based on its extension.

    Supported extensions (matched case-insensitively): ``.txt``, ``.pdf``,
    ``.docx``, ``.csv`` and ``.json``.

    Args:
        uploaded_file: File-like object exposing a ``name`` attribute and a
            ``read()`` method returning bytes.

    Returns:
        The extracted text. For an unsupported extension the message
        "Tipo de archivo no soportado." is returned. If extraction raises,
        the exception message is returned instead of propagating, so callers
        always receive a string.
    """
    try:
        # Lower-case the name so "INFORME.PDF" is handled like "informe.pdf".
        filename = uploaded_file.name.lower()

        if filename.endswith(".txt"):
            return uploaded_file.read().decode("utf-8")

        if filename.endswith(".pdf"):
            reader = PdfReader(uploaded_file)
            # Guard against pages that yield no text (e.g. scanned images).
            return "".join(page.extract_text() or "" for page in reader.pages)

        if filename.endswith(".docx"):
            doc = Document(uploaded_file)
            return "\n".join(para.text for para in doc.paragraphs)

        if filename.endswith(".csv"):
            lines = uploaded_file.read().decode("utf-8").splitlines()
            return " ".join(" ".join(row) for row in csv.reader(lines))

        if filename.endswith(".json"):
            data = json.load(uploaded_file)
            # Keep accented characters readable instead of \uXXXX escapes.
            return json.dumps(data, indent=4, ensure_ascii=False)

        return "Tipo de archivo no soportado."
    except Exception as e:
        # Deliberate best effort: callers expect a string, never an exception.
        return str(e)
def _ask_model(prompt, max_tokens=1024):
    """Send a single-turn user prompt to the chat model and return its reply text."""
    # InferenceClient exposes chat through chat_completion(); the reply text
    # lives in choices[0].message.content, not in a generated_text attribute.
    response = client.chat_completion(
        [{"role": "user", "content": prompt}], max_tokens=max_tokens
    )
    return response.choices[0].message.content


def main():
    """Render the LexAIcon Streamlit page and answer the user's query.

    The page shows the stored conversation, a query box and, once a query is
    entered, the operation selector and a file uploader. The reply is built
    as follows:

    * files uploaded: each file is classified, the vector store of its
      category is searched with the query, and the model is asked the query
      with the retrieved context; the per-file answers are joined;
    * otherwise "Resumir" summarizes the query text and "Traducir"
      translates it;
    * any other operation sends the query to the model unchanged.

    The user query and the reply are appended to
    ``st.session_state.messages``.
    """
    st.title("LexAIcon")
    st.write("Puedes conversar con este chatbot basado en Mistral-7B-Instruct y subir archivos para que el chatbot los procese.")

    if "messages" not in st.session_state:
        st.session_state["messages"] = []

    with st.sidebar:
        # Fall back to "" so a missing token is not passed as None.
        st.text_input("HuggingFace Token", value=huggingface_token or "", type="password", key="huggingface_token")
        st.caption("[Consigue un HuggingFace Token](https://huggingface.co/settings/tokens)")

    for msg in st.session_state.messages:
        st.write(f"**{msg['role'].capitalize()}:** {msg['content']}")

    user_input = st.text_input("Introduce tu consulta:", "")
    if not user_input:
        return

    # NOTE(review): this runs on every script execution while the query box
    # is non-empty, so the same message may be appended more than once —
    # confirm whether that is intended.
    st.session_state.messages.append({"role": "user", "content": user_input})

    operation = st.radio("Selecciona una operación", ["Resumir", "Traducir", "Explicar"])
    target_language = None
    summary_length = None
    if operation == "Traducir":
        target_language = st.selectbox("Selecciona el idioma de traducción", ["español", "inglés", "francés", "alemán"])
    if operation == "Resumir":
        summary_length = st.selectbox("Selecciona la longitud del resumen", ["corto", "medio", "largo"])

    # Phrase inserted into the summary prompt for each selectable length.
    summary_phrases = {
        "corto": "de aproximadamente 50 palabras",
        "medio": "de aproximadamente 100 palabras",
        "largo": "de aproximadamente 500 palabras",
    }

    if uploaded_files := st.file_uploader("Sube un archivo", type=["txt", "pdf", "docx", "csv", "json"], accept_multiple_files=True):
        # Keep every file's answer; previously only the last one survived.
        replies = []
        for uploaded_file in uploaded_files:
            file_content = handle_uploaded_file(uploaded_file)
            classification = classify_text(file_content)
            search_docs = vector_stores[classification].similarity_search(user_input)
            context = " ".join(doc.page_content for doc in search_docs)
            replies.append(_ask_model(f"Contexto: {context}\n\nPregunta: {user_input}"))
        bot_response = "\n\n".join(replies)
    elif operation == "Resumir":
        bot_response = summarize(user_input, summary_phrases[summary_length])
    elif operation == "Traducir":
        bot_response = translate(user_input, target_language)
    else:
        bot_response = _ask_model(user_input)

    st.session_state.messages.append({"role": "assistant", "content": bot_response})
    st.write(f"**Assistant:** {bot_response}")


if __name__ == "__main__":
    main()