barunsaha commited on
Commit
4184417
1 Parent(s): 2551512

Add support for offline LLMs via Ollama

Browse files
Files changed (4) hide show
  1. app.py +59 -26
  2. global_config.py +22 -2
  3. helpers/llm_helper.py +30 -9
  4. requirements.txt +6 -1
app.py CHANGED
@@ -3,23 +3,34 @@ Streamlit app containing the UI and the application logic.
3
  """
4
  import datetime
5
  import logging
 
6
  import pathlib
7
  import random
8
  import tempfile
9
  from typing import List, Union
10
 
 
11
  import huggingface_hub
12
  import json5
 
13
  import requests
14
  import streamlit as st
 
15
  from langchain_community.chat_message_histories import StreamlitChatMessageHistory
16
  from langchain_core.messages import HumanMessage
17
  from langchain_core.prompts import ChatPromptTemplate
18
 
 
19
  from global_config import GlobalConfig
20
  from helpers import llm_helper, pptx_helper, text_helper
21
 
22
 
 
 
 
 
 
 
23
  @st.cache_data
24
  def _load_strings() -> dict:
25
  """
@@ -135,25 +146,36 @@ with st.sidebar:
135
  horizontal=True
136
  )
137
 
138
- # The LLMs
139
- llm_provider_to_use = st.sidebar.selectbox(
140
- label='2: Select an LLM to use:',
141
- options=[f'{k} ({v["description"]})' for k, v in GlobalConfig.VALID_MODELS.items()],
142
- index=GlobalConfig.DEFAULT_MODEL_INDEX,
143
- help=GlobalConfig.LLM_PROVIDER_HELP,
144
- on_change=reset_api_key
145
- ).split(' ')[0]
146
-
147
- # The API key/access token
148
- api_key_token = st.text_input(
149
- label=(
150
- '3: Paste your API key/access token:\n\n'
151
- '*Mandatory* for Cohere and Gemini LLMs.'
152
- ' *Optional* for HF Mistral LLMs but still encouraged.\n\n'
153
- ),
154
- type='password',
155
- key='api_key_input'
156
- )
 
 
 
 
 
 
 
 
 
 
 
157
 
158
 
159
  def build_ui():
@@ -200,7 +222,11 @@ def set_up_chat_ui():
200
  placeholder=APP_TEXT['chat_placeholder'],
201
  max_chars=GlobalConfig.LLM_MODEL_MAX_INPUT_LENGTH
202
  ):
203
- provider, llm_name = llm_helper.get_provider_model(llm_provider_to_use)
 
 
 
 
204
 
205
  if not are_all_inputs_valid(prompt, provider, llm_name, api_key_token):
206
  return
@@ -233,7 +259,7 @@ def set_up_chat_ui():
233
  llm = llm_helper.get_langchain_llm(
234
  provider=provider,
235
  model=llm_name,
236
- max_new_tokens=GlobalConfig.VALID_MODELS[llm_provider_to_use]['max_new_tokens'],
237
  api_key=api_key_token.strip(),
238
  )
239
 
@@ -252,18 +278,17 @@ def set_up_chat_ui():
252
  # Update the progress bar with an approx progress percentage
253
  progress_bar.progress(
254
  min(
255
- len(response) / GlobalConfig.VALID_MODELS[
256
- llm_provider_to_use
257
- ]['max_new_tokens'],
258
  0.95
259
  ),
260
  text='Streaming content...this might take a while...'
261
  )
262
- except requests.exceptions.ConnectionError:
263
  handle_error(
264
  'A connection error occurred while streaming content from the LLM endpoint.'
265
  ' Unfortunately, the slide deck cannot be generated. Please try again later.'
266
- ' Alternatively, try selecting a different LLM from the dropdown list.',
 
267
  True
268
  )
269
  return
@@ -274,6 +299,14 @@ def set_up_chat_ui():
274
  True
275
  )
276
  return
 
 
 
 
 
 
 
 
277
  except Exception as ex:
