aiben / src /client_test.py
abugaber's picture
Upload folder using huggingface_hub
3943768 verified
"""
Client test.
Run server:
python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
NOTE: For private models, add --use-auth_token=True
NOTE: --use_gpu_id=True (default) must be used for multi-GPU in case you see failures with cuda:x cuda:y mismatches.
Currently, this will force model to be on a single GPU.
Then run this client as:
python src/client_test.py
For HF spaces:
HOST="https://h2oai-h2ogpt-chatbot.hf.space" python src/client_test.py
Result:
Loaded as API: https://h2oai-h2ogpt-chatbot.hf.space ✔
{'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a large language model developed by LAION.', 'sources': ''}
For demo:
HOST="https://gpt.h2o.ai" python src/client_test.py
Result:
Loaded as API: https://gpt.h2o.ai ✔
{'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a chatbot created by LAION.', 'sources': ''}
NOTE: Raw output from API for nochat case is a string of a python dict and will remain so if other entries are added to dict:
{'response': "I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.", 'sources': ''}
"""
import ast
import time
import os
import markdown # pip install markdown
import pytest
from bs4 import BeautifulSoup # pip install beautifulsoup4
from utils import is_gradio_version4
try:
from enums import DocumentSubset, LangChainAction
except:
from enums import DocumentSubset, LangChainAction
from tests.utils import get_inf_server
# when True, get_client() prints client.view_api(all_endpoints=True) after connecting
debug = False
# set for this process before any client is created
# NOTE(review): presumably opts out of Hugging Face Hub telemetry -- confirm against huggingface_hub docs
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
def get_client(serialize=not is_gradio_version4):
    """
    Create a gradio client connected to the server returned by ``get_inf_server()``.

    :param serialize: forwarded unchanged to ``gradio_client.Client``; the default is the
        negation of ``is_gradio_version4``, evaluated once when this function is defined
    :return: the connected ``Client`` instance
    """
    # imported at call time rather than at module import time
    from gradio_client import Client

    server = get_inf_server()
    gr_client = Client(server, serialize=serialize)
    if debug:
        # module-level debug switch: dump every endpoint the server exposes
        api_description = gr_client.view_api(all_endpoints=True)
        print(api_description)
    return gr_client
