barunsaha commited on
Commit
813ce6e
·
1 Parent(s): c4c876e

Allow users to choose from two different Mistral models

Browse files
Files changed (3) hide show
  1. app.py +22 -11
  2. global_config.py +10 -2
  3. helpers/llm_helper.py +70 -69
app.py CHANGED
@@ -20,7 +20,6 @@ from langchain_core.prompts import ChatPromptTemplate
20
  sys.path.append('..')
21
  sys.path.append('../..')
22
 
23
- import helpers.icons_embeddings as ice
24
  from global_config import GlobalConfig
25
  from helpers import llm_helper, pptx_helper, text_helper
26
 
@@ -56,14 +55,16 @@ def _get_prompt_template(is_refinement: bool) -> str:
56
 
57
 
58
  @st.cache_resource
59
- def _get_llm():
60
  """
61
  Get an LLM instance.
62
 
 
 
63
  :return: The LLM.
64
  """
65
 
66
- return llm_helper.get_hf_endpoint()
67
 
68
 
69
  APP_TEXT = _load_strings()
@@ -78,12 +79,19 @@ logger = logging.getLogger(__name__)
78
 
79
  texts = list(GlobalConfig.PPTX_TEMPLATE_FILES.keys())
80
  captions = [GlobalConfig.PPTX_TEMPLATE_FILES[x]['caption'] for x in texts]
81
- pptx_template = st.sidebar.radio(
82
- 'Select a presentation template:',
83
- texts,
84
- captions=captions,
85
- horizontal=True
86
- )
 
 
 
 
 
 
 
87
 
88
 
89
  def build_ui():
@@ -187,12 +195,15 @@ def set_up_chat_ui():
187
  response = ''
188
 
189
  try:
190
- for chunk in _get_llm().stream(formatted_template):
 
 
 
191
  response += chunk
192
 
193
  # Update the progress bar
194
  progress_percentage = min(
195
- len(response) / GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH, 0.95
196
  )