278
  handle_error(
279
  f'An unexpected error occurred while generating the content: {ex}'
 
3
  """
4
  import datetime
5
  import logging
6
+ import os
7
  import pathlib
8
  import random
9
  import tempfile
10
  from typing import List, Union
11
 
12
+ import httpx
13
  import huggingface_hub
14
  import json5
15
+ import ollama
16
  import requests
17
  import streamlit as st
18
+ from dotenv import load_dotenv
19
  from langchain_community.chat_message_histories import StreamlitChatMessageHistory
20
  from langchain_core.messages import HumanMessage
21
  from langchain_core.prompts import ChatPromptTemplate
22
 
23
+ import global_config as gcfg
24
  from global_config import GlobalConfig
25
  from helpers import llm_helper, pptx_helper, text_helper
26
 
27
 
28
+ load_dotenv()
29
+
30
+
31
+ RUN_IN_OFFLINE_MODE = os.getenv('RUN_IN_OFFLINE_MODE', 'False').lower() == 'true'
32
+
33
+
34
  @st.cache_data
35
  def _load_strings() -> dict:
36
  """
 
146
  horizontal=True
147
  )
148
 
149
+ if RUN_IN_OFFLINE_MODE:
150
+ llm_provider_to_use = st.text_input(
151
+ label='2: Enter Ollama model name to use:',
152
+ help=(
153
+ 'Specify a correct, locally available LLM, found by running `ollama list`, for'
154
+ ' example `mistral:v0.2` and `mistral-nemo:latest`. Having an Ollama-compatible'
155
+ ' and supported GPU is strongly recommended.'
156
+ )
157
+ )
158
+ api_key_token: str = ''
159
+ else:
160
+ # The LLMs
161
+ llm_provider_to_use = st.sidebar.selectbox(
162
+ label='2: Select an LLM to use:',
163
+ options=[f'{k} ({v["description"]})' for k, v in GlobalConfig.VALID_MODELS.items()],
164
+ index=GlobalConfig.DEFAULT_MODEL_INDEX,
165
+ help=GlobalConfig.LLM_PROVIDER_HELP,
166
+ on_change=reset_api_key
167
+ ).split(' ')[0]
168
+
169
+ # The API key/access token
170
+ api_key_token = st.text_input(
171
+ label=(
172
+ '3: Paste your API key/access token:\n\n'
173
+ '*Mandatory* for Cohere and Gemini LLMs.'
174
+ ' *Optional* for HF Mistral LLMs but still encouraged.\n\n'
175
+ ),
176
+ type='password',
177
+ key='api_key_input'
178
+ )
179
 
180
 
181
  def build_ui():
 
222
  placeholder=APP_TEXT['chat_placeholder'],
223
  max_chars=GlobalConfig.LLM_MODEL_MAX_INPUT_LENGTH
224
  ):
225
+ provider, llm_name = llm_helper.get_provider_model(
226
+ llm_provider_to_use,
227
+ use_ollama=RUN_IN_OFFLINE_MODE
228
+ )
229
+ print(f'{llm_provider_to_use=}, {provider=}, {llm_name=}, {api_key_token=}')
230
 
231
  if not are_all_inputs_valid(prompt, provider, llm_name, api_key_token):
232
  return
 
259
  llm = llm_helper.get_langchain_llm(
260
  provider=provider,
261
  model=llm_name,
262
+ max_new_tokens=gcfg.get_max_output_tokens(llm_provider_to_use),
263
  api_key=api_key_token.strip(),
264
  )
265
 
 
278
  # Update the progress bar with an approx progress percentage
279
  progress_bar.progress(
280
  min(
281
+ len(response) / gcfg.get_max_output_tokens(llm_provider_to_use),
 
 
282
  0.95
283
  ),
284
  text='Streaming content...this might take a while...'
285
  )
286
+ except (httpx.ConnectError, requests.exceptions.ConnectionError):
287
  handle_error(
288
  'A connection error occurred while streaming content from the LLM endpoint.'
289
  ' Unfortunately, the slide deck cannot be generated. Please try again later.'
290
+ ' Alternatively, try selecting a different LLM from the dropdown list. If you are'
291
+ ' using Ollama, make sure that Ollama is already running on your system.',
292
  True
293
  )
294
  return
 
299
  True
300
  )
301
  return