def get_args(prompt, prompt_type=None, chat=False,
             stream_output=False,
             enable_caching=False,
             max_new_tokens=50,
             top_k_docs=3,
             langchain_mode='Disabled',
             add_chat_history_to_context=True,
             langchain_action=LangChainAction.QUERY.value,
             langchain_agents=[],
             prompt_dict=None,
             chat_template=None,
             version=None,
             h2ogpt_key=None,
             visible_models=None,
             visible_image_models=None,
             image_size=None,
             image_quality=None,
             image_guidance_scale=None,
             image_num_inference_steps=None,
             system_prompt='',  # default of no system prompt triggered by empty string
             add_search_to_context=False,
             chat_conversation=None,
             text_context_list=None,
             document_choice=[],
             document_source_substrings=[],
             document_source_substrings_op='and',
             document_content_substrings=[],
             document_content_substrings_op='and',
             max_time=40,  # nominally want test to complete, not exercise timeout code (llama.cpp gets stuck behind file lock if prior generation is still going)
             repetition_penalty=1.0,
             do_sample=True,
             seed=0,
             metadata_in_context=[],
             ):
    """
    Build the full, ordered set of evaluation parameters sent to the server.

    The returned mapping must list exactly the names in
    ``evaluate_params.eval_func_param_names`` and in the same order, because callers
    also pass the values positionally; this is asserted below.

    :param prompt: user prompt; stored under ``instruction`` when ``chat`` is true,
        otherwise under ``instruction_nochat``
    :param chat: selects chat vs. nochat layout; when true a trailing ``chatbot=[]``
        entry is appended after the assertion
    :param version: accepted for call compatibility; not referenced in this function
    :param document_choice: list filter; a falsy value (including None) is sent as ``[]``
    :param document_source_substrings: list filter; a falsy value is sent as ``[]``
    :param document_source_substrings_op: operator for the source substrings;
        a falsy value falls back to ``'and'``
    :param document_content_substrings: list filter; a falsy value is sent as ``[]``
    :param document_content_substrings_op: operator for the content substrings;
        a falsy value falls back to ``'and'``
    :return: ``(kwargs, args)`` where ``kwargs`` is an ``OrderedDict`` of parameter
        name to value and ``args`` is the list of its values in the same order

    All remaining parameters are copied into the mapping under their own names.
    """
    from collections import OrderedDict
    kwargs = OrderedDict(instruction=prompt if chat else '',  # only for chat=True
                         iinput='',  # only for chat=True
                         context='',
                         # streaming output is supported, loops over and outputs each generation in streaming mode
                         # but leave stream_output=False for simple input/output mode
                         stream_output=stream_output,
                         enable_caching=enable_caching,
                         prompt_type=prompt_type,
                         prompt_dict=prompt_dict,
                         chat_template=chat_template,
                         temperature=0.1,
                         top_p=1.0,
                         top_k=40,
                         penalty_alpha=0,
                         num_beams=1,
                         max_new_tokens=max_new_tokens,
                         min_new_tokens=0,
                         early_stopping=False,
                         max_time=max_time,
                         repetition_penalty=repetition_penalty,
                         num_return_sequences=1,
                         do_sample=do_sample,
                         seed=seed,
                         chat=chat,
                         instruction_nochat=prompt if not chat else '',
                         iinput_nochat='',  # only for chat=False
                         langchain_mode=langchain_mode,
                         add_chat_history_to_context=add_chat_history_to_context,
                         langchain_action=langchain_action,
                         langchain_agents=langchain_agents,
                         top_k_docs=top_k_docs,
                         chunk=True,
                         chunk_size=512,
                         document_subset=DocumentSubset.Relevant.name,
                         # the caller's value comes first: with the operands the other way round
                         # ('and' or x) the literal always won and the argument was ignored
                         document_choice=document_choice or [],
                         document_source_substrings=document_source_substrings or [],
                         document_source_substrings_op=document_source_substrings_op or 'and',
                         document_content_substrings=document_content_substrings or [],
                         document_content_substrings_op=document_content_substrings_op or 'and',
                         pre_prompt_query=None,
                         prompt_query=None,
                         pre_prompt_summary=None,
                         prompt_summary=None,
                         hyde_llm_prompt=None,
                         all_docs_start_prompt=None,
                         all_docs_finish_prompt=None,
                         user_prompt_for_fake_system_prompt=None,
                         json_object_prompt=None,
                         json_object_prompt_simpler=None,
                         json_code_prompt=None,
                         json_code_prompt_if_no_schema=None,
                         json_schema_instruction=None,
                         json_preserve_system_prompt=None,
                         json_object_post_prompt_reminder=None,
                         json_code_post_prompt_reminder=None,
                         json_code2_post_prompt_reminder=None,
                         system_prompt=system_prompt,
                         image_audio_loaders=None,
                         pdf_loaders=None,
                         url_loaders=None,
                         jq_schema=None,
                         extract_frames=None,
                         llava_prompt=None,
                         visible_models=visible_models,
                         visible_image_models=visible_image_models,
                         image_size=image_size,
                         image_quality=image_quality,
                         image_guidance_scale=image_guidance_scale,
                         image_num_inference_steps=image_num_inference_steps,
                         h2ogpt_key=h2ogpt_key,
                         add_search_to_context=add_search_to_context,
                         chat_conversation=chat_conversation,
                         text_context_list=text_context_list,
                         docs_ordering_type=None,
                         min_max_new_tokens=None,
                         max_input_tokens=None,
                         max_total_input_tokens=None,
                         docs_token_handling=None,
                         docs_joiner=None,
                         hyde_level=0,
                         hyde_template=None,
                         hyde_show_only_final=False,
                         doc_json_mode=False,
                         metadata_in_context=metadata_in_context,
                         chatbot_role='None',
                         speaker='None',
                         tts_language='autodetect',
                         tts_speed=1.0,
                         image_file=None,
                         image_control=None,
                         images_num_max=None,
                         image_resolution=None,
                         image_format=None,
                         rotate_align_resize_image=None,
                         video_frame_period=None,
                         image_batch_image_prompt=None,
                         image_batch_final_prompt=None,
                         image_batch_stream=None,
                         visible_vision_models=None,
                         video_file=None,
                         response_format=None,
                         guided_json=None,
                         guided_regex=None,
                         guided_choice=None,
                         guided_grammar=None,
                         guided_whitespace_pattern=None,
                         model_lock=None,
                         client_metadata=None,
                         )
    # number of server-side parameter names allowed to be missing from kwargs
    diff = 0
    from evaluate_params import eval_func_param_names
    param_names = list(kwargs.keys())
    # first check gives a clearer failure for missing names, second also pins the order
    assert len(set(eval_func_param_names).difference(set(param_names))) == diff
    assert eval_func_param_names == param_names
    if chat:
        # add chatbot output on end.  Assumes serialize=False
        kwargs.update(dict(chatbot=[]))

    return kwargs, list(kwargs.values())
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_basic(prompt_type='human_bot', version=None, visible_models=None, prompt='Who are you?',
                      h2ogpt_key=None):
    """
    Send one nochat prompt, capped at 50 new tokens, through ``run_client_nochat``.

    :return: whatever ``run_client_nochat`` returns, i.e. ``(res_dict, client)``
    """
    call_kwargs = dict(
        prompt=prompt,
        prompt_type=prompt_type,
        max_new_tokens=50,
        version=version,
        visible_models=visible_models,
        h2ogpt_key=h2ogpt_key,
    )
    return run_client_nochat(**call_kwargs)
