# Hugging Face Space (status when captured: Sleeping)
import gradio as gr | |
from langchain_community.llms import HuggingFaceEndpoint | |
from langchain_community.document_loaders import WebBaseLoader, PyPDFLoader | |
from langchain_community.vectorstores import Chroma | |
from langchain_community import embeddings | |
from langchain_community.chat_models import ChatOllama | |
from langchain_core.runnables import RunnablePassthrough | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain.output_parsers import PydanticOutputParser | |
from langchain.text_splitter import CharacterTextSplitter | |
import os | |
def process_input(urls, question):
    """Answer ``question`` using the web pages listed in ``urls`` as context.

    The pages are downloaded with ``WebBaseLoader``, split into chunks,
    embedded into a Chroma collection, and the retriever built on that
    collection supplies the context for a prompt sent to a Hugging Face
    hosted Mistral model.

    Args:
        urls: Newline-separated URLs. Blank lines and surrounding whitespace
            are ignored.
        question: The question to answer from the loaded pages.

    Returns:
        The model output after ``StrOutputParser``.

    Raises:
        ValueError: If ``urls`` contains no non-blank line.
    """
    # Convert string of URLs to list, dropping blank lines and stray
    # whitespace (a trailing newline or "\r" would otherwise become a bogus
    # URL handed to the loader).
    urls_list = [line.strip() for line in (urls or "").split("\n") if line.strip()]
    if not urls_list:
        raise ValueError("No URLs provided: enter at least one URL, one per line.")

    # get a token: https://huggingface.co/docs/api-inference/quicktour#get-your-api-token
    # Prefer a token already present in the environment so a request does not
    # block on an interactive prompt; prompt only as a fallback.
    api_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    if not api_token:
        from getpass import getpass

        api_token = getpass()
        # Store the token value itself (not the variable's name) so later
        # calls and libraries reading the environment see it.
        os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_token

    repo_id = "mistralai/Mistral-7B-Instruct-v0.2"
    # NOTE(review): `max_length` is kept from the original call; the documented
    # length-limit field on HuggingFaceEndpoint is `max_new_tokens` - confirm
    # that `max_length` has the intended effect.
    model_local = HuggingFaceEndpoint(
        repo_id=repo_id,
        max_length=128,
        temperature=0.5,
        huggingfacehub_api_token=api_token,
    )

    # Each loader returns a list of documents; flatten them into one list.
    docs_list = [doc for url in urls_list for doc in WebBaseLoader(url).load()]

    text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=7500, chunk_overlap=100
    )
    doc_splits = text_splitter.split_documents(docs_list)

    # NOTE(review): the collection name is fixed; confirm whether documents
    # from earlier requests remain in it between calls.
    vectorstore = Chroma.from_documents(
        documents=doc_splits,
        collection_name="rag-chroma",
        embedding=embeddings.ollama.OllamaEmbeddings(model='nomic-embed-text'),
    )
    retriever = vectorstore.as_retriever()

    after_rag_template = """Answer the question based only on the following context:
{context}
Question: {question}
"""
    after_rag_prompt = ChatPromptTemplate.from_template(after_rag_template)
    # The question is sent to the retriever (to fetch context) and passed
    # through unchanged into the prompt's {question} slot.
    after_rag_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | after_rag_prompt
        | model_local
        | StrOutputParser()
    )
    return after_rag_chain.invoke(question)
# Gradio front end: two text boxes (URL list, question) feed process_input,
# whose return value is shown as plain text.
_urls_box = gr.Textbox(label="Enter URLs separated by new lines")
_question_box = gr.Textbox(label="Question")

iface = gr.Interface(
    fn=process_input,
    inputs=[_urls_box, _question_box],
    outputs="text",
    title="Document Query with Ollama",
    description="Enter URLs and a question to query the documents.",
)

# Launch the interface with Gradio's default settings.
iface.launch()