import copy
import os
import pathlib
import queue
import time
import types
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

from src.utils import hash_file, get_sha

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document

|
class StreamingGradioCallbackHandler(BaseCallbackHandler):
    """
    Streams LLM tokens through a queue so they can be consumed as an iterator;
    the LangChain-backend analogue of H2OTextIteratorStreamer for the HF backend.
    """

    def __init__(self, timeout: Optional[float] = None, block=True):
        super().__init__()
        self.text_queue = queue.SimpleQueue()
        self.stop_signal = None  # sentinel enqueued to terminate iteration
        self.do_stop = False  # set externally to abort streaming
        self.timeout = timeout
        self.block = block

    def on_llm_start(
            self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running. Clean the queue."""
        while not self.text_queue.empty():
            try:
                self.text_queue.get(block=False)
            except queue.Empty:
                continue

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        self.text_queue.put(token)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        self.text_queue.put(self.stop_signal)

    def on_llm_error(
            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when LLM errors."""
        self.text_queue.put(self.stop_signal)

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            try:
                # safe default so value is defined even if stopped before any token
                value = self.stop_signal
                if self.do_stop:
                    print("hit stop", flush=True)
                    raise StopIteration()
                value = self.text_queue.get(block=self.block, timeout=self.timeout)
                break
            except queue.Empty:
                time.sleep(0.01)
        if value == self.stop_signal:
            raise StopIteration()
        else:
            return value
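

# Minimal usage sketch of the handler above (names like fake_llm are
# placeholders, not part of any real backend): a producer thread stands in for
# a streaming LangChain LLM firing on_llm_new_token()/on_llm_end(), while the
# caller simply iterates the handler until it raises StopIteration.
def _demo_streaming_handler():
    import threading

    handler = StreamingGradioCallbackHandler(timeout=5)

    def fake_llm():
        for tok in ["Hello", ",", " ", "world", "!"]:
            handler.on_llm_new_token(tok)
        handler.on_llm_end(response=None)  # enqueues stop_signal, ending iteration

    threading.Thread(target=fake_llm, daemon=True).start()
    return ''.join(handler)  # -> "Hello, world!"

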
def _chunk_sources(sources, chunk=True, chunk_size=512, language=None, db_type=None):
    assert db_type is not None

    if not isinstance(sources, (list, tuple, types.GeneratorType)) and not callable(sources):
        # handle the case of a single document
        sources = [sources]
    if not chunk:
        [x.metadata.update(dict(chunk_id=0)) for x in sources]
        if db_type in ['chroma', 'chroma_old']:
            # copy documents so the originals can be handled separately below
            source_chunks = [Document(page_content=x.page_content,
                                      metadata=copy.deepcopy(x.metadata) or {})
                             for x in sources]
        else:
            source_chunks = sources
    else:
        if language and False:
            # language-specific splitting is intentionally disabled for now
            # (note the `and False`); when enabled it would keep separators and
            # use per-language separator lists
            keep_separator = True
            separators = RecursiveCharacterTextSplitter.get_separators_for_language(language)
        else:
            separators = ["\n\n", "\n", " ", ""]
            keep_separator = False
        splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0,
                                                  keep_separator=keep_separator, separators=separators)
        source_chunks = splitter.split_documents(sources)

        # chunks come back in order; record that order so it can be recovered
        # after retrieval from the db
        [x.metadata.update(dict(chunk_id=chunk_id)) for chunk_id, x in enumerate(source_chunks)]

    if db_type in ['chroma', 'chroma_old']:
        # also keep the original full documents, marked with chunk_id=-1,
        # so tasks like summarization can use unchunked sources
        [x.metadata.update(dict(chunk_id=-1)) for x in sources]

        # sources may be a generator, so materialize it before concatenating
        return list(sources) + source_chunks
    else:
        return source_chunks
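

# Illustrative sketch of _chunk_sources (the document texts and sources here
# are made up): with db_type='chroma', the return value is the original
# documents marked chunk_id=-1 followed by the split chunks numbered 0, 1, ...
def _demo_chunk_sources():
    docs = [Document(page_content="first document. " * 40, metadata=dict(source='a.txt')),
            Document(page_content="second document. " * 40, metadata=dict(source='b.txt'))]
    chunks = _chunk_sources(docs, chunk=True, chunk_size=128, db_type='chroma')
    # returns (chunk_id, length) pairs: originals first with chunk_id=-1, then chunks
    return [(x.metadata['chunk_id'], len(x.page_content)) for x in chunks]

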
def add_parser(docs1, parser):
    # record which parser produced each document, without overriding an existing value
    [x.metadata.update(dict(parser=x.metadata.get('parser', parser))) for x in docs1]


def _add_meta(docs1, file, headsize=50, filei=0, parser='NotSet'):
    if os.path.isfile(file):
        file_extension = pathlib.Path(file).suffix
        hashid = hash_file(file)
    else:
        # not a real file (e.g. a URL or raw text), so hash the identifier string itself
        file_extension = str(file)
        hashid = get_sha(file)
    doc_hash = str(uuid.uuid4())[:10]
    if not isinstance(docs1, (list, tuple, types.GeneratorType)):
        docs1 = [docs1]
    [x.metadata.update(dict(input_type=file_extension,
                            parser=x.metadata.get('parser', parser),
                            date=str(datetime.now()),
                            time=time.time(),
                            order_id=order_id,
                            hashid=hashid,
                            doc_hash=doc_hash,
                            file_id=filei,
                            head=x.page_content[:headsize].strip())) for order_id, x in enumerate(docs1)]
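

# Illustrative sketch of _add_meta (the path 'notes.txt' is hypothetical): for
# an existing file the hash comes from hash_file(), otherwise the identifier
# string itself is hashed with get_sha().
def _demo_add_meta():
    docs = [Document(page_content="hello world", metadata={})]
    _add_meta(docs, 'notes.txt', filei=0, parser='TextLoader')
    # metadata now includes input_type, parser, date, time, order_id,
    # hashid, doc_hash, file_id, and a short 'head' preview
    return docs[0].metadata

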
def fix_json_meta(docs1):
    if not isinstance(docs1, (list, tuple, types.GeneratorType)):
        docs1 = [docs1]
    # ensure these keys are present with string values, since vector stores
    # like chroma reject None metadata values
    [x.metadata.update(dict(sender_name=x.metadata.get('sender_name') or '')) for x in docs1]
    [x.metadata.update(dict(timestamp_ms=x.metadata.get('timestamp_ms') or '')) for x in docs1]
|
|