"""
time HOST=https://gpt-internal.h2o.ai PYTHONPATH=. pytest -n 20 src/client_test.py::test_client_basic_benchmark
32 seconds to answer 20 questions at once with 70B llama2 on 4x A100 80GB using TGI 0.9.3
"""
@pytest.mark.skip(reason="For manual use against some server, no server launched")
@pytest.mark.parametrize("id", range(20))
def test_client_basic_benchmark(id, prompt_type='human_bot', version=None):
    """
    Send one fixed, long prompt (a pasted pytest failure log ending in a question)
    through ``run_client_nochat`` with up to 100 new tokens.

    :param id: parametrize index (0..19); it only multiplies the test, the value is not used
    :return: whatever ``run_client_nochat`` returns, i.e. ``(res_dict, client)``
    """
    # sent verbatim, including the leading and trailing newline
    benchmark_prompt = """
/nfs4/llm/h2ogpt/h2ogpt/bin/python /home/arno/pycharm-2022.2.2/plugins/python/helpers/pycharm/_jb_pytest_runner.py --target src/client_test.py::test_client_basic
Testing started at 8:41 AM ...
Launching pytest with arguments src/client_test.py::test_client_basic --no-header --no-summary -q in /nfs4/llm/h2ogpt
============================= test session starts ==============================
collecting ...
src/client_test.py:None (src/client_test.py)
ImportError while importing test module '/nfs4/llm/h2ogpt/src/client_test.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
h2ogpt/lib/python3.10/site-packages/_pytest/python.py:618: in _importtestmodule
mod = import_path(self.path, mode=importmode, root=self.config.rootpath)
h2ogpt/lib/python3.10/site-packages/_pytest/pathlib.py:533: in import_path
importlib.import_module(module_name)
/usr/lib/python3.10/importlib/__init__.py:126: in import_module
return _bootstrap._gcd_import(name[level:], package, level)
<frozen importlib._bootstrap>:1050: in _gcd_import
???
<frozen importlib._bootstrap>:1027: in _find_and_load
???
<frozen importlib._bootstrap>:1006: in _find_and_load_unlocked
???
<frozen importlib._bootstrap>:688: in _load_unlocked
???
h2ogpt/lib/python3.10/site-packages/_pytest/assertion/rewrite.py:168: in exec_module
exec(co, module.__dict__)
src/client_test.py:51: in <module>
from enums import DocumentSubset, LangChainAction
E ModuleNotFoundError: No module named 'enums'
collected 0 items / 1 error
=============================== 1 error in 0.14s ===============================
ERROR: not found: /nfs4/llm/h2ogpt/src/client_test.py::test_client_basic
(no name '/nfs4/llm/h2ogpt/src/client_test.py::test_client_basic' in any of [<Module client_test.py>])
Process finished with exit code 4
What happened?
"""
    return run_client_nochat(prompt=benchmark_prompt, prompt_type=prompt_type, max_new_tokens=100,
                             version=version)
