Spaces:
Running
Running
Add Azure OpenAI support
Browse files- README.md +13 -12
- app.py +55 -8
- global_config.py +10 -3
- helpers/llm_helper.py +54 -9
- requirements.txt +1 -0
README.md
CHANGED
@@ -40,22 +40,23 @@ Clicking on the button will download the file.
|
|
40 |
|
41 |
# Summary of the LLMs
|
42 |
|
43 |
-
SlideDeck AI allows the use of different LLMs from
|
44 |
|
45 |
-
Based on several experiments, SlideDeck AI generally recommends the use of Mistral NeMo
|
46 |
|
47 |
The supported LLMs offer different styles of content generation. Use one of the following LLMs along with relevant API keys/access tokens, as appropriate, to create the content of the slide deck:
|
48 |
|
49 |
-
| LLM | Provider (code) | Requires API key
|
50 |
-
|:---------------------------------| :-------
|
51 |
-
| Mistral 7B Instruct v0.2 | Hugging Face (`hf`) | Optional but strongly encouraged; [get here](https://huggingface.co/settings/tokens)
|
52 |
-
| Mistral NeMo Instruct 2407 | Hugging Face (`hf`) | Optional but strongly encouraged; [get here](https://huggingface.co/settings/tokens)
|
53 |
-
| Gemini 1.5 Flash | Google Gemini API (`gg`) | Mandatory; [get here](https://aistudio.google.com/apikey)
|
54 |
-
| Gemini 2.0 Flash | Google Gemini API (`gg`) | Mandatory; [get here](https://aistudio.google.com/apikey)
|
55 |
-
| Gemini 2.0 Flash Lite | Google Gemini API (`gg`) | Mandatory; [get here](https://aistudio.google.com/apikey)
|
56 |
-
|
|
57 |
-
|
|
58 |
-
| Llama 3.
|
|
|
59 |
|
60 |
The Mistral models (via Hugging Face) do not mandatorily require an access token. In other words, you are always free to use these two LLMs, subject to Hugging Face's usage constraints. However, you are strongly encouraged to get and use your own Hugging Face access token.
|
61 |
|
|
|
40 |
|
41 |
# Summary of the LLMs
|
42 |
|
43 |
+
SlideDeck AI allows the use of different LLMs from five online providers—Azure OpenAI, Hugging Face, Google, Cohere, and Together AI. The latter four service providers offer generous free usage of relevant LLMs without requiring any billing information.
|
44 |
|
45 |
+
Based on several experiments, SlideDeck AI generally recommends the use of Mistral NeMo, Gemini Flash, and GPT-4o to generate the slide decks.
|
46 |
|
47 |
The supported LLMs offer different styles of content generation. Use one of the following LLMs along with relevant API keys/access tokens, as appropriate, to create the content of the slide deck:
|
48 |
|
49 |
+
| LLM | Provider (code) | Requires API key | Characteristics |
|
50 |
+
|:---------------------------------| :------- |:-------------------------------------------------------------------------------------------------------------------------|:-------------------------|
|
51 |
+
| Mistral 7B Instruct v0.2 | Hugging Face (`hf`) | Optional but strongly encouraged; [get here](https://huggingface.co/settings/tokens) | Faster, shorter content |
|
52 |
+
| Mistral NeMo Instruct 2407 | Hugging Face (`hf`) | Optional but strongly encouraged; [get here](https://huggingface.co/settings/tokens) | Slower, longer content |
|
53 |
+
| Gemini 1.5 Flash | Google Gemini API (`gg`) | Mandatory; [get here](https://aistudio.google.com/apikey) | Faster, longer content |
|
54 |
+
| Gemini 2.0 Flash | Google Gemini API (`gg`) | Mandatory; [get here](https://aistudio.google.com/apikey) | Faster, longer content |
|
55 |
+
| Gemini 2.0 Flash Lite | Google Gemini API (`gg`) | Mandatory; [get here](https://aistudio.google.com/apikey) | Faster, longer content |
|
56 |
+
| GPT | Azure OpenAI (`az`) | Mandatory; [get here](https://ai.azure.com/resource/playground) NOTE: You need to have your subscription/billing set up | Faster, longer content |
|
57 |
+
| Command R+ | Cohere (`co`) | Mandatory; [get here](https://dashboard.cohere.com/api-keys) | Shorter, simpler content |
|
58 |
+
| Llama 3.3 70B Instruct Turbo | Together AI (`to`) | Mandatory; [get here](https://api.together.ai/settings/api-keys) | Detailed, slower |
|
59 |
+
| Llama 3.1 8B Instruct Turbo 128K | Together AI (`to`) | Mandatory; [get here](https://api.together.ai/settings/api-keys) | Shorter |
|
60 |
|
61 |
The Mistral models (via Hugging Face) do not mandatorily require an access token. In other words, you are always free to use these two LLMs, subject to Hugging Face's usage constraints. However, you are strongly encouraged to get and use your own Hugging Face access token.
|
62 |
|
app.py
CHANGED
@@ -66,6 +66,9 @@ def are_all_inputs_valid(
|
|
66 |
selected_provider: str,
|
67 |
selected_model: str,
|
68 |
user_key: str,
|
|
|
|
|
|
|
69 |
) -> bool:
|
70 |
"""
|
71 |
Validate user input and LLM selection.
|
@@ -74,6 +77,9 @@ def are_all_inputs_valid(
|
|
74 |
:param selected_provider: The LLM provider.
|
75 |
:param selected_model: Name of the model.
|
76 |
:param user_key: User-provided API key.
|
|
|
|
|
|
|
77 |
:return: `True` if all inputs "look" OK; `False` otherwise.
|
78 |
"""
|
79 |
|
@@ -90,11 +96,16 @@ def are_all_inputs_valid(
|
|
90 |
handle_error('No valid LLM provider and/or model name found!', False)
|
91 |
return False
|
92 |
|
93 |
-
if not llm_helper.is_valid_llm_provider_model(
|
|
|
|
|
|
|
94 |
handle_error(
|
95 |
'The LLM settings do not look correct. Make sure that an API key/access token'
|
96 |
-
' is provided if the selected LLM requires it. An API key should be 6-
|
97 |
-
' long, only containing alphanumeric characters, hyphens, and underscores
|
|
|
|
|
98 |
False
|
99 |
)
|
100 |
return False
|
@@ -170,13 +181,35 @@ with st.sidebar:
|
|
170 |
api_key_token = st.text_input(
|
171 |
label=(
|
172 |
'3: Paste your API key/access token:\n\n'
|
173 |
-
'*Mandatory* for Cohere, Google Gemini, and Together AI providers.'
|
174 |
' *Optional* for HF Mistral LLMs but still encouraged.\n\n'
|
175 |
),
|
176 |
type='password',
|
177 |
key='api_key_input'
|
178 |
)
|
179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
|
181 |
def build_ui():
|
182 |
"""
|
@@ -238,7 +271,15 @@ def set_up_chat_ui():
|
|
238 |
use_ollama=RUN_IN_OFFLINE_MODE
|
239 |
)
|
240 |
|
241 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
242 |
return
|
243 |
|
244 |
logger.info(
|
@@ -270,7 +311,10 @@ def set_up_chat_ui():
|
|
270 |
provider=provider,
|
271 |
model=llm_name,
|
272 |
max_new_tokens=gcfg.get_max_output_tokens(llm_provider_to_use),
|
273 |
-
api_key=
|
|
|
|
|
|
|
274 |
)
|
275 |
|
276 |
if not llm:
|
@@ -282,8 +326,11 @@ def set_up_chat_ui():
|
|
282 |
)
|
283 |
return
|
284 |
|
285 |
-
for
|
286 |
-
|
|
|
|
|
|
|
287 |
|
288 |
# Update the progress bar with an approx progress percentage
|
289 |
progress_bar.progress(
|
|
|
66 |
selected_provider: str,
|
67 |
selected_model: str,
|
68 |
user_key: str,
|
69 |
+
azure_deployment_url: str = '',
|
70 |
+
azure_endpoint_name: str = '',
|
71 |
+
azure_api_version: str = '',
|
72 |
) -> bool:
|
73 |
"""
|
74 |
Validate user input and LLM selection.
|
|
|
77 |
:param selected_provider: The LLM provider.
|
78 |
:param selected_model: Name of the model.
|
79 |
:param user_key: User-provided API key.
|
80 |
+
:param azure_deployment_url: Azure OpenAI deployment URL.
|
81 |
+
:param azure_endpoint_name: Azure OpenAI model endpoint.
|
82 |
+
:param azure_api_version: Azure OpenAI API version.
|
83 |
:return: `True` if all inputs "look" OK; `False` otherwise.
|
84 |
"""
|
85 |
|
|
|
96 |
handle_error('No valid LLM provider and/or model name found!', False)
|
97 |
return False
|
98 |
|
99 |
+
if not llm_helper.is_valid_llm_provider_model(
|
100 |
+
selected_provider, selected_model, user_key,
|
101 |
+
azure_endpoint_name, azure_deployment_url, azure_api_version
|
102 |
+
):
|
103 |
handle_error(
|
104 |
'The LLM settings do not look correct. Make sure that an API key/access token'
|
105 |
+
' is provided if the selected LLM requires it. An API key should be 6-94 characters'
|
106 |
+
' long, only containing alphanumeric characters, hyphens, and underscores.\n\n'
|
107 |
+
'If you are using Azure OpenAI, make sure that you have provided the additional and'
|
108 |
+
' correct configurations.',
|
109 |
False
|
110 |
)
|
111 |
return False
|
|
|
181 |
api_key_token = st.text_input(
|
182 |
label=(
|
183 |
'3: Paste your API key/access token:\n\n'
|
184 |
+
'*Mandatory* for Azure OpenAI, Cohere, Google Gemini, and Together AI providers.'
|
185 |
' *Optional* for HF Mistral LLMs but still encouraged.\n\n'
|
186 |
),
|
187 |
type='password',
|
188 |
key='api_key_input'
|
189 |
)
|
190 |
|
191 |
+
# Additional configs for Azure OpenAI
|
192 |
+
with st.expander('**Azure OpenAI-specific configurations**'):
|
193 |
+
azure_endpoint = st.text_input(
|
194 |
+
label=(
|
195 |
+
'4: Azure endpoint URL, e.g., https://example.openai.azure.com/.\n\n'
|
196 |
+
'*Mandatory* for Azure OpenAI (only).'
|
197 |
+
)
|
198 |
+
)
|
199 |
+
azure_deployment = st.text_input(
|
200 |
+
label=(
|
201 |
+
'5: Deployment name on Azure OpenAI:\n\n'
|
202 |
+
'*Mandatory* for Azure OpenAI (only).'
|
203 |
+
),
|
204 |
+
)
|
205 |
+
api_version = st.text_input(
|
206 |
+
label=(
|
207 |
+
'6: API version:\n\n'
|
208 |
+
'*Mandatory* field. Change based on your deployment configurations.'
|
209 |
+
),
|
210 |
+
value='2024-05-01-preview',
|
211 |
+
)
|
212 |
+
|
213 |
|
214 |
def build_ui():
|
215 |
"""
|
|
|
271 |
use_ollama=RUN_IN_OFFLINE_MODE
|
272 |
)
|
273 |
|
274 |
+
user_key = api_key_token.strip()
|
275 |
+
az_deployment = azure_deployment.strip()
|
276 |
+
az_endpoint = azure_endpoint.strip()
|
277 |
+
api_ver = api_version.strip()
|
278 |
+
|
279 |
+
if not are_all_inputs_valid(
|
280 |
+
prompt, provider, llm_name, user_key,
|
281 |
+
az_deployment, az_endpoint, api_ver
|
282 |
+
):
|
283 |
return
|
284 |
|
285 |
logger.info(
|
|
|
311 |
provider=provider,
|
312 |
model=llm_name,
|
313 |
max_new_tokens=gcfg.get_max_output_tokens(llm_provider_to_use),
|
314 |
+
api_key=user_key,
|
315 |
+
azure_endpoint_url=az_endpoint,
|
316 |
+
azure_deployment_name=az_deployment,
|
317 |
+
azure_api_version=api_ver,
|
318 |
)
|
319 |
|
320 |
if not llm:
|
|
|
326 |
)
|
327 |
return
|
328 |
|
329 |
+
for chunk in llm.stream(formatted_template):
|
330 |
+
if isinstance(chunk, str):
|
331 |
+
response += chunk
|
332 |
+
else:
|
333 |
+
response += chunk.content # AIMessageChunk
|
334 |
|
335 |
# Update the progress bar with an approx progress percentage
|
336 |
progress_bar.progress(
|
global_config.py
CHANGED
@@ -22,14 +22,21 @@ class GlobalConfig:
|
|
22 |
PROVIDER_HUGGING_FACE = 'hf'
|
23 |
PROVIDER_OLLAMA = 'ol'
|
24 |
PROVIDER_TOGETHER_AI = 'to'
|
|
|
25 |
VALID_PROVIDERS = {
|
26 |
PROVIDER_COHERE,
|
27 |
PROVIDER_GOOGLE_GEMINI,
|
28 |
PROVIDER_HUGGING_FACE,
|
29 |
PROVIDER_OLLAMA,
|
30 |
-
PROVIDER_TOGETHER_AI
|
|
|
31 |
}
|
32 |
VALID_MODELS = {
|
|
|
|
|
|
|
|
|
|
|
33 |
'[co]command-r-08-2024': {
|
34 |
'description': 'simpler, slower',
|
35 |
'max_new_tokens': 4096,
|
@@ -79,7 +86,7 @@ class GlobalConfig:
|
|
79 |
'- **[to]**: Together AI\n\n'
|
80 |
'[Find out more](https://github.com/barun-saha/slide-deck-ai?tab=readme-ov-file#summary-of-the-llms)'
|
81 |
)
|
82 |
-
DEFAULT_MODEL_INDEX =
|
83 |
LLM_MODEL_TEMPERATURE = 0.2
|
84 |
LLM_MODEL_MIN_OUTPUT_LENGTH = 100
|
85 |
LLM_MODEL_MAX_INPUT_LENGTH = 400 # characters
|
@@ -135,7 +142,7 @@ class GlobalConfig:
|
|
135 |
'Remember, the conversational interface is meant to (and will) update your *initial*'
|
136 |
' slide deck. If you want to create a new slide deck on a different topic,'
|
137 |
' start a new chat session by reloading this page.\n\n'
|
138 |
-
'Currently,
|
139 |
' If one is not available, choose the other from the dropdown list. A [summary of'
|
140 |
' the supported LLMs]('
|
141 |
'https://github.com/barun-saha/slide-deck-ai/blob/main/README.md#summary-of-the-llms)'
|
|
|
22 |
PROVIDER_HUGGING_FACE = 'hf'
|
23 |
PROVIDER_OLLAMA = 'ol'
|
24 |
PROVIDER_TOGETHER_AI = 'to'
|
25 |
+
PROVIDER_AZURE_OPENAI = 'az'
|
26 |
VALID_PROVIDERS = {
|
27 |
PROVIDER_COHERE,
|
28 |
PROVIDER_GOOGLE_GEMINI,
|
29 |
PROVIDER_HUGGING_FACE,
|
30 |
PROVIDER_OLLAMA,
|
31 |
+
PROVIDER_TOGETHER_AI,
|
32 |
+
PROVIDER_AZURE_OPENAI,
|
33 |
}
|
34 |
VALID_MODELS = {
|
35 |
+
'[az]azure/open-ai': {
|
36 |
+
'description': 'faster, detailed',
|
37 |
+
'max_new_tokens': 8192,
|
38 |
+
'paid': True,
|
39 |
+
},
|
40 |
'[co]command-r-08-2024': {
|
41 |
'description': 'simpler, slower',
|
42 |
'max_new_tokens': 4096,
|
|
|
86 |
'- **[to]**: Together AI\n\n'
|
87 |
'[Find out more](https://github.com/barun-saha/slide-deck-ai?tab=readme-ov-file#summary-of-the-llms)'
|
88 |
)
|
89 |
+
DEFAULT_MODEL_INDEX = 5
|
90 |
LLM_MODEL_TEMPERATURE = 0.2
|
91 |
LLM_MODEL_MIN_OUTPUT_LENGTH = 100
|
92 |
LLM_MODEL_MAX_INPUT_LENGTH = 400 # characters
|
|
|
142 |
'Remember, the conversational interface is meant to (and will) update your *initial*'
|
143 |
' slide deck. If you want to create a new slide deck on a different topic,'
|
144 |
' start a new chat session by reloading this page.\n\n'
|
145 |
+
'Currently, paid or *free-to-use* LLMs from five different providers are supported.'
|
146 |
' If one is not available, choose the other from the dropdown list. A [summary of'
|
147 |
' the supported LLMs]('
|
148 |
'https://github.com/barun-saha/slide-deck-ai/blob/main/README.md#summary-of-the-llms)'
|
helpers/llm_helper.py
CHANGED
@@ -4,12 +4,14 @@ Helper functions to access LLMs.
|
|
4 |
import logging
|
5 |
import re
|
6 |
import sys
|
|
|
7 |
from typing import Tuple, Union
|
8 |
|
9 |
import requests
|
10 |
from requests.adapters import HTTPAdapter
|
11 |
from urllib3.util import Retry
|
12 |
-
from langchain_core.language_models import BaseLLM
|
|
|
13 |
|
14 |
sys.path.append('..')
|
15 |
|
@@ -18,14 +20,16 @@ from global_config import GlobalConfig
|
|
18 |
|
19 |
LLM_PROVIDER_MODEL_REGEX = re.compile(r'\[(.*?)\](.*)')
|
20 |
OLLAMA_MODEL_REGEX = re.compile(r'[a-zA-Z0-9._:-]+$')
|
21 |
-
#
|
22 |
-
API_KEY_REGEX = re.compile(r'^[a-zA-Z0-9_-]{6,
|
23 |
HF_API_HEADERS = {'Authorization': f'Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}'}
|
24 |
REQUEST_TIMEOUT = 35
|
25 |
|
|
|
26 |
logger = logging.getLogger(__name__)
|
27 |
logging.getLogger('httpx').setLevel(logging.WARNING)
|
28 |
logging.getLogger('httpcore').setLevel(logging.WARNING)
|
|
|
29 |
|
30 |
retries = Retry(
|
31 |
total=5,
|
@@ -66,7 +70,14 @@ def get_provider_model(provider_model: str, use_ollama: bool) -> Tuple[str, str]
|
|
66 |
return '', ''
|
67 |
|
68 |
|
69 |
-
def is_valid_llm_provider_model(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
"""
|
71 |
Verify whether LLM settings are proper.
|
72 |
This function does not verify whether `api_key` is correct. It only confirms that the key has
|
@@ -75,6 +86,9 @@ def is_valid_llm_provider_model(provider: str, model: str, api_key: str) -> bool
|
|
75 |
:param provider: Name of the LLM provider.
|
76 |
:param model: Name of the model.
|
77 |
:param api_key: The API key or access token.
|
|
|
|
|
|
|
78 |
:return: `True` if the settings "look" OK; `False` otherwise.
|
79 |
"""
|
80 |
|
@@ -85,11 +99,19 @@ def is_valid_llm_provider_model(provider: str, model: str, api_key: str) -> bool
|
|
85 |
GlobalConfig.PROVIDER_GOOGLE_GEMINI,
|
86 |
GlobalConfig.PROVIDER_COHERE,
|
87 |
GlobalConfig.PROVIDER_TOGETHER_AI,
|
|
|
88 |
] and not api_key:
|
89 |
return False
|
90 |
|
91 |
-
if api_key:
|
92 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
return True
|
95 |
|
@@ -98,8 +120,11 @@ def get_langchain_llm(
|
|
98 |
provider: str,
|
99 |
model: str,
|
100 |
max_new_tokens: int,
|
101 |
-
api_key: str = ''
|
102 |
-
|
|
|
|
|
|
|
103 |
"""
|
104 |
Get an LLM based on the provider and model specified.
|
105 |
|
@@ -107,7 +132,10 @@ def get_langchain_llm(
|
|
107 |
:param model: The name of the LLM.
|
108 |
:param max_new_tokens: The maximum number of tokens to generate.
|
109 |
:param api_key: API key or access token to use.
|
110 |
-
:
|
|
|
|
|
|
|
111 |
"""
|
112 |
|
113 |
if provider == GlobalConfig.PROVIDER_HUGGING_FACE:
|
@@ -149,6 +177,23 @@ def get_langchain_llm(
|
|
149 |
}
|
150 |
)
|
151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
if provider == GlobalConfig.PROVIDER_COHERE:
|
153 |
from langchain_cohere.llms import Cohere
|
154 |
|
|
|
4 |
import logging
|
5 |
import re
|
6 |
import sys
|
7 |
+
import urllib3
|
8 |
from typing import Tuple, Union
|
9 |
|
10 |
import requests
|
11 |
from requests.adapters import HTTPAdapter
|
12 |
from urllib3.util import Retry
|
13 |
+
from langchain_core.language_models import BaseLLM, BaseChatModel
|
14 |
+
|
15 |
|
16 |
sys.path.append('..')
|
17 |
|
|
|
20 |
|
21 |
LLM_PROVIDER_MODEL_REGEX = re.compile(r'\[(.*?)\](.*)')
|
22 |
OLLAMA_MODEL_REGEX = re.compile(r'[a-zA-Z0-9._:-]+$')
|
23 |
+
# 94 characters long, only containing alphanumeric characters, hyphens, and underscores
|
24 |
+
API_KEY_REGEX = re.compile(r'^[a-zA-Z0-9_-]{6,94}$')
|
25 |
HF_API_HEADERS = {'Authorization': f'Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}'}
|
26 |
REQUEST_TIMEOUT = 35
|
27 |
|
28 |
+
|
29 |
logger = logging.getLogger(__name__)
|
30 |
logging.getLogger('httpx').setLevel(logging.WARNING)
|
31 |
logging.getLogger('httpcore').setLevel(logging.WARNING)
|
32 |
+
logging.getLogger('openai').setLevel(logging.ERROR)
|
33 |
|
34 |
retries = Retry(
|
35 |
total=5,
|
|
|
70 |
return '', ''
|
71 |
|
72 |
|
73 |
+
def is_valid_llm_provider_model(
|
74 |
+
provider: str,
|
75 |
+
model: str,
|
76 |
+
api_key: str,
|
77 |
+
azure_endpoint_url: str = '',
|
78 |
+
azure_deployment_name: str = '',
|
79 |
+
azure_api_version: str = '',
|
80 |
+
) -> bool:
|
81 |
"""
|
82 |
Verify whether LLM settings are proper.
|
83 |
This function does not verify whether `api_key` is correct. It only confirms that the key has
|
|
|
86 |
:param provider: Name of the LLM provider.
|
87 |
:param model: Name of the model.
|
88 |
:param api_key: The API key or access token.
|
89 |
+
:param azure_endpoint_url: Azure OpenAI endpoint URL.
|
90 |
+
:param azure_deployment_name: Azure OpenAI deployment name.
|
91 |
+
:param azure_api_version: Azure OpenAI API version.
|
92 |
:return: `True` if the settings "look" OK; `False` otherwise.
|
93 |
"""
|
94 |
|
|
|
99 |
GlobalConfig.PROVIDER_GOOGLE_GEMINI,
|
100 |
GlobalConfig.PROVIDER_COHERE,
|
101 |
GlobalConfig.PROVIDER_TOGETHER_AI,
|
102 |
+
GlobalConfig.PROVIDER_AZURE_OPENAI,
|
103 |
] and not api_key:
|
104 |
return False
|
105 |
|
106 |
+
if api_key and API_KEY_REGEX.match(api_key) is None:
|
107 |
+
return False
|
108 |
+
|
109 |
+
if provider == GlobalConfig.PROVIDER_AZURE_OPENAI:
|
110 |
+
valid_url = urllib3.util.parse_url(azure_endpoint_url)
|
111 |
+
all_status = all(
|
112 |
+
[azure_api_version, azure_deployment_name, str(valid_url)]
|
113 |
+
)
|
114 |
+
return all_status
|
115 |
|
116 |
return True
|
117 |
|
|
|
120 |
provider: str,
|
121 |
model: str,
|
122 |
max_new_tokens: int,
|
123 |
+
api_key: str = '',
|
124 |
+
azure_endpoint_url: str = '',
|
125 |
+
azure_deployment_name: str = '',
|
126 |
+
azure_api_version: str = '',
|
127 |
+
) -> Union[BaseLLM, BaseChatModel, None]:
|
128 |
"""
|
129 |
Get an LLM based on the provider and model specified.
|
130 |
|
|
|
132 |
:param model: The name of the LLM.
|
133 |
:param max_new_tokens: The maximum number of tokens to generate.
|
134 |
:param api_key: API key or access token to use.
|
135 |
+
:param azure_endpoint_url: Azure OpenAI endpoint URL.
|
136 |
+
:param azure_deployment_name: Azure OpenAI deployment name.
|
137 |
+
:param azure_api_version: Azure OpenAI API version.
|
138 |
+
:return: An instance of the LLM or Chat model; `None` in case of any error.
|
139 |
"""
|
140 |
|
141 |
if provider == GlobalConfig.PROVIDER_HUGGING_FACE:
|
|
|
177 |
}
|
178 |
)
|
179 |
|
180 |
+
if provider == GlobalConfig.PROVIDER_AZURE_OPENAI:
|
181 |
+
from langchain_openai import AzureChatOpenAI
|
182 |
+
|
183 |
+
logger.debug('Getting LLM via Azure OpenAI: %s', model)
|
184 |
+
|
185 |
+
# The `model` parameter is not used here; `azure_deployment` points to the desired name
|
186 |
+
return AzureChatOpenAI(
|
187 |
+
azure_deployment=azure_deployment_name,
|
188 |
+
api_version=azure_api_version,
|
189 |
+
azure_endpoint=azure_endpoint_url,
|
190 |
+
temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
|
191 |
+
max_tokens=max_new_tokens,
|
192 |
+
timeout=None,
|
193 |
+
max_retries=1,
|
194 |
+
api_key=api_key,
|
195 |
+
)
|
196 |
+
|
197 |
if provider == GlobalConfig.PROVIDER_COHERE:
|
198 |
from langchain_cohere.llms import Cohere
|
199 |
|
requirements.txt
CHANGED
@@ -14,6 +14,7 @@ langchain-google-genai==2.0.6
|
|
14 |
langchain-cohere==0.3.3
|
15 |
langchain-together==0.3.0
|
16 |
langchain-ollama==0.2.1
|
|
|
17 |
streamlit~=1.38.0
|
18 |
|
19 |
python-pptx~=0.6.21
|
|
|
14 |
langchain-cohere==0.3.3
|
15 |
langchain-together==0.3.0
|
16 |
langchain-ollama==0.2.1
|
17 |
+
langchain-openai==0.3.3
|
18 |
streamlit~=1.38.0
|
19 |
|
20 |
python-pptx~=0.6.21
|