197
  progress_bar.progress(
198
  progress_percentage,
 
20
  sys.path.append('..')
21
  sys.path.append('../..')
22
 
 
23
  from global_config import GlobalConfig
24
  from helpers import llm_helper, pptx_helper, text_helper
25
 
 
55
 
56
 
57
  @st.cache_resource
58
+ def _get_llm(repo_id: str, max_new_tokens: int):
59
  """
60
  Get an LLM instance.
61
 
62
+ :param repo_id: The model name.
63
+ :param max_new_tokens: The max new tokens to generate.
64
  :return: The LLM.
65
  """
66
 
67
+ return llm_helper.get_hf_endpoint(repo_id, max_new_tokens)
68
 
69
 
70
  APP_TEXT = _load_strings()
 
79
 
80
  texts = list(GlobalConfig.PPTX_TEMPLATE_FILES.keys())
81
  captions = [GlobalConfig.PPTX_TEMPLATE_FILES[x]['caption'] for x in texts]
82
+
83
+ with st.sidebar:
84
+ pptx_template = st.sidebar.radio(
85
+ 'Select a presentation template:',
86
+ texts,
87
+ captions=captions,
88
+ horizontal=True
89
+ )
90
+ st.divider()
91
+ llm_to_use = st.sidebar.selectbox(
92
+ 'Select an LLM to use:',
93
+ [f'{k} ({v["description"]})' for k, v in GlobalConfig.HF_MODELS.items()]
94
+ ).split(' ')[0]
95
 
96
 
97
  def build_ui():
 
195
  response = ''
196
 
197
  try:
198
+ for chunk in _get_llm(
199
+ repo_id=llm_to_use,
200
+ max_new_tokens=GlobalConfig.HF_MODELS[llm_to_use]['max_new_tokens']
201
+ ).stream(formatted_template):
202
  response += chunk
203
 
204
  # Update the progress bar
205
  progress_percentage = min(
206
+ len(response) / GlobalConfig.HF_MODELS[llm_to_use]['max_new_tokens'], 0.95
207
  )
208
  progress_bar.progress(
209
  progress_percentage,
global_config.py CHANGED
@@ -17,10 +17,18 @@ class GlobalConfig:
17
  A data class holding the configurations.
18
  """
19
 
20
- HF_LLM_MODEL_NAME = 'mistralai/Mistral-Nemo-Instruct-2407'
 
 
 
 
 
 
 
 
 
21
  LLM_MODEL_TEMPERATURE = 0.2
22
  LLM_MODEL_MIN_OUTPUT_LENGTH = 100
23
- LLM_MODEL_MAX_OUTPUT_LENGTH = 4 * 4096 # tokens
24
  LLM_MODEL_MAX_INPUT_LENGTH = 400 # characters
25
 
26
  HUGGINGFACEHUB_API_TOKEN = os.environ.get('HUGGINGFACEHUB_API_TOKEN', '')
 
17
  A data class holding the configurations.
18
  """
19
 
20
+ HF_MODELS = {
21
+ 'mistralai/Mistral-Nemo-Instruct-2407': {
22
+ 'description': 'longer response',
23
+ 'max_new_tokens': 12228
24
+ },
25
+ 'mistralai/Mistral-7B-Instruct-v0.2': {
26
+ 'description': 'faster, shorter',
27
+ 'max_new_tokens': 8192
28
+ },
29
+ }
30
  LLM_MODEL_TEMPERATURE = 0.2
31
  LLM_MODEL_MIN_OUTPUT_LENGTH = 100
 
32
  LLM_MODEL_MAX_INPUT_LENGTH = 400 # characters
33
 
34
  HUGGINGFACEHUB_API_TOKEN = os.environ.get('HUGGINGFACEHUB_API_TOKEN', '')
helpers/llm_helper.py CHANGED
@@ -9,7 +9,6 @@ from langchain_core.language_models import LLM
9
  from global_config import GlobalConfig
10
 
11
 
12
- HF_API_URL = f"https://api-inference.huggingface.co/models/{GlobalConfig.HF_LLM_MODEL_NAME}"
13
  HF_API_HEADERS = {"Authorization": f"Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}"}
14
  REQUEST_TIMEOUT = 35
15
 
@@ -28,18 +27,20 @@ http_session.mount('https://', adapter)
28
  http_session.mount('http://', adapter)
29
 
30
 
31
- def get_hf_endpoint() -> LLM:
32
  """
33
  Get an LLM via the HuggingFaceEndpoint of LangChain.
34
 
35
- :return: The LLM.
 
 
36
  """
37
 
38
- logger.debug('Getting LLM via HF endpoint')
39
 
40
  return HuggingFaceEndpoint(
41
- repo_id=GlobalConfig.HF_LLM_MODEL_NAME,
42
- max_new_tokens=GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
43
  top_k=40,
44
  top_p=0.95,
45
  temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
@@ -51,69 +52,69 @@ def get_hf_endpoint() -> LLM:
51
  )
52
 
53
 
54
- def hf_api_query(payload: dict) -> dict:
55
- """
56
- Invoke HF inference end-point API.
57
-
58
- :param payload: The prompt for the LLM and related parameters.
59
- :return: The output from the LLM.
60
- """
61
-
62
- try:
63
- response = http_session.post(
64
- HF_API_URL,
65
- headers=HF_API_HEADERS,
66
- json=payload,
67
- timeout=REQUEST_TIMEOUT
68
- )
69
- result = response.json()
70
- except requests.exceptions.Timeout as te:
71
- logger.error('*** Error: hf_api_query timeout! %s', str(te))
72
- result = []
73
-
74
- return result
75
-
76
-
77
- def generate_slides_content(topic: str) -> str:
78
- """
79
- Generate the outline/contents of slides for a presentation on a given topic.
80
-
81
- :param topic: Topic on which slides are to be generated.
82
- :return: The content in JSON format.
83
- """
84
-
85
- with open(GlobalConfig.SLIDES_TEMPLATE_FILE, 'r', encoding='utf-8') as in_file:
86
- template_txt = in_file.read().strip()
87
- template_txt = template_txt.replace('<REPLACE_PLACEHOLDER>', topic)
88
-
89
- output = hf_api_query({
90
- 'inputs': template_txt,
91
- 'parameters': {
92
- 'temperature': GlobalConfig.LLM_MODEL_TEMPERATURE,
93
- 'min_length': GlobalConfig.LLM_MODEL_MIN_OUTPUT_LENGTH,
94
- 'max_length': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
95
- 'max_new_tokens': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
96
- 'num_return_sequences': 1,
97
- 'return_full_text': False,
98
- # "repetition_penalty": 0.0001
99
- },
100
- 'options': {
101
- 'wait_for_model': True,
102
- 'use_cache': True
103
- }
104
- })
105
-
106
- output = output[0]['generated_text'].strip()
107
- # output = output[len(template_txt):]
108
-
109
- json_end_idx = output.rfind('```')
110
- if json_end_idx != -1:
111
- # logging.debug(f'{json_end_idx=}')
112
- output = output[:json_end_idx]
113
-
114
- logger.debug('generate_slides_content: output: %s', output)
115
-
116
- return output
117
 
118
 
119
  if __name__ == '__main__':
 
9
  from global_config import GlobalConfig
10
 
11
 
 
12
  HF_API_HEADERS = {"Authorization": f"Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}"}
13
  REQUEST_TIMEOUT = 35
14
 
 
27
  http_session.mount('http://', adapter)
28
 
29
 
30
+ def get_hf_endpoint(repo_id: str, max_new_tokens: int) -> LLM:
31
  """
32
  Get an LLM via the HuggingFaceEndpoint of LangChain.
33
 
34
+ :param repo_id: The model name.
35
+ :param max_new_tokens: The max new tokens to generate.
36
+ :return: The HF LLM inference endpoint.
37
  """
38
 
39
+ logger.debug('Getting LLM via HF endpoint: %s', repo_id)
40
 
41
  return HuggingFaceEndpoint(
42
+ repo_id=repo_id,
43
+ max_new_tokens=max_new_tokens,
44
  top_k=40,
45
  top_p=0.95,
46
  temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
 
52
  )
53
 
54
 
55
+ # def hf_api_query(payload: dict) -> dict:
56
+ # """
57
+ # Invoke HF inference end-point API.
58
+ #
59
+ # :param payload: The prompt for the LLM and related parameters.
60
+ # :return: The output from the LLM.
61
+ # """
62
+ #
63
+ # try:
64
+ # response = http_session.post(
65
+ # HF_API_URL,
66
+ # headers=HF_API_HEADERS,
67
+ # json=payload,
68
+ # timeout=REQUEST_TIMEOUT
69
+ # )
70
+ # result = response.json()
71
+ # except requests.exceptions.Timeout as te:
72
+ # logger.error('*** Error: hf_api_query timeout! %s', str(te))
73
+ # result = []
74
+ #
75
+ # return result
76
+
77
+
78
+ # def generate_slides_content(topic: str) -> str:
79
+ # """
80
+ # Generate the outline/contents of slides for a presentation on a given topic.
81
+ #
82
+ # :param topic: Topic on which slides are to be generated.
83
+ # :return: The content in JSON format.
84
+ # """
85
+ #
86
+ # with open(GlobalConfig.SLIDES_TEMPLATE_FILE, 'r', encoding='utf-8') as in_file:
87
+ # template_txt = in_file.read().strip()
88
+ # template_txt = template_txt.replace('<REPLACE_PLACEHOLDER>', topic)
89
+ #
90
+ # output = hf_api_query({
91
+ # 'inputs': template_txt,
92
+ # 'parameters': {
93
+ # 'temperature': GlobalConfig.LLM_MODEL_TEMPERATURE,
94
+ # 'min_length': GlobalConfig.LLM_MODEL_MIN_OUTPUT_LENGTH,
95
+ # 'max_length': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
96
+ # 'max_new_tokens': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
97
+ # 'num_return_sequences': 1,
98
+ # 'return_full_text': False,
99
+ # # "repetition_penalty": 0.0001
100
+ # },
101
+ # 'options': {
102
+ # 'wait_for_model': True,
103
+ # 'use_cache': True
104
+ # }
105
+ # })
106
+ #
107
+ # output = output[0]['generated_text'].strip()
108
+ # # output = output[len(template_txt):]
109
+ #
110
+ # json_end_idx = output.rfind('```')
111
+ # if json_end_idx != -1:
112
+ # # logging.debug(f'{json_end_idx=}')
113
+ # output = output[:json_end_idx]
114
+ #
115
+ # logger.debug('generate_slides_content: output: %s', output)
116
+ #
117
+ # return output
118
 
119
 
120
  if __name__ == '__main__':