def run_client_nochat(prompt, prompt_type, max_new_tokens, version=None, h2ogpt_key=None, visible_models=None):
    """
    Call the positional ``/submit_nochat`` endpoint once and return its text.

    The full ordered argument list from ``get_args`` is sent positionally.

    :return: ``(res_dict, client)``; ``res_dict`` has keys ``prompt``, ``iinput`` and
        ``response`` (the raw result passed through ``md_to_text``)
    """
    kwargs, args = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens, version=version,
                            visible_models=visible_models, h2ogpt_key=h2ogpt_key)

    client = get_client(serialize=not is_gradio_version4)
    res = client.predict(*args, api_name='/submit_nochat')
    print("Raw client result: %s" % res, flush=True)

    res_dict = {
        'prompt': kwargs['instruction_nochat'],
        'iinput': kwargs['iinput_nochat'],
        'response': md_to_text(res),
    }
    print(res_dict)
    return res_dict, client
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_basic_api(prompt_type='human_bot', version=None, h2ogpt_key=None):
    """
    Ask 'Who are you?' (50 new tokens max) through ``run_client_nochat_api``.

    :return: whatever ``run_client_nochat_api`` returns, i.e. ``(res_dict, client)``
    """
    call_kwargs = dict(
        prompt='Who are you?',
        prompt_type=prompt_type,
        max_new_tokens=50,
        version=version,
        h2ogpt_key=h2ogpt_key,
    )
    return run_client_nochat_api(**call_kwargs)
def run_client_nochat_api(prompt, prompt_type, max_new_tokens, version=None, h2ogpt_key=None):
    """
    Call ``/submit_nochat_api`` once with the complete parameter dict from ``get_args``.

    The parameters are sent as the string form of a Python dict, and the reply is
    likewise a string that is parsed with ``ast.literal_eval``.

    :return: ``(res_dict, client)``; ``res_dict`` has keys ``prompt``, ``iinput``,
        ``response`` (passed through ``md_to_text``) and ``sources``
    """
    kwargs, args = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens, version=version,
                            h2ogpt_key=h2ogpt_key)

    # NOTE: like submit_nochat but stable API for string dict passing
    client = get_client(serialize=not is_gradio_version4)
    payload = str(dict(kwargs))
    res = client.predict(payload, api_name='/submit_nochat_api')
    print("Raw client result: %s" % res, flush=True)

    parsed = ast.literal_eval(res)
    response_text = md_to_text(parsed['response'])
    res_dict = {
        'prompt': kwargs['instruction_nochat'],
        'iinput': kwargs['iinput_nochat'],
        'response': response_text,
        'sources': parsed['sources'],
    }
    print(res_dict)
    return res_dict, client
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_basic_api_lean(prompt='Who are you?', prompt_type='human_bot', version=None, h2ogpt_key=None,
                               chat_conversation=None, system_prompt=''):
    """
    Run a single prompt (50 new tokens max) through ``run_client_nochat_api_lean``.

    :return: whatever ``run_client_nochat_api_lean`` returns, i.e. ``(res_dict, client)``
    """
    call_kwargs = dict(
        prompt=prompt,
        prompt_type=prompt_type,
        max_new_tokens=50,
        version=version,
        h2ogpt_key=h2ogpt_key,
        chat_conversation=chat_conversation,
        system_prompt=system_prompt,
    )
    return run_client_nochat_api_lean(**call_kwargs)
