pseudotensor commited on
Commit
cf9ad1a
·
1 Parent(s): 1c0f538

Update with h2oGPT hash e7d4914948ac2b9a5a82f1cc82556197b261cb46

Browse files
app.py CHANGED
@@ -1 +1 @@
1
- generate.py
 
1
+ gen.py
client_test.py CHANGED
@@ -12,13 +12,13 @@ Currently, this will force model to be on a single GPU.
12
 
13
  Then run this client as:
14
 
15
- python client_test.py
16
 
17
 
18
 
19
  For HF spaces:
20
 
21
- HOST="https://h2oai-h2ogpt-chatbot.hf.space" python client_test.py
22
 
23
  Result:
24
 
@@ -28,7 +28,7 @@ Loaded as API: https://h2oai-h2ogpt-chatbot.hf.space ✔
28
 
29
  For demo:
30
 
31
- HOST="https://gpt.h2o.ai" python client_test.py
32
 
33
  Result:
34
 
@@ -48,7 +48,7 @@ import markdown # pip install markdown
48
  import pytest
49
  from bs4 import BeautifulSoup # pip install beautifulsoup4
50
 
51
- from enums import DocumentChoices
52
 
53
  debug = False
54
 
@@ -67,7 +67,9 @@ def get_client(serialize=True):
67
  def get_args(prompt, prompt_type, chat=False, stream_output=False,
68
  max_new_tokens=50,
69
  top_k_docs=3,
70
- langchain_mode='Disabled'):
 
 
71
  from collections import OrderedDict
72
  kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True
73
  iinput='', # only for chat=True
@@ -76,7 +78,7 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
76
  # but leave stream_output=False for simple input/output mode
77
  stream_output=stream_output,
78
  prompt_type=prompt_type,
79
- prompt_dict='',
80
  temperature=0.1,
81
  top_p=0.75,
82
  top_k=40,
@@ -92,12 +94,13 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
92
  instruction_nochat=prompt if not chat else '',
93
  iinput_nochat='', # only for chat=False
94
  langchain_mode=langchain_mode,
 
95
  top_k_docs=top_k_docs,
96
  chunk=True,
97
  chunk_size=512,
98
  document_choice=[DocumentChoices.All_Relevant.name],
99
  )
100
- from generate import eval_func_param_names
101
  assert len(set(eval_func_param_names).difference(set(list(kwargs.keys())))) == 0
102
  if chat:
103
  # add chatbot output on end. Assumes serialize=False
@@ -198,6 +201,7 @@ def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_ne
198
  instruction_nochat=prompt,
199
  iinput_nochat='',
200
  langchain_mode='Disabled',
 
201
  top_k_docs=4,
202
  document_choice=['All'],
203
  )
@@ -219,21 +223,24 @@ def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_ne
219
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
220
  def test_client_chat(prompt_type='human_bot'):
221
  return run_client_chat(prompt='Who are you?', prompt_type=prompt_type, stream_output=False, max_new_tokens=50,
222
- langchain_mode='Disabled')
223
 
224
 
225
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
226
  def test_client_chat_stream(prompt_type='human_bot'):
227
  return run_client_chat(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
228
  stream_output=True, max_new_tokens=512,
229
- langchain_mode='Disabled')
230
 
231
 
232
- def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode):
 
233
  client = get_client(serialize=False)
234
 
235
  kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
236
- max_new_tokens=max_new_tokens, langchain_mode=langchain_mode)
 
 
237
  return run_client(client, prompt, args, kwargs)
238
 
239
 
@@ -276,14 +283,15 @@ def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
276
  def test_client_nochat_stream(prompt_type='human_bot'):
277
  return run_client_nochat_gen(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
278
  stream_output=True, max_new_tokens=512,
279
- langchain_mode='Disabled')
280
 
281
 
282
- def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode):
283
  client = get_client(serialize=False)
284
 
285
  kwargs, args = get_args(prompt, prompt_type, chat=False, stream_output=stream_output,
286
- max_new_tokens=max_new_tokens, langchain_mode=langchain_mode)
 
287
  return run_client_gen(client, prompt, args, kwargs)
288
 
289
 
 
12
 
13
  Then run this client as:
14
 
15
+ python src/client_test.py
16
 
17
 
18
 
19
  For HF spaces:
20
 
21
+ HOST="https://h2oai-h2ogpt-chatbot.hf.space" python src/client_test.py
22
 
23
  Result:
24
 
 
28
 
29
  For demo:
30
 
31
+ HOST="https://gpt.h2o.ai" python src/client_test.py
32
 
33
  Result:
34
 
 
48
  import pytest
49
  from bs4 import BeautifulSoup # pip install beautifulsoup4
50
 
51
+ from enums import DocumentChoices, LangChainAction
52
 
53
  debug = False
54
 
 
67
  def get_args(prompt, prompt_type, chat=False, stream_output=False,
68
  max_new_tokens=50,
69
  top_k_docs=3,
70
+ langchain_mode='Disabled',
71
+ langchain_action=LangChainAction.QUERY.value,
72
+ prompt_dict=None):
73
  from collections import OrderedDict
74
  kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True
75
  iinput='', # only for chat=True
 
78
  # but leave stream_output=False for simple input/output mode
79
  stream_output=stream_output,
80
  prompt_type=prompt_type,
81
+ prompt_dict=prompt_dict,
82
  temperature=0.1,
83
  top_p=0.75,
84
  top_k=40,
 
94
  instruction_nochat=prompt if not chat else '',
95
  iinput_nochat='', # only for chat=False
96
  langchain_mode=langchain_mode,
97
+ langchain_action=langchain_action,
98
  top_k_docs=top_k_docs,
99
  chunk=True,
100
  chunk_size=512,
101
  document_choice=[DocumentChoices.All_Relevant.name],
102
  )
103
+ from evaluate_params import eval_func_param_names
104
  assert len(set(eval_func_param_names).difference(set(list(kwargs.keys())))) == 0
105
  if chat:
106
  # add chatbot output on end. Assumes serialize=False
 
201
  instruction_nochat=prompt,
202
  iinput_nochat='',
203
  langchain_mode='Disabled',
204
+ langchain_action=LangChainAction.QUERY.value,
205
  top_k_docs=4,
206
  document_choice=['All'],
207
  )
 
223
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
224
  def test_client_chat(prompt_type='human_bot'):
225
  return run_client_chat(prompt='Who are you?', prompt_type=prompt_type, stream_output=False, max_new_tokens=50,
226
+ langchain_mode='Disabled', langchain_action=LangChainAction.QUERY.value)
227
 
228
 
229
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
230
  def test_client_chat_stream(prompt_type='human_bot'):
231
  return run_client_chat(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
232
  stream_output=True, max_new_tokens=512,
233
+ langchain_mode='Disabled', langchain_action=LangChainAction.QUERY.value)
234
 
235
 
236
+ def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode, langchain_action,
237
+ prompt_dict=None):
238
  client = get_client(serialize=False)
239
 
240
  kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
241
+ max_new_tokens=max_new_tokens, langchain_mode=langchain_mode,
242
+ langchain_action=langchain_action,
243
+ prompt_dict=prompt_dict)
244
  return run_client(client, prompt, args, kwargs)
245
 
246
 
 
283
  def test_client_nochat_stream(prompt_type='human_bot'):
284
  return run_client_nochat_gen(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
285
  stream_output=True, max_new_tokens=512,
286
+ langchain_mode='Disabled', langchain_action=LangChainAction.QUERY.value)
287
 
288
 
289
+ def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode, langchain_action):
290
  client = get_client(serialize=False)
291
 
292
  kwargs, args = get_args(prompt, prompt_type, chat=False, stream_output=stream_output,
293
+ max_new_tokens=max_new_tokens, langchain_mode=langchain_mode,
294
+ langchain_action=langchain_action)
295
  return run_client_gen(client, prompt, args, kwargs)
296
 
297
 
enums.py CHANGED
@@ -37,6 +37,9 @@ class DocumentChoices(Enum):
37
  Just_LLM = 3
38
 
39
 
 
 
 
40
  class LangChainMode(Enum):
41
  """LangChain mode"""
42
 
@@ -52,10 +55,22 @@ class LangChainMode(Enum):
52
  H2O_DAI_DOCS = "DriverlessAI docs"
53
 
54
 
 
 
 
 
 
 
 
 
 
 
 
55
  no_server_str = no_lora_str = no_model_str = '[None/Remove]'
56
 
57
 
58
- # from site-packages/langchain/llms/openai.py, but needed since ChatOpenAI doesn't have this information
 
