barunsaha committed
Commit 69fbdcb
1 Parent(s): 0b846c8

Allow users to provide their own HF access token/API key

Files changed (3)
  1. app.py +38 -26
  2. global_config.py +4 -3
  3. helpers/llm_helper.py +67 -72
app.py CHANGED
@@ -54,19 +54,6 @@ def _get_prompt_template(is_refinement: bool) -> str:
     return template


-@st.cache_resource
-def _get_llm(repo_id: str, max_new_tokens: int):
-    """
-    Get an LLM instance.
-
-    :param repo_id: The model name.
-    :param max_new_tokens: The max new tokens to generate.
-    :return: The LLM.
-    """
-
-    return llm_helper.get_hf_endpoint(repo_id, max_new_tokens)
-
-
 APP_TEXT = _load_strings()

 # Session variables
@@ -81,18 +68,35 @@ texts = list(GlobalConfig.PPTX_TEMPLATE_FILES.keys())
 captions = [GlobalConfig.PPTX_TEMPLATE_FILES[x]['caption'] for x in texts]

 with st.sidebar:
+    # The PPT templates
     pptx_template = st.sidebar.radio(
-        'Select a presentation template:',
+        '1: Select a presentation template:',
         texts,
         captions=captions,
         horizontal=True
     )
-    st.divider()
-    llm_to_use = st.sidebar.selectbox(
-        'Select an LLM to use:',
-        [f'{k} ({v["description"]})' for k, v in GlobalConfig.HF_MODELS.items()]
+
+    # The LLMs
+    llm_provider_to_use = st.sidebar.selectbox(
+        label='2: Select an LLM to use:',
+        options=[f'{k} ({v["description"]})' for k, v in GlobalConfig.VALID_MODELS.items()],
+        index=0,
+        help=(
+            'LLM provider codes:\n\n'
+            '- **[hf]**: Hugging Face Inference Endpoint\n'
+        ),
     ).split(' ')[0]

+    # The API key/access token
+    api_key_token = st.text_input(
+        label=(
+            '3: Paste your API key/access token:\n\n'
+            '*Optional* if an HF Mistral LLM is selected from the list but still encouraged.\n\n'
+        ),
+        type='password',
+    )
+    st.caption('(Wrong HF access token will lead to validation error)')
+

 def build_ui():
     """
@@ -101,9 +105,9 @@ def build_ui():

     st.title(APP_TEXT['app_name'])
     st.subheader(APP_TEXT['caption'])
-    st.markdown(
-        '![Visitors](https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fbarunsaha%2Fslide-deck-ai&countColor=%23263759)'  # noqa: E501
-    )
+    # st.markdown(
+    #     '![Visitors](https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fbarunsaha%2Fslide-deck-ai&countColor=%23263759)'  # noqa: E501
+    # )

     with st.expander('Usage Policies and Limitations'):
         st.text(APP_TEXT['tos'] + '\n\n' + APP_TEXT['tos2'])
@@ -162,9 +166,15 @@ def set_up_chat_ui():
         )
         return

+    provider, llm_name = llm_helper.get_provider_model(llm_provider_to_use)
+
+    if not provider or not llm_name:
+        st.error('No valid LLM provider and/or model name found!')
+        return
+
     logger.info(
         'User input: %s | #characters: %d | LLM: %s',
-        prompt, len(prompt), llm_to_use
+        prompt, len(prompt), llm_name
     )
     st.chat_message('user').write(prompt)

@@ -193,15 +203,17 @@ def set_up_chat_ui():
     response = ''

     try:
-        for chunk in _get_llm(
-            repo_id=llm_to_use,
-            max_new_tokens=GlobalConfig.HF_MODELS[llm_to_use]['max_new_tokens']
+        for chunk in llm_helper.get_langchain_llm(
+            provider=provider,
+            model=llm_name,
+            max_new_tokens=GlobalConfig.VALID_MODELS[llm_provider_to_use]['max_new_tokens'],
+            api_key=api_key_token.strip(),
         ).stream(formatted_template):
             response += chunk

             # Update the progress bar
             progress_percentage = min(
-                len(response) / GlobalConfig.HF_MODELS[llm_to_use]['max_new_tokens'], 0.95
+                len(response) / GlobalConfig.VALID_MODELS[llm_provider_to_use]['max_new_tokens'], 0.95
             )
             progress_bar.progress(
                 progress_percentage,
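
Note on the sidebar wiring above: the selectbox options are built as f'{k} ({v["description"]})' from GlobalConfig.VALID_MODELS and then collapsed back to the dictionary key with .split(' ')[0], which works because the keys themselves contain no spaces. A minimal sketch of that round trip outside Streamlit (the variable names here are illustrative, not part of the commit):

from global_config import GlobalConfig

# Build the option labels exactly as the sidebar selectbox does
options = [f'{k} ({v["description"]})' for k, v in GlobalConfig.VALID_MODELS.items()]
# e.g. '[hf]mistralai/Mistral-7B-Instruct-v0.2 (faster, shorter)'

selected = options[0]               # stands in for the user's sidebar choice
model_key = selected.split(' ')[0]  # back to '[hf]mistralai/Mistral-7B-Instruct-v0.2'
assert model_key in GlobalConfig.VALID_MODELS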
global_config.py CHANGED
@@ -17,12 +17,13 @@ class GlobalConfig:
     A data class holding the configurations.
     """

-    HF_MODELS = {
-        'mistralai/Mistral-7B-Instruct-v0.2': {
+    VALID_PROVIDERS = {'hf'}
+    VALID_MODELS = {
+        '[hf]mistralai/Mistral-7B-Instruct-v0.2': {
             'description': 'faster, shorter',
             'max_new_tokens': 8192
         },
-        'mistralai/Mistral-Nemo-Instruct-2407': {
+        '[hf]mistralai/Mistral-Nemo-Instruct-2407': {
             'description': 'longer response',
             'max_new_tokens': 12228
         },
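
The [provider] prefix on each VALID_MODELS key is what get_provider_model() in helpers/llm_helper.py parses off, and VALID_PROVIDERS gates which prefixes get_langchain_llm() accepts. A hypothetical extra entry would follow the same shape (the model name below is only an example and is not part of this commit):

# Hypothetical extension, mirroring the two entries added above
EXTRA_MODELS = {
    '[hf]mistralai/Mixtral-8x7B-Instruct-v0.1': {
        'description': 'example entry',
        'max_new_tokens': 8192
    },
}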
helpers/llm_helper.py CHANGED
@@ -1,4 +1,7 @@
 import logging
+import re
+from typing import Tuple, Union
+
 import requests
 from requests.adapters import HTTPAdapter
 from urllib3.util import Retry
@@ -9,7 +12,8 @@ from langchain_core.language_models import LLM
 from global_config import GlobalConfig


-HF_API_HEADERS = {"Authorization": f"Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}"}
+LLM_PROVIDER_MODEL_REGEX = re.compile(r'\[(.*?)\](.*)')
+HF_API_HEADERS = {'Authorization': f'Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}'}
 REQUEST_TIMEOUT = 35

 logger = logging.getLogger(__name__)
@@ -27,12 +31,31 @@ http_session.mount('https://', adapter)
 http_session.mount('http://', adapter)


-def get_hf_endpoint(repo_id: str, max_new_tokens: int) -> LLM:
+def get_provider_model(provider_model: str) -> Tuple[str, str]:
+    """
+    Parse and get LLM provider and model name from strings like `[provider]model/name-version`.
+
+    :param provider_model: The provider, model name string from `GlobalConfig`.
+    :return: The provider and the model name.
+    """
+
+    match = LLM_PROVIDER_MODEL_REGEX.match(provider_model)
+
+    if match:
+        inside_brackets = match.group(1)
+        outside_brackets = match.group(2)
+        return inside_brackets, outside_brackets
+
+    return '', ''
+
+
+def get_hf_endpoint(repo_id: str, max_new_tokens: int, api_key: str = '') -> LLM:
     """
     Get an LLM via the HuggingFaceEndpoint of LangChain.

     :param repo_id: The model name.
     :param max_new_tokens: The max new tokens to generate.
+    :param api_key: [Optional] Hugging Face access token.
     :return: The HF LLM inference endpoint.
     """

@@ -46,82 +69,54 @@ def get_hf_endpoint(repo_id: str, max_new_tokens: int) -> LLM:
         temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
         repetition_penalty=1.03,
         streaming=True,
-        huggingfacehub_api_token=GlobalConfig.HUGGINGFACEHUB_API_TOKEN,
+        huggingfacehub_api_token=api_key or GlobalConfig.HUGGINGFACEHUB_API_TOKEN,
         return_full_text=False,
         stop_sequences=['</s>'],
     )


-# def hf_api_query(payload: dict) -> dict:
-#     """
-#     Invoke HF inference end-point API.
-#
-#     :param payload: The prompt for the LLM and related parameters.
-#     :return: The output from the LLM.
-#     """
-#
-#     try:
-#         response = http_session.post(
-#             HF_API_URL,
-#             headers=HF_API_HEADERS,
-#             json=payload,
-#             timeout=REQUEST_TIMEOUT
-#         )
-#         result = response.json()
-#     except requests.exceptions.Timeout as te:
-#         logger.error('*** Error: hf_api_query timeout! %s', str(te))
-#         result = []
-#
-#     return result
-
-
-# def generate_slides_content(topic: str) -> str:
-#     """
-#     Generate the outline/contents of slides for a presentation on a given topic.
-#
-#     :param topic: Topic on which slides are to be generated.
-#     :return: The content in JSON format.
-#     """
-#
-#     with open(GlobalConfig.SLIDES_TEMPLATE_FILE, 'r', encoding='utf-8') as in_file:
-#         template_txt = in_file.read().strip()
-#         template_txt = template_txt.replace('<REPLACE_PLACEHOLDER>', topic)
-#
-#     output = hf_api_query({
-#         'inputs': template_txt,
-#         'parameters': {
-#             'temperature': GlobalConfig.LLM_MODEL_TEMPERATURE,
-#             'min_length': GlobalConfig.LLM_MODEL_MIN_OUTPUT_LENGTH,
-#             'max_length': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
-#             'max_new_tokens': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
-#             'num_return_sequences': 1,
-#             'return_full_text': False,
-#             # "repetition_penalty": 0.0001
-#         },
-#         'options': {
-#             'wait_for_model': True,
-#             'use_cache': True
-#         }
-#     })
-#
-#     output = output[0]['generated_text'].strip()
-#     # output = output[len(template_txt):]
-#
-#     json_end_idx = output.rfind('```')
-#     if json_end_idx != -1:
-#         # logging.debug(f'{json_end_idx=}')
-#         output = output[:json_end_idx]
-#
-#     logger.debug('generate_slides_content: output: %s', output)
-#
-#     return output
+def get_langchain_llm(
+        provider: str,
+        model: str,
+        max_new_tokens: int,
+        api_key: str = ''
+) -> Union[LLM, None]:
+    """
+    Get an LLM based on the provider and model specified.
+
+    :param provider: The LLM provider. Valid values are `hf` for Hugging Face.
+    :param model:
+    :param max_new_tokens:
+    :param api_key:
+    :return:
+    """
+    if not provider or not model or provider not in GlobalConfig.VALID_PROVIDERS:
+        return None
+
+    if provider == 'hf':
+        logger.debug('Getting LLM via HF endpoint: %s', model)
+
+        return HuggingFaceEndpoint(
+            repo_id=model,
+            max_new_tokens=max_new_tokens,
+            top_k=40,
+            top_p=0.95,
+            temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
+            repetition_penalty=1.03,
+            streaming=True,
+            huggingfacehub_api_token=api_key or GlobalConfig.HUGGINGFACEHUB_API_TOKEN,
+            return_full_text=False,
+            stop_sequences=['</s>'],
+        )
+
+    return None


 if __name__ == '__main__':
-    # results = get_related_websites('5G AI WiFi 6')
-    #
-    # for a_result in results.results:
-    #     print(a_result.title, a_result.url, a_result.extract)
+    inputs = [
+        '[hf]mistralai/Mistral-7B-Instruct-v0.2',
+        '[gg]gemini-1.5-flash-002'
+    ]

-    # get_ai_image('A talk on AI, covering pros and cons')
-    pass
+    for text in inputs:
+        print(get_provider_model(text))
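
Putting the three files together, the new flow can be exercised outside Streamlit roughly as follows. This is a sketch, assuming the repository layout shown above (global_config.py at the top level, helpers/llm_helper.py) and that hf_token holds a user-supplied Hugging Face access token; an empty token makes get_langchain_llm() fall back to GlobalConfig.HUGGINGFACEHUB_API_TOKEN, per the api_key or ... expression in the diff.

from global_config import GlobalConfig
from helpers import llm_helper

model_key = '[hf]mistralai/Mistral-7B-Instruct-v0.2'
provider, model = llm_helper.get_provider_model(model_key)  # ('hf', 'mistralai/Mistral-7B-Instruct-v0.2')

hf_token = ''  # paste a user-supplied token here; '' falls back to the server-side token

llm = llm_helper.get_langchain_llm(
    provider=provider,
    model=model,
    max_new_tokens=GlobalConfig.VALID_MODELS[model_key]['max_new_tokens'],
    api_key=hf_token,
)

if llm:
    # HuggingFaceEndpoint streams plain text chunks, as the app's progress loop expects
    for chunk in llm.stream('Say hello in one short sentence.'):
        print(chunk, end='')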