def run_client_nochat_api_lean(prompt, prompt_type, max_new_tokens, version=None, h2ogpt_key=None,
                               chat_conversation=None, system_prompt=''):
    """
    Call ``/submit_nochat_api`` with a minimal parameter dict.

    Only ``instruction_nochat``, ``h2ogpt_key``, ``chat_conversation`` and
    ``system_prompt`` are sent; ``prompt_type``, ``max_new_tokens`` and ``version``
    are accepted but not included in the request.

    :return: ``(res_dict, client)``; ``res_dict`` has keys ``prompt``, ``response``
        (passed through ``md_to_text``), ``sources`` and ``h2ogpt_key``
    """
    kwargs = {
        'instruction_nochat': prompt,
        'h2ogpt_key': h2ogpt_key,
        'chat_conversation': chat_conversation,
        'system_prompt': system_prompt,
    }

    # NOTE: like submit_nochat but stable API for string dict passing
    client = get_client(serialize=not is_gradio_version4)
    payload = str(dict(kwargs))
    res = client.predict(payload, api_name='/submit_nochat_api')
    print("Raw client result: %s" % res, flush=True)

    parsed = ast.literal_eval(res)
    response_text = md_to_text(parsed['response'])
    res_dict = {
        'prompt': kwargs['instruction_nochat'],
        'response': response_text,
        'sources': parsed['sources'],
        'h2ogpt_key': h2ogpt_key,
    }
    print(res_dict)
    return res_dict, client
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_basic_api_lean_morestuff(prompt_type='human_bot', version=None, h2ogpt_key=None):
    """
    Ask 'Who are you?' through ``run_client_nochat_api_lean_morestuff``, requesting 50 new tokens.

    :return: whatever that runner returns, i.e. ``(res_dict, client)``
    """
    call_kwargs = dict(
        prompt='Who are you?',
        prompt_type=prompt_type,
        max_new_tokens=50,
        version=version,
        h2ogpt_key=h2ogpt_key,
    )
    return run_client_nochat_api_lean_morestuff(**call_kwargs)
