Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
""" | |
@Time : 2023/6/8 14:03 | |
@Author : alexanderwu | |
@File : https://github.com/geekan/MetaGPT/blob/main/metagpt/document_store/document.py | |
""" | |
from pathlib import Path | |
import pandas as pd | |
from langchain.document_loaders import ( | |
TextLoader, | |
UnstructuredPDFLoader, | |
UnstructuredWordDocumentLoader, | |
) | |
from langchain.text_splitter import CharacterTextSplitter | |
from tqdm import tqdm | |
def validate_cols(content_col: str, df: pd.DataFrame):
    """Ensure ``df`` contains the column that holds the document text.

    Args:
        content_col: Name of the column that must be present.
        df: DataFrame to check.

    Raises:
        ValueError: If ``content_col`` is not one of ``df.columns``. The
            message names the missing column and lists the available ones.
    """
    if content_col not in df.columns:
        raise ValueError(
            f"Content column {content_col!r} not found; "
            f"available columns: {list(df.columns)}"
        )
def read_data(data_path: Path):
    """Load a file as a pandas DataFrame or a list of LangChain documents.

    Tabular formats (``.xlsx``, ``.csv``, ``.json``) are read with pandas and
    returned as a ``pd.DataFrame``. Document formats (``.docx``/``.doc``,
    ``.txt``, ``.pdf``) are loaded with LangChain loaders and returned as the
    list those loaders produce. ``.txt`` content is also split on newlines by
    ``CharacterTextSplitter`` with ``chunk_size=256`` and ``chunk_overlap=0``.

    Args:
        data_path: Path to the file. A plain ``str`` path is also accepted.
            The extension is matched case-insensitively.

    Returns:
        A ``pd.DataFrame`` for tabular formats, otherwise the list of
        documents returned by the LangChain loader/splitter.

    Raises:
        NotImplementedError: If the file extension is not supported.
    """
    # Accept str paths as well as Path objects (.suffix needs a Path).
    data_path = Path(data_path)
    # Normalise case so e.g. "REPORT.CSV" is handled like "report.csv".
    suffix = data_path.suffix.lower()
    if suffix == '.xlsx':
        data = pd.read_excel(data_path)
    elif suffix == '.csv':
        data = pd.read_csv(data_path)
    elif suffix == '.json':
        data = pd.read_json(data_path)
    elif suffix in ('.docx', '.doc'):
        data = UnstructuredWordDocumentLoader(str(data_path), mode='elements').load()
    elif suffix == '.txt':
        raw_docs = TextLoader(str(data_path)).load()
        text_splitter = CharacterTextSplitter(separator='\n', chunk_size=256, chunk_overlap=0)
        data = text_splitter.split_documents(raw_docs)
    elif suffix == '.pdf':
        data = UnstructuredPDFLoader(str(data_path), mode="elements").load()
    else:
        raise NotImplementedError(
            f"Unsupported file type {data_path.suffix!r} for {data_path}"
        )
    return data
class Document:
    """Wrap a data file and expose its contents as parallel text/metadata lists.

    The file is loaded eagerly by ``read_data``. Depending on the extension,
    ``self.data`` is either a ``pd.DataFrame`` (tabular formats) or a list of
    document objects exposing ``page_content`` and ``metadata`` (LangChain
    loaders).
    """

    def __init__(self, data_path, content_col='content', meta_col='metadata'):
        """Load ``data_path`` and record which columns hold text and metadata.

        Args:
            data_path: Path (or path string) of the file to load.
            content_col: DataFrame column holding the document text. Only
                validated when the loaded data is a DataFrame.
            meta_col: DataFrame column whose values become each document's
                metadata. If falsy, or absent from the DataFrame, every
                document gets empty metadata.

        Raises:
            ValueError: If the data is a DataFrame lacking ``content_col``.
        """
        # Path() lets callers pass a plain string; read_data relies on .suffix.
        self.data = read_data(Path(data_path))
        if isinstance(self.data, pd.DataFrame):
            validate_cols(content_col, self.data)
        self.content_col = content_col
        self.meta_col = meta_col

    def _get_docs_and_metadatas_by_df(self) -> "tuple[list, list]":
        """Return (texts, metadatas) built from the DataFrame columns.

        Each metadata entry is ``{meta_col: value}`` for the same row, or
        ``{}`` when ``meta_col`` is falsy or not a column of the DataFrame.
        """
        df = self.data
        # tolist() yields native Python values (int/float/str) instead of
        # numpy scalars, and avoids a slow per-row .iloc loop.
        docs = df[self.content_col].tolist()
        # meta_col defaults to 'metadata', which many files do not contain;
        # fall back to empty metadata rather than raising KeyError.
        if self.meta_col and self.meta_col in df.columns:
            metadatas = [{self.meta_col: value} for value in df[self.meta_col].tolist()]
        else:
            metadatas = [{} for _ in docs]
        return docs, metadatas

    def _get_docs_and_metadatas_by_langchain(self) -> "tuple[list, list]":
        """Return (texts, metadatas) from a list of loader-produced documents."""
        data = self.data
        docs = [doc.page_content for doc in data]
        metadatas = [doc.metadata for doc in data]
        return docs, metadatas

    def get_docs_and_metadatas(self) -> "tuple[list, list]":
        """Return parallel lists of document texts and metadata dicts.

        Raises:
            NotImplementedError: If ``self.data`` is neither a DataFrame nor
                a list.
        """
        if isinstance(self.data, pd.DataFrame):
            return self._get_docs_and_metadatas_by_df()
        elif isinstance(self.data, list):
            return self._get_docs_and_metadatas_by_langchain()
        else:
            raise NotImplementedError(
                f"Unsupported data type: {type(self.data).__name__}"
            )