# surveyia / main.py
# datacipen's picture
# Update main.py
# db46aef verified
# raw
# history blame
# 7.53 kB
import os
import json
import bcrypt
from typing import List
from pathlib import Path
from langchain_huggingface import HuggingFaceEmbeddings
#from langchain_community.llms import HuggingFaceEndpoint
from langchain_huggingface import HuggingFaceEndpoint
#from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from langchain_community.document_loaders import (
PyMuPDFLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.indexes import SQLRecordManager, index
from langchain.schema import Document
from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableConfig
from langchain.callbacks.base import BaseCallbackHandler
import chainlit as cl
from chainlit.input_widget import TextInput, Select, Switch, Slider
from literalai import LiteralClient
@cl.password_auth_callback
def auth_callback(username: str, password: str):
    """Authenticate a user against the accounts listed in CHAINLIT_AUTH_LOGIN.

    The CHAINLIT_AUTH_LOGIN environment variable is parsed as a JSON list of
    objects that each carry the keys ``ident``, ``pwd`` and ``role``. The
    first entry whose ``ident`` equals ``username`` is used.

    Args:
        username: Identifier typed in the login form.
        password: Password typed in the login form.

    Returns:
        A ``cl.User`` for the roles ``admindatapcc`` and ``userdatapcc`` when
        the password matches; ``None`` when the user is unknown, the password
        is wrong, or the role is not one of those two.

    Raises:
        KeyError: If CHAINLIT_AUTH_LOGIN is not set, or a matched entry lacks
            one of the expected keys.
        json.JSONDecodeError: If CHAINLIT_AUTH_LOGIN is not valid JSON.
    """
    accounts = json.loads(os.environ['CHAINLIT_AUTH_LOGIN'])

    # Single lookup with a default: an unknown username is an authentication
    # failure, not a StopIteration escaping from next().
    account = next((entry for entry in accounts if entry['ident'] == username), None)
    if account is None:
        return None

    ident = account['ident']
    # No separate username check is needed: the entry was selected because its
    # ident equals the submitted username.
    # The stored password is hashed on the fly and compared through bcrypt,
    # exactly as before.
    password_matches = bcrypt.checkpw(
        password.encode('utf-8'),
        bcrypt.hashpw(account['pwd'].encode('utf-8'), bcrypt.gensalt()),
    )
    if not password_matches:
        return None

    role = account['role']
    if role == "admindatapcc":
        return cl.User(
            identifier=ident + " : 🧑‍💼 Admin Datapcc", metadata={"role": "admin", "provider": "credentials"}
        )
    if role == "userdatapcc":
        return cl.User(
            identifier=ident + " : 🧑‍🎓 User Datapcc", metadata={"role": "user", "provider": "credentials"}
        )
    # Any other role is not allowed to log in.
    return None
# Literal AI client; its API key is read from the LITERAL_API_KEY environment
# variable (None when unset). Not referenced again in the visible code.
literal_client = LiteralClient(api_key=os.getenv("LITERAL_API_KEY"))
# NOTE(review): process_pdfs does not read these two module-level values (its
# splitter is configured with 1000 / 100) and nothing else visible uses them
# - confirm whether they are still needed.
chunk_size = 1024
chunk_overlap = 50
# Embedding model, constructed with the library defaults; process_pdfs passes
# it to Chroma.
embeddings_model = HuggingFaceEmbeddings()
# Directory handed to process_pdfs at import time.
PDF_STORAGE_PATH = "./public/pdfs"
def process_pdfs(pdf_storage_path: str, *, chunk_size: int = 1000, chunk_overlap: int = 100):
    """Load the PDFs of a directory, split them into chunks and index them.

    Every ``*.pdf`` file directly inside ``pdf_storage_path`` (no recursion)
    is loaded with PyMuPDFLoader and split with a
    RecursiveCharacterTextSplitter. The chunks are stored in a Chroma vector
    store built with the module-level ``embeddings_model`` and registered with
    a SQLRecordManager backed by ``record_manager_cache.sql``.

    Args:
        pdf_storage_path: Directory that contains the PDF files.
        chunk_size: ``chunk_size`` given to the text splitter. The default
            (1000) is the value this function always used.
        chunk_overlap: ``chunk_overlap`` given to the text splitter. The
            default (100) is the value this function always used.

    Returns:
        The Chroma vector store holding the chunks.
    """
    pdf_directory = Path(pdf_storage_path)
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )

    docs: List[Document] = []
    for pdf_path in pdf_directory.glob("*.pdf"):
        pages = PyMuPDFLoader(str(pdf_path)).load()
        docs.extend(text_splitter.split_documents(pages))

    # NOTE(review): from_documents already inserts every chunk, and index()
    # below may insert the same chunks again when the record manager has not
    # seen them yet - presumably yielding duplicate retrieval hits on a fresh
    # cache; confirm before changing.
    doc_search = Chroma.from_documents(docs, embeddings_model)

    namespace = "chromadb/my_documents"
    record_manager = SQLRecordManager(
        namespace, db_url="sqlite:///record_manager_cache.sql"
    )
    record_manager.create_schema()

    index_result = index(
        docs,
        record_manager,
        doc_search,
        cleanup="incremental",
        source_id_key="source",
    )
    print(f"Indexing stats: {index_result}")

    return doc_search