def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_new_tokens=512, version=None,
                                         h2ogpt_key=None):
    """
    Call ``/submit_nochat_api`` with a hand-written, partial parameter dict.

    Unlike ``get_args`` this sends only the subset of parameters listed below, as the
    string form of a Python dict.

    :param prompt: sent as ``instruction_nochat``
    :param prompt_type: sent as ``prompt_type``
    :param max_new_tokens: sent as ``max_new_tokens`` (previously a hard-coded 1024 was
        sent and this argument was ignored)
    :param version: accepted for call compatibility; not included in the request
    :param h2ogpt_key: sent as ``h2ogpt_key`` and echoed in the result
    :return: ``(res_dict, client)``; ``res_dict`` has keys ``prompt``, ``response``
        (passed through ``md_to_text``), ``sources`` and ``h2ogpt_key``
    """
    kwargs = dict(
        instruction='',
        iinput='',
        context='',
        stream_output=False,
        prompt_type=prompt_type,
        temperature=0.1,
        top_p=1.0,
        top_k=40,
        penalty_alpha=0,
        num_beams=1,
        max_new_tokens=max_new_tokens,  # honor the caller's limit
        min_new_tokens=0,
        early_stopping=False,
        max_time=20,
        repetition_penalty=1.0,
        num_return_sequences=1,
        do_sample=True,
        seed=0,
        chat=False,
        instruction_nochat=prompt,
        iinput_nochat='',
        langchain_mode='Disabled',
        add_chat_history_to_context=True,
        langchain_action=LangChainAction.QUERY.value,
        langchain_agents=[],
        top_k_docs=4,
        document_subset=DocumentSubset.Relevant.name,
        document_choice=[],
        document_source_substrings=[],
        document_source_substrings_op='and',
        document_content_substrings=[],
        document_content_substrings_op='and',
        h2ogpt_key=h2ogpt_key,
        add_search_to_context=False,
    )

    api_name = '/submit_nochat_api'  # NOTE: like submit_nochat but stable API for string dict passing
    client = get_client(serialize=not is_gradio_version4)
    res = client.predict(
        str(dict(kwargs)),
        api_name=api_name,
    )
    print("Raw client result: %s" % res, flush=True)

    # the reply is the string form of a dict; parse it once and read both fields from it
    parsed = ast.literal_eval(res)
    res_dict = dict(prompt=kwargs['instruction_nochat'],
                    response=md_to_text(parsed['response']),
                    sources=parsed['sources'],
                    h2ogpt_key=h2ogpt_key)
    print(res_dict)
    return res_dict, client
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_chat(prompt_type='human_bot', version=None, h2ogpt_key=None):
    """
    Ask 'Who are you?' in chat mode without streaming (50 new tokens max).

    :return: whatever ``run_client_chat`` returns, i.e. ``(res_dict, client)``
    """
    call_kwargs = dict(
        prompt='Who are you?',
        prompt_type=prompt_type,
        stream_output=False,
        max_new_tokens=50,
        langchain_mode='Disabled',
        langchain_action=LangChainAction.QUERY.value,
        langchain_agents=[],
        version=version,
        h2ogpt_key=h2ogpt_key,
    )
    return run_client_chat(**call_kwargs)
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_chat_stream(prompt_type='human_bot', version=None, h2ogpt_key=None):
    """
    Request a long story in chat mode with streaming enabled (512 new tokens max).

    :return: whatever ``run_client_chat`` returns, i.e. ``(res_dict, client)``
    """
    call_kwargs = dict(
        prompt="Tell a very long kid's story about birds.",
        prompt_type=prompt_type,
        stream_output=True,
        max_new_tokens=512,
        langchain_mode='Disabled',
        langchain_action=LangChainAction.QUERY.value,
        langchain_agents=[],
        version=version,
        h2ogpt_key=h2ogpt_key,
    )
    return run_client_chat(**call_kwargs)
def run_client_chat(prompt='',
                    stream_output=None,
                    max_new_tokens=128,
                    langchain_mode='Disabled',
                    langchain_action=LangChainAction.QUERY.value,
                    langchain_agents=[],
                    prompt_type=None, prompt_dict=None, chat_template=None,
                    version=None,
                    h2ogpt_key=None,
                    chat_conversation=None,
                    system_prompt='',
                    document_choice=[],
                    document_content_substrings=[],
                    document_content_substrings_op='and',
                    document_source_substrings=[],
                    document_source_substrings_op='and',
                    top_k_docs=3,
                    max_time=20,
                    repetition_penalty=1.0,
                    do_sample=True,
                    seed=0,
                    ):
    """
    Run one chat-mode request: connect, build the arguments, hand over to ``run_client``.

    The client is created with ``serialize=False`` and the arguments come from
    ``get_args(..., chat=True)``; every other parameter is forwarded to ``get_args``
    under its own name.

    :return: whatever ``run_client`` returns, i.e. ``(res_dict, client)``
    """
    # connect first, then assemble the request
    client = get_client(serialize=False)

    forwarded = dict(
        chat=True,
        stream_output=stream_output,
        max_new_tokens=max_new_tokens,
        langchain_mode=langchain_mode,
        langchain_action=langchain_action,
        langchain_agents=langchain_agents,
        prompt_dict=prompt_dict,
        chat_template=chat_template,
        version=version,
        h2ogpt_key=h2ogpt_key,
        chat_conversation=chat_conversation,
        system_prompt=system_prompt,
        document_choice=document_choice,
        document_source_substrings=document_source_substrings,
        document_source_substrings_op=document_source_substrings_op,
        document_content_substrings=document_content_substrings,
        document_content_substrings_op=document_content_substrings_op,
        top_k_docs=top_k_docs,
        max_time=max_time,
        repetition_penalty=repetition_penalty,
        do_sample=do_sample,
        seed=seed,
    )
    kwargs, args = get_args(prompt, prompt_type, **forwarded)
    return run_client(client, prompt, args, kwargs)