59
  model_token_mapping = {
60
  "gpt-4": 8192,
61
  "gpt-4-0314": 8192,
 
37
  Just_LLM = 3
38
 
39
 
40
+ non_query_commands = [DocumentChoices.All_Relevant_Only_Sources.name, DocumentChoices.Only_All_Sources.name]
41
+
42
+
43
  class LangChainMode(Enum):
44
  """LangChain mode"""
45
 
 
55
  H2O_DAI_DOCS = "DriverlessAI docs"
56
 
57
 
58
+ class LangChainAction(Enum):
59
+ """LangChain action"""
60
+
61
+ QUERY = "Query"
62
+ # WIP:
63
+ #SUMMARIZE_MAP = "Summarize_map_reduce"
64
+ SUMMARIZE_MAP = "Summarize"
65
+ SUMMARIZE_ALL = "Summarize_all"
66
+ SUMMARIZE_REFINE = "Summarize_refine"
67
+
68
+
69
  no_server_str = no_lora_str = no_model_str = '[None/Remove]'
70
 
71
 
72
+ # from site-packages/langchain/llms/openai.py
73
+ # but needed since ChatOpenAI doesn't have this information
74
  model_token_mapping = {
75
  "gpt-4": 8192,
76
  "gpt-4-0314": 8192,
evaluate_params.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ no_default_param_names = [
2
+ 'instruction',
3
+ 'iinput',
4
+ 'context',
5
+ 'instruction_nochat',
6
+ 'iinput_nochat',
7
+ ]
8
+
9
+ gen_hyper = ['temperature',
10
+ 'top_p',
11
+ 'top_k',
12
+ 'num_beams',
13
+ 'max_new_tokens',
14
+ 'min_new_tokens',
15
+ 'early_stopping',
16
+ 'max_time',
17
+ 'repetition_penalty',
18
+ 'num_return_sequences',
19
+ 'do_sample',
20
+ ]
21
+
22
+ eval_func_param_names = ['instruction',
23
+ 'iinput',
24
+ 'context',
25
+ 'stream_output',
26
+ 'prompt_type',
27
+ 'prompt_dict'] + \
28
+ gen_hyper + \
29
+ ['chat',
30
+ 'instruction_nochat',
31
+ 'iinput_nochat',
32
+ 'langchain_mode',
33
+ 'langchain_action',
34
+ 'top_k_docs',
35
+ 'chunk',
36
+ 'chunk_size',
37
+ 'document_choice',
38
+ ]
39
+
40
+ # form evaluate defaults for submit_nochat_api
41
+ eval_func_param_names_defaults = eval_func_param_names.copy()
42
+ for k in no_default_param_names:
43
+ if k in eval_func_param_names_defaults:
44
+ eval_func_param_names_defaults.remove(k)
45
+
46
+
47
+ eval_extra_columns = ['prompt', 'response', 'score']
gen.py ADDED
The diff for this file is too large to render. See raw diff
 
gpt4all_llm.py CHANGED
@@ -19,6 +19,15 @@ def get_model_tokenizer_gpt4all(base_model, **kwargs):
19
  n_ctx=2048 - 256)
20
  env_gpt4all_file = ".env_gpt4all"
21
  model_kwargs.update(dotenv_values(env_gpt4all_file))
 
 
 
 
 
 
 
 
 
22
 
23
  if base_model == "llama":
24
  if 'model_path_llama' not in model_kwargs:
 
19
  n_ctx=2048 - 256)
20
  env_gpt4all_file = ".env_gpt4all"
21
  model_kwargs.update(dotenv_values(env_gpt4all_file))
22
+ # make int or float if can to satisfy types for class
23
+ for k, v in model_kwargs.items():
24
+ try:
25
+ if float(v) == int(v):
26
+ model_kwargs[k] = int(v)
27
+ else:
28
+ model_kwargs[k] = float(v)
29
+ except:
30
+ pass
31
 
32
  if base_model == "llama":
33
  if 'model_path_llama' not in model_kwargs:
gpt_langchain.py CHANGED
@@ -23,8 +23,10 @@ from langchain.callbacks import streaming_stdout
23
  from langchain.embeddings import HuggingFaceInstructEmbeddings
24
  from tqdm import tqdm
25
 
26
- from enums import DocumentChoices, no_lora_str, model_token_mapping, source_prefix, source_postfix
27
- from generate import gen_hyper, get_model, SEED
 
 
28
  from prompter import non_hf_types, PromptType, Prompter
29
  from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
30
  get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer
@@ -43,7 +45,8 @@ from langchain.chains.qa_with_sources import load_qa_with_sources_chain
43
  from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
44
  UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
45
  EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
46
- UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader, UnstructuredPDFLoader
 
47
  from langchain.text_splitter import RecursiveCharacterTextSplitter, Language
48
  from langchain.chains.question_answering import load_qa_chain
49
  from langchain.docstore.document import Document
@@ -351,6 +354,7 @@ class GradioInference(LLM):
351
  stream_output = self.stream
352
  gr_client = self.client
353
  client_langchain_mode = 'Disabled'
 
354
  top_k_docs = 1
355
  chunk = True
356
  chunk_size = 512
@@ -379,6 +383,7 @@ class GradioInference(LLM):
379
  instruction_nochat=prompt if not self.chat_client else '',
380
  iinput_nochat='', # only for chat=False
381
  langchain_mode=client_langchain_mode,
 
382
  top_k_docs=top_k_docs,
383
  chunk=chunk,
384
  chunk_size=chunk_size,
@@ -637,6 +642,7 @@ def get_llm(use_openai_model=False,
637
  callbacks = [StreamingGradioCallbackHandler()]
638
  assert prompter is not None
639
  stop_sequences = list(set(prompter.terminate_response + [prompter.PreResponse]))
 
640
 
641
  if gr_client:
642
  chat_client = False
@@ -744,7 +750,7 @@ def get_llm(use_openai_model=False,
744
 
745
  if stream_output:
746
  skip_prompt = False
747
- from generate import H2OTextIteratorStreamer
748
  decoder_kwargs = {}
749
  streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False, **decoder_kwargs)
750
  gen_kwargs.update(dict(streamer=streamer))
@@ -944,14 +950,16 @@ have_playwright = False
944
 
945
  image_types = ["png", "jpg", "jpeg"]
946
  non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
947
- "md", "html",
 
948
  "enex", "eml", "epub", "odt", "pptx", "ppt",
949
  "zip", "urls",
 
950
  ]
951
  # "msg", GPL3
952
 
953
  if have_libreoffice:
954
- non_image_types.extend(["docx", "doc"])
955
 
956
  file_types = non_image_types + image_types
957
 
@@ -961,7 +969,7 @@ def add_meta(docs1, file):
961
  hashid = hash_file(file)
962
  if not isinstance(docs1, (list, tuple, types.GeneratorType)):
963
  docs1 = [docs1]
964
- [x.metadata.update(dict(input_type=file_extension, date=str(datetime.now), hashid=hashid)) for x in docs1]
965
 
966
 
967
  def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
@@ -1038,6 +1046,10 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
1038
  docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
1039
  add_meta(docs1, file)
1040
  doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
 
 
 
 
1041
  elif file.lower().endswith('.odt'):
1042
  docs1 = UnstructuredODTLoader(file_path=file).load()
1043
  add_meta(docs1, file)
@@ -1171,7 +1183,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
1171
  # so just extract in path where
1172
  zip_ref.extractall(base_path)
1173
  # recurse
1174
- doc1 = path_to_docs(base_path, verbose=verbose, fail_any_exception=fail_any_exception)
1175
  else:
1176
  raise RuntimeError("No file handler for %s" % os.path.basename(file))
1177
 
@@ -1758,6 +1770,8 @@ def run_qa_db(**kwargs):
1758
 
1759
 