# Build the vector store once, at import time; on_chat_start reads this
# module-level name to create its retriever.
doc_search = process_pdfs(PDF_STORAGE_PATH)
#model = ChatOpenAI(model_name="gpt-4", streaming=True)
# Self-assignment: the value is left unchanged, but the right-hand lookup
# raises KeyError at import time when HUGGINGFACEHUB_API_TOKEN is not set.
os.environ['HUGGINGFACEHUB_API_TOKEN'] = os.environ['HUGGINGFACEHUB_API_TOKEN']
repo_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
# Streaming Hugging Face endpoint used as the LLM step of the chain built in
# on_chat_start.
# NOTE(review): task="text2text-generation" and max_new_tokens=8000 look
# unusual for this repo_id - confirm against the endpoint's limits.
model = HuggingFaceEndpoint(
repo_id=repo_id, max_new_tokens=8000, temperature=1.0, task="text2text-generation", streaming=True
)
@cl.author_rename
def rename(orig_author: str):
    """Map an internal author name to the label shown in the UI.

    Names without an alias are returned unchanged.
    """
    aliases = {"Doc Chain Assistant": "Assistant Reviewstream"}
    if orig_author in aliases:
        return aliases[orig_author]
    return orig_author
@cl.set_chat_profiles
async def chat_profile():
    """Return the two chat profiles offered by the app.

    NOTE(review): the two icons use different path forms
    ("/public/..." vs "./public/..."); both are kept as they were - confirm
    whether the relative form is intentional.
    """
    reviewstream = cl.ChatProfile(
        name="Reviewstream",
        markdown_description="Requêter sur les publications de recherche",
        icon="/public/logo-ofipe.jpg",
    )
    imagestream = cl.ChatProfile(
        name="Imagestream",
        markdown_description="Requêter sur un ensemble d'images",
        icon="./public/logo-ofipe.jpg",
    )
    return [reviewstream, imagestream]
@cl.on_chat_start
async def on_chat_start():
    """Greet the user and prepare the question-answering chain for the session.

    Sends, in order: the banner, the welcome message, a message carrying the
    side panel that describes the publication sources, and the chat settings
    with the "Model" selector. It then builds the chain
    retriever -> prompt -> model -> string parser and stores it in the user
    session under the key "runnable".
    """
    await cl.Message("> REVIEWSTREAM").send()
    await cl.Message("Nous avons le plaisir de vous accueillir dans l'application de recherche et d'analyse des publications.").send()

    # Side panel listing the publication sources.
    sources_title = "Liste des revues de recherche"
    sources_html = (
        "<p><img src='/public/hal-logo-header.png' width='32' align='absmiddle' /> <strong> Hal Archives Ouvertes</strong> : Une archive ouverte est un réservoir numérique contenant des documents issus de la recherche scientifique, généralement déposés par leurs auteurs, et permettant au grand public d'y accéder gratuitement et sans contraintes.\n"
        "</p>\n"
        "<p><img src='/public/logo-persee.png' width='32' align='absmiddle' /> <strong>Persée</strong> : offre un accès libre et gratuit à des collections complètes de publications scientifiques (revues, livres, actes de colloques, publications en série, sources primaires, etc.) associé à une gamme d'outils de recherche et d'exploitation.</p>\n"
    )
    side_panel = [cl.Text(content=sources_html, name=sources_title, display="side")]
    await cl.Message(content="📚 " + sources_title, elements=side_panel).send()

    # Settings panel; the value returned by send() is not used.
    publication_select = Select(
        id="Model",
        label="Publications de recherche",
        values=["---", "HAL", "Persée"],
        initial_index=0,
    )
    await cl.ChatSettings([publication_select]).send()

    prompt = ChatPromptTemplate.from_template(
        "Answer the question based only on the following context:\n"
        "{context}\n"
        "Question: {question}\n"
    )

    def format_docs(docs):
        # Concatenate the retrieved chunks, separated by a blank line.
        return "\n\n".join(doc.page_content for doc in docs)

    retriever = doc_search.as_retriever()
    chain_inputs = {"context": retriever | format_docs, "question": RunnablePassthrough()}
    runnable = chain_inputs | prompt | model | StrOutputParser()

    cl.user_session.set("runnable", runnable)
@cl.on_message
async def on_message(message: cl.Message):
    """Answer a user message by streaming the session's chain output.

    The chain stored under "runnable" in the user session is streamed token
    by token into a new Chainlit message. A callback handler collects the
    (source, page) pairs of the retrieved documents and attaches them to the
    message as a "Sources" text element when the LLM run ends.

    Args:
        message: Incoming user message; its ``content`` is the chain input.
    """
    runnable = cl.user_session.get("runnable")  # type: Runnable
    msg = cl.Message(content="")

    class PostMessageHandler(BaseCallbackHandler):
        """
        Callback handler for handling the retriever and LLM processes.
        Used to post the sources of the retrieved documents as a Chainlit element.
        """

        def __init__(self, msg: cl.Message):
            super().__init__()
            self.msg = msg
            self.sources = set()  # Unique (source, page) pairs

        def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
            # Record where each retrieved chunk came from.
            for d in documents:
                self.sources.add((d.metadata['source'], d.metadata['page']))

        def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
            if not self.sources:
                return
            # A set has no stable iteration order; sort so the source list is
            # displayed in the same order on every run.
            try:
                ordered_sources = sorted(self.sources)
            except TypeError:
                # Pairs that cannot be compared keep the set's own order.
                ordered_sources = list(self.sources)
            sources_text = "\n".join(
                f"{source}#page={page}" for source, page in ordered_sources
            )
            self.msg.elements.append(
                cl.Text(name="Sources", content=sources_text, display="inline")
            )

    async with cl.Step(type="run", name="QA Assistant"):
        async for chunk in runnable.astream(
            message.content,
            config=RunnableConfig(callbacks=[
                cl.LangchainCallbackHandler(),
                PostMessageHandler(msg),
            ]),
        ):
            await msg.stream_token(chunk)

    await msg.send()