|
from __future__ import annotations |
|
|
|
import atexit |
|
import concurrent.futures
|
import copy |
|
import difflib |
|
import re |
|
import threading |
|
import traceback |
|
import os |
|
import time |
|
import urllib.parse |
|
import uuid |
|
import warnings |
|
from concurrent.futures import Future |
|
from datetime import timedelta |
|
from enum import Enum |
|
from functools import lru_cache |
|
from pathlib import Path |
|
from typing import Callable, Generator, Any, Union, List, Dict, Literal, Tuple |
|
import ast |
|
import inspect |
|
import numpy as np |
|
|
|
try: |
|
from gradio_utils.yield_utils import ReturnType |
|
except (ImportError, ModuleNotFoundError): |
|
try: |
|
from yield_utils import ReturnType |
|
except (ImportError, ModuleNotFoundError): |
|
try: |
|
from src.yield_utils import ReturnType |
|
except (ImportError, ModuleNotFoundError): |
|
from .src.yield_utils import ReturnType |
|
|
|
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1" |
|
|
|
from huggingface_hub import SpaceStage |
|
from huggingface_hub.utils import ( |
|
build_hf_headers, |
|
) |
|
|
|
from gradio_client import utils |
|
|
|
from importlib.metadata import distribution, PackageNotFoundError |
|
|
|
lock = threading.Lock() |
|
|
|
try: |
|
assert distribution("gradio_client") is not None |
|
have_gradio_client = True |
|
from packaging import version |
|
|
|
client_version = distribution("gradio_client").version |
|
is_gradio_client_version7plus = version.parse(client_version) >= version.parse( |
|
"0.7.0" |
|
) |
|
except (PackageNotFoundError, AssertionError): |
|
have_gradio_client = False |
|
is_gradio_client_version7plus = False |
|
|
|
from gradio_client.client import Job, DEFAULT_TEMP_DIR, Endpoint |
|
from gradio_client import Client |
|
|
|
|
|
def check_job(job, timeout=0.0, raise_exception=True, verbose=False): |
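    """Check a gradio Job for failure.

    Waits up to `timeout` seconds for the job's exception (if any). Returns the
    exception (or None) when raise_exception=False, otherwise raises RuntimeError
    with the formatted traceback.
    """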
|
try: |
|
e = job.exception(timeout=timeout) |
|
except concurrent.futures.TimeoutError: |
|
|
|
if verbose: |
|
print("not enough time to determine job status: %s" % timeout) |
|
e = None |
|
if e: |
|
|
|
if raise_exception: |
|
            raise RuntimeError("".join(traceback.format_exception(e)))
|
else: |
|
return e |
|
|
|
|
|
|
|
class LangChainAction(Enum): |
|
"""LangChain action""" |
|
|
|
QUERY = "Query" |
|
SUMMARIZE_MAP = "Summarize" |
|
EXTRACT = "Extract" |
|
|
|
|
|
pre_prompt_query0 = "Pay attention and remember the information below, which will help to answer the question or imperative after the context ends." |
|
prompt_query0 = "According to only the information in the document sources provided within the context above: " |
|
|
|
pre_prompt_summary0 = ""
|
prompt_summary0 = "Using only the information in the document sources above, write a condensed and concise well-structured Markdown summary of key results." |
|
|
|
pre_prompt_extraction0 = ( |
|
"""In order to extract information, pay attention to the following text.""" |
|
) |
|
prompt_extraction0 = ( |
|
"Using only the information in the document sources above, extract " |
|
) |
|
|
|
hyde_llm_prompt0 = "Answer this question with vibrant details in order for some NLP embedding model to use that answer as better query than original question: " |
|
|
|
# Reuse the version info detected above; old_gradio selects H2OGradioClient vs. CloneableGradioClient at the bottom of this file
old_gradio = have_gradio_client and version.parse(client_version) <= version.parse("0.6.1")
|
|
|
|
|
class CommonClient: |
|
def question(self, instruction, *args, **kwargs) -> str: |
|
""" |
|
Prompt LLM (direct to LLM with instruct prompting required for instruct models) and get response |
|
""" |
|
kwargs["instruction"] = kwargs.get("instruction", instruction) |
|
kwargs["langchain_action"] = LangChainAction.QUERY.value |
|
kwargs["langchain_mode"] = "LLM" |
|
ret = "" |
|
for ret1 in self.query_or_summarize_or_extract(*args, **kwargs): |
|
ret = ret1.reply |
|
return ret |
|
|
|
def question_stream( |
|
self, instruction, *args, **kwargs |
|
) -> Generator[ReturnType, None, None]: |
|
""" |
|
Prompt LLM (direct to LLM with instruct prompting required for instruct models) and get response |
|
""" |
|
kwargs["instruction"] = kwargs.get("instruction", instruction) |
|
kwargs["langchain_action"] = LangChainAction.QUERY.value |
|
kwargs["langchain_mode"] = "LLM" |
|
ret = yield from self.query_or_summarize_or_extract(*args, **kwargs) |
|
return ret |
|
|
|
def query(self, query, *args, **kwargs) -> str: |
|
""" |
|
Search for documents matching a query, then ask that query to LLM with those documents |
|
""" |
|
kwargs["instruction"] = kwargs.get("instruction", query) |
|
kwargs["langchain_action"] = LangChainAction.QUERY.value |
|
ret = "" |
|
for ret1 in self.query_or_summarize_or_extract(*args, **kwargs): |
|
ret = ret1.reply |
|
return ret |
|
|
|
def query_stream(self, query, *args, **kwargs) -> Generator[ReturnType, None, None]: |
|
""" |
|
Search for documents matching a query, then ask that query to LLM with those documents |
|
""" |
|
kwargs["instruction"] = kwargs.get("instruction", query) |
|
kwargs["langchain_action"] = LangChainAction.QUERY.value |
|
ret = yield from self.query_or_summarize_or_extract(*args, **kwargs) |
|
return ret |
|
|
|
def summarize(self, *args, query=None, focus=None, **kwargs) -> str: |
|
""" |
|
Search for documents matching a focus, then ask a query to LLM with those documents |
|
If focus "" or None, no similarity search is done and all documents (up to top_k_docs) are used |
|
""" |
|
kwargs["prompt_summary"] = kwargs.get( |
|
"prompt_summary", query or prompt_summary0 |
|
) |
|
kwargs["instruction"] = kwargs.get("instruction", focus) |
|
kwargs["langchain_action"] = LangChainAction.SUMMARIZE_MAP.value |
|
ret = "" |
|
for ret1 in self.query_or_summarize_or_extract(*args, **kwargs): |
|
ret = ret1.reply |
|
return ret |
|
|
|
    def summarize_stream(self, *args, query=None, focus=None, **kwargs) -> Generator[ReturnType, None, None]:
|
""" |
|
Search for documents matching a focus, then ask a query to LLM with those documents |
|
If focus "" or None, no similarity search is done and all documents (up to top_k_docs) are used |
|
""" |
|
kwargs["prompt_summary"] = kwargs.get( |
|
"prompt_summary", query or prompt_summary0 |
|
) |
|
kwargs["instruction"] = kwargs.get("instruction", focus) |
|
kwargs["langchain_action"] = LangChainAction.SUMMARIZE_MAP.value |
|
ret = yield from self.query_or_summarize_or_extract(*args, **kwargs) |
|
return ret |
|
|
|
def extract(self, *args, query=None, focus=None, **kwargs) -> list[str]: |
|
""" |
|
Search for documents matching a focus, then ask a query to LLM with those documents |
|
If focus "" or None, no similarity search is done and all documents (up to top_k_docs) are used |
|
""" |
|
kwargs["prompt_extraction"] = kwargs.get( |
|
"prompt_extraction", query or prompt_extraction0 |
|
) |
|
kwargs["instruction"] = kwargs.get("instruction", focus) |
|
kwargs["langchain_action"] = LangChainAction.EXTRACT.value |
|
ret = "" |
|
for ret1 in self.query_or_summarize_or_extract(*args, **kwargs): |
|
ret = ret1.reply |
|
return ret |
|
|
|
    def extract_stream(self, *args, query=None, focus=None, **kwargs) -> Generator[ReturnType, None, None]:
|
""" |
|
Search for documents matching a focus, then ask a query to LLM with those documents |
|
If focus "" or None, no similarity search is done and all documents (up to top_k_docs) are used |
|
""" |
|
kwargs["prompt_extraction"] = kwargs.get( |
|
"prompt_extraction", query or prompt_extraction0 |
|
) |
|
kwargs["instruction"] = kwargs.get("instruction", focus) |
|
kwargs["langchain_action"] = LangChainAction.EXTRACT.value |
|
ret = yield from self.query_or_summarize_or_extract(*args, **kwargs) |
|
return ret |
|
|
|
def get_client_kwargs(self, **kwargs): |
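        """Filter kwargs down to the parameters accepted by the server's evaluate function.

        With HARD_ASSERTS set, also verifies that query_or_summarize_or_extract and
        eval_func_param_names stay in sync.
        """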
|
client_kwargs = {} |
|
try: |
|
from src.evaluate_params import eval_func_param_names |
|
except (ImportError, ModuleNotFoundError): |
|
try: |
|
from evaluate_params import eval_func_param_names |
|
except (ImportError, ModuleNotFoundError): |
|
from .src.evaluate_params import eval_func_param_names |
|
|
|
for k in eval_func_param_names: |
|
if k in kwargs: |
|
client_kwargs[k] = kwargs[k] |
|
|
|
if os.getenv("HARD_ASSERTS"): |
|
fun_kwargs = { |
|
k: v.default |
|
for k, v in dict( |
|
inspect.signature(self.query_or_summarize_or_extract).parameters |
|
).items() |
|
} |
|
diff = set(eval_func_param_names).difference(fun_kwargs) |
|
assert len(diff) == 0, ( |
|
"Add query_or_summarize_or_extract entries: %s" % diff |
|
) |
|
|
|
extra_query_params = [ |
|
"file", |
|
"bad_error_string", |
|
"print_info", |
|
"asserts", |
|
"url", |
|
"prompt_extraction", |
|
"model", |
|
"text", |
|
"print_error", |
|
"pre_prompt_extraction", |
|
"embed", |
|
"print_warning", |
|
"sanitize_llm", |
|
] |
|
diff = set(fun_kwargs).difference( |
|
eval_func_param_names + extra_query_params |
|
) |
|
assert len(diff) == 0, "Add eval_func_params entries: %s" % diff |
|
|
|
return client_kwargs |
|
|
|
def get_query_kwargs(self, **kwargs): |
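        """Return kwargs merged with the defaults of query_or_summarize_or_extract."""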
|
fun_dict = dict( |
|
inspect.signature(self.query_or_summarize_or_extract).parameters |
|
).items() |
|
fun_kwargs = {k: kwargs.get(k, v.default) for k, v in fun_dict} |
|
|
|
return fun_kwargs |
|
|
|
@staticmethod |
|
def check_error(res_dict): |
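        """Raise if the server response dict contains an error or lacks a response."""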
|
actual_llm = "" |
|
try: |
|
actual_llm = res_dict["save_dict"]["display_name"] |
|
        except Exception:
|
pass |
|
if "error" in res_dict and res_dict["error"]: |
|
raise RuntimeError(f"Error from LLM {actual_llm}: {res_dict['error']}") |
|
if "error_ex" in res_dict and res_dict["error_ex"]: |
|
raise RuntimeError( |
|
f"Error Traceback from LLM {actual_llm}: {res_dict['error_ex']}" |
|
) |
|
if "response" not in res_dict: |
|
raise ValueError(f"No response from LLM {actual_llm}") |
|
|
|
def query_or_summarize_or_extract( |
|
self, |
|
print_error=print, |
|
print_info=print, |
|
print_warning=print, |
|
bad_error_string=None, |
|
sanitize_llm=None, |
|
h2ogpt_key: str = None, |
|
instruction: str = "", |
|
text: list[str] | str | None = None, |
|
file: list[str] | str | None = None, |
|
url: list[str] | str | None = None, |
|
embed: bool = True, |
|
chunk: bool = True, |
|
chunk_size: int = 512, |
|
langchain_mode: str = None, |
|
langchain_action: str | None = None, |
|
langchain_agents: List[str] = [], |
|
top_k_docs: int = 10, |
|
document_choice: Union[str, List[str]] = "All", |
|
document_subset: str = "Relevant", |
|
document_source_substrings: Union[str, List[str]] = [], |
|
document_source_substrings_op: str = "and", |
|
document_content_substrings: Union[str, List[str]] = [], |
|
document_content_substrings_op: str = "and", |
|
system_prompt: str | None = "", |
|
pre_prompt_query: str | None = pre_prompt_query0, |
|
prompt_query: str | None = prompt_query0, |
|
pre_prompt_summary: str | None = pre_prompt_summary0, |
|
prompt_summary: str | None = prompt_summary0, |
|
pre_prompt_extraction: str | None = pre_prompt_extraction0, |
|
prompt_extraction: str | None = prompt_extraction0, |
|
hyde_llm_prompt: str | None = hyde_llm_prompt0, |
|
all_docs_start_prompt: str | None = None, |
|
all_docs_finish_prompt: str | None = None, |
|
user_prompt_for_fake_system_prompt: str = None, |
|
json_object_prompt: str = None, |
|
json_object_prompt_simpler: str = None, |
|
json_code_prompt: str = None, |
|
json_code_prompt_if_no_schema: str = None, |
|
json_schema_instruction: str = None, |
|
json_preserve_system_prompt: bool = False, |
|
json_object_post_prompt_reminder: str = None, |
|
json_code_post_prompt_reminder: str = None, |
|
json_code2_post_prompt_reminder: str = None, |
|
model: str | int | None = None, |
|
model_lock: dict | None = None, |
|
stream_output: bool = False, |
|
enable_caching: bool = False, |
|
do_sample: bool = False, |
|
seed: int | None = 0, |
|
temperature: float = 0.0, |
|
top_p: float = 1.0, |
|
top_k: int = 40, |
|
|
|
repetition_penalty: float = 1.0, |
|
penalty_alpha: float = 0.0, |
|
max_time: int = 360, |
|
max_new_tokens: int = 1024, |
|
add_search_to_context: bool = False, |
|
chat_conversation: list[tuple[str, str]] | None = None, |
|
text_context_list: list[str] | None = None, |
|
docs_ordering_type: str | None = None, |
|
min_max_new_tokens: int = 512, |
|
max_input_tokens: int = -1, |
|
max_total_input_tokens: int = -1, |
|
docs_token_handling: str = "split_or_merge", |
|
docs_joiner: str = "\n\n", |
|
hyde_level: int = 0, |
|
hyde_template: str = None, |
|
hyde_show_only_final: bool = True, |
|
doc_json_mode: bool = False, |
|
metadata_in_context: list = [], |
|
image_file: Union[str, list] = None, |
|
image_control: str = None, |
|
images_num_max: int = None, |
|
image_resolution: tuple = None, |
|
image_format: str = None, |
|
rotate_align_resize_image: bool = None, |
|
video_frame_period: int = None, |
|
image_batch_image_prompt: str = None, |
|
image_batch_final_prompt: str = None, |
|
image_batch_stream: bool = None, |
|
visible_vision_models: Union[str, int, list] = None, |
|
video_file: Union[str, list] = None, |
|
response_format: str = "text", |
|
guided_json: Union[str, dict] = "", |
|
guided_regex: str = "", |
|
guided_choice: List[str] | None = None, |
|
guided_grammar: str = "", |
|
guided_whitespace_pattern: str = None, |
|
prompt_type: Union[int, str] = None, |
|
prompt_dict: Dict = None, |
|
chat_template: str = None, |
|
jq_schema=".[]", |
|
llava_prompt: str = "auto", |
|
image_audio_loaders: list = None, |
|
url_loaders: list = None, |
|
pdf_loaders: list = None, |
|
extract_frames: int = 10, |
|
add_chat_history_to_context: bool = True, |
|
chatbot_role: str = "None", |
|
speaker: str = "None", |
|
tts_language: str = "autodetect", |
|
tts_speed: float = 1.0, |
|
visible_image_models: List[str] = [], |
|
image_size: str = "1024x1024", |
|
image_quality: str = 'standard', |
|
image_guidance_scale: float = 3.0, |
|
image_num_inference_steps: int = 30, |
|
visible_models: Union[str, int, list] = None, |
|
client_metadata: str = '', |
|
|
|
num_return_sequences: int = None, |
|
chat: bool = True, |
|
min_new_tokens: int = None, |
|
early_stopping: Union[bool, str] = None, |
|
iinput: str = "", |
|
iinput_nochat: str = "", |
|
instruction_nochat: str = "", |
|
context: str = "", |
|
num_beams: int = 1, |
|
asserts: bool = False, |
|
do_lock: bool = False, |
|
) -> Generator[ReturnType, None, None]: |
|
""" |
|
Query or Summarize or Extract using h2oGPT |
|
Args: |
|
instruction: Query for LLM chat. Used for similarity search |
|
|
|
For query, prompt template is: |
|
"{pre_prompt_query} |
|
\"\"\" |
|
{content} |
|
\"\"\" |
|
{prompt_query}{instruction}" |
|
If added to summarization, prompt template is |
|
"{pre_prompt_summary} |
|
\"\"\" |
|
{content} |
|
\"\"\" |
|
Focusing on {instruction}, {prompt_summary}" |
|
text: textual content or list of such contents |
|
file: a local file to upload or files to upload |
|
url: a url to give or urls to use |
|
embed: whether to embed content uploaded |
|
|
|
:param langchain_mode: "LLM" to talk to LLM with no docs, "MyData" for personal docs, "UserData" for shared docs, etc. |
|
:param langchain_action: Action to take, "Query" or "Summarize" or "Extract" |
|
:param langchain_agents: Which agents to use, if any |
|
        :param top_k_docs: number of document parts.
                           When doing a query, this is the number of chunks retrieved.
                           When doing summarization, vectorDB chunks are not used; this is the number of
                           original document parts (e.g. for a PDF, the number of pages).
|
:param chunk: whether to chunk sources for document Q/A |
|
:param chunk_size: Size in characters of chunks |
|
:param document_choice: Which documents ("All" means all) -- need to use upload_api API call to get server's name if want to select |
|
:param document_subset: Type of query, see src/gen.py |
|
:param document_source_substrings: See gen.py |
|
:param document_source_substrings_op: See gen.py |
|
:param document_content_substrings: See gen.py |
|
:param document_content_substrings_op: See gen.py |
|
|
|
:param system_prompt: pass system prompt to models that support it. |
|
If 'auto' or None, then use automatic version |
|
If '', then use no system prompt (default) |
|
:param pre_prompt_query: Prompt that comes before document part |
|
:param prompt_query: Prompt that comes after document part |
|
:param pre_prompt_summary: Prompt that comes before document part |
|
None makes h2oGPT internally use its defaults |
|
E.g. "In order to write a concise single-paragraph or bulleted list summary, pay attention to the following text" |
|
:param prompt_summary: Prompt that comes after document part |
|
None makes h2oGPT internally use its defaults |
|
E.g. "Using only the text above, write a condensed and concise summary of key results (preferably as bullet points):\n" |
|
i.e. for some internal document part fstring, the template looks like: |
|
template = "%s |
|
\"\"\" |
|
%s |
|
\"\"\" |
|
%s" % (pre_prompt_summary, fstring, prompt_summary) |
|
:param hyde_llm_prompt: hyde prompt for first step when using LLM |
|
:param all_docs_start_prompt: start of document block |
|
:param all_docs_finish_prompt: finish of document block |
|
|
|
:param user_prompt_for_fake_system_prompt: user part of pre-conversation if LLM doesn't handle system prompt |
|
:param json_object_prompt: prompt for getting LLM to do JSON object |
|
        :param json_object_prompt_simpler: simpler version of json_object_prompt, for MistralAI
|
        :param json_code_prompt: prompt for getting LLM to do JSON in code block
|
:param json_code_prompt_if_no_schema: prompt for getting LLM to do JSON in code block if no schema |
|
:param json_schema_instruction: prompt for LLM to use schema |
|
:param json_preserve_system_prompt: Whether to preserve system prompt for json mode |
|
:param json_object_post_prompt_reminder: json object reminder about JSON |
|
:param json_code_post_prompt_reminder: json code w/ schema reminder about JSON |
|
:param json_code2_post_prompt_reminder: json code wo/ schema reminder about JSON |
|
|
|
:param h2ogpt_key: Access Key to h2oGPT server (if not already set in client at init time) |
|
:param model: base_model name or integer index of model_lock on h2oGPT server |
|
None results in use of first (0th index) model in server |
|
to get list of models do client.list_models() |
|
:param model_lock: dict of states or single state, with dict of things like inference server, to use when using dynamic LLM (not from existing model lock on h2oGPT) |
|
:param pre_prompt_extraction: Same as pre_prompt_summary but for when doing extraction |
|
:param prompt_extraction: Same as prompt_summary but for when doing extraction |
|
:param do_sample: see src/gen.py |
|
:param seed: see src/gen.py |
|
:param temperature: see src/gen.py |
|
:param top_p: see src/gen.py |
|
:param top_k: see src/gen.py |
|
:param repetition_penalty: see src/gen.py |
|
:param penalty_alpha: see src/gen.py |
|
:param max_new_tokens: see src/gen.py |
|
:param min_max_new_tokens: see src/gen.py |
|
:param max_input_tokens: see src/gen.py |
|
:param max_total_input_tokens: see src/gen.py |
|
:param stream_output: Whether to stream output |
|
:param enable_caching: Whether to enable caching |
|
        :param max_time: maximum time in seconds to allow for the LLM call
|
|
|
:param add_search_to_context: Whether to do web search and add results to context |
|
        :param chat_conversation: List of tuples for (human, bot) conversation that will be prepended to an (instruction, None) case for a query
|
:param text_context_list: List of strings to add to context for non-database version of document Q/A for faster handling via API etc. |
|
Forces LangChain code path and uses as many entries in list as possible given max_seq_len, with first assumed to be most relevant and to go near prompt. |
|
:param docs_ordering_type: By default uses 'reverse_ucurve_sort' for optimal retrieval |
|
:param max_input_tokens: Max input tokens to place into model context for each LLM call |
|
-1 means auto, fully fill context for query, and fill by original document chunk for summarization |
|
>=0 means use that to limit context filling to that many tokens |
|
:param max_total_input_tokens: like max_input_tokens but instead of per LLM call, applies across all LLM calls for single summarization/extraction action |
|
:param max_new_tokens: Maximum new tokens |
|
:param min_max_new_tokens: minimum value for max_new_tokens when auto-adjusting for content of prompt, docs, etc. |
|
|
|
:param docs_token_handling: 'chunk' means fill context with top_k_docs (limited by max_input_tokens or model_max_len) chunks for query |
|
or top_k_docs original document chunks summarization |
|
None or 'split_or_merge' means same as 'chunk' for query, while for summarization merges documents to fill up to max_input_tokens or model_max_len tokens |
|
:param docs_joiner: string to join lists of text when doing split_or_merge. None means '\n\n' |
|
:param hyde_level: 0-3 for HYDE. |
|
0 uses just query to find similarity with docs |
|
1 uses query + pure LLM response to find similarity with docs |
|
2: uses query + LLM response using docs to find similarity with docs |
|
3+: etc. |
|
:param hyde_template: see src/gen.py |
|
:param hyde_show_only_final: see src/gen.py |
|
:param doc_json_mode: see src/gen.py |
|
:param metadata_in_context: see src/gen.py |
|
|
|
:param image_file: Initial image for UI (or actual image for CLI) Vision Q/A. Or list of images for some models |
|
:param image_control: Initial image for UI Image Control |
|
:param images_num_max: Max. number of images per LLM call |
|
:param image_resolution: Resolution of any images |
|
:param image_format: Image format |
|
:param rotate_align_resize_image: Whether to apply rotation, alignment, resize before giving to LLM |
|
:param video_frame_period: Period of frames to use from video |
|
:param image_batch_image_prompt: Prompt used to query image only if doing batching of images |
|
:param image_batch_final_prompt: Prompt used to query result of batching of images |
|
:param image_batch_stream: Whether to stream batching of images. |
|
:param visible_vision_models: Model to use for vision, e.g. if base LLM has no vision |
|
If 'auto', then use CLI value, else use model display name given here |
|
:param video_file: DO NOT USE FOR API, put images, videos, urls, and youtube urls in image_file as list |
|
|
|
:param response_format: text or json_object or json_code |
|
# https://github.com/vllm-project/vllm/blob/a3c226e7eb19b976a937e745f3867eb05f809278/vllm/entrypoints/openai/protocol.py#L117-L135 |
|
:param guided_json: str or dict of JSON schema |
|
:param guided_regex: |
|
:param guided_choice: list of strings to have LLM choose from |
|
:param guided_grammar: |
|
:param guided_whitespace_pattern: |
|
|
|
:param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model |
|
:param prompt_dict: If prompt_type=custom, then expects (some) items returned by get_prompt(..., return_dict=True) |
|
:param chat_template: jinja HF transformers chat_template to use. '' or None means no change to template |
|
|
|
:param jq_schema: control json loader |
|
By default '.[]' ingests everything in brute-force way, but better to match your schema |
|
See: https://python.langchain.com/docs/modules/data_connection/document_loaders/json#using-jsonloader |
|
|
|
:param extract_frames: How many unique frames to extract from video (if 0, then just do audio if audio type file as well) |
|
|
|
:param llava_prompt: Prompt passed to LLaVa for querying the image |
|
|
|
:param image_audio_loaders: which loaders to use for image and audio parsing (None means default) |
|
:param url_loaders: which loaders to use for url parsing (None means default) |
|
:param pdf_loaders: which loaders to use for pdf parsing (None means default) |
|
|
|
:param add_chat_history_to_context: Include chat context when performing action |
|
Not supported when using CLI mode |
|
|
|
:param chatbot_role: Default role for coqui models. If 'None', then don't by default speak when launching h2oGPT for coqui model choice. |
|
:param speaker: Default speaker for microsoft models If 'None', then don't by default speak when launching h2oGPT for microsoft model choice. |
|
:param tts_language: Default language for coqui models |
|
:param tts_speed: Default speed of TTS, < 1.0 (needs rubberband) for slower than normal, > 1.0 for faster. Tries to keep fixed pitch. |
|
|
|
:param visible_image_models: Which image gen models to include |
|
        :param image_size: size of generated images, e.g. "1024x1024"
        :param image_quality: quality setting for image generation, e.g. "standard"
        :param image_guidance_scale: guidance scale for image generation
        :param image_num_inference_steps: number of inference steps for image generation
|
:param visible_models: Which models in model_lock list to show by default |
|
Takes integers of position in model_lock (model_states) list or strings of base_model names |
|
Ignored if model_lock not used |
|
For nochat API, this is single item within a list for model by name or by index in model_lock |
|
If None, then just use first model in model_lock list |
|
If model_lock not set, use model selected by CLI --base_model etc. |
|
Note that unlike h2ogpt_key, this visible_models only applies to this running h2oGPT server, |
|
and the value is not used to access the inference server. |
|
If need a visible_models for an inference server, then use --model_lock and group together. |
|
:param client_metadata: |
|
:param asserts: whether to do asserts to ensure handling is correct |
|
|
|
Returns: summary/answer: str or extraction List[str] |
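
        Example (illustrative only; assumes a running h2oGPT server and that "report.pdf" exists locally):
            for chunk in client.query_or_summarize_or_extract(
                instruction="What is the total revenue?",
                file="report.pdf",
                langchain_action="Query",
                stream_output=True,
            ):
                print(chunk.reply)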
|
|
|
""" |
|
if self.config is None: |
|
self.setup() |
|
if self.persist: |
|
client = self |
|
else: |
|
client = self.clone() |
|
try: |
|
h2ogpt_key = h2ogpt_key or self.h2ogpt_key |
|
client.h2ogpt_key = h2ogpt_key |
|
|
|
if model is not None and visible_models is None: |
|
visible_models = model |
|
client.check_model(model) |
|
|
|
|
|
|
|
langchain_mode = langchain_mode or "MyData" |
|
loaders = tuple([None, None, None, None, None, None]) |
|
doc_options = tuple([langchain_mode, chunk, chunk_size, embed]) |
|
asserts |= bool(os.getenv("HARD_ASSERTS", False)) |
|
if ( |
|
text |
|
and isinstance(text, list) |
|
and not file |
|
and not url |
|
and not text_context_list |
|
): |
|
|
|
text_context_list = text |
|
text = None |
|
|
|
res = [] |
|
if text: |
|
t0 = time.time() |
|
res = client.predict( |
|
text, *doc_options, *loaders, h2ogpt_key, api_name="/add_text" |
|
) |
|
t1 = time.time() |
|
print_info("upload text: %s" % str(timedelta(seconds=t1 - t0))) |
|
if asserts: |
|
assert res[0] is None |
|
assert res[1] == langchain_mode |
|
assert "user_paste" in res[2] |
|
assert res[3] == "" |
|
if file: |
|
|
|
|
|
_, file = client.predict(file, api_name="/upload_api") |
|
|
|
res = client.predict( |
|
file, *doc_options, *loaders, h2ogpt_key, api_name="/add_file_api" |
|
) |
|
if asserts: |
|
assert res[0] is None |
|
assert res[1] == langchain_mode |
|
assert os.path.basename(file) in res[2] |
|
assert res[3] == "" |
|
if url: |
|
res = client.predict( |
|
url, *doc_options, *loaders, h2ogpt_key, api_name="/add_url" |
|
) |
|
if asserts: |
|
assert res[0] is None |
|
assert res[1] == langchain_mode |
|
assert url in res[2] |
|
assert res[3] == "" |
|
assert res[4] |
|
if res and not res[4] and "Exception" in res[2]: |
|
print_error("Exception: %s" % res[2]) |
|
|
|
|
|
api_name = "/submit_nochat_api" |
|
|
|
pre_prompt_summary = ( |
|
pre_prompt_summary |
|
if langchain_action == LangChainAction.SUMMARIZE_MAP.value |
|
else pre_prompt_extraction |
|
) |
|
prompt_summary = ( |
|
prompt_summary |
|
if langchain_action == LangChainAction.SUMMARIZE_MAP.value |
|
else prompt_extraction |
|
) |
|
|
|
chat_conversation = ( |
|
chat_conversation |
|
if chat_conversation or not self.persist |
|
else self.chat_conversation.copy() |
|
) |
|
|
|
locals_for_client = locals().copy() |
|
locals_for_client.pop("self", None) |
|
client_kwargs = self.get_client_kwargs(**locals_for_client) |
|
|
|
|
|
if do_lock: |
|
with lock: |
|
self.server_hash = client.server_hash |
|
else: |
|
self.server_hash = client.server_hash |
|
|
|
|
|
if self.persist: |
|
self.chat_conversation.append((instruction, None)) |
|
|
|
|
|
actual_llm = visible_models |
|
response = "" |
|
texts_out = [] |
|
trials = 3 |
|
|
|
|
|
trials_generation = 10 |
|
trial = 0 |
|
trial_generation = 0 |
|
t0 = time.time() |
|
input_tokens = 0 |
|
output_tokens = 0 |
|
tokens_per_second = 0 |
|
vision_visible_model = None |
|
vision_batch_input_tokens = 0 |
|
vision_batch_output_tokens = 0 |
|
vision_batch_tokens_per_second = 0 |
|
t_taken_s = None |
|
while True: |
|
time_to_first_token = None |
|
t0 = time.time() |
|
try: |
|
if not stream_output: |
|
res = client.predict( |
|
str(dict(client_kwargs)), |
|
api_name=api_name, |
|
) |
|
if time_to_first_token is None: |
|
time_to_first_token = time.time() - t0 |
|
t_taken_s = time.time() - t0 |
|
|
|
if do_lock: |
|
with lock: |
|
self.server_hash = client.server_hash |
|
else: |
|
self.server_hash = client.server_hash |
|
res_dict = ast.literal_eval(res) |
|
self.check_error(res_dict) |
|
response = res_dict["response"] |
|
if langchain_action != LangChainAction.EXTRACT.value: |
|
response = response.strip() |
|
else: |
|
response = [r.strip() for r in ast.literal_eval(response)] |
|
sources = res_dict["sources"] |
|
scores_out = [x["score"] for x in sources] |
|
texts_out = [x["content"] for x in sources] |
|
prompt_raw = res_dict.get("prompt_raw", "") |
|
try: |
|
actual_llm = res_dict["save_dict"][ |
|
"display_name" |
|
] |
|
except Exception as e: |
|
print_warning( |
|
f"Unable to access save_dict to get actual_llm: {str(e)}" |
|
) |
|
try: |
|
extra_dict = res_dict["save_dict"]["extra_dict"] |
|
input_tokens = extra_dict["num_prompt_tokens"] |
|
output_tokens = extra_dict["ntokens"] |
|
tokens_per_second = np.round( |
|
extra_dict["tokens_persecond"], decimals=3 |
|
) |
|
vision_visible_model = extra_dict.get( |
|
"batch_vision_visible_model" |
|
) |
|
vision_batch_input_tokens = extra_dict.get( |
|
"vision_batch_input_tokens", 0 |
|
) |
|
except: |
|
if os.getenv("HARD_ASSERTS"): |
|
raise |
|
if asserts: |
|
if text and not file and not url: |
|
assert any( |
|
text[:cutoff] == texts_out |
|
for cutoff in range(len(text)) |
|
) |
|
assert len(texts_out) == len(scores_out) |
|
|
|
yield ReturnType( |
|
reply=response, |
|
text_context_list=texts_out, |
|
prompt_raw=prompt_raw, |
|
actual_llm=actual_llm, |
|
input_tokens=input_tokens, |
|
output_tokens=output_tokens, |
|
tokens_per_second=tokens_per_second, |
|
time_to_first_token=time_to_first_token or (time.time() - t0), |
|
vision_visible_model=vision_visible_model, |
|
vision_batch_input_tokens=vision_batch_input_tokens, |
|
vision_batch_output_tokens=vision_batch_output_tokens, |
|
vision_batch_tokens_per_second=vision_batch_tokens_per_second, |
|
) |
|
if self.persist: |
|
self.chat_conversation[-1] = (instruction, response) |
|
else: |
|
job = client.submit(str(dict(client_kwargs)), api_name=api_name) |
|
text0 = "" |
|
while not job.done(): |
|
e = check_job(job, timeout=0, raise_exception=False) |
|
if e is not None: |
|
break |
|
outputs_list = job.outputs().copy() |
|
if outputs_list: |
|
res = outputs_list[-1] |
|
res_dict = ast.literal_eval(res) |
|
self.check_error(res_dict) |
|
response = res_dict["response"] |
|
prompt_raw = res_dict.get( |
|
"prompt_raw", "" |
|
) |
|
text_chunk = response[ |
|
len(text0): |
|
] |
|
if not text_chunk: |
|
time.sleep(0.001) |
|
continue |
|
text0 = response |
|
assert text_chunk, "must yield non-empty string" |
|
if time_to_first_token is None: |
|
time_to_first_token = time.time() - t0 |
|
yield ReturnType( |
|
reply=text_chunk, |
|
actual_llm=actual_llm, |
|
) |
|
time.sleep(0.005) |
|
|
|
|
|
res_all = job.outputs().copy() |
|
success = job.communicator.job.latest_status.success |
|
timeout = 0.1 if success else 10 |
|
if len(res_all) > 0: |
|
try: |
|
check_job(job, timeout=timeout, raise_exception=True) |
|
except ( |
|
Exception |
|
) as e: |
|
if "Abrupt termination of communication" in str(e): |
|
t_taken = "%.4f" % (time.time() - t0) |
|
raise TimeoutError( |
|
f"LLM {actual_llm} timed out after {t_taken} seconds." |
|
) |
|
else: |
|
raise |
|
|
|
res = res_all[-1] |
|
res_dict = ast.literal_eval(res) |
|
self.check_error(res_dict) |
|
response = res_dict["response"] |
|
sources = res_dict["sources"] |
|
prompt_raw = res_dict["prompt_raw"] |
|
save_dict = res_dict.get("save_dict", dict(extra_dict={})) |
|
extra_dict = save_dict.get("extra_dict", {}) |
|
texts_out = [x["content"] for x in sources] |
|
t_taken_s = time.time() - t0 |
|
t_taken = "%.4f" % t_taken_s |
|
|
|
if langchain_action != LangChainAction.EXTRACT.value: |
|
text_chunk = response.strip() |
|
else: |
|
text_chunk = [ |
|
r.strip() for r in ast.literal_eval(response) |
|
] |
|
|
|
if not text_chunk: |
|
raise TimeoutError( |
|
f"No output from LLM {actual_llm} after {t_taken} seconds." |
|
) |
|
if "error" in save_dict and not prompt_raw: |
|
raise RuntimeError( |
|
f"Error from LLM {actual_llm}: {save_dict['error']}" |
|
) |
|
assert ( |
|
prompt_raw or extra_dict |
|
), "LLM response failed to return final metadata." |
|
|
|
try: |
|
extra_dict = res_dict["save_dict"]["extra_dict"] |
|
input_tokens = extra_dict["num_prompt_tokens"] |
|
output_tokens = extra_dict["ntokens"] |
|
vision_visible_model = extra_dict.get( |
|
"batch_vision_visible_model" |
|
) |
|
vision_batch_input_tokens = extra_dict.get( |
|
"batch_num_prompt_tokens", 0 |
|
) |
|
vision_batch_output_tokens = extra_dict.get( |
|
"batch_ntokens", 0 |
|
) |
|
tokens_per_second = np.round( |
|
extra_dict["tokens_persecond"], decimals=3 |
|
) |
|
vision_batch_tokens_per_second = extra_dict.get( |
|
"batch_tokens_persecond", 0 |
|
) |
|
if vision_batch_tokens_per_second: |
|
vision_batch_tokens_per_second = np.round( |
|
vision_batch_tokens_per_second, decimals=3 |
|
) |
|
except: |
|
if os.getenv("HARD_ASSERTS"): |
|
raise |
|
try: |
|
actual_llm = res_dict["save_dict"][ |
|
"display_name" |
|
] |
|
except Exception as e: |
|
print_warning( |
|
f"Unable to access save_dict to get actual_llm: {str(e)}" |
|
) |
|
|
|
if text_context_list: |
|
assert texts_out, "No texts_out 1" |
|
|
|
if time_to_first_token is None: |
|
time_to_first_token = time.time() - t0 |
|
yield ReturnType( |
|
reply=text_chunk, |
|
text_context_list=texts_out, |
|
prompt_raw=prompt_raw, |
|
actual_llm=actual_llm, |
|
input_tokens=input_tokens, |
|
output_tokens=output_tokens, |
|
tokens_per_second=tokens_per_second, |
|
time_to_first_token=time_to_first_token, |
|
trial=trial, |
|
vision_visible_model=vision_visible_model, |
|
vision_batch_input_tokens=vision_batch_input_tokens, |
|
vision_batch_output_tokens=vision_batch_output_tokens, |
|
vision_batch_tokens_per_second=vision_batch_tokens_per_second, |
|
) |
|
if self.persist: |
|
self.chat_conversation[-1] = ( |
|
instruction, |
|
text_chunk, |
|
) |
|
else: |
|
assert not success |
|
check_job(job, timeout=2.0 * timeout, raise_exception=True) |
|
if trial > 0 or trial_generation > 0: |
|
print("trial recovered: %s %s" % (trial, trial_generation)) |
|
break |
|
except Exception as e: |
|
if "No generations" in str( |
|
e |
|
) or """'NoneType' object has no attribute 'generations'""" in str( |
|
e |
|
): |
|
trial_generation += 1 |
|
else: |
|
trial += 1 |
|
print_error( |
|
"h2oGPT predict failed: %s %s" |
|
% (str(e), "".join(traceback.format_tb(e.__traceback__))), |
|
) |
|
if "invalid model" in str(e).lower(): |
|
raise |
|
if bad_error_string and bad_error_string in str(e): |
|
|
|
raise |
|
if trial == trials or trial_generation == trials_generation: |
|
print_error( |
|
"trying again failed: %s %s" % (trial, trial_generation) |
|
) |
|
raise |
|
else: |
|
|
|
if "Overloaded" in str(traceback.format_tb(e.__traceback__)): |
|
sleep_time = 30 + 2 ** (trial + 1) |
|
else: |
|
sleep_time = 1 * trial |
|
print_warning( |
|
"trying again: %s in %s seconds" % (trial, sleep_time) |
|
) |
|
time.sleep(sleep_time) |
|
finally: |
|
|
|
if do_lock: |
|
with lock: |
|
self.server_hash = client.server_hash |
|
else: |
|
self.server_hash = client.server_hash |
|
|
|
t1 = time.time() |
|
print_info( |
|
dict( |
|
api="submit_nochat_api", |
|
streaming=stream_output, |
|
texts_in=len(text or []) + len(text_context_list or []), |
|
texts_out=len(texts_out), |
|
images=len(image_file) |
|
if isinstance(image_file, list) |
|
else 1 |
|
if image_file |
|
else 0, |
|
response_time=str(timedelta(seconds=t1 - t0)), |
|
response_len=len(response), |
|
llm=visible_models, |
|
actual_llm=actual_llm, |
|
) |
|
) |
|
finally: |
|
|
|
if do_lock: |
|
with lock: |
|
self.server_hash = client.server_hash |
|
else: |
|
self.server_hash = client.server_hash |
|
|
|
def check_model(self, model): |
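        """Validate that `model` (integer index or display name) is available on the server,
        raising RuntimeError with a close-match suggestion otherwise."""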
|
if model != 0 and self.check_model_name: |
|
valid_llms = self.list_models() |
|
if ( |
|
isinstance(model, int) |
|
and model >= len(valid_llms) |
|
or isinstance(model, str) |
|
and model not in valid_llms |
|
): |
|
did_you_mean = "" |
|
if isinstance(model, str): |
|
alt = difflib.get_close_matches(model, valid_llms, 1) |
|
if alt: |
|
did_you_mean = f"\nDid you mean {repr(alt[0])}?" |
|
raise RuntimeError( |
|
f"Invalid llm: {repr(model)}, must be either an integer between " |
|
f"0 and {len(valid_llms) - 1} or one of the following values: {valid_llms}.{did_you_mean}" |
|
) |
|
|
|
@staticmethod |
|
def _get_ttl_hash(seconds=60): |
|
"""Return the same value within `seconds` time period""" |
|
return round(time.time() / seconds) |
|
|
|
@lru_cache() |
|
def _get_models_full(self, ttl_hash=None, do_lock=False) -> List[Dict[str, Any]]: |
|
""" |
|
        Full model info as a list of dicts (cached)
|
""" |
|
del ttl_hash |
|
if self.config is None: |
|
self.setup() |
|
client = self.clone() |
|
try: |
|
return ast.literal_eval(client.predict(api_name="/model_names")) |
|
finally: |
|
if do_lock: |
|
with lock: |
|
self.server_hash = client.server_hash |
|
else: |
|
self.server_hash = client.server_hash |
|
|
|
def get_models_full(self, do_lock=False) -> List[Dict[str, Any]]: |
|
""" |
|
        Full model info as a list of dicts
|
""" |
|
return self._get_models_full(ttl_hash=self._get_ttl_hash(), do_lock=do_lock) |
|
|
|
def list_models(self) -> List[str]: |
|
""" |
|
Model names available from endpoint |
|
""" |
|
return [x["display_name"] for x in self.get_models_full()] |
|
|
|
def simple_stream( |
|
self, |
|
client_kwargs={}, |
|
api_name="/submit_nochat_api", |
|
prompt="", |
|
prompter=None, |
|
sanitize_bot_response=False, |
|
max_time=300, |
|
is_public=False, |
|
raise_exception=True, |
|
verbose=False, |
|
): |
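        """Low-level streaming helper: submit client_kwargs and yield raw response dicts.

        Yields an initial empty dict, then updated dicts as text streams in, and a final
        dict once the job completes (also returned); raises on abrupt termination.
        """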
|
job = self.submit(str(dict(client_kwargs)), api_name=api_name) |
|
sources = [] |
|
res_dict = dict( |
|
response="", |
|
sources=sources, |
|
save_dict={}, |
|
llm_answers={}, |
|
response_no_refs="", |
|
sources_str="", |
|
prompt_raw="", |
|
) |
|
yield res_dict |
|
text = "" |
|
text0 = "" |
|
strex = "" |
|
tgen0 = time.time() |
|
while not job.done(): |
|
e = check_job(job, timeout=0, raise_exception=False) |
|
if e is not None: |
|
break |
|
outputs_list = job.outputs().copy() |
|
if outputs_list: |
|
res = outputs_list[-1] |
|
res_dict = ast.literal_eval(res) |
|
text = res_dict["response"] if "response" in res_dict else "" |
|
prompt_and_text = prompt + text |
|
if prompter: |
|
response = prompter.get_response( |
|
prompt_and_text, |
|
prompt=prompt, |
|
sanitize_bot_response=sanitize_bot_response, |
|
) |
|
else: |
|
response = text |
|
text_chunk = response[len(text0):] |
|
if not text_chunk: |
|
|
|
time.sleep(0.001) |
|
continue |
|
|
|
text0 = response |
|
res_dict.update( |
|
dict( |
|
response=response, |
|
sources=sources, |
|
error=strex, |
|
response_no_refs=response, |
|
) |
|
) |
|
yield res_dict |
|
if time.time() - tgen0 > max_time: |
|
if verbose: |
|
print( |
|
"Took too long for Gradio: %s" % (time.time() - tgen0), |
|
flush=True, |
|
) |
|
break |
|
time.sleep(0.005) |
|
|
|
res_all = job.outputs().copy() |
|
success = job.communicator.job.latest_status.success |
|
timeout = 0.1 if success else 10 |
|
if len(res_all) > 0: |
|
|
|
e = check_job(job, timeout=timeout, raise_exception=True) |
|
if e is not None: |
|
strex = "".join(traceback.format_tb(e.__traceback__)) |
|
|
|
res = res_all[-1] |
|
res_dict = ast.literal_eval(res) |
|
text = res_dict["response"] |
|
sources = res_dict.get("sources") |
|
if sources is None: |
|
|
|
if is_public: |
|
raise ValueError("Abrupt termination of communication") |
|
else: |
|
raise ValueError("Abrupt termination of communication: %s" % strex) |
|
else: |
|
|
|
|
|
e = check_job(job, timeout=2.0 * timeout, raise_exception=True) |
|
|
|
if e is not None: |
|
stre = str(e) |
|
strex = "".join(traceback.format_tb(e.__traceback__)) |
|
else: |
|
stre = "" |
|
strex = "" |
|
|
|
print( |
|
"Bad final response:%s %s %s: %s %s" |
|
% (res_all, prompt, text, stre, strex), |
|
flush=True, |
|
) |
|
prompt_and_text = prompt + text |
|
if prompter: |
|
response = prompter.get_response( |
|
prompt_and_text, |
|
prompt=prompt, |
|
sanitize_bot_response=sanitize_bot_response, |
|
) |
|
else: |
|
response = text |
|
res_dict.update( |
|
dict( |
|
response=response, |
|
sources=sources, |
|
error=strex, |
|
response_no_refs=response, |
|
) |
|
) |
|
yield res_dict |
|
return res_dict |
|
|
|
def stream( |
|
self, |
|
client_kwargs={}, |
|
api_name="/submit_nochat_api", |
|
prompt="", |
|
prompter=None, |
|
sanitize_bot_response=False, |
|
max_time=None, |
|
is_public=False, |
|
raise_exception=True, |
|
verbose=False, |
|
): |
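        """Like _stream, but converts recorded stream timeouts into TimeoutError and
        missing sources (abrupt termination) into ValueError."""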
|
strex = "" |
|
e = None |
|
res_dict = {} |
|
try: |
|
res_dict = yield from self._stream( |
|
client_kwargs, |
|
api_name=api_name, |
|
prompt=prompt, |
|
prompter=prompter, |
|
sanitize_bot_response=sanitize_bot_response, |
|
max_time=max_time, |
|
verbose=verbose, |
|
) |
|
        except Exception as e1:
            # bind to a separate name so `e` survives past the except block (Python clears `e1` here)
            e = e1
            strex = "".join(traceback.format_tb(e1.__traceback__))
|
|
|
|
|
if raise_exception: |
|
raise |
|
|
|
if "timeout" in res_dict["save_dict"]["extra_dict"]: |
|
timeout_time = res_dict["save_dict"]["extra_dict"]["timeout"] |
|
raise TimeoutError( |
|
"Timeout from local after %s %s" |
|
% (timeout_time, ": " + strex if e else "") |
|
) |
|
|
|
|
|
if res_dict.get("sources") is None: |
|
|
|
if is_public: |
|
raise ValueError("Abrupt termination of communication") |
|
else: |
|
raise ValueError("Abrupt termination of communication: %s" % strex) |
|
return res_dict |
|
|
|
def _stream( |
|
self, |
|
client_kwargs, |
|
api_name="/submit_nochat_api", |
|
prompt="", |
|
prompter=None, |
|
sanitize_bot_response=False, |
|
max_time=None, |
|
verbose=False, |
|
): |
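        """Submit the request and yield response dicts as results stream from the server,
        marking a timeout in save_dict['extra_dict'] if max_time is exceeded."""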
|
job = self.submit(str(dict(client_kwargs)), api_name=api_name) |
|
|
|
text = "" |
|
sources = [] |
|
save_dict = {} |
|
save_dict["extra_dict"] = {} |
|
res_dict = dict( |
|
response=text, |
|
sources=sources, |
|
save_dict=save_dict, |
|
llm_answers={}, |
|
response_no_refs=text, |
|
sources_str="", |
|
prompt_raw="", |
|
) |
|
yield res_dict |
|
|
|
text0 = "" |
|
tgen0 = time.time() |
|
n = 0 |
|
for res in job: |
|
res_dict, text0 = yield from self.yield_res( |
|
res, |
|
res_dict, |
|
prompt, |
|
prompter, |
|
sanitize_bot_response, |
|
max_time, |
|
text0, |
|
tgen0, |
|
verbose, |
|
) |
|
n += 1 |
|
if "timeout" in res_dict["save_dict"]["extra_dict"]: |
|
break |
|
|
|
outputs = job.outputs().copy() |
|
all_n = len(outputs) |
|
for nn in range(n, all_n): |
|
res = outputs[nn] |
|
res_dict, text0 = yield from self.yield_res( |
|
res, |
|
res_dict, |
|
prompt, |
|
prompter, |
|
sanitize_bot_response, |
|
max_time, |
|
text0, |
|
tgen0, |
|
verbose, |
|
) |
|
return res_dict |
|
|
|
@staticmethod |
|
def yield_res( |
|
res, |
|
res_dict, |
|
prompt, |
|
prompter, |
|
sanitize_bot_response, |
|
max_time, |
|
text0, |
|
tgen0, |
|
verbose, |
|
): |
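        """Process one streamed result: update res_dict, yield it if new text arrived,
        and record a timeout in save_dict['extra_dict'] when max_time is exceeded.
        Returns (res_dict, text0) for the caller's bookkeeping."""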
|
do_yield = True |
|
res_dict_server = ast.literal_eval(res) |
|
|
|
text = res_dict_server["response"] |
|
if text is None: |
|
print("text None", flush=True) |
|
text = "" |
|
if prompter: |
|
response = prompter.get_response( |
|
prompt + text, |
|
prompt=prompt, |
|
sanitize_bot_response=sanitize_bot_response, |
|
) |
|
else: |
|
response = text |
|
text_chunk = response[len(text0):] |
|
if not text_chunk: |
|
|
|
time.sleep(0.001) |
|
do_yield = False |
|
|
|
text0 = response |
|
res_dict.update(res_dict_server) |
|
res_dict.update(dict(response=response, response_no_refs=response)) |
|
|
|
timeout_time_other = ( |
|
res_dict.get("save_dict", {}).get("extra_dict", {}).get("timeout") |
|
) |
|
if timeout_time_other: |
|
if verbose: |
|
print( |
|
"Took too long for other Gradio: %s" % (time.time() - tgen0), |
|
flush=True, |
|
) |
|
return res_dict, text0 |
|
|
|
timeout_time = time.time() - tgen0 |
|
if max_time is not None and timeout_time > max_time: |
|
if "save_dict" not in res_dict: |
|
res_dict["save_dict"] = {} |
|
if "extra_dict" not in res_dict["save_dict"]: |
|
res_dict["save_dict"]["extra_dict"] = {} |
|
res_dict["save_dict"]["extra_dict"]["timeout"] = timeout_time |
|
yield res_dict |
|
if verbose: |
|
print( |
|
"Took too long for Gradio: %s" % (time.time() - tgen0), flush=True |
|
) |
|
return res_dict, text0 |
|
if do_yield: |
|
yield res_dict |
|
time.sleep(0.005) |
|
return res_dict, text0 |
|
|
|
|
|
class H2OGradioClient(CommonClient, Client): |
|
""" |
|
    Gradio client wrapper that automatically refreshes the client
    if it detects the gradio server has changed
|
""" |
|
|
|
def reset_session(self) -> None: |
|
self.session_hash = str(uuid.uuid4()) |
|
if hasattr(self, "include_heartbeat") and self.include_heartbeat: |
|
self._refresh_heartbeat.set() |
|
|
|
def __init__( |
|
self, |
|
src: str, |
|
hf_token: str | None = None, |
|
max_workers: int = 40, |
|
serialize: bool | None = None, |
|
output_dir: str |
|
| Path = DEFAULT_TEMP_DIR, |
|
verbose: bool = False, |
|
auth: tuple[str, str] | None = None, |
|
*, |
|
headers: dict[str, str] | None = None, |
|
upload_files: bool = True, |
|
download_files: bool = True, |
|
_skip_components: bool = True, |
|
|
|
ssl_verify: bool = True, |
|
h2ogpt_key: str = None, |
|
persist: bool = False, |
|
check_hash: bool = True, |
|
check_model_name: bool = False, |
|
include_heartbeat: bool = False, |
|
): |
|
""" |
|
Parameters: |
|
Base Class parameters |
|
+ |
|
h2ogpt_key: h2oGPT key to gain access to the server |
|
persist: whether to persist the state, so repeated calls are aware of the prior user session |
|
This allows the scratch MyData to be reused, etc. |
|
This also maintains the chat_conversation history |
|
check_hash: whether to check git hash for consistency between server and client to ensure API always up to date |
|
check_model_name: whether to check the model name here (adds delays), or just let server fail (faster) |
|
""" |
|
if serialize is None: |
|
|
|
|
|
serialize = False |
|
self.args = tuple([src]) |
|
self.kwargs = dict( |
|
hf_token=hf_token, |
|
max_workers=max_workers, |
|
serialize=serialize, |
|
output_dir=output_dir, |
|
verbose=verbose, |
|
h2ogpt_key=h2ogpt_key, |
|
persist=persist, |
|
check_hash=check_hash, |
|
check_model_name=check_model_name, |
|
include_heartbeat=include_heartbeat, |
|
) |
|
if is_gradio_client_version7plus: |
|
|
|
|
|
|
|
|
|
|
|
self._skip_components = _skip_components |
|
self.ssl_verify = ssl_verify |
|
self.kwargs.update( |
|
dict( |
|
auth=auth, |
|
upload_files=upload_files, |
|
download_files=download_files, |
|
ssl_verify=ssl_verify, |
|
) |
|
) |
|
|
|
self.verbose = verbose |
|
self.hf_token = hf_token |
|
if serialize is not None: |
|
warnings.warn( |
|
"The `serialize` parameter is deprecated and will be removed. Please use the equivalent `upload_files` parameter instead." |
|
) |
|
upload_files = serialize |
|
self.serialize = serialize |
|
self.upload_files = upload_files |
|
self.download_files = download_files |
|
self.space_id = None |
|
self.cookies: dict[str, str] = {} |
|
if is_gradio_client_version7plus: |
|
self.output_dir = ( |
|
str(output_dir) if isinstance(output_dir, Path) else output_dir |
|
) |
|
else: |
|
self.output_dir = output_dir |
|
self.max_workers = max_workers |
|
self.src = src |
|
self.auth = auth |
|
self.headers = headers |
|
|
|
self.config = None |
|
self.h2ogpt_key = h2ogpt_key |
|
self.persist = persist |
|
self.check_hash = check_hash |
|
self.check_model_name = check_model_name |
|
self.include_heartbeat = include_heartbeat |
|
|
|
self.chat_conversation = [] |
|
self.server_hash = None |
|
|
|
def __repr__(self): |
|
if self.config and False: |
|
|
|
return self.view_api(print_info=False, return_format="str") |
|
return "Not setup for %s" % self.src |
|
|
|
def __str__(self): |
|
if self.config and False: |
|
|
|
return self.view_api(print_info=False, return_format="str") |
|
return "Not setup for %s" % self.src |
|
|
|
def setup(self): |
|
src = self.src |
|
|
|
headers0 = self.headers |
|
self.headers = build_hf_headers( |
|
token=self.hf_token, |
|
library_name="gradio_client", |
|
library_version=utils.__version__, |
|
) |
|
if headers0: |
|
self.headers.update(headers0) |
|
if ( |
|
"authorization" in self.headers |
|
and self.headers["authorization"] == "Bearer " |
|
): |
|
self.headers["authorization"] = "Bearer hf_xx" |
|
if src.startswith("http://") or src.startswith("https://"): |
|
_src = src if src.endswith("/") else src + "/" |
|
else: |
|
_src = self._space_name_to_src(src) |
|
if _src is None: |
|
raise ValueError( |
|
f"Could not find Space: {src}. If it is a private Space, please provide an hf_token." |
|
) |
|
self.space_id = src |
|
self.src = _src |
|
state = self._get_space_state() |
|
if state == SpaceStage.BUILDING: |
|
if self.verbose: |
|
print("Space is still building. Please wait...") |
|
while self._get_space_state() == SpaceStage.BUILDING: |
|
time.sleep(2) |
|
|
if state in utils.INVALID_RUNTIME: |
|
raise ValueError( |
|
f"The current space is in the invalid state: {state}. " |
|
"Please contact the owner to fix this." |
|
) |
|
if self.verbose: |
|
print(f"Loaded as API: {self.src} ✔") |
|
|
|
if is_gradio_client_version7plus: |
|
if self.auth is not None: |
|
self._login(self.auth) |
|
|
|
self.config = self._get_config() |
|
self.api_url = urllib.parse.urljoin(self.src, utils.API_URL) |
|
if is_gradio_client_version7plus: |
|
self.protocol: Literal[ |
|
"ws", "sse", "sse_v1", "sse_v2", "sse_v2.1" |
|
] = self.config.get("protocol", "ws") |
|
self.sse_url = urllib.parse.urljoin( |
|
self.src, utils.SSE_URL_V0 if self.protocol == "sse" else utils.SSE_URL |
|
) |
|
if hasattr(utils, "HEARTBEAT_URL") and self.include_heartbeat: |
|
self.heartbeat_url = urllib.parse.urljoin(self.src, utils.HEARTBEAT_URL) |
|
else: |
|
self.heartbeat_url = None |
|
self.sse_data_url = urllib.parse.urljoin( |
|
self.src, |
|
utils.SSE_DATA_URL_V0 if self.protocol == "sse" else utils.SSE_DATA_URL, |
|
) |
|
self.ws_url = urllib.parse.urljoin( |
|
self.src.replace("http", "ws", 1), utils.WS_URL |
|
) |
|
self.upload_url = urllib.parse.urljoin(self.src, utils.UPLOAD_URL) |
|
self.reset_url = urllib.parse.urljoin(self.src, utils.RESET_URL) |
|
if is_gradio_client_version7plus: |
|
self.app_version = version.parse(self.config.get("version", "2.0")) |
|
self._info = self._get_api_info() |
|
self.session_hash = str(uuid.uuid4()) |
|
|
|
self.get_endpoints(self) |
|
|
|
|
|
|
|
if ( |
|
is_gradio_client_version7plus |
|
and hasattr(utils, "HEARTBEAT_URL") |
|
and self.include_heartbeat |
|
): |
|
self._refresh_heartbeat = threading.Event() |
|
self._kill_heartbeat = threading.Event() |
|
|
|
self.heartbeat = threading.Thread( |
|
target=self._stream_heartbeat, daemon=True |
|
) |
|
self.heartbeat.start() |
|
|
|
self.server_hash = self.get_server_hash() |
|
|
|
return self |
|
|
|
@staticmethod |
|
def get_endpoints(client, verbose=False): |
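        """(Re)build the client's thread pool executor and endpoint objects from the server
        config, handling both old and new gradio_client versions."""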
|
t0 = time.time() |
|
|
|
client.executor = concurrent.futures.ThreadPoolExecutor( |
|
max_workers=client.max_workers |
|
) |
|
if is_gradio_client_version7plus: |
|
from gradio_client.client import EndpointV3Compatibility |
|
|
|
endpoint_class = ( |
|
Endpoint |
|
if client.protocol.startswith("sse") |
|
else EndpointV3Compatibility |
|
) |
|
else: |
|
endpoint_class = Endpoint |
|
|
|
if is_gradio_client_version7plus: |
|
client.endpoints = [ |
|
endpoint_class(client, fn_index, dependency, client.protocol) |
|
for fn_index, dependency in enumerate(client.config["dependencies"]) |
|
] |
|
else: |
|
client.endpoints = [ |
|
endpoint_class(client, fn_index, dependency) |
|
for fn_index, dependency in enumerate(client.config["dependencies"]) |
|
] |
|
if is_gradio_client_version7plus: |
|
client.stream_open = False |
|
client.streaming_future = None |
|
from gradio_client.utils import Message |
|
|
|
client.pending_messages_per_event = {} |
|
client.pending_event_ids = set() |
|
if verbose: |
|
print("duration endpoints: %s" % (time.time() - t0), flush=True) |
|
|
|
@staticmethod |
|
def is_full_git_hash(s): |
|
|
|
return bool(re.fullmatch(r"[0-9a-f]{40}", s)) |
|
|
|
def get_server_hash(self) -> str: |
|
return self._get_server_hash(ttl_hash=self._get_ttl_hash()) |
|
|
|
def _get_server_hash(self, ttl_hash=None) -> str: |
|
""" |
|
Get server hash using super without any refresh action triggered |
|
Returns: git hash of gradio server |
|
""" |
|
del ttl_hash |
|
t0 = time.time() |
|
if self.config is None: |
|
self.setup() |
|
t1 = time.time() |
|
ret = "GET_GITHASH_UNSET" |
|
try: |
|
if self.check_hash: |
|
ret = super().submit(api_name="/system_hash").result() |
|
assert self.is_full_git_hash(ret), f"ret is not a full git hash: {ret}" |
|
return ret |
|
finally: |
|
if self.verbose: |
|
print( |
|
"duration server_hash: %s full time: %s system_hash time: %s" |
|
% (ret, time.time() - t0, time.time() - t1), |
|
flush=True, |
|
) |
|
|
|
def refresh_client_if_should(self): |
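        """Refresh the underlying client if the server's git hash has changed since the last call."""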
|
if self.config is None: |
|
self.setup() |
|
|
|
|
|
server_hash = self.get_server_hash() |
|
if self.server_hash != server_hash: |
|
if self.verbose: |
|
print( |
|
"server hash changed: %s %s" % (self.server_hash, server_hash), |
|
flush=True, |
|
) |
|
if self.server_hash is not None and self.persist: |
|
if self.verbose: |
|
print( |
|
"Failed to persist due to server hash change, only kept chat_conversation not user session hash", |
|
flush=True, |
|
) |
|
|
|
self.refresh_client() |
|
self.server_hash = server_hash |
|
|
|
def refresh_client(self): |
|
""" |
|
Ensure every client call is independent |
|
Also ensure map between api_name and fn_index is updated in case server changed (e.g. restarted with new code) |
|
Returns: |
|
""" |
|
if self.config is None: |
|
self.setup() |
|
|
|
kwargs = self.kwargs.copy() |
|
kwargs.pop("h2ogpt_key", None) |
|
kwargs.pop("persist", None) |
|
kwargs.pop("check_hash", None) |
|
kwargs.pop("check_model_name", None) |
|
kwargs.pop("include_heartbeat", None) |
|
ntrials = 3 |
|
client = None |
|
for trial in range(0, ntrials): |
|
try: |
|
client = Client(*self.args, **kwargs) |
|
break |
|
except ValueError as e: |
|
                if trial >= ntrials - 1:
|
raise |
|
else: |
|
if self.verbose: |
|
print("Trying refresh %d/%d %s" % (trial, ntrials - 1, str(e))) |
|
trial += 1 |
|
time.sleep(10) |
|
if client is None: |
|
raise RuntimeError("Failed to get new client") |
|
session_hash0 = self.session_hash if self.persist else None |
|
for k, v in client.__dict__.items(): |
|
setattr(self, k, v) |
|
if session_hash0: |
|
|
|
self.session_hash = session_hash0 |
|
if self.verbose: |
|
print("Hit refresh_client(): %s %s" % (self.session_hash, session_hash0)) |
|
|
|
self.server_hash = self.get_server_hash() |
|
|
|
def clone(self, do_lock=False): |
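        """Return a copy of this client with a fresh session hash, optionally acquiring the module lock."""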
|
if do_lock: |
|
with lock: |
|
return self._clone() |
|
else: |
|
return self._clone() |
|
|
|
def _clone(self): |
|
if self.config is None: |
|
self.setup() |
|
client = self.__class__("") |
|
for k, v in self.__dict__.items(): |
|
setattr(client, k, v) |
|
client.reset_session() |
|
|
|
self.get_endpoints(client) |
|
|
|
|
|
client.server_hash = self.server_hash |
|
client.chat_conversation = self.chat_conversation |
|
return client |
|
|
|
def submit( |
|
self, |
|
*args, |
|
api_name: str | None = None, |
|
fn_index: int | None = None, |
|
result_callbacks: Callable | list[Callable] | None = None, |
|
exception_handling=True, |
|
) -> Job: |
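        """Submit a job, refreshing the client and retrying if the server changed or the job fails early."""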
|
if self.config is None: |
|
self.setup() |
|
|
|
try: |
|
self.refresh_client_if_should() |
|
job = super().submit(*args, api_name=api_name, fn_index=fn_index) |
|
except Exception as e: |
|
ex = traceback.format_exc() |
|
print( |
|
"Hit e=%s\n\n%s\n\n%s" |
|
                % (str(e), ex, self.__dict__),
|
flush=True, |
|
) |
|
|
|
self.refresh_client() |
|
job = super().submit(*args, api_name=api_name, fn_index=fn_index) |
|
|
|
if exception_handling: |
|
|
|
e = check_job(job, timeout=0.01, raise_exception=False) |
|
if e is not None: |
|
print( |
|
"GR job failed: %s %s" |
|
% (str(e), "".join(traceback.format_tb(e.__traceback__))), |
|
flush=True, |
|
) |
|
|
|
self.refresh_client() |
|
job = super().submit(*args, api_name=api_name, fn_index=fn_index) |
|
e2 = check_job(job, timeout=0.1, raise_exception=False) |
|
if e2 is not None: |
|
print( |
|
"GR job failed again: %s\n%s" |
|
% (str(e2), "".join(traceback.format_tb(e2.__traceback__))), |
|
flush=True, |
|
) |
|
|
|
return job |
|
|
|
|
|
class CloneableGradioClient(CommonClient, Client): |
|
def __init__(self, *args, **kwargs): |
|
self._original_config = None |
|
self._original_info = None |
|
self._original_endpoints = None |
|
self._original_executor = None |
|
self._original_heartbeat = None |
|
self._quiet = kwargs.pop('quiet', False) |
|
super().__init__(*args, **kwargs) |
|
self._initialize_session_specific() |
|
self._initialize_shared_info() |
|
atexit.register(self.cleanup) |
|
self.auth = kwargs.get('auth') |
|
|
|
def _initialize_session_specific(self): |
|
"""Initialize or reset session-specific attributes.""" |
|
self.session_hash = str(uuid.uuid4()) |
|
self._refresh_heartbeat = threading.Event() |
|
self._kill_heartbeat = threading.Event() |
|
self.stream_open = False |
|
self.streaming_future = None |
|
self.pending_messages_per_event = {} |
|
self.pending_event_ids = set() |
|
|
|
def _initialize_shared_info(self): |
|
"""Initialize information that can be shared across clones.""" |
|
if self._original_config is None: |
|
self._original_config = super().config |
|
if self._original_info is None: |
|
self._original_info = super()._info |
|
if self._original_endpoints is None: |
|
self._original_endpoints = super().endpoints |
|
if self._original_executor is None: |
|
self._original_executor = super().executor |
|
if self._original_heartbeat is None: |
|
self._original_heartbeat = super().heartbeat |
|
|
|
@property |
|
def config(self): |
|
return self._original_config |
|
|
|
@config.setter |
|
def config(self, value): |
|
self._original_config = value |
|
|
|
@property |
|
def _info(self): |
|
return self._original_info |
|
|
|
@_info.setter |
|
def _info(self, value): |
|
self._original_info = value |
|
|
|
@property |
|
def endpoints(self): |
|
return self._original_endpoints |
|
|
|
@endpoints.setter |
|
def endpoints(self, value): |
|
self._original_endpoints = value |
|
|
|
@property |
|
def executor(self): |
|
return self._original_executor |
|
|
|
@executor.setter |
|
def executor(self, value): |
|
self._original_executor = value |
|
|
|
@property |
|
def heartbeat(self): |
|
return self._original_heartbeat |
|
|
|
@heartbeat.setter |
|
def heartbeat(self, value): |
|
self._original_heartbeat = value |
|
|
|
def setup(self): |
|
|
|
pass |
|
|
|
@staticmethod |
|
def _get_ttl_hash(seconds=60): |
|
"""Return the same value within `seconds` time period""" |
|
return round(time.time() / seconds) |
|
|
|
def get_server_hash(self) -> str: |
|
return self._get_server_hash(ttl_hash=self._get_ttl_hash()) |
|
|
|
def _get_server_hash(self, ttl_hash=None): |
|
del ttl_hash |
|
return self.predict(api_name="/system_hash") |
|
|
|
def clone(self): |
|
"""Create a new CloneableGradioClient instance with the same configuration but a new session.""" |
|
new_client = copy.copy(self) |
|
new_client._initialize_session_specific() |
|
new_client._quiet = True |
|
atexit.register(new_client.cleanup) |
|
return new_client |
|
|
|
def __repr__(self): |
|
if self._quiet: |
|
return f"<CloneableGradioClient (quiet) connected to {self.src}>" |
|
return super().__repr__() |
|
|
|
def __str__(self): |
|
if self._quiet: |
|
return f"CloneableGradioClient (quiet) connected to {self.src}" |
|
return super().__str__() |
|
|
|
def cleanup(self): |
|
"""Clean up resources used by this client.""" |
|
if self._original_executor: |
|
self._original_executor.shutdown(wait=False) |
|
if self._kill_heartbeat: |
|
self._kill_heartbeat.set() |
|
if self._original_heartbeat: |
|
self._original_heartbeat.join(timeout=1) |
|
atexit.unregister(self.cleanup) |
|
|
|
|
|
if old_gradio: |
|
GradioClient = H2OGradioClient |
|
else: |
|
GradioClient = CloneableGradioClient |
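

# Minimal usage sketch (not executed on import). Assumes an h2oGPT gradio server is
# reachable at the URL below; the URL and the H2OGPT_SERVER/H2OGPT_KEY environment
# variable names are illustrative, not part of this module's API.
if __name__ == "__main__":
    _example_client = H2OGradioClient(
        os.getenv("H2OGPT_SERVER", "http://localhost:7860"),
        h2ogpt_key=os.getenv("H2OGPT_KEY"),
    )
    # Direct LLM question (no document retrieval)
    print(_example_client.question("Who are you?"))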
|
|