1760
  def _run_qa_db(query=None,
 
 
1761
  use_openai_model=False, use_openai_embedding=False,
1762
  first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
1763
  user_path=None,
@@ -1787,6 +1801,7 @@ def _run_qa_db(query=None,
1787
  repetition_penalty=1.0,
1788
  num_return_sequences=1,
1789
  langchain_mode=None,
 
1790
  document_choice=[DocumentChoices.All_Relevant.name],
1791
  n_jobs=-1,
1792
  verbose=False,
@@ -1803,7 +1818,7 @@ def _run_qa_db(query=None,
1803
  :param use_openai_embedding:
1804
  :param first_para:
1805
  :param text_limit:
1806
- :param k:
1807
  :param chunk:
1808
  :param chunk_size:
1809
  :param user_path: user path to glob recursively from
@@ -1869,12 +1884,28 @@ def _run_qa_db(query=None,
1869
  sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
1870
  missing_kwargs = [x for x in func_names if x not in sim_kwargs]
1871
  assert not missing_kwargs, "Missing: %s" % missing_kwargs
1872
- docs, chain, scores, use_context = get_similarity_chain(**sim_kwargs)
1873
- if cmd in [DocumentChoices.All_Relevant_Only_Sources.name, DocumentChoices.Only_All_Sources.name]:
1874
  formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
1875
  yield formatted_doc_chunks, ''
1876
  return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1877
  if chain is None and model_name not in non_hf_types:
 
1878
  # can only return if HF type
1879
  return
1880
 
@@ -1933,6 +1964,7 @@ def _run_qa_db(query=None,
1933
 
1934
 
1935
  def get_similarity_chain(query=None,
 
1936
  use_openai_model=False, use_openai_embedding=False,
1937
  first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
1938
  user_path=None,
@@ -1947,6 +1979,7 @@ def get_similarity_chain(query=None,
1947
  load_db_if_exists=False,
1948
  db=None,
1949
  langchain_mode=None,
 
1950
  document_choice=[DocumentChoices.All_Relevant.name],
1951
  n_jobs=-1,
1952
  # beyond run_db_query:
@@ -1997,25 +2030,56 @@ def get_similarity_chain(query=None,
1997
  db=db,
1998
  n_jobs=n_jobs,
1999
  verbose=verbose)
2000
-
2001
- if 'falcon' in model_name:
2002
- extra = "According to only the information in the document sources provided within the context above, "
2003
- prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends."
2004
- elif inference_server in ['openai', 'openai_chat']:
2005
- extra = "According to (primarily) the information in the document sources provided within context above, "
2006
- prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends. If the answer cannot be primarily obtained from information within the context, then respond that the answer does not appear in the context of the documents."
2007
- else:
2008
- extra = ""
2009
- prefix = ""
2010
- if langchain_mode in ['Disabled', 'ChatLLM', 'LLM'] or not use_context:
2011
- template_if_no_docs = template = """%s{context}{question}""" % prefix
2012
- else:
2013
- template = """%s
2014
- \"\"\"
2015
- {context}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2016
  \"\"\"
2017
- %s{question}""" % (prefix, extra)
2018
- template_if_no_docs = """%s{context}%s{question}""" % (prefix, extra)
 
 
 
 
 
 
 
2019
  if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
2020
  use_template = True
2021
  else:
@@ -2040,14 +2104,26 @@ def get_similarity_chain(query=None,
2040
  if cmd == DocumentChoices.Just_LLM.name:
2041
  docs = []
2042
  scores = []
2043
- elif cmd == DocumentChoices.Only_All_Sources.name:
2044
  db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
2045
  # similar to langchain's chroma's _results_to_docs_and_scores
2046
  docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
2047
- for result in zip(db_documents, db_metadatas)][:top_k_docs]
 
 
 
 
 
 
 
 
 
2048
  docs = [x[0] for x in docs_with_score]
2049
  scores = [x[1] for x in docs_with_score]
 
2050
  else:
 
 
2051
  if top_k_docs == -1 or auto_reduce_chunks:
2052
  # docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
2053
  top_k_docs_tokenize = 100
@@ -2120,6 +2196,7 @@ def get_similarity_chain(query=None,
2120
  if reverse_docs:
2121
  docs_with_score.reverse()
2122
  # cut off so no high distance docs/sources considered
 
2123
  docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
2124
  scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
2125
  if len(scores) > 0 and verbose:
@@ -2131,14 +2208,14 @@ def get_similarity_chain(query=None,
2131
 
2132
  if not docs and use_context and model_name not in non_hf_types:
2133
  # if HF type and have no docs, can bail out
2134
- return docs, None, [], False
2135
 
2136
- if cmd in [DocumentChoices.All_Relevant_Only_Sources.name, DocumentChoices.Only_All_Sources.name]:
2137
  # no LLM use
2138
- return docs, None, [], False
2139
 
2140
  common_words_file = "data/NGSL_1.2_stats.csv.zip"
2141
- if os.path.isfile(common_words_file):
2142
  df = pd.read_csv("data/NGSL_1.2_stats.csv.zip")
2143
  import string
2144
  reduced_query = query.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation))).strip()
@@ -2155,25 +2232,47 @@ def get_similarity_chain(query=None,
2155
  use_context = False
2156
  template = template_if_no_docs
2157
 
2158
- if use_template:
2159
- # instruct-like, rather than few-shot prompt_type='plain' as default
2160
- # but then sources confuse the model with how inserted among rest of text, so avoid
2161
- prompt = PromptTemplate(
2162
- # input_variables=["summaries", "question"],
2163
- input_variables=["context", "question"],
2164
- template=template,
2165
- )
2166
- chain = load_qa_chain(llm, prompt=prompt)
2167
- else:
2168
- chain = load_qa_with_sources_chain(llm)
2169
-
2170
- if not use_context:
2171
- chain_kwargs = dict(input_documents=[], question=query)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2172
  else:
2173
- chain_kwargs = dict(input_documents=docs, question=query)
2174
 
2175
- target = wrapped_partial(chain, chain_kwargs)
2176
- return docs, target, scores, use_context
2177
 
2178
 
2179
  def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=False):
@@ -2243,6 +2342,11 @@ def chunk_sources(sources, chunk=True, chunk_size=512, language=None):
2243
  splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0, keep_separator=keep_separator,
2244
  separators=separators)
2245
  source_chunks = splitter.split_documents(sources)
 
 
 
 
 
2246
  return source_chunks
2247
 
2248
 
 
23
  from langchain.embeddings import HuggingFaceInstructEmbeddings
24
  from tqdm import tqdm
25
 
26
+ from enums import DocumentChoices, no_lora_str, model_token_mapping, source_prefix, source_postfix, non_query_commands, \
27
+ LangChainAction, LangChainMode
28
+ from evaluate_params import gen_hyper
29
+ from gen import get_model, SEED
30
  from prompter import non_hf_types, PromptType, Prompter
31
  from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
32
  get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer
 
45
  from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
46
  UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
47
  EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
48
+ UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader, UnstructuredPDFLoader, \
49
+ UnstructuredExcelLoader
50
  from langchain.text_splitter import RecursiveCharacterTextSplitter, Language
51
  from langchain.chains.question_answering import load_qa_chain
52
  from langchain.docstore.document import Document
 
354
  stream_output = self.stream
355
  gr_client = self.client
356
  client_langchain_mode = 'Disabled'
357
+ client_langchain_action = LangChainAction.QUERY.value
358
  top_k_docs = 1
359
  chunk = True
360
  chunk_size = 512
 
383
  instruction_nochat=prompt if not self.chat_client else '',
384
  iinput_nochat='', # only for chat=False
385
  langchain_mode=client_langchain_mode,
386
+ langchain_action=client_langchain_action,
387
  top_k_docs=top_k_docs,
388
  chunk=chunk,
389
  chunk_size=chunk_size,
 
642
  callbacks = [StreamingGradioCallbackHandler()]
643
  assert prompter is not None
644
  stop_sequences = list(set(prompter.terminate_response + [prompter.PreResponse]))
645
+ stop_sequences = [x for x in stop_sequences if x]
646
 
647
  if gr_client:
648
  chat_client = False
 
750
 
751
  if stream_output:
752
  skip_prompt = False
753
+ from gen import H2OTextIteratorStreamer
754
  decoder_kwargs = {}
755
  streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False, **decoder_kwargs)
756
  gen_kwargs.update(dict(streamer=streamer))
 
950
 
951
  image_types = ["png", "jpg", "jpeg"]
952
  non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
953
+ "md",
954
+ "html", "mhtml",
955
  "enex", "eml", "epub", "odt", "pptx", "ppt",
956
  "zip", "urls",
957
+
958
  ]
959
  # "msg", GPL3
960
 
961
  if have_libreoffice:
962
+ non_image_types.extend(["docx", "doc", "xls", "xlsx"])
963
 
964
  file_types = non_image_types + image_types
965
 
 
969
  hashid = hash_file(file)
970
  if not isinstance(docs1, (list, tuple, types.GeneratorType)):
971
  docs1 = [docs1]
972
+ [x.metadata.update(dict(input_type=file_extension, date=str(datetime.now()), hashid=hashid)) for x in docs1]
973
 
974
 
975
  def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
 
1046
  docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
1047
  add_meta(docs1, file)
1048
  doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
1049
+ elif (file.lower().endswith('.xlsx') or file.lower().endswith('.xls')) and have_libreoffice:
1050
+ docs1 = UnstructuredExcelLoader(file_path=file).load()
1051
+ add_meta(docs1, file)
1052
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
1053
  elif file.lower().endswith('.odt'):
1054
  docs1 = UnstructuredODTLoader(file_path=file).load()
1055
  add_meta(docs1, file)
 
1183
  # so just extract in path where
1184
  zip_ref.extractall(base_path)
1185
  # recurse
1186
+ doc1 = path_to_docs(base_path, verbose=verbose, fail_any_exception=fail_any_exception, n_jobs=n_jobs)
1187
  else:
1188
  raise RuntimeError("No file handler for %s" % os.path.basename(file))
1189
 
 
1770
 
1771
 
1772
  def _run_qa_db(query=None,
1773
+ iinput=None,
1774
+ context=None,
1775
  use_openai_model=False, use_openai_embedding=False,
1776
  first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
1777
  user_path=None,
 
1801
  repetition_penalty=1.0,
1802
  num_return_sequences=1,
1803
  langchain_mode=None,
1804
+ langchain_action=None,
1805
  document_choice=[DocumentChoices.All_Relevant.name],
1806
  n_jobs=-1,
1807
  verbose=False,
 
1818
  :param use_openai_embedding:
1819
  :param first_para:
1820
  :param text_limit:
1821
+ :param top_k_docs:
1822
  :param chunk:
1823
  :param chunk_size:
1824
  :param user_path: user path to glob recursively from
 
1884
  sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
1885
  missing_kwargs = [x for x in func_names if x not in sim_kwargs]
1886
  assert not missing_kwargs, "Missing: %s" % missing_kwargs
1887
+ docs, chain, scores, use_context, have_any_docs = get_similarity_chain(**sim_kwargs)
1888
+ if cmd in non_query_commands:
1889
  formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
1890
  yield formatted_doc_chunks, ''
1891
  return
1892
+ if not docs and langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
1893
+ LangChainAction.SUMMARIZE_ALL.value,
1894
+ LangChainAction.SUMMARIZE_REFINE.value]:
1895
+ ret = 'No relevant documents to summarize.' if have_any_docs else 'No documents to summarize.'
1896
+ extra = ''
1897
+ yield ret, extra
1898
+ return
1899
+ if not docs and langchain_mode not in [LangChainMode.DISABLED.value,
1900
+ LangChainMode.CHAT_LLM.value,
1901
+ LangChainMode.LLM.value]:
1902
+ ret = 'No relevant documents to query.' if have_any_docs else 'No documents to query.'
1903
+ extra = ''
1904
+ yield ret, extra
1905
+ return
1906
+
1907
  if chain is None and model_name not in non_hf_types:
1908
+ # here if no docs at all and not HF type
1909
  # can only return if HF type
1910
  return
1911
 
 
1964
 
1965
 
1966
  def get_similarity_chain(query=None,
1967
+ iinput=None,
1968
  use_openai_model=False, use_openai_embedding=False,
1969
  first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
1970
  user_path=None,
 
1979
  load_db_if_exists=False,
1980
  db=None,
1981
  langchain_mode=None,
1982
+ langchain_action=None,
1983
  document_choice=[DocumentChoices.All_Relevant.name],
1984
  n_jobs=-1,
1985
  # beyond run_db_query:
 
2030
  db=db,
2031
  n_jobs=n_jobs,
2032
  verbose=verbose)
2033
+ have_any_docs = db is not None
2034
+ if langchain_action == LangChainAction.QUERY.value:
2035
+ if iinput:
2036
+ query = "%s\n%s" % (query, iinput)
2037
+
2038
+ if 'falcon' in model_name:
2039
+ extra = "According to only the information in the document sources provided within the context above, "
2040
+ prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends."
2041
+ elif inference_server in ['openai', 'openai_chat']:
2042
+ extra = "According to (primarily) the information in the document sources provided within context above, "
2043
+ prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends. If the answer cannot be primarily obtained from information within the context, then respond that the answer does not appear in the context of the documents."
2044
+ else:
2045
+ extra = ""
2046
+ prefix = ""
2047
+ if langchain_mode in ['Disabled', 'ChatLLM', 'LLM'] or not use_context:
2048
+ template_if_no_docs = template = """%s{context}{question}""" % prefix
2049
+ else:
2050
+ template = """%s
2051
+ \"\"\"
2052
+ {context}
2053
+ \"\"\"
2054
+ %s{question}""" % (prefix, extra)
2055
+ template_if_no_docs = """%s{context}%s{question}""" % (prefix, extra)
2056
+ elif langchain_action in [LangChainAction.SUMMARIZE_ALL.value, LangChainAction.SUMMARIZE_MAP.value]:
2057
+ none = ['', '\n', None]
2058
+ if query in none and iinput in none:
2059
+ prompt_summary = "Using only the text above, write a condensed and concise summary:\n"
2060
+ elif query not in none:
2061
+ prompt_summary = "Focusing on %s, write a condensed and concise Summary:\n" % query
2062
+ elif iinput not in None:
2063
+ prompt_summary = iinput
2064
+ else:
2065
+ prompt_summary = "Focusing on %s, %s:\n" % (query, iinput)
2066
+ # don't auto reduce
2067
+ auto_reduce_chunks = False
2068
+ if langchain_action == LangChainAction.SUMMARIZE_MAP.value:
2069
+ fstring = '{text}'
2070
+ else:
2071
+ fstring = '{input_documents}'
2072
+ template = """In order to write a concise single-paragraph or bulleted list summary, pay attention to the following text:
2073
  \"\"\"
2074
+ %s
2075
+ \"\"\"\n%s""" % (fstring, prompt_summary)
2076
+ template_if_no_docs = "Exactly only say: There are no documents to summarize."
2077
+ elif langchain_action in [LangChainAction.SUMMARIZE_REFINE]:
2078
+ template = '' # unused
2079
+ template_if_no_docs = '' # unused
2080
+ else:
2081
+ raise RuntimeError("No such langchain_action=%s" % langchain_action)
2082
+
2083
  if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
2084
  use_template = True
2085
  else:
 
2104
  if cmd == DocumentChoices.Just_LLM.name:
2105
  docs = []
2106
  scores = []
2107
+ elif cmd == DocumentChoices.Only_All_Sources.name or query in [None, '', '\n']:
2108
  db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
2109
  # similar to langchain's chroma's _results_to_docs_and_scores
2110
  docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
2111
+ for result in zip(db_documents, db_metadatas)]
2112
+
2113
+ # order documents
2114
+ doc_hashes = [x['doc_hash'] for x in db_metadatas]
2115
+ doc_chunk_ids = [x['chunk_id'] for x in db_metadatas]
2116
+ docs_with_score = [x for _, _, x in
2117
+ sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1]))
2118
+ ]
2119
+
2120
+ docs_with_score = docs_with_score[:top_k_docs]
2121
  docs = [x[0] for x in docs_with_score]