def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
    """
    Execute a chat request and return the result dict together with the client.

    With gradio 4 (``is_gradio_version4`` true) the request is delegated to
    ``run_client_gen`` after enabling the source-related flags, and any
    ``sources_str`` in the result is appended to ``response``.
    Otherwise the ``/instruction`` and ``/instruction_bot`` endpoints are called with
    the positional ``args``, either in one shot or as a polled streaming job.

    :param args: positional argument list; its last element is extended in place
    :param kwargs: parameter dict; it is modified in place and is the returned ``res_dict``
    :param do_md_to_text: forwarded to ``md_to_text`` / ``run_client_gen``
    :param verbose: when true, print all job outputs after streaming finishes
    :return: ``(res_dict, client)``
    """
    if is_gradio_version4:
        kwargs.update(
            answer_with_sources=True,
            sources_show_text_in_accordion=True,
            append_sources_to_answer=True,
            append_sources_to_chat=False,
            show_link_in_sources=True,
        )
        res_dict, client = run_client_gen(client, kwargs, do_md_to_text=do_md_to_text)
        res_dict['response'] += str(res_dict.get('sources_str', ''))
        return res_dict, client

    # FIXME: https://github.com/gradio-app/gradio/issues/6592
    assert kwargs['chat'], "Chat mode only"
    bot_endpoint = '/instruction_bot'

    instruction_result = client.predict(*args, api_name='/instruction')
    # carry the last element of the /instruction result into the final positional argument
    args[-1] += [instruction_result[-1]]

    res_dict = kwargs
    res_dict['prompt'] = prompt

    if not kwargs['stream_output']:
        bot_result = client.predict(*args, api_name=bot_endpoint)
        res_dict['response'] = bot_result[0][-1][1]
        print(md_to_text(res_dict['response'], do_md_to_text=do_md_to_text))
        return res_dict, client

    job = client.submit(*args, api_name=bot_endpoint)
    while not job.done():
        snapshot = job.outputs().copy()
        if snapshot:
            # show the newest partial answer
            partial_text = md_to_text(snapshot[-1][0][-1][-1], do_md_to_text=do_md_to_text)
            print(partial_text)
        time.sleep(0.1)

    full_outputs = job.outputs().copy()
    if verbose:
        print('job.outputs: %s' % str(full_outputs))
    # ensure get ending to avoid race
    # -1 means last response if streaming
    # 0 means get text_output, ignore exception_text
    # 0 means get list within text_output that looks like [[prompt], [answer]]
    # 1 means get bot answer, so will have last bot answer
    final_answer = full_outputs[-1][0][0][1]
    res_dict['response'] = md_to_text(final_answer, do_md_to_text=do_md_to_text)
    return res_dict, client
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_nochat_stream(prompt_type='human_bot', version=None, h2ogpt_key=None):
    """
    Request a long story in nochat mode with streaming enabled (512 new tokens max).

    :return: whatever ``run_client_nochat_gen`` returns, i.e. ``(res_dict, client)``
    """
    call_kwargs = dict(
        prompt="Tell a very long kid's story about birds.",
        prompt_type=prompt_type,
        stream_output=True,
        max_new_tokens=512,
        langchain_mode='Disabled',
        langchain_action=LangChainAction.QUERY.value,
        langchain_agents=[],
        version=version,
        h2ogpt_key=h2ogpt_key,
    )
    return run_client_nochat_gen(**call_kwargs)