302
+ except ollama.ResponseError:
303
+ handle_error(
304
+ f'The model `{llm_name}` is unavailable with Ollama on your system.'
305
+ f' Make sure that you have provided the correct LLM name or pull it using'
306
+ f' `ollama pull {llm_name}`. View LLMs available locally by running `ollama list`.',
307
+ True
308
+ )
309
+ return
310
  except Exception as ex:
311
  handle_error(
312
  f'An unexpected error occurred while generating the content: {ex}'
global_config.py CHANGED
@@ -20,7 +20,13 @@ class GlobalConfig:
20
  PROVIDER_COHERE = 'co'
21
  PROVIDER_GOOGLE_GEMINI = 'gg'
22
  PROVIDER_HUGGING_FACE = 'hf'
23
- VALID_PROVIDERS = {PROVIDER_COHERE, PROVIDER_GOOGLE_GEMINI, PROVIDER_HUGGING_FACE}
 
 
 
 
 
 
24
  VALID_MODELS = {
25
  '[co]command-r-08-2024': {
26
  'description': 'simpler, slower',
@@ -47,7 +53,7 @@ class GlobalConfig:
47
  'LLM provider codes:\n\n'
48
  '- **[co]**: Cohere\n'
49
  '- **[gg]**: Google Gemini API\n'
50
- '- **[hf]**: Hugging Face Inference Endpoint\n'
51
  )
52
  DEFAULT_MODEL_INDEX = 2
53
  LLM_MODEL_TEMPERATURE = 0.2
@@ -125,3 +131,17 @@ logging.basicConfig(
125
  format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
126
  datefmt='%Y-%m-%d %H:%M:%S'
127
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  PROVIDER_COHERE = 'co'
21
  PROVIDER_GOOGLE_GEMINI = 'gg'
22
  PROVIDER_HUGGING_FACE = 'hf'
23
+ PROVIDER_OLLAMA = 'ol'
24
+ VALID_PROVIDERS = {
25
+ PROVIDER_COHERE,
26
+ PROVIDER_GOOGLE_GEMINI,
27
+ PROVIDER_HUGGING_FACE,
28
+ PROVIDER_OLLAMA
29
+ }
30
  VALID_MODELS = {
31
  '[co]command-r-08-2024': {
32
  'description': 'simpler, slower',
 
53
  'LLM provider codes:\n\n'
54
  '- **[co]**: Cohere\n'
55
  '- **[gg]**: Google Gemini API\n'
56
+ '- **[hf]**: Hugging Face Inference API\n'
57
  )
58
  DEFAULT_MODEL_INDEX = 2
59
  LLM_MODEL_TEMPERATURE = 0.2
 
131
  format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
132
  datefmt='%Y-%m-%d %H:%M:%S'
133
  )
134
+
135
+
136
+ def get_max_output_tokens(llm_name: str) -> int:
137
+ """
138
+ Get the max output tokens value configured for an LLM. Return a default value if not configured.
139
+
140
+ :param llm_name: The name of the LLM.
141
+ :return: Max output tokens or a default count.
142
+ """
143
+
144
+ try:
145
+ return GlobalConfig.VALID_MODELS[llm_name]['max_new_tokens']
146
+ except KeyError:
147
+ return 2048
helpers/llm_helper.py CHANGED
@@ -17,8 +17,9 @@ from global_config import GlobalConfig
17
 
18
 
19
  LLM_PROVIDER_MODEL_REGEX = re.compile(r'\[(.*?)\](.*)')
 
20
  # 6-64 characters long, only containing alphanumeric characters, hyphens, and underscores
21
- API_KEY_REGEX = re.compile(r'^[a-zA-Z0-9\-_]{6,64}$')
22
  HF_API_HEADERS = {'Authorization': f'Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}'}
23
  REQUEST_TIMEOUT = 35
24
 
@@ -39,20 +40,28 @@ http_session.mount('https://', adapter)
39
  http_session.mount('http://', adapter)
40
 
41
 
42
- def get_provider_model(provider_model: str) -> Tuple[str, str]:
43
  """
44
  Parse and get LLM provider and model name from strings like `[provider]model/name-version`.
45
 
46
  :param provider_model: The provider, model name string from `GlobalConfig`.
47
- :return: The provider and the model name.
 
48
  """
49
 
50
- match = LLM_PROVIDER_MODEL_REGEX.match(provider_model)
51
 
52
- if match:
53
- inside_brackets = match.group(1)
54
- outside_brackets = match.group(2)
55
- return inside_brackets, outside_brackets
 
 
 
 
 
 
 
56
 
57
  return '', ''
58
 
@@ -152,6 +161,18 @@ def get_langchain_llm(
152
  streaming=True,
153
  )
154
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  return None
156
 
157
 
@@ -163,4 +184,4 @@ if __name__ == '__main__':
163
  ]
164
 
165
  for text in inputs:
166
- print(get_provider_model(text))
 
17
 
18
 
19
  LLM_PROVIDER_MODEL_REGEX = re.compile(r'\[(.*?)\](.*)')
20
+ OLLAMA_MODEL_REGEX = re.compile(r'[a-zA-Z0-9._:-]+$')
21
  # 6-64 characters long, only containing alphanumeric characters, hyphens, and underscores
22
+ API_KEY_REGEX = re.compile(r'^[a-zA-Z0-9_-]{6,64}$')
23
  HF_API_HEADERS = {'Authorization': f'Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}'}
24
  REQUEST_TIMEOUT = 35
25
 
 
40
  http_session.mount('http://', adapter)
41
 
42
 
43
+ def get_provider_model(provider_model: str, use_ollama: bool) -> Tuple[str, str]:
44
  """
45
  Parse and get LLM provider and model name from strings like `[provider]model/name-version`.
46
 
47
  :param provider_model: The provider, model name string from `GlobalConfig`.
48
+ :param use_ollama: Whether Ollama is used (i.e., running in offline mode).
49
+ :return: The provider and the model name; empty strings in case no matching pattern found.
50
  """
51
 
52
+ provider_model = provider_model.strip()
53
 
54
+ if use_ollama:
55
+ match = OLLAMA_MODEL_REGEX.match(provider_model)
56
+ if match:
57
+ return GlobalConfig.PROVIDER_OLLAMA, match.group(0)
58
+ else:
59
+ match = LLM_PROVIDER_MODEL_REGEX.match(provider_model)
60
+
61
+ if match:
62
+ inside_brackets = match.group(1)
63
+ outside_brackets = match.group(2)
64
+ return inside_brackets, outside_brackets
65
 
66
  return '', ''
67
 
 
161
  streaming=True,
162
  )
163
 
164
+ if provider == GlobalConfig.PROVIDER_OLLAMA:
165
+ from langchain_ollama.llms import OllamaLLM
166
+
167
+ logger.debug('Getting LLM via Ollama: %s', model)
168
+ return OllamaLLM(
169
+ model=model,
170
+ temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
171
+ num_predict=max_new_tokens,
172
+ format='json',
173
+ streaming=True,
174
+ )
175
+
176
  return None
177
 
178
 
 
184
  ]
185
 
186
  for text in inputs:
187
+ print(get_provider_model(text, use_ollama=False))
requirements.txt CHANGED
@@ -12,9 +12,10 @@ langchain-core~=0.3.0
12
  langchain-community==0.3.0
13
  langchain-google-genai==2.0.6
14
  langchain-cohere==0.3.3
 
15
  streamlit~=1.38.0
16
 
17
- python-pptx
18
  # metaphor-python
19
  json5~=0.9.14
20
  requests~=2.32.3
@@ -32,3 +33,7 @@ certifi==2024.8.30
32
  urllib3==2.2.3
33
 
34
  anyio==4.4.0
 
 
 
 
 
12
  langchain-community==0.3.0
13
  langchain-google-genai==2.0.6
14
  langchain-cohere==0.3.3
15
+ langchain-ollama==0.2.1
16
  streamlit~=1.38.0
17
 
18
+ python-pptx~=0.6.21
19
  # metaphor-python
20
  json5~=0.9.14
21
  requests~=2.32.3
 
33
  urllib3==2.2.3
34
 
35
  anyio==4.4.0
36
+
37
+ httpx~=0.27.2
38
+ huggingface-hub~=0.24.5
39
+ ollama~=0.4.3