2122
  scores = [x[1] for x in docs_with_score]
2123
+ have_any_docs |= len(docs) > 0
2124
  else:
2125
+ # FIXME: if langchain_action == LangChainAction.SUMMARIZE_MAP.value
2126
+ # if map_reduce, then no need to auto reduce chunks
2127
  if top_k_docs == -1 or auto_reduce_chunks:
2128
  # docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
2129
  top_k_docs_tokenize = 100
 
2196
  if reverse_docs:
2197
  docs_with_score.reverse()
2198
  # cut off so no high distance docs/sources considered
2199
+ have_any_docs |= len(docs_with_score) > 0 # before cut
2200
  docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
2201
  scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
2202
  if len(scores) > 0 and verbose:
 
2208
 
2209
  if not docs and use_context and model_name not in non_hf_types:
2210
  # if HF type and have no docs, can bail out
2211
+ return docs, None, [], False, have_any_docs
2212
 
2213
+ if cmd in non_query_commands:
2214
  # no LLM use
2215
+ return docs, None, [], False, have_any_docs
2216
 
2217
  common_words_file = "data/NGSL_1.2_stats.csv.zip"
2218
+ if os.path.isfile(common_words_file) and langchain_mode == LangChainAction.QUERY.value:
2219
  df = pd.read_csv("data/NGSL_1.2_stats.csv.zip")
2220
  import string
2221
  reduced_query = query.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation))).strip()
 
2232
  use_context = False
2233
  template = template_if_no_docs
2234
 
2235
+ if langchain_action == LangChainAction.QUERY.value:
2236
+ if use_template:
2237
+ # instruct-like, rather than few-shot prompt_type='plain' as default
2238
+ # but then sources confuse the model with how inserted among rest of text, so avoid
2239
+ prompt = PromptTemplate(
2240
+ # input_variables=["summaries", "question"],
2241
+ input_variables=["context", "question"],
2242
+ template=template,
2243
+ )
2244
+ chain = load_qa_chain(llm, prompt=prompt)
2245
+ else:
2246
+ # only if use_openai_model = True, unused normally except in testing
2247
+ chain = load_qa_with_sources_chain(llm)
2248
+ if not use_context:
2249
+ chain_kwargs = dict(input_documents=[], question=query)
2250
+ else:
2251
+ chain_kwargs = dict(input_documents=docs, question=query)
2252
+ target = wrapped_partial(chain, chain_kwargs)
2253
+ elif langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
2254
+ LangChainAction.SUMMARIZE_REFINE,
2255
+ LangChainAction.SUMMARIZE_ALL.value]:
2256
+ from langchain.chains.summarize import load_summarize_chain
2257
+ if langchain_action == LangChainAction.SUMMARIZE_MAP.value:
2258
+ prompt = PromptTemplate(input_variables=["text"], template=template)
2259
+ chain = load_summarize_chain(llm, chain_type="map_reduce",
2260
+ map_prompt=prompt, combine_prompt=prompt, return_intermediate_steps=True)
2261
+ target = wrapped_partial(chain, {"input_documents": docs}) # , return_only_outputs=True)
2262
+ elif langchain_action == LangChainAction.SUMMARIZE_ALL.value:
2263
+ assert use_template
2264
+ prompt = PromptTemplate(input_variables=["text"], template=template)
2265
+ chain = load_summarize_chain(llm, chain_type="stuff", prompt=prompt, return_intermediate_steps=True)
2266
+ target = wrapped_partial(chain)
2267
+ elif langchain_action == LangChainAction.SUMMARIZE_REFINE.value:
2268
+ chain = load_summarize_chain(llm, chain_type="refine", return_intermediate_steps=True)
2269
+ target = wrapped_partial(chain)
2270
+ else:
2271
+ raise RuntimeError("No such langchain_action=%s" % langchain_action)
2272
  else:
