File size: 3,911 Bytes
c5446a2
 
 
 
 
 
 
 
 
 
e5da8b5
 
c5446a2
 
e5da8b5
c5446a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# app.py
import torch
from transformers import (
    DPRContextEncoder, DPRContextEncoderTokenizerFast,
    DPRQuestionEncoder, DPRQuestionEncoderTokenizerFast,
    BartForConditionalGeneration, BartTokenizer
)
from datasets import Dataset
import faiss
import numpy as np
import gradio as gr

# Importar funciones de extracci贸n
from extract_text import extract_text_from_pdf, extract_text_from_docx, extract_text_from_image

# Inicializar modelos y variables globales
ctx_encoder = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')

q_encoder = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
q_tokenizer = DPRQuestionEncoderTokenizerFast.from_pretrained('facebook/dpr-question_encoder-single-nq-base')

generator = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
gen_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')

# Inicializar dataset y 铆ndice
dataset = Dataset.from_dict({'text': []})
embeddings = np.empty((0, ctx_encoder.config.hidden_size), dtype='float32')
index = faiss.IndexFlatIP(ctx_encoder.config.hidden_size)

# Funci贸n para actualizar el 铆ndice con nuevo texto
def actualizar_indice(nuevo_texto):
    global dataset, embeddings, index

    # A帽adir nuevo documento al dataset
    dataset = dataset.add_item({'text': nuevo_texto})

    # Codificar el nuevo documento
    inputs = ctx_tokenizer(nuevo_texto, truncation=True, padding='longest', return_tensors='pt')
    embedding = ctx_encoder(**inputs).pooler_output.detach().numpy()

    # Actualizar embeddings y 铆ndice
    embeddings = np.vstack([embeddings, embedding])
    index.add(embedding)

# Funci贸n para recuperar documentos relevantes
def retrieve_docs(question, k=5):
    inputs = q_tokenizer(question, return_tensors='pt')
    question_embedding = q_encoder(**inputs).pooler_output.detach().numpy()

    distances, indices = index.search(question_embedding, k)
    retrieved_texts = [dataset[i]['text'] for i in indices[0]]
    return retrieved_texts

# Funci贸n para generar respuesta
def generate_answer(question):
    retrieved_docs = retrieve_docs(question)
    context = ' '.join(retrieved_docs)

    input_text = f"Pregunta: {question} Contexto: {context}"
    inputs = gen_tokenizer([input_text], max_length=1024, return_tensors='pt', truncation=True)
    summary_ids = generator.generate(inputs['input_ids'], num_beams=4, max_length=100, early_stopping=True)
    answer = gen_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return answer

# Funci贸n principal de la aplicaci贸n
def responder(archivo, pregunta):
    texto_extraido = ''
    if archivo is not None:
        file_path = archivo.name
        if file_path.endswith('.pdf'):
            texto_extraido = extract_text_from_pdf(file_path)
        elif file_path.endswith('.docx'):
            texto_extraido = extract_text_from_docx(file_path)
        elif file_path.lower().endswith(('.png', '.jpg', '.jpeg')):
            texto_extraido = extract_text_from_image(file_path)
        else:
            return "Formato de archivo no soportado."

        # Actualizar el 铆ndice con el nuevo texto
        actualizar_indice(texto_extraido)

        # Generar respuesta
        respuesta = generate_answer(pregunta)
        return respuesta
    else:
        return "Por favor, sube un archivo."

# Configurar la interfaz de Gradio
interfaz = gr.Interface(
    fn=responder,
    inputs=[
        gr.inputs.File(label="Sube un archivo (PDF, DOCX, Imagen)"),
        gr.inputs.Textbox(lines=2, placeholder="Escribe tu pregunta aqu铆...")
    ],
    outputs="text",
    title="Aplicaci贸n RAG con Extracci贸n de Texto",
    description="Sube un archivo y haz una pregunta sobre su contenido."
)

if __name__ == "__main__":
    interfaz.launch()