Spaces:
Running
Running
""" | |
Gradio UI for Mistral 7B with RAG | |
""" | |
import os | |
from typing import List | |
import gradio as gr | |
from langchain_core.runnables.base import RunnableSequence | |
import numpy as np | |
from confluence_rag import generate_rag_chain, load_pdf, store_vector, load_multiple_pdf | |
def initialize_chain(file: gr.File) -> RunnableSequence:
    """
    Builds the chain, optionally backed by a retriever over the given PDF file(s).

    If no file is provided (``None`` or an empty list), the LLM is used without RAG.

    Args:
        file (gr.File): uploaded file object(s) to index; each item's ``.name``
            is used as the path handed to the PDF loader
    Returns:
        RunnableSequence: the chain
    """
    # `not file` covers both None and an empty upload list; without it an empty
    # list would fall through to load_multiple_pdf([]).
    if not file:
        return generate_rag_chain()
    paths = [f.name for f in file]
    # A single upload is loaded with load_pdf; several are loaded together.
    pdf = load_pdf(paths[0]) if len(paths) == 1 else load_multiple_pdf(paths)
    retriever = store_vector(pdf)
    return generate_rag_chain(retriever)
def invoke_chain(message: str, history: List[str], file: gr.File = None) -> str:
    """
    Answers a chat message, using the uploaded PDF(s) for RAG when provided.

    Args:
        message (str): message to invoke the chain with
        history (List[str]): chat history supplied by the Gradio ChatInterface;
            not used here (only ``message`` is sent to the chain)
        file (gr.File, optional): uploaded file object(s) to answer questions
            about. ``None`` or an empty list means "no file", in which case the
            LLM is used without RAG. Defaults to None.
    Returns:
        str: the response of the chain, or an error message if an uploaded file
        does not exist or does not have a ``.pdf`` extension (any case)
    """
    # Check for "no upload" first so the per-file checks below never see None;
    # Gradio may report an empty upload as None or as an empty list.
    if not file:
        file = None
    else:
        # Validate every upload before doing any expensive indexing work.
        if not all(os.path.exists(f.name) for f in file):
            return "Error: File not found."
        # Accept any capitalisation of the extension (e.g. "report.PDF").
        if not all(f.name.lower().endswith(".pdf") for f in file):
            return "Error: File is not a pdf."
    # NOTE(review): the chain (including the vector store) is rebuilt on every
    # message; consider caching it per set of uploaded files.
    chain = initialize_chain(file)
    return chain.invoke(message)
def create_demo() -> gr.ChatInterface:
    """
    Creates the Gradio chat UI for asking questions about uploaded PDFs.

    Each user message is answered by ``invoke_chain``. The extra "File" input
    accepts multiple uploads, and its value is forwarded to ``invoke_chain``
    as the additional ``file`` argument.

    Returns:
        gr.ChatInterface: the chat interface (not yet launched)
    """
    return gr.ChatInterface(
        invoke_chain,
        additional_inputs=[gr.File(label="File", file_count="multiple")],
        title="Mistral 7B with RAG",
        description="Ask questions to Mistral about your pdf document.",
        theme="soft",
    )
if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    # The UI is bound to `demo` before launching.
    demo = create_demo()
    demo.launch()