def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens,
                          langchain_mode, langchain_action, langchain_agents, version=None,
                          h2ogpt_key=None):
    """
    Run one nochat request through ``run_client_gen``.

    The client is created with ``serialize=False``; only the keyword dict from
    ``get_args`` is used, its positional list is discarded.

    :return: whatever ``run_client_gen`` returns, i.e. ``(res_dict, client)``
    """
    client = get_client(serialize=False)
    forwarded = dict(
        chat=False,
        stream_output=stream_output,
        max_new_tokens=max_new_tokens,
        langchain_mode=langchain_mode,
        langchain_action=langchain_action,
        langchain_agents=langchain_agents,
        version=version,
        h2ogpt_key=h2ogpt_key,
    )
    kwargs, _ = get_args(prompt, prompt_type, **forwarded)
    return run_client_gen(client, kwargs)
def run_client_gen(client, kwargs, do_md_to_text=True):
    """
    Send ``kwargs`` to ``/submit_nochat_api`` and merge the parsed reply into it.

    The request payload is the string form of ``kwargs`` (after ``prompt`` has been
    added to it); replies are strings parsed with ``ast.literal_eval``.  With
    ``kwargs['stream_output']`` true the job is polled every 0.1s and each partial
    response is printed, then the last output is taken as the final reply.

    :param kwargs: parameter dict; it is modified in place and is the returned ``res_dict``
    :param do_md_to_text: forwarded to ``md_to_text`` when printing the non-streaming response
    :return: ``(res_dict, client)``
    """
    result = kwargs  # same object on purpose: the caller's dict carries the reply
    result['prompt'] = kwargs['instruction'] or kwargs['instruction_nochat']
    endpoint = '/submit_nochat_api'

    if kwargs['stream_output']:
        job = client.submit(str(dict(kwargs)), api_name=endpoint)
        while not job.done():
            partials = job.outputs().copy()
            if partials:
                latest = ast.literal_eval(partials[-1])
                print('Stream: %s' % latest['response'])
            time.sleep(0.1)

        outputs = job.outputs().copy()
        assert len(outputs) > 0, "No response, check server"
        final_reply = ast.literal_eval(outputs[-1])
        print('Final: %s' % final_reply['response'])
        result.update(final_reply)
        return result, client

    raw_reply = client.predict(str(dict(kwargs)), api_name=endpoint)
    result.update(ast.literal_eval(raw_reply))
    print(md_to_text(result['response'], do_md_to_text=do_md_to_text))
    return result, client
def md_to_text(md, do_md_to_text=True):
    """
    Convert markdown to plain text.

    :param md: markdown source; must not be None when conversion is enabled
    :param do_md_to_text: when false, ``md`` is returned untouched (no None check)
    :return: text extracted from the HTML rendering of ``md``, or ``md`` itself
    """
    if do_md_to_text:
        assert md is not None, "Markdown is None"
        # markdown -> HTML -> text content
        rendered = markdown.markdown(md)
        return BeautifulSoup(rendered, features='html.parser').get_text()
    return md
def run_client_many(prompt_type='human_bot', version=None, h2ogpt_key=None):
    """
    Run every client scenario in this file in a fixed order against the server.

    :return: 7-tuple of the result dicts in this order: chat, chat stream, nochat stream,
        basic, basic api, basic api lean, basic api lean morestuff
    """
    kwargs = dict(prompt_type=prompt_type, version=version, h2ogpt_key=h2ogpt_key)
    scenarios = (
        test_client_chat,
        test_client_chat_stream,
        test_client_nochat_stream,
        test_client_basic,
        test_client_basic_api,
        test_client_basic_api_lean,
        test_client_basic_api_lean_morestuff,
    )
    collected = []
    for scenario in scenarios:
        # each scenario returns (res_dict, client); only the result dict is kept
        res_dict, _ = scenario(**kwargs)
        collected.append(res_dict)
    return tuple(collected)
# when executed as a script, run every client scenario once with default arguments
if __name__ == '__main__':
    run_client_many()