2273
+ raise RuntimeError("No such langchain_action=%s" % langchain_action)
2274
 
2275
+ return docs, target, scores, use_context, have_any_docs
 
2276
 
2277
 
2278
  def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=False):
 
2342
  splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0, keep_separator=keep_separator,
2343
  separators=separators)
2344
  source_chunks = splitter.split_documents(sources)
2345
+
2346
+ # currently in order, but when pull from db won't be, so mark order and document by hash
2347
+ doc_hash = str(uuid.uuid4())[:10]
2348
+ [x.metadata.update(dict(doc_hash=doc_hash, chunk_id=chunk_id)) for chunk_id, x in enumerate(source_chunks)]
2349
+
2350
  return source_chunks
2351
 
2352
 
gradio_runner.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import copy
2
  import functools
3
  import inspect
@@ -49,16 +50,16 @@ def fix_pydantic_duplicate_validators_error():
49
 
50
  fix_pydantic_duplicate_validators_error()
51
 
52
- from enums import DocumentChoices, no_model_str, no_lora_str, no_server_str, LangChainMode
53
  from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js, spacing_xsm, radius_xsm, \
54
  text_xsm
55
  from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \
56
  get_prompt
57
  from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
58
  ping, get_short_name, get_url, makedirs, get_kwargs, remove, system_info, ping_gpu
59
- from generate import get_model, languages_covered, evaluate, eval_func_param_names, score_qa, langchain_modes, \
60
- inputs_kwargs_list, scratch_base_dir, evaluate_from_str, no_default_param_names, \
61
- eval_func_param_names_defaults, get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context
62
 
63
  from apscheduler.schedulers.background import BackgroundScheduler
64
 
@@ -99,6 +100,7 @@ def go_gradio(**kwargs):
99
  dbs = kwargs['dbs']
100
  db_type = kwargs['db_type']
101
  visible_langchain_modes = kwargs['visible_langchain_modes']
 
102
  allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
103
  allow_upload_to_my_data = kwargs['allow_upload_to_my_data']
104
  enable_sources_list = kwargs['enable_sources_list']
@@ -213,7 +215,28 @@ def go_gradio(**kwargs):
213
  'base_model') else no_model_msg
214
  output_label0_model2 = no_model_msg
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  default_kwargs = {k: kwargs[k] for k in eval_func_param_names_defaults}
 
 
 
 
217
  for k in no_default_param_names:
218
  default_kwargs[k] = ''
219
 
@@ -239,7 +262,8 @@ def go_gradio(**kwargs):
239
  model_options_state = gr.State([model_options])
240
  lora_options_state = gr.State([lora_options])
241
  server_options_state = gr.State([server_options])
242
- my_db_state = gr.State([None, None])
 
243
  chat_state = gr.State({})
244
  # make user default first and default choice, dedup
245
  docs_state00 = kwargs['document_choice'] + [x.name for x in list(DocumentChoices)]
@@ -283,7 +307,7 @@ def go_gradio(**kwargs):
283
 
284
  col_chat = gr.Column(visible=kwargs['chat'])
285
  with col_chat:
286
- instruction, submit, stop_btn = make_prompt_form(kwargs)
287
  text_output, text_output2, text_outputs = make_chatbots(output_label0, output_label0_model2,
288
  **kwargs)
289
 
@@ -332,6 +356,12 @@ def go_gradio(**kwargs):
332
  value=kwargs['langchain_mode'],
333
  label="Data Collection of Sources",
334
  visible=kwargs['langchain_mode'] != 'Disabled')
 
 
 
 
 
 
335
  data_row2 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled')
336
  with data_row2:
337
  with gr.Column(scale=50):
@@ -726,6 +756,7 @@ def go_gradio(**kwargs):
726
  caption_loader=caption_loader,
727
  verbose=kwargs['verbose'],
728
  user_path=kwargs['user_path'],
 
729
  )
730
  add_file_outputs = [fileup_output, langchain_mode, add_to_shared_db_btn, add_to_my_db_btn]
731
  add_file_kwargs = dict(fn=update_user_db_func,
@@ -804,6 +835,7 @@ def go_gradio(**kwargs):
804
  caption_loader=caption_loader,
805
  verbose=kwargs['verbose'],
806
  user_path=kwargs['user_path'],
 
807
  )
808
 
809
  add_my_file_outputs = [fileup_output, langchain_mode, my_db_state, add_to_shared_db_btn, add_to_my_db_btn]
@@ -920,19 +952,59 @@ def go_gradio(**kwargs):
920
  for k in inputs_kwargs_list:
921
  assert k in kwargs_evaluate, "Missing %s" % k
922
 
