Add support for offline LLMs via Ollama
- app.py +59 -26
- global_config.py +22 -2
- helpers/llm_helper.py +30 -9
- requirements.txt +6 -1
app.py
CHANGED
@@ -3,23 +3,34 @@ Streamlit app containing the UI and the application logic.
 """
 import datetime
 import logging
+import os
 import pathlib
 import random
 import tempfile
 from typing import List, Union

+import httpx
 import huggingface_hub
 import json5
+import ollama
 import requests
 import streamlit as st
+from dotenv import load_dotenv
 from langchain_community.chat_message_histories import StreamlitChatMessageHistory
 from langchain_core.messages import HumanMessage
 from langchain_core.prompts import ChatPromptTemplate

+import global_config as gcfg
 from global_config import GlobalConfig
 from helpers import llm_helper, pptx_helper, text_helper


+load_dotenv()
+
+
+RUN_IN_OFFLINE_MODE = os.getenv('RUN_IN_OFFLINE_MODE', 'False').lower() == 'true'
+
+
 @st.cache_data
 def _load_strings() -> dict:
     """
@@ -135,25 +146,36 @@ with st.sidebar:
         horizontal=True
     )

-    # The LLMs
-    llm_provider_to_use = st.sidebar.selectbox(
-        label='2: Select an LLM to use:',
-        options=[f'{k} ({v["description"]})' for k, v in GlobalConfig.VALID_MODELS.items()],
-        index=GlobalConfig.DEFAULT_MODEL_INDEX,
-        help=GlobalConfig.LLM_PROVIDER_HELP,
-        on_change=reset_api_key
-    ).split(' ')[0]
-
-    # The API key/access token
-    api_key_token = st.text_input(
-        label=(
-            '3: Paste your API key/access token:\n\n'
-            '*Mandatory* for Cohere and Gemini LLMs.'
-            ' *Optional* for HF Mistral LLMs but still encouraged.\n\n'
-        ),
-        type='password',
-        key='api_key_input'
-    )
+    if RUN_IN_OFFLINE_MODE:
+        llm_provider_to_use = st.text_input(
+            label='2: Enter Ollama model name to use:',
+            help=(
+                'Specify a correct, locally available LLM, found by running `ollama list`, for'
+                ' example `mistral:v0.2` and `mistral-nemo:latest`. Having an Ollama-compatible'
+                ' and supported GPU is strongly recommended.'
+            )
+        )
+        api_key_token: str = ''
+    else:
+        # The LLMs
+        llm_provider_to_use = st.sidebar.selectbox(
+            label='2: Select an LLM to use:',
+            options=[f'{k} ({v["description"]})' for k, v in GlobalConfig.VALID_MODELS.items()],
+            index=GlobalConfig.DEFAULT_MODEL_INDEX,
+            help=GlobalConfig.LLM_PROVIDER_HELP,
+            on_change=reset_api_key
+        ).split(' ')[0]
+
+        # The API key/access token
+        api_key_token = st.text_input(
+            label=(
+                '3: Paste your API key/access token:\n\n'
+                '*Mandatory* for Cohere and Gemini LLMs.'
+                ' *Optional* for HF Mistral LLMs but still encouraged.\n\n'
+            ),
+            type='password',
+            key='api_key_input'
+        )


 def build_ui():
@@ -200,7 +222,11 @@ def set_up_chat_ui():
         placeholder=APP_TEXT['chat_placeholder'],
         max_chars=GlobalConfig.LLM_MODEL_MAX_INPUT_LENGTH
     ):
-        provider, llm_name = llm_helper.get_provider_model(llm_provider_to_use)
+        provider, llm_name = llm_helper.get_provider_model(
+            llm_provider_to_use,
+            use_ollama=RUN_IN_OFFLINE_MODE
+        )
+        print(f'{llm_provider_to_use=}, {provider=}, {llm_name=}, {api_key_token=}')

         if not are_all_inputs_valid(prompt, provider, llm_name, api_key_token):
             return
@@ -233,7 +259,7 @@ def set_up_chat_ui():
            llm = llm_helper.get_langchain_llm(
                provider=provider,
                model=llm_name,
-               max_new_tokens=GlobalConfig.VALID_MODELS[llm_provider_to_use]['max_new_tokens'],
+               max_new_tokens=gcfg.get_max_output_tokens(llm_provider_to_use),
                api_key=api_key_token.strip(),
            )

@@ -252,18 +278,17 @@ def set_up_chat_ui():
                # Update the progress bar with an approx progress percentage
                progress_bar.progress(
                    min(
-                       len(response) / GlobalConfig.VALID_MODELS[
-                           llm_provider_to_use
-                       ]['max_new_tokens'],
+                       len(response) / gcfg.get_max_output_tokens(llm_provider_to_use),
                        0.95
                    ),
                    text='Streaming content...this might take a while...'
                )
-        except requests.exceptions.ConnectionError:
+        except (httpx.ConnectError, requests.exceptions.ConnectionError):
            handle_error(
                'A connection error occurred while streaming content from the LLM endpoint.'
                ' Unfortunately, the slide deck cannot be generated. Please try again later.'
-               ' Alternatively, try selecting a different LLM from the dropdown list.',
+               ' Alternatively, try selecting a different LLM from the dropdown list. If you are'
+               ' using Ollama, make sure that Ollama is already running on your system.',
                True
            )
            return
@@ -274,6 +299,14 @@ def set_up_chat_ui():
                True
            )
            return
+        except ollama.ResponseError:
+            handle_error(
+                f'The model `{llm_name}` is unavailable with Ollama on your system.'
+                f' Make sure that you have provided the correct LLM name or pull it using'
+                f' `ollama pull {llm_name}`. View LLMs available locally by running `ollama list`.',
+                True
+            )
+            return
         except Exception as ex:
            handle_error(
                f'An unexpected error occurred while generating the content: {ex}'
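Note on the new offline switch: `RUN_IN_OFFLINE_MODE` is read once at start-up via `python-dotenv`, and only the literal string `true` (case-insensitive) enables it. A minimal, self-contained sketch of that parsing (not part of the commit; the values below are just examples):

# Illustrative only: how the RUN_IN_OFFLINE_MODE value from .env is interpreted
import os

os.environ['RUN_IN_OFFLINE_MODE'] = 'True'
print(os.getenv('RUN_IN_OFFLINE_MODE', 'False').lower() == 'true')  # True: offline mode enabled

os.environ['RUN_IN_OFFLINE_MODE'] = '1'
print(os.getenv('RUN_IN_OFFLINE_MODE', 'False').lower() == 'true')  # False: only 'true' enables it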
global_config.py
CHANGED
@@ -20,7 +20,13 @@ class GlobalConfig:
     PROVIDER_COHERE = 'co'
     PROVIDER_GOOGLE_GEMINI = 'gg'
     PROVIDER_HUGGING_FACE = 'hf'
-
+    PROVIDER_OLLAMA = 'ol'
+    VALID_PROVIDERS = {
+        PROVIDER_COHERE,
+        PROVIDER_GOOGLE_GEMINI,
+        PROVIDER_HUGGING_FACE,
+        PROVIDER_OLLAMA
+    }
     VALID_MODELS = {
         '[co]command-r-08-2024': {
             'description': 'simpler, slower',
@@ -47,7 +53,7 @@ class GlobalConfig:
         'LLM provider codes:\n\n'
         '- **[co]**: Cohere\n'
         '- **[gg]**: Google Gemini API\n'
-        '- **[hf]**: Hugging Face Inference
+        '- **[hf]**: Hugging Face Inference API\n'
     )
     DEFAULT_MODEL_INDEX = 2
     LLM_MODEL_TEMPERATURE = 0.2
@@ -125,3 +131,17 @@ logging.basicConfig(
     format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
     datefmt='%Y-%m-%d %H:%M:%S'
 )
+
+
+def get_max_output_tokens(llm_name: str) -> int:
+    """
+    Get the max output tokens value configured for an LLM. Return a default value if not configured.
+
+    :param llm_name: The name of the LLM.
+    :return: Max output tokens or a default count.
+    """
+
+    try:
+        return GlobalConfig.VALID_MODELS[llm_name]['max_new_tokens']
+    except KeyError:
+        return 2048
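The new `get_max_output_tokens()` helper is what lets app.py work with free-form Ollama model names that are not listed in `VALID_MODELS`. A short usage sketch, assuming it is run from the project root so that `global_config` is importable (the Ollama model name is only an example):

# Illustrative only: configured models return their 'max_new_tokens'; unknown names fall back to 2048
import global_config as gcfg

print(gcfg.get_max_output_tokens('[co]command-r-08-2024'))  # value configured in VALID_MODELS
print(gcfg.get_max_output_tokens('mistral-nemo:latest'))    # 2048, since Ollama names are not in VALID_MODELS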
helpers/llm_helper.py
CHANGED
@@ -17,8 +17,9 @@ from global_config import GlobalConfig


 LLM_PROVIDER_MODEL_REGEX = re.compile(r'\[(.*?)\](.*)')
+OLLAMA_MODEL_REGEX = re.compile(r'[a-zA-Z0-9._:-]+$')
 # 6-64 characters long, only containing alphanumeric characters, hyphens, and underscores
-API_KEY_REGEX = re.compile(r'^[a-zA-Z0-
+API_KEY_REGEX = re.compile(r'^[a-zA-Z0-9_-]{6,64}$')
 HF_API_HEADERS = {'Authorization': f'Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}'}
 REQUEST_TIMEOUT = 35

@@ -39,20 +40,28 @@ http_session.mount('https://', adapter)
 http_session.mount('http://', adapter)


-def get_provider_model(provider_model: str) -> Tuple[str, str]:
+def get_provider_model(provider_model: str, use_ollama: bool) -> Tuple[str, str]:
     """
     Parse and get LLM provider and model name from strings like `[provider]model/name-version`.

     :param provider_model: The provider, model name string from `GlobalConfig`.
-    :return:
+    :param use_ollama: Whether Ollama is used (i.e., running in offline mode).
+    :return: The provider and the model name; empty strings in case no matching pattern found.
     """

-    match = LLM_PROVIDER_MODEL_REGEX.match(provider_model)
+    provider_model = provider_model.strip()

-    if match:
-        inside_brackets = match.group(1)
-        outside_brackets = match.group(2)
-        return inside_brackets, outside_brackets
+    if use_ollama:
+        match = OLLAMA_MODEL_REGEX.match(provider_model)
+        if match:
+            return GlobalConfig.PROVIDER_OLLAMA, match.group(0)
+    else:
+        match = LLM_PROVIDER_MODEL_REGEX.match(provider_model)
+
+        if match:
+            inside_brackets = match.group(1)
+            outside_brackets = match.group(2)
+            return inside_brackets, outside_brackets

     return '', ''

@@ -152,6 +161,18 @@ def get_langchain_llm(
            streaming=True,
        )

+    if provider == GlobalConfig.PROVIDER_OLLAMA:
+        from langchain_ollama.llms import OllamaLLM
+
+        logger.debug('Getting LLM via Ollama: %s', model)
+        return OllamaLLM(
+            model=model,
+            temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
+            num_predict=max_new_tokens,
+            format='json',
+            streaming=True,
+        )
+
     return None


@@ -163,4 +184,4 @@ if __name__ == '__main__':
     ]

     for text in inputs:
-        print(get_provider_model(text))
+        print(get_provider_model(text, use_ollama=False))
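For reference, a self-contained sketch (not part of the commit) of how the two regexes split user input in the two modes; the sample strings come from `VALID_MODELS` and the Ollama help text in app.py:

import re

LLM_PROVIDER_MODEL_REGEX = re.compile(r'\[(.*?)\](.*)')
OLLAMA_MODEL_REGEX = re.compile(r'[a-zA-Z0-9._:-]+$')

# Online mode: '[provider]model' strings selected from VALID_MODELS
match = LLM_PROVIDER_MODEL_REGEX.match('[co]command-r-08-2024')
print(match.group(1), match.group(2))  # co command-r-08-2024

# Offline mode: a bare, locally available Ollama model name
match = OLLAMA_MODEL_REGEX.match('mistral-nemo:latest')
print(match.group(0))  # mistral-nemo:latest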
requirements.txt
CHANGED
@@ -12,9 +12,10 @@ langchain-core~=0.3.0
 langchain-community==0.3.0
 langchain-google-genai==2.0.6
 langchain-cohere==0.3.3
+langchain-ollama==0.2.1
 streamlit~=1.38.0

-python-pptx
+python-pptx~=0.6.21
 # metaphor-python
 json5~=0.9.14
 requests~=2.32.3
@@ -32,3 +33,7 @@ certifi==2024.8.30
 urllib3==2.2.3

 anyio==4.4.0
+
+httpx~=0.27.2
+huggingface-hub~=0.24.5
+ollama~=0.4.3
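A possible pre-flight check for offline mode, sketched with the `ollama` client pinned above; this assumes a locally running Ollama server and uses the client's `show()`/`pull()` calls to inspect and download models (the model name is only an example):

# Illustrative only: verify that the chosen model is available locally before starting the app
import ollama

MODEL_NAME = 'mistral-nemo:latest'  # example; use any name reported by `ollama list`

try:
    ollama.show(MODEL_NAME)  # raises ollama.ResponseError if the model is not present locally
    print(f'{MODEL_NAME} is available.')
except ollama.ResponseError:
    print(f'{MODEL_NAME} not found; pulling it...')
    ollama.pull(MODEL_NAME)  # same effect as running `ollama pull mistral-nemo:latest`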
|