barunsaha committed
Commit e65d286 · Parent: ec952fc

Add Azure OpenAI support

Files changed (5):
  1. README.md +13 -12
  2. app.py +55 -8
  3. global_config.py +10 -3
  4. helpers/llm_helper.py +54 -9
  5. requirements.txt +1 -0
README.md CHANGED
```diff
@@ -40,22 +40,23 @@ Clicking on the button will download the file.
 
 # Summary of the LLMs
 
-SlideDeck AI allows the use of different LLMs from four online providers—Hugging Face, Google, Cohere, and Together AI. These service providers—even the latter three—offer generous free usage of relevant LLMs without requiring any billing information.
+SlideDeck AI allows the use of different LLMs from five online providers—Azure OpenAI, Hugging Face, Google, Cohere, and Together AI. The latter four service providers offer generous free usage of relevant LLMs without requiring any billing information.
 
-Based on several experiments, SlideDeck AI generally recommends the use of Mistral NeMo and Gemini Flash to generate the slide decks.
+Based on several experiments, SlideDeck AI generally recommends the use of Mistral NeMo, Gemini Flash, and GPT-4o to generate the slide decks.
 
 The supported LLMs offer different styles of content generation. Use one of the following LLMs along with relevant API keys/access tokens, as appropriate, to create the content of the slide deck:
 
 | LLM | Provider (code) | Requires API key | Characteristics |
 |:---|:---|:---|:---|
 | Mistral 7B Instruct v0.2 | Hugging Face (`hf`) | Optional but strongly encouraged; [get here](https://huggingface.co/settings/tokens) | Faster, shorter content |
 | Mistral NeMo Instruct 2407 | Hugging Face (`hf`) | Optional but strongly encouraged; [get here](https://huggingface.co/settings/tokens) | Slower, longer content |
 | Gemini 1.5 Flash | Google Gemini API (`gg`) | Mandatory; [get here](https://aistudio.google.com/apikey) | Faster, longer content |
 | Gemini 2.0 Flash | Google Gemini API (`gg`) | Mandatory; [get here](https://aistudio.google.com/apikey) | Faster, longer content |
 | Gemini 2.0 Flash Lite | Google Gemini API (`gg`) | Mandatory; [get here](https://aistudio.google.com/apikey) | Faster, longer content |
+| GPT | Azure OpenAI (`az`) | Mandatory; [get here](https://ai.azure.com/resource/playground). NOTE: You need to have your subscription/billing set up | Faster, longer content |
 | Command R+ | Cohere (`co`) | Mandatory; [get here](https://dashboard.cohere.com/api-keys) | Shorter, simpler content |
 | Llama 3.3 70B Instruct Turbo | Together AI (`to`) | Mandatory; [get here](https://api.together.ai/settings/api-keys) | Detailed, slower |
 | Llama 3.1 8B Instruct Turbo 128K | Together AI (`to`) | Mandatory; [get here](https://api.together.ai/settings/api-keys) | Shorter |
 
 The Mistral models (via Hugging Face) do not mandatorily require an access token. In other words, you are always free to use these two LLMs, subject to Hugging Face's usage constraints. However, you are strongly encouraged to get and use your own Hugging Face access token.
 
```
app.py CHANGED
```diff
@@ -66,6 +66,9 @@ def are_all_inputs_valid(
         selected_provider: str,
         selected_model: str,
         user_key: str,
+        azure_deployment_url: str = '',
+        azure_endpoint_name: str = '',
+        azure_api_version: str = '',
 ) -> bool:
     """
     Validate user input and LLM selection.
@@ -74,6 +77,9 @@ def are_all_inputs_valid(
     :param selected_provider: The LLM provider.
     :param selected_model: Name of the model.
     :param user_key: User-provided API key.
+    :param azure_deployment_url: Azure OpenAI deployment URL.
+    :param azure_endpoint_name: Azure OpenAI model endpoint.
+    :param azure_api_version: Azure OpenAI API version.
     :return: `True` if all inputs "look" OK; `False` otherwise.
     """
 
@@ -90,11 +96,16 @@ def are_all_inputs_valid(
         handle_error('No valid LLM provider and/or model name found!', False)
         return False
 
-    if not llm_helper.is_valid_llm_provider_model(selected_provider, selected_model, user_key):
+    if not llm_helper.is_valid_llm_provider_model(
+            selected_provider, selected_model, user_key,
+            azure_endpoint_name, azure_deployment_url, azure_api_version
+    ):
         handle_error(
             'The LLM settings do not look correct. Make sure that an API key/access token'
-            ' is provided if the selected LLM requires it. An API key should be 6-64 characters'
-            ' long, only containing alphanumeric characters, hyphens, and underscores.',
+            ' is provided if the selected LLM requires it. An API key should be 6-94 characters'
+            ' long, only containing alphanumeric characters, hyphens, and underscores.\n\n'
+            'If you are using Azure OpenAI, make sure that you have provided the additional and'
+            ' correct configurations.',
             False
         )
         return False
@@ -170,13 +181,35 @@ with st.sidebar:
     api_key_token = st.text_input(
         label=(
             '3: Paste your API key/access token:\n\n'
-            '*Mandatory* for Cohere, Google Gemini, and Together AI providers.'
+            '*Mandatory* for Azure OpenAI, Cohere, Google Gemini, and Together AI providers.'
             ' *Optional* for HF Mistral LLMs but still encouraged.\n\n'
         ),
         type='password',
        key='api_key_input'
     )
 
+    # Additional configs for Azure OpenAI
+    with st.expander('**Azure OpenAI-specific configurations**'):
+        azure_endpoint = st.text_input(
+            label=(
+                '4: Azure endpoint URL, e.g., https://example.openai.azure.com/.\n\n'
+                '*Mandatory* for Azure OpenAI (only).'
+            )
+        )
+        azure_deployment = st.text_input(
+            label=(
+                '5: Deployment name on Azure OpenAI:\n\n'
+                '*Mandatory* for Azure OpenAI (only).'
+            ),
+        )
+        api_version = st.text_input(
+            label=(
+                '6: API version:\n\n'
+                '*Mandatory* field. Change based on your deployment configurations.'
+            ),
+            value='2024-05-01-preview',
+        )
+
 
 def build_ui():
     """
@@ -238,7 +271,15 @@ def set_up_chat_ui():
             use_ollama=RUN_IN_OFFLINE_MODE
         )
 
-        if not are_all_inputs_valid(prompt, provider, llm_name, api_key_token):
+        user_key = api_key_token.strip()
+        az_deployment = azure_deployment.strip()
+        az_endpoint = azure_endpoint.strip()
+        api_ver = api_version.strip()
+
+        if not are_all_inputs_valid(
+                prompt, provider, llm_name, user_key,
+                az_deployment, az_endpoint, api_ver
+        ):
             return
 
         logger.info(
@@ -270,7 +311,10 @@ def set_up_chat_ui():
             provider=provider,
             model=llm_name,
             max_new_tokens=gcfg.get_max_output_tokens(llm_provider_to_use),
-            api_key=api_key_token.strip(),
+            api_key=user_key,
+            azure_endpoint_url=az_endpoint,
+            azure_deployment_name=az_deployment,
+            azure_api_version=api_ver,
         )
 
         if not llm:
@@ -282,8 +326,11 @@ def set_up_chat_ui():
             )
             return
 
-        for _ in llm.stream(formatted_template):
-            response += _
+        for chunk in llm.stream(formatted_template):
+            if isinstance(chunk, str):
+                response += chunk
+            else:
+                response += chunk.content  # AIMessageChunk
 
         # Update the progress bar with an approx progress percentage
         progress_bar.progress(
```
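
The reworked streaming loop in `set_up_chat_ui()` is needed because `AzureChatOpenAI` is a chat model: it streams `AIMessageChunk` objects, whereas the text-completion LLMs used by the other providers stream plain strings. Below is a minimal, self-contained sketch of that normalization; the `collect_stream` helper and the sample stream are illustrative, not part of the commit:

```python
from typing import Iterable, Union

from langchain_core.messages import AIMessageChunk


def collect_stream(chunks: Iterable[Union[str, AIMessageChunk]]) -> str:
    """Accumulate a streamed LLM response regardless of the chunk type."""
    response = ''
    for chunk in chunks:
        if isinstance(chunk, str):
            response += chunk  # text-completion LLMs stream plain strings
        else:
            response += chunk.content  # chat models stream AIMessageChunk
    return response


# Hypothetical mixed stream, just to exercise both branches:
print(collect_stream(['Hello, ', AIMessageChunk(content='world!')]))
```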
global_config.py CHANGED
```diff
@@ -22,14 +22,21 @@ class GlobalConfig:
     PROVIDER_HUGGING_FACE = 'hf'
     PROVIDER_OLLAMA = 'ol'
     PROVIDER_TOGETHER_AI = 'to'
+    PROVIDER_AZURE_OPENAI = 'az'
     VALID_PROVIDERS = {
         PROVIDER_COHERE,
         PROVIDER_GOOGLE_GEMINI,
         PROVIDER_HUGGING_FACE,
         PROVIDER_OLLAMA,
-        PROVIDER_TOGETHER_AI
+        PROVIDER_TOGETHER_AI,
+        PROVIDER_AZURE_OPENAI,
     }
     VALID_MODELS = {
+        '[az]azure/open-ai': {
+            'description': 'faster, detailed',
+            'max_new_tokens': 8192,
+            'paid': True,
+        },
         '[co]command-r-08-2024': {
             'description': 'simpler, slower',
             'max_new_tokens': 4096,
@@ -79,7 +86,7 @@ class GlobalConfig:
         '- **[to]**: Together AI\n\n'
         '[Find out more](https://github.com/barun-saha/slide-deck-ai?tab=readme-ov-file#summary-of-the-llms)'
     )
-    DEFAULT_MODEL_INDEX = 4
+    DEFAULT_MODEL_INDEX = 5
     LLM_MODEL_TEMPERATURE = 0.2
     LLM_MODEL_MIN_OUTPUT_LENGTH = 100
     LLM_MODEL_MAX_INPUT_LENGTH = 400  # characters
@@ -135,7 +142,7 @@ class GlobalConfig:
         'Remember, the conversational interface is meant to (and will) update your *initial*'
         ' slide deck. If you want to create a new slide deck on a different topic,'
         ' start a new chat session by reloading this page.\n\n'
-        'Currently, eight *free-to-use* LLMs from four different providers are supported.'
+        'Currently, paid or *free-to-use* LLMs from five different providers are supported.'
         ' If one is not available, choose the other from the dropdown list. A [summary of'
         ' the supported LLMs]('
         'https://github.com/barun-saha/slide-deck-ai/blob/main/README.md#summary-of-the-llms)'
```
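
The keys of `VALID_MODELS` embed the provider code in brackets, and the new `[az]azure/open-ai` entry sits at the top of the dict, which is presumably why `DEFAULT_MODEL_INDEX` moves from 4 to 5: every pre-existing model shifts down one position, so the index keeps pointing at the same default model. A short sketch of how such a key splits into provider and model, reusing the `LLM_PROVIDER_MODEL_REGEX` pattern from `helpers/llm_helper.py`:

```python
import re

# Same pattern as LLM_PROVIDER_MODEL_REGEX in helpers/llm_helper.py:
# group 1 is the provider code in brackets, group 2 is the model name.
LLM_PROVIDER_MODEL_REGEX = re.compile(r'\[(.*?)\](.*)')

match = LLM_PROVIDER_MODEL_REGEX.match('[az]azure/open-ai')
if match:
    provider, model = match.group(1), match.group(2)
    print(provider, model)  # -> az azure/open-ai
```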
helpers/llm_helper.py CHANGED
```diff
@@ -4,12 +4,14 @@ Helper functions to access LLMs.
 import logging
 import re
 import sys
+import urllib3
 from typing import Tuple, Union
 
 import requests
 from requests.adapters import HTTPAdapter
 from urllib3.util import Retry
-from langchain_core.language_models import BaseLLM
+from langchain_core.language_models import BaseLLM, BaseChatModel
+
 
 sys.path.append('..')
 
@@ -18,14 +20,16 @@ from global_config import GlobalConfig
 
 LLM_PROVIDER_MODEL_REGEX = re.compile(r'\[(.*?)\](.*)')
 OLLAMA_MODEL_REGEX = re.compile(r'[a-zA-Z0-9._:-]+$')
-# 6-64 characters long, only containing alphanumeric characters, hyphens, and underscores
-API_KEY_REGEX = re.compile(r'^[a-zA-Z0-9_-]{6,64}$')
+# 6-94 characters long, only containing alphanumeric characters, hyphens, and underscores
+API_KEY_REGEX = re.compile(r'^[a-zA-Z0-9_-]{6,94}$')
 HF_API_HEADERS = {'Authorization': f'Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}'}
 REQUEST_TIMEOUT = 35
 
+
 logger = logging.getLogger(__name__)
 logging.getLogger('httpx').setLevel(logging.WARNING)
 logging.getLogger('httpcore').setLevel(logging.WARNING)
+logging.getLogger('openai').setLevel(logging.ERROR)
 
 retries = Retry(
     total=5,
@@ -66,7 +70,14 @@ def get_provider_model(provider_model: str, use_ollama: bool) -> Tuple[str, str]
     return '', ''
 
 
-def is_valid_llm_provider_model(provider: str, model: str, api_key: str) -> bool:
+def is_valid_llm_provider_model(
+        provider: str,
+        model: str,
+        api_key: str,
+        azure_endpoint_url: str = '',
+        azure_deployment_name: str = '',
+        azure_api_version: str = '',
+) -> bool:
     """
     Verify whether LLM settings are proper.
     This function does not verify whether `api_key` is correct. It only confirms that the key has
@@ -75,6 +86,9 @@ def is_valid_llm_provider_model(provider: str, model: str, api_key: str) -> bool
     :param provider: Name of the LLM provider.
     :param model: Name of the model.
     :param api_key: The API key or access token.
+    :param azure_endpoint_url: Azure OpenAI endpoint URL.
+    :param azure_deployment_name: Azure OpenAI deployment name.
+    :param azure_api_version: Azure OpenAI API version.
     :return: `True` if the settings "look" OK; `False` otherwise.
     """
 
@@ -85,11 +99,19 @@ def is_valid_llm_provider_model(provider: str, model: str, api_key: str) -> bool
         GlobalConfig.PROVIDER_GOOGLE_GEMINI,
         GlobalConfig.PROVIDER_COHERE,
         GlobalConfig.PROVIDER_TOGETHER_AI,
+        GlobalConfig.PROVIDER_AZURE_OPENAI,
     ] and not api_key:
         return False
 
-    if api_key:
-        return API_KEY_REGEX.match(api_key) is not None
+    if api_key and API_KEY_REGEX.match(api_key) is None:
+        return False
+
+    if provider == GlobalConfig.PROVIDER_AZURE_OPENAI:
+        valid_url = urllib3.util.parse_url(azure_endpoint_url)
+        all_status = all(
+            [azure_api_version, azure_deployment_name, str(valid_url)]
+        )
+        return all_status
 
     return True
 
@@ -98,8 +120,11 @@ def get_langchain_llm(
         provider: str,
         model: str,
         max_new_tokens: int,
-        api_key: str = ''
-) -> Union[BaseLLM, None]:
+        api_key: str = '',
+        azure_endpoint_url: str = '',
+        azure_deployment_name: str = '',
+        azure_api_version: str = '',
+) -> Union[BaseLLM, BaseChatModel, None]:
     """
     Get an LLM based on the provider and model specified.
 
@@ -107,7 +132,10 @@ def get_langchain_llm(
     :param model: The name of the LLM.
     :param max_new_tokens: The maximum number of tokens to generate.
     :param api_key: API key or access token to use.
-    :return: An instance of the LLM or `None` in case of any error.
+    :param azure_endpoint_url: Azure OpenAI endpoint URL.
+    :param azure_deployment_name: Azure OpenAI deployment name.
+    :param azure_api_version: Azure OpenAI API version.
+    :return: An instance of the LLM or Chat model; `None` in case of any error.
     """
 
     if provider == GlobalConfig.PROVIDER_HUGGING_FACE:
@@ -149,6 +177,23 @@ def get_langchain_llm(
             }
         )
 
+    if provider == GlobalConfig.PROVIDER_AZURE_OPENAI:
+        from langchain_openai import AzureChatOpenAI
+
+        logger.debug('Getting LLM via Azure OpenAI: %s', model)
+
+        # The `model` parameter is not used here; `azure_deployment` points to the desired name
+        return AzureChatOpenAI(
+            azure_deployment=azure_deployment_name,
+            api_version=azure_api_version,
+            azure_endpoint=azure_endpoint_url,
+            temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
+            max_tokens=max_new_tokens,
+            timeout=None,
+            max_retries=1,
+            api_key=api_key,
+        )
+
     if provider == GlobalConfig.PROVIDER_COHERE:
         from langchain_cohere.llms import Cohere
 
```
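Putting the helper changes together, an Azure-backed model is requested by passing the three extra Azure arguments alongside the API key. A hypothetical end-to-end call follows; the endpoint, deployment name, and key are placeholders, not values from the commit:

```python
from global_config import GlobalConfig
from helpers import llm_helper

# Placeholder values; in the app these come from the new sidebar inputs.
llm = llm_helper.get_langchain_llm(
    provider=GlobalConfig.PROVIDER_AZURE_OPENAI,  # 'az'
    model='azure/open-ai',
    max_new_tokens=8192,
    api_key='YOUR-AZURE-OPENAI-KEY',
    azure_endpoint_url='https://example.openai.azure.com/',
    azure_deployment_name='my-gpt-deployment',
    azure_api_version='2024-05-01-preview',
)

if llm:
    # AzureChatOpenAI streams AIMessageChunk objects, hence `.content`
    for chunk in llm.stream('Write a one-line greeting.'):
        print(chunk.content, end='')
```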
requirements.txt CHANGED
```diff
@@ -14,6 +14,7 @@ langchain-google-genai==2.0.6
 langchain-cohere==0.3.3
 langchain-together==0.3.0
 langchain-ollama==0.2.1
+langchain-openai==0.3.3
 streamlit~=1.38.0
 
 python-pptx~=0.6.21
```