ChatData / backend /chat_bot /private_knowledge_base.py
lqhl's picture
Synced repo using 'sync_with_huggingface' Github Action
e931b70 verified
import hashlib
from datetime import datetime
from typing import List, Optional
import pandas as pd
from clickhouse_connect import get_client
from langchain.schema.embeddings import Embeddings
from langchain.vectorstores.myscale import MyScaleWithoutJSON, MyScaleSettings
from streamlit.runtime.uploaded_file_manager import UploadedFile
from backend.chat_bot.tools import parse_files, extract_embedding
from backend.construct.build_retriever_tool import create_retriever_tool
from logger import logger
class ChatBotKnowledgeTable:
def __init__(self, host, port, username, password,
embedding: Embeddings, parser_api_key: str, db="chat",
kb_table="private_kb", tool_table="private_tool") -> None:
super().__init__()
personal_files_schema_ = f"""
CREATE TABLE IF NOT EXISTS {db}.{kb_table}(
entity_id String,
file_name String,
text String,
user_id String,
created_by DateTime,
vector Array(Float32),
CONSTRAINT cons_vec_len CHECK length(vector) = 768,
VECTOR INDEX vidx vector TYPE MSTG('metric_type=Cosine')
) ENGINE = ReplacingMergeTree ORDER BY entity_id
"""
# `tool_name` represent private knowledge database name.
private_knowledge_base_schema_ = f"""
CREATE TABLE IF NOT EXISTS {db}.{tool_table}(
tool_id String,
tool_name String,
file_names Array(String),
user_id String,
created_by DateTime,
tool_description String
) ENGINE = ReplacingMergeTree ORDER BY tool_id
"""
self.personal_files_table = kb_table
self.private_knowledge_base_table = tool_table
config = MyScaleSettings(
host=host,
port=port,
username=username,
password=password,
database=db,
table=kb_table,
)
self.client = get_client(
host=config.host,
port=config.port,
username=config.username,
password=config.password,
)
self.client.command("SET allow_experimental_object_type=1")
self.client.command(personal_files_schema_)
self.client.command(private_knowledge_base_schema_)
self.parser_api_key = parser_api_key
self.vector_store = MyScaleWithoutJSON(
embedding=embedding,
config=config,
must_have_cols=["file_name", "text", "created_by"],
)
# List all files with given `user_id`
def list_files(self, user_id: str):
query = f"""
SELECT DISTINCT file_name, COUNT(entity_id) AS num_paragraph,
arrayMax(arrayMap(x->length(x), groupArray(text))) AS max_chars
FROM {self.vector_store.config.database}.{self.personal_files_table}
WHERE user_id = '{user_id}' GROUP BY file_name
"""
return [r for r in self.vector_store.client.query(query).named_results()]
# Parse and embedding files
def add_by_file(self, user_id, files: List[UploadedFile]):
data = parse_files(self.parser_api_key, user_id, files)
data = extract_embedding(self.vector_store.embeddings, data)
self.vector_store.client.insert_df(
table=self.personal_files_table,
df=pd.DataFrame(data),
database=self.vector_store.config.database,
)
# Remove all files and private_knowledge_bases with given `user_id`
def clear(self, user_id: str):
self.vector_store.client.command(
f"DELETE FROM {self.vector_store.config.database}.{self.personal_files_table} "
f"WHERE user_id='{user_id}'"
)
query = f"""DELETE FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
WHERE user_id = '{user_id}'"""
self.vector_store.client.command(query)
def create_private_knowledge_base(
self, user_id: str, tool_name: str, tool_description: str, files: Optional[List[str]] = None
):
self.vector_store.client.insert_df(
self.private_knowledge_base_table,
pd.DataFrame(
[
{
"tool_id": hashlib.sha256(
(user_id + tool_name).encode("utf-8")
).hexdigest(),
"tool_name": tool_name, # tool_name represent user's private knowledge base.
"file_names": files,
"user_id": user_id,
"created_by": datetime.now(),
"tool_description": tool_description,
}
]
),
database=self.vector_store.config.database,
)
# Show all private knowledge bases with given `user_id`
def list_private_knowledge_bases(self, user_id: str, private_knowledge_base=None):
extended_where = f"AND tool_name = '{private_knowledge_base}'" if private_knowledge_base else ""
query = f"""
SELECT tool_name, tool_description, length(file_names)
FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
WHERE user_id = '{user_id}' {extended_where}
"""
return [r for r in self.vector_store.client.query(query).named_results()]
def remove_private_knowledge_bases(self, user_id: str, private_knowledge_bases: List[str]):
unique_list = list(set(private_knowledge_bases))
unique_list = ",".join([f"'{t}'" for t in unique_list])
query = f"""DELETE FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
WHERE user_id = '{user_id}' AND tool_name IN [{unique_list}]"""
self.vector_store.client.command(query)
def as_retrieval_tools(self, user_id, tool_name=None):
logger.info(f"")
private_knowledge_bases = self.list_private_knowledge_bases(user_id=user_id, private_knowledge_base=tool_name)
retrievers = {}
for private_kb in private_knowledge_bases:
file_names_sql = f"""
SELECT arrayJoin(file_names) FROM (
SELECT file_names
FROM chat.private_tool
WHERE user_id = '{user_id}' AND tool_name = '{private_kb["tool_name"]}'
)
"""
logger.info(f"user_id is {user_id}, file_names_sql is {file_names_sql}")
res = self.client.query(file_names_sql)
file_names = []
for line in res.result_rows:
file_names.append(line[0])
file_names = ', '.join(f"'{item}'" for item in file_names)
logger.info(f"user_id is {user_id}, file_names is {file_names}")
retrievers[private_kb["tool_name"]] = create_retriever_tool(
self.vector_store.as_retriever(
search_kwargs={"where_str": f"user_id='{user_id}' AND file_name IN ({file_names})"},
),
tool_name=private_kb["tool_name"],
description=private_kb["tool_description"],
)
return retrievers