923
- def evaluate_gradio(*args1, **kwargs1):
924
- for res_dict in evaluate(*args1, **kwargs1):
925
- if kwargs['langchain_mode'] == 'Disabled':
926
- yield fix_text_for_gradio(res_dict['response'])
927
- else:
928
- yield '<br>' + fix_text_for_gradio(res_dict['response'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
929
 
930
- fun = partial(evaluate_gradio,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
931
  **kwargs_evaluate)
932
- fun2 = partial(evaluate_gradio,
 
 
933
  **kwargs_evaluate)
934
- fun_with_dict_str = partial(evaluate_from_str,
935
- default_kwargs=default_kwargs,
 
936
  **kwargs_evaluate
937
  )
938
 
@@ -1072,14 +1144,17 @@ def go_gradio(**kwargs):
1072
  User that fills history for bot
1073
  :param args:
1074
  :param undo:
 
1075
  :param sanitize_user_prompt:
1076
- :param model2:
1077
  :return:
1078
  """
1079
  args_list = list(args)
1080
  user_message = args_list[eval_func_param_names.index('instruction')] # chat only
1081
  input1 = args_list[eval_func_param_names.index('iinput')] # chat only
1082
  prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
 
 
 
1083
  if not prompt_type1:
1084
  # shouldn't have to specify if CLI launched model
1085
  prompt_type1 = kwargs['prompt_type']
@@ -1110,8 +1185,12 @@ def go_gradio(**kwargs):
1110
  history[-1][1] = None
1111
  return history
1112
  if user_message1 in ['', None, '\n']:
1113
- # reject non-retry submit/enter
1114
- return history
 
 
 
 
1115
  user_message1 = fix_text_for_gradio(user_message1)
1116
  return history + [[user_message1, None]]
1117
 
@@ -1147,11 +1226,13 @@ def go_gradio(**kwargs):
1147
  else:
1148
  return 2000
1149
 
1150
- def prep_bot(*args, retry=False):
1151
  """
1152
 
1153
  :param args:
1154
  :param retry:
 
 
1155
  :return: last element is True if should run bot, False if should just yield history
1156
  """
1157
  # don't deepcopy, can contain model itself
@@ -1159,12 +1240,16 @@ def go_gradio(**kwargs):
1159
  model_state1 = args_list[-3]
1160
  my_db_state1 = args_list[-2]
1161
  history = args_list[-1]
1162
- langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
 
1163
 
1164
  if model_state1['model'] is None or model_state1['model'] == no_model_str:
1165
  return history, None, None, None
1166
 
1167
  args_list = args_list[:-3] # only keep rest needed for evaluate()
 
 
 
1168
  if not history:
1169
  print("No history", flush=True)
1170
  history = []
@@ -1175,22 +1260,23 @@ def go_gradio(**kwargs):
1175
  instruction1 = history[-1][0]
1176
  history[-1][1] = None
1177
  elif not instruction1:
1178
- # if not retrying, then reject empty query
1179
- return history, None, None, None
 
 
 
 
1180
  elif len(history) > 0 and history[-1][1] not in [None, '']:
1181
  # reject submit button if already filled and not retrying
1182
  # None when not filling with '' to keep client happy
1183
  return history, None, None, None
1184
 
1185
  # shouldn't have to specify in API prompt_type if CLI launched model, so prefer global CLI one if have it
1186
- prompt_type1 = kwargs.get('prompt_type', args_list[eval_func_param_names.index('prompt_type')])
1187
- # prefer model specific prompt type instead of global one, and apply back to args_list for evaluate()
1188
- args_list[eval_func_param_names.index('prompt_type')] = prompt_type1 = \
1189
- model_state1.get('prompt_type', prompt_type1)
1190
-
1191
- prompt_dict1 = kwargs.get('prompt_dict', args_list[eval_func_param_names.index('prompt_dict')])
1192
- args_list[eval_func_param_names.index('prompt_dict')] = prompt_dict1 = \
1193
- model_state1.get('prompt_dict', prompt_dict1)
1194
 
1195
  chat1 = args_list[eval_func_param_names.index('chat')]
1196
  model_max_length1 = get_model_max_length(model_state1)
@@ -1264,6 +1350,7 @@ def go_gradio(**kwargs):
1264
  for res in get_response(fun1, history):
1265
  yield res
1266
  finally:
 
1267
  clear_embeddings(langchain_mode1, my_db_state1)
1268
 
1269
  def all_bot(*args, retry=False, model_states1=None):
@@ -1277,7 +1364,7 @@ def go_gradio(**kwargs):
1277
  my_db_state1 = None # will be filled below by some bot
1278
  try:
1279
  gen_list = []
1280
- for chatbot1, model_state1 in zip(chatbots, model_states1):
1281
  args_list1 = args_list0.copy()
1282
  args_list1.insert(-1, model_state1) # insert at -1 so is at -2
1283
  # if at start, have None in response still, replace with '' so client etc. acts like normal
@@ -1289,7 +1376,8 @@ def go_gradio(**kwargs):
1289
  # so consistent with prep_bot()
1290
  # with model_state1 at -3, my_db_state1 at -2, and history(chatbot) at -1
1291
  # langchain_mode1 and my_db_state1 should be same for every bot
1292
- history, fun1, langchain_mode1, my_db_state1 = prep_bot(*tuple(args_list1), retry=retry)
 
1293
  gen1 = get_response(fun1, history)
1294
  if stream_output1:
1295
  gen1 = TimeoutIterator(gen1, timeout=0.01, sentinel=None, raise_on_exception=False)
@@ -1301,6 +1389,7 @@ def go_gradio(**kwargs):
1301
  tgen0 = time.time()
1302
  for res1 in itertools.zip_longest(*gen_list):
1303
  if time.time() - tgen0 > max_time1:
 
1304
  break
1305
 
1306
  bots = [x[0] if x is not None and not isinstance(x, BaseException) else y for x, y in
@@ -1735,6 +1824,9 @@ def go_gradio(**kwargs):
1735
 
1736
  def load_model(model_name, lora_weights, server_name, model_state_old, prompt_type_old, load_8bit,
1737
  infer_devices, gpu_id):
 
 
 
1738
  # ensure old model removed from GPU memory
1739
  if kwargs['debug']:
1740
  print("Pre-switch pre-del GPU memory: %s" % get_torch_allocated(), flush=True)
@@ -2161,6 +2253,15 @@ def update_user_db(file, db1, x, y, *args, dbs=None, langchain_mode='UserData',
2161
  clear_torch_cache()
2162
 
2163
 
 
 
 
 
 
 
 
 
 
2164
  def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None, langchain_mode='UserData',
2165
  user_path=None,
2166
  use_openai_embedding=None,
@@ -2170,7 +2271,8 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
2170
  captions_model=None,
2171
  enable_ocr=None,
2172
  verbose=None,
2173
- is_url=None, is_txt=None):
 
2174
  assert use_openai_embedding is not None
2175
  assert hf_embedding_model is not None
2176
  assert caption_loader is not None
@@ -2211,6 +2313,7 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
2211
  print("Adding %s" % file, flush=True)
2212
  sources = path_to_docs(file if not is_url and not is_txt else None,
2213
  verbose=verbose,
 
2214
  chunk=chunk, chunk_size=chunk_size,
2215
  url=file if is_url else None,
2216
  text=file if is_txt else None,
@@ -2222,7 +2325,8 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
2222
  exceptions = [x for x in sources if x.metadata.get('exception')]
2223
  sources = [x for x in sources if 'exception' not in x.metadata]
2224
 
2225
- with filelock.FileLock("db_%s.lock" % langchain_mode.replace(' ', '_')):
 
2226
  if langchain_mode == 'MyData':
2227
  if db1[0] is not None:
2228
  # then add
@@ -2235,18 +2339,14 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
2235
  # for production hit, when user gets clicky:
2236
  assert len(db1) == 2, "Bad MyData db: %s" % db1
2237
  # then create
2238
- # assign fresh hash for this user session, so not shared
2239
  # if added has to original state and didn't change, then would be shared db for all users
2240
- db1[1] = str(uuid.uuid4())
2241
  persist_directory = os.path.join(scratch_base_dir, 'db_dir_%s_%s' % (langchain_mode, db1[1]))
2242
  db = get_db(sources, use_openai_embedding=use_openai_embedding,
2243
  db_type=db_type,
2244
  persist_directory=persist_directory,
2245
  langchain_mode=langchain_mode,
2246
  hf_embedding_model=hf_embedding_model)
2247
- if db is None:
2248
- db1[1] = None
2249
- else:
2250
  db1[0] = db
2251
  source_files_added = get_source_files(db=db1[0], exceptions=exceptions)
2252
  return None, langchain_mode, db1, x, y, source_files_added
@@ -2274,7 +2374,9 @@ def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None,
2274
 
2275
 
2276
  def get_db(db1, langchain_mode, dbs=None):
2277
- with filelock.FileLock("db_%s.lock" % langchain_mode.replace(' ', '_')):
 
 
2278
  if langchain_mode in ['wiki_full']:
2279
  # NOTE: avoid showing full wiki. Takes about 30 seconds over about 90k entries, but not useful for now
2280
  db = None
 
1
+ import ast
2
  import copy
3
  import functools
4
  import inspect
 
50
 
51
  fix_pydantic_duplicate_validators_error()
52
 
53
+ from enums import DocumentChoices, no_model_str, no_lora_str, no_server_str, LangChainAction, LangChainMode
54
  from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js, spacing_xsm, radius_xsm, \
55
  text_xsm
56
  from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \
57
  get_prompt
58
  from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
59
  ping, get_short_name, get_url, makedirs, get_kwargs, remove, system_info, ping_gpu
60
+ from gen import get_model, languages_covered, evaluate, score_qa, langchain_modes, inputs_kwargs_list, scratch_base_dir, \
61
+ get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context, langchain_actions
62
+ from evaluate_params import eval_func_param_names, no_default_param_names, eval_func_param_names_defaults
63
 
64
  from apscheduler.schedulers.background import BackgroundScheduler
65
 
 
100
  dbs = kwargs['dbs']
101
  db_type = kwargs['db_type']
102
  visible_langchain_modes = kwargs['visible_langchain_modes']
103
+ visible_langchain_actions = kwargs['visible_langchain_actions']
104
  allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
105
  allow_upload_to_my_data = kwargs['allow_upload_to_my_data']
106
  enable_sources_list = kwargs['enable_sources_list']
 
215
  'base_model') else no_model_msg
216
  output_label0_model2 = no_model_msg
217
 
218
+ def update_prompt(prompt_type1, prompt_dict1, model_state1, which_model=0):
219
+ if not prompt_type1 or which_model != 0:
220
+ # keep prompt_type and prompt_dict in sync if possible
221
+ prompt_type1 = kwargs.get('prompt_type', prompt_type1)
222
+ prompt_dict1 = kwargs.get('prompt_dict', prompt_dict1)
223
+ # prefer model specific prompt type instead of global one
224
+ if not prompt_type1 or which_model != 0:
225
+ prompt_type1 = model_state1.get('prompt_type', prompt_type1)
226
+ prompt_dict1 = model_state1.get('prompt_dict', prompt_dict1)
227
+
228
+ if not prompt_dict1 or which_model != 0:
229
+ # if still not defined, try to get
230
+ prompt_dict1 = kwargs.get('prompt_dict', prompt_dict1)
231
+ if not prompt_dict1 or which_model != 0:
232
+ prompt_dict1 = model_state1.get('prompt_dict', prompt_dict1)
233
+ return prompt_type1, prompt_dict1
234
+
235
  default_kwargs = {k: kwargs[k] for k in eval_func_param_names_defaults}
236
+ # ensure prompt_type consistent with prep_bot(), so nochat API works same way
237
+ default_kwargs['prompt_type'], default_kwargs['prompt_dict'] = \
238
+ update_prompt(default_kwargs['prompt_type'], default_kwargs['prompt_dict'],
239
+ model_state1=model_state0, which_model=0)
240
  for k in no_default_param_names:
241
  default_kwargs[k] = ''
242
 
 
262
  model_options_state = gr.State([model_options])
263
  lora_options_state = gr.State([lora_options])
264
  server_options_state = gr.State([server_options])
265
+ # uuid in db is used as user ID
266
+ my_db_state = gr.State([None, str(uuid.uuid4())])
267
  chat_state = gr.State({})
268
  # make user default first and default choice, dedup
269
  docs_state00 = kwargs['document_choice'] + [x.name for x in list(DocumentChoices)]
 
307
 
308
  col_chat = gr.Column(visible=kwargs['chat'])
309
  with col_chat:
310
+ instruction, submit, stop_btn = make_prompt_form(kwargs, LangChainMode)
311
  text_output, text_output2, text_outputs = make_chatbots(output_label0, output_label0_model2,
312
  **kwargs)
313
 
 
356
  value=kwargs['langchain_mode'],
357
  label="Data Collection of Sources",
358
  visible=kwargs['langchain_mode'] != 'Disabled')
359
+ allowed_actions = [x for x in langchain_actions if x in visible_langchain_actions]
360
+ langchain_action = gr.Radio(
361
+ allowed_actions,
362
+ value=allowed_actions[0] if len(allowed_actions) > 0 else None,
363
+ label="Data Action",
364
+ visible=True)
365
  data_row2 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled')
366
  with data_row2:
367
  with gr.Column(scale=50):
 
756
  caption_loader=caption_loader,
757
  verbose=kwargs['verbose'],
758
  user_path=kwargs['user_path'],
759
+ n_jobs=kwargs['n_jobs'],
760
  )
761
  add_file_outputs = [fileup_output, langchain_mode, add_to_shared_db_btn, add_to_my_db_btn]
762
  add_file_kwargs = dict(fn=update_user_db_func,
 
835
  caption_loader=caption_loader,
836
  verbose=kwargs['verbose'],
837
  user_path=kwargs['user_path'],
838
+ n_jobs=kwargs['n_jobs'],
839
  )
840
 
841
  add_my_file_outputs = [fileup_output, langchain_mode, my_db_state, add_to_shared_db_btn, add_to_my_db_btn]
 
952
  for k in inputs_kwargs_list:
953
  assert k in kwargs_evaluate, "Missing %s" % k
954
 
955
+ def evaluate_nochat(*args1, default_kwargs1=None, str_api=False, **kwargs1):
956
+ args_list = list(args1)
957
+ if str_api:
958
+ user_kwargs = args_list[2]
959
+ assert isinstance(user_kwargs, str)
960
+ user_kwargs = ast.literal_eval(user_kwargs)
961
+ else:
962
+ user_kwargs = {k: v for k, v in zip(eval_func_param_names, args_list[2:])}
963
+ # only used for submit_nochat_api
964
+ user_kwargs['chat'] = False
965
+ if 'stream_output' not in user_kwargs:
966
+ user_kwargs['stream_output'] = False
967
+ if 'langchain_mode' not in user_kwargs:
968
+ # if user doesn't specify, then assume disabled, not use default
969
+ user_kwargs['langchain_mode'] = 'Disabled'
970
+ if 'langchain_action' not in user_kwargs:
971
+ user_kwargs['langchain_action'] = LangChainAction.QUERY.value
972
+
973
+ set1 = set(list(default_kwargs1.keys()))
974
+ set2 = set(eval_func_param_names)
975
+ assert set1 == set2, "Set diff: %s %s: %s" % (set1, set2, set1.symmetric_difference(set2))
976
+ # correct ordering. Note some things may not be in default_kwargs, so can't be default of user_kwargs.get()
977
+ model_state1 = args_list[0]
978
+ my_db_state1 = args_list[1]
979
+ args_list = [user_kwargs[k] if k in user_kwargs and user_kwargs[k] is not None else default_kwargs1[k] for k
980
+ in eval_func_param_names]
981
+ assert len(args_list) == len(eval_func_param_names)
982
+ args_list = [model_state1, my_db_state1] + args_list
983
 
984
+ try:
985
+ for res_dict in evaluate(*tuple(args_list), **kwargs1):
986
+ if str_api:
987
+ # full return of dict
988
+ yield res_dict
989
+ elif kwargs['langchain_mode'] == 'Disabled':
990
+ yield fix_text_for_gradio(res_dict['response'])
991
+ else:
992
+ yield '<br>' + fix_text_for_gradio(res_dict['response'])
993
+ finally:
994
+ clear_torch_cache()
995
+ clear_embeddings(user_kwargs['langchain_mode'], my_db_state1)
996
+
997
+ fun = partial(evaluate_nochat,
998
+ default_kwargs1=default_kwargs,
999
+ str_api=False,
1000
  **kwargs_evaluate)
1001
+ fun2 = partial(evaluate_nochat,
1002
+ default_kwargs1=default_kwargs,
1003
+ str_api=False,
1004
  **kwargs_evaluate)
1005
+ fun_with_dict_str = partial(evaluate_nochat,
1006
+ default_kwargs1=default_kwargs,
1007
+ str_api=True,
1008
  **kwargs_evaluate
1009
  )
1010
 
 
1144
  User that fills history for bot
1145
  :param args:
1146
  :param undo:
1147
+ :param retry:
1148
  :param sanitize_user_prompt:
 
1149
  :return:
1150
  """
1151
  args_list = list(args)
1152
  user_message = args_list[eval_func_param_names.index('instruction')] # chat only
1153
  input1 = args_list[eval_func_param_names.index('iinput')] # chat only
1154
  prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
1155
+ langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
1156
+ langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
1157
+ document_choice1 = args_list[eval_func_param_names.index('document_choice')]
1158
  if not prompt_type1:
1159
  # shouldn't have to specify if CLI launched model
1160
  prompt_type1 = kwargs['prompt_type']
 
1185
  history[-1][1] = None
1186
  return history
1187
  if user_message1 in ['', None, '\n']:
1188
+ if langchain_action1 in LangChainAction.QUERY.value and \
1189
+ DocumentChoices.Only_All_Sources.name not in document_choice1 \
1190
+ or \
1191
+ langchain_mode1 in [LangChainMode.CHAT_LLM.value, LangChainMode.LLM.value]:
1192
+ # reject non-retry submit/enter
1193
+ return history
1194
  user_message1 = fix_text_for_gradio(user_message1)
1195
  return history + [[user_message1, None]]
1196
 
 
1226
  else:
1227
  return 2000
1228
 
1229
+ def prep_bot(*args, retry=False, which_model=0):
1230
  """
1231
 
1232
  :param args:
1233
  :param retry:
1234
+ :param which_model: identifies which model if doing model_lock
1235
+ API only called for which_model=0, default for inputs_list, but rest should ignore inputs_list
1236
  :return: last element is True if should run bot, False if should just yield history
1237
  """
1238
  # don't deepcopy, can contain model itself
 
1240
  model_state1 = args_list[-3]
1241
  my_db_state1 = args_list[-2]
1242
  history = args_list[-1]
1243
+ prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
1244
+ prompt_dict1 = args_list[eval_func_param_names.index('prompt_dict')]
1245
 
1246
  if model_state1['model'] is None or model_state1['model'] == no_model_str:
1247
  return history, None, None, None
1248
 
1249
  args_list = args_list[:-3] # only keep rest needed for evaluate()
1250
+ langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
1251
+ langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
1252
+ document_choice1 = args_list[eval_func_param_names.index('document_choice')]
1253
  if not history:
1254
  print("No history", flush=True)
1255
  history = []
 
1260
  instruction1 = history[-1][0]
1261
  history[-1][1] = None
1262
  elif not instruction1:
1263
+ if langchain_action1 in LangChainAction.QUERY.value and \
1264
+ DocumentChoices.Only_All_Sources.name not in document_choice1 \
1265
+ or \
1266
+ langchain_mode1 in [LangChainMode.CHAT_LLM.value, LangChainMode.LLM.value]:
1267
+ # if not retrying, then reject empty query
1268
+ return history, None, None, None
1269
  elif len(history) > 0 and history[-1][1] not in [None, '']:
1270
  # reject submit button if already filled and not retrying
1271
  # None when not filling with '' to keep client happy
1272
  return history, None, None, None
1273
 
1274
  # shouldn't have to specify in API prompt_type if CLI launched model, so prefer global CLI one if have it
1275
+ prompt_type1, prompt_dict1 = update_prompt(prompt_type1, prompt_dict1, model_state1,
1276
+ which_model=which_model)
1277
+ # apply back to args_list for evaluate()
1278
+ args_list[eval_func_param_names.index('prompt_type')] = prompt_type1
1279
+ args_list[eval_func_param_names.index('prompt_dict')] = prompt_dict1
 
 
 
1280
 
1281
  chat1 = args_list[eval_func_param_names.index('chat')]
1282
  model_max_length1 = get_model_max_length(model_state1)
 
1350
  for res in get_response(fun1, history):
1351
  yield res
1352
  finally:
1353
+ clear_torch_cache()
1354
  clear_embeddings(langchain_mode1, my_db_state1)
1355
 
1356
  def all_bot(*args, retry=False, model_states1=None):
 
1364
  my_db_state1 = None # will be filled below by some bot
1365
  try:
1366
  gen_list = []
1367
+ for chatboti, (chatbot1, model_state1) in enumerate(zip(chatbots, model_states1)):
1368
  args_list1 = args_list0.copy()
1369
  args_list1.insert(-1, model_state1) # insert at -1 so is at -2
1370
  # if at start, have None in response still, replace with '' so client etc. acts like normal
 
1376
  # so consistent with prep_bot()
1377
  # with model_state1 at -3, my_db_state1 at -2, and history(chatbot) at -1
1378
  # langchain_mode1 and my_db_state1 should be same for every bot
1379
+ history, fun1, langchain_mode1, my_db_state1 = prep_bot(*tuple(args_list1), retry=retry,
1380
+ which_model=chatboti)
1381
  gen1 = get_response(fun1, history)
1382
  if stream_output1:
1383
  gen1 = TimeoutIterator(gen1, timeout=0.01, sentinel=None, raise_on_exception=False)
 
1389
  tgen0 = time.time()
1390
  for res1 in itertools.zip_longest(*gen_list):
1391
  if time.time() - tgen0 > max_time1:
1392
+ print("Took too long: %s" % max_time1, flush=True)
1393
  break
1394
 
1395
  bots = [x[0] if x is not None and not isinstance(x, BaseException) else y for x, y in
 
1824
 
1825
  def load_model(model_name, lora_weights, server_name, model_state_old, prompt_type_old, load_8bit,
1826
  infer_devices, gpu_id):
1827
+ # ensure no API calls reach here
1828
+ if is_public:
1829
+ raise RuntimeError("Illegal access for %s" % model_name)
1830
  # ensure old model removed from GPU memory
1831
  if kwargs['debug']:
1832
  print("Pre-switch pre-del GPU memory: %s" % get_torch_allocated(), flush=True)
 
2253
  clear_torch_cache()
2254
 
2255
 
2256
+ def get_lock_file(db1, langchain_mode):
2257
+ assert len(db1) == 2 and db1[1] is not None and isinstance(db1[1], str)
2258
+ user_id = db1[1]
2259
+ base_path = 'locks'
2260
+ makedirs(base_path)
2261
+ lock_file = "db_%s_%s.lock" % (langchain_mode.replace(' ', '_'), user_id)
2262
+ return lock_file
2263
+
2264
+
2265
  def _update_user_db(file, db1, x, y, chunk, chunk_size, dbs=None, db_type=None, langchain_mode='UserData',
2266
  user_path=None,
2267
  use_openai_embedding=None,
 
2271
  captions_model=None,
2272
  enable_ocr=None,
2273
  verbose=None,
2274
+ is_url=None, is_txt=None,
2275
+ n_jobs=-1):
2276
  assert use_openai_embedding is not None
2277
  assert hf_embedding_model is not None
2278
  assert caption_loader is not None
 
2313
  print("Adding %s" % file, flush=True)
2314
  sources = path_to_docs(file if not is_url and not is_txt else None,
2315
  verbose=verbose,
2316
+ n_jobs=n_jobs,
2317
  chunk=chunk, chunk_size=chunk_size,
2318
  url=file if is_url else None,
2319
  text=file if is_txt else None,
 
2325
  exceptions = [x for x in sources if x.metadata.get('exception')]
2326
  sources = [x for x in sources if 'exception' not in x.metadata]
2327
 
2328
+ lock_file = get_lock_file(db1, langchain_mode)
2329
+ with filelock.FileLock(lock_file):
2330
  if langchain_mode == 'MyData':
2331
  if db1[0] is not None:
2332
  # then add
 
2339
  # for production hit, when user gets clicky:
2340
  assert len(db1) == 2, "Bad MyData db: %s" % db1
2341
  # then create
 
2342
  # if added has to original state and didn't change, then would be shared db for all users
 
2343
  persist_directory = os.path.join(scratch_base_dir, 'db_dir_%s_%s' % (langchain_mode, db1[1]))
2344
  db = get_db(sources, use_openai_embedding=use_openai_embedding,
2345
  db_type=db_type,
2346
  persist_directory=persist_directory,
2347
  langchain_mode=langchain_mode,
2348
  hf_embedding_model=hf_embedding_model)
2349
+ if db is not None:
 
 
2350
  db1[0] = db
2351
  source_files_added = get_source_files(db=db1[0], exceptions=exceptions)
2352
  return None, langchain_mode, db1, x, y, source_files_added
 
2374
 
2375
 
2376
  def get_db(db1, langchain_mode, dbs=None):
2377
+ lock_file = get_lock_file(db1, langchain_mode)
2378
+
2379
+ with filelock.FileLock(lock_file):
2380
  if langchain_mode in ['wiki_full']:
2381
  # NOTE: avoid showing full wiki. Takes about 30 seconds over about 90k entries, but not useful for now
2382
  db = None
gradio_utils/__pycache__/grclient.cpython-310.pyc CHANGED
Binary files a/gradio_utils/__pycache__/grclient.cpython-310.pyc and b/gradio_utils/__pycache__/grclient.cpython-310.pyc differ
 
gradio_utils/__pycache__/prompt_form.cpython-310.pyc CHANGED
Binary files a/gradio_utils/__pycache__/prompt_form.cpython-310.pyc and b/gradio_utils/__pycache__/prompt_form.cpython-310.pyc differ
 
gradio_utils/prompt_form.py CHANGED
@@ -95,11 +95,15 @@ def make_chatbots(output_label0, output_label0_model2, **kwargs):
95
  return text_output, text_output2, text_outputs
96
 
97
 
98
- def make_prompt_form(kwargs):
 
 
 
 
99
  if kwargs['input_lines'] > 1:
100
- instruction_label = "Shift-Enter to Submit, Enter for more lines"
101
  else:
102
- instruction_label = "Enter to Submit, Shift-Enter for more lines"
103
 
104
  with gr.Row():#elem_id='prompt-form-area'):
105
  with gr.Column(scale=50):
 
95
  return text_output, text_output2, text_outputs
96
 
97
 
98
+ def make_prompt_form(kwargs, LangChainMode):
99
+ if kwargs['langchain_mode'] != LangChainMode.DISABLED.value:
100
+ extra_prompt_form = ". For summarization, empty submission uses first top_k_docs documents."
101
+ else:
102
+ extra_prompt_form = ""
103
  if kwargs['input_lines'] > 1:
104
+ instruction_label = "Shift-Enter to Submit, Enter for more lines%s" % extra_prompt_form
105
  else:
106
+ instruction_label = "Enter to Submit, Shift-Enter for more lines%s" % extra_prompt_form
107
 
108
  with gr.Row():#elem_id='prompt-form-area'):
109
  with gr.Column(scale=50):
h2oai_pipeline.py CHANGED
@@ -136,6 +136,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
136
  else:
137
  outputs = rec['generated_text']
138
  rec['generated_text'] = outputs
 
139
  return records
140
 
141
  def _forward(self, model_inputs, **generate_kwargs):
 
136
  else:
137
  outputs = rec['generated_text']
138
  rec['generated_text'] = outputs
139
+ print("prompt: %s\noutputs: %s\n\n" % (self.prompt_text, outputs), flush=True)
140
  return records
141
 
142
  def _forward(self, model_inputs, **generate_kwargs):
prompter.py CHANGED
@@ -120,7 +120,7 @@ def get_prompt(prompt_type, prompt_dict, chat, context, reduced, making_context,
120
  elif prompt_type in [PromptType.custom.value, str(PromptType.custom.value),
121
  PromptType.custom.name]:
122
  promptA = prompt_dict.get('promptA', '')
123
- promptB = prompt_dict('promptB', '')
124
  PreInstruct = prompt_dict.get('PreInstruct', '')
125
  PreInput = prompt_dict.get('PreInput', '')
126
  PreResponse = prompt_dict.get('PreResponse', '')
@@ -693,7 +693,9 @@ class Prompter(object):
693
  output = clean_response(output)
694
  elif prompt is None:
695
  # then use most basic parsing like pipeline
696
- if self.botstr in output:
 
 
697
  if self.humanstr:
698
  output = clean_response(output.split(self.botstr)[1].split(self.humanstr)[0])
699
  else:
 
120
  elif prompt_type in [PromptType.custom.value, str(PromptType.custom.value),
121
  PromptType.custom.name]:
122
  promptA = prompt_dict.get('promptA', '')
123
+ promptB = prompt_dict.get('promptB', '')
124
  PreInstruct = prompt_dict.get('PreInstruct', '')
125
  PreInput = prompt_dict.get('PreInput', '')
126
  PreResponse = prompt_dict.get('PreResponse', '')
 
693
  output = clean_response(output)
694
  elif prompt is None:
695
  # then use most basic parsing like pipeline
696
+ if not self.botstr:
697
+ pass
698
+ elif self.botstr in output:
699
  if self.humanstr:
700
  output = clean_response(output.split(self.botstr)[1].split(self.humanstr)[0])
701
  else: