mattoofahad commited on
Commit
f79cf2d
·
1 Parent(s): ed24ac5

adding support ti run lm-studio

Browse files
.gitignore CHANGED
@@ -158,3 +158,5 @@ cython_debug/
158
  # and can be added to the global gitignore or merged into this file. For a more nuclear
159
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
  #.idea/
 
 
 
158
  # and can be added to the global gitignore or merged into this file. For a more nuclear
159
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
  #.idea/
161
+
162
+ /notebooks
src/app.py CHANGED
@@ -24,7 +24,7 @@ def main():
24
  if (
25
  st.session_state.openai_api_key is not None
26
  and st.session_state.openai_api_key != ""
27
- ):
28
  logger.info("OpenAI key Checking condition passed")
29
  if OpenAIFunctions.check_openai_api_key():
30
  logger.info("Inference Started")
 
24
  if (
25
  st.session_state.openai_api_key is not None
26
  and st.session_state.openai_api_key != ""
27
+ ) or st.session_state.provider_select != "OpenAI":
28
  logger.info("OpenAI key Checking condition passed")
29
  if OpenAIFunctions.check_openai_api_key():
30
  logger.info("Inference Started")
src/utils/config.py CHANGED
@@ -1,12 +1,12 @@
1
- """Module doc string"""
2
-
3
- import os
4
-
5
- from dotenv import find_dotenv, load_dotenv
6
-
7
- load_dotenv(find_dotenv(), override=True)
8
-
9
- LOGGER_LEVEL = os.getenv("LOGGER_LEVEL", "INFO")
10
- DISCORD_HOOK = os.getenv("DISCORD_HOOK", "NO_HOOK")
11
- ENVIRONMENT = os.getenv("ENVIRONMENT", "NOT_LOCAL")
12
- OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "NO_KEY")
 
1
+ """Module doc string"""
2
+
3
+ import os
4
+
5
+ from dotenv import find_dotenv, load_dotenv
6
+
7
+ load_dotenv(find_dotenv(), override=True)
8
+
9
+ LOGGER_LEVEL = os.getenv("LOGGER_LEVEL", "INFO")
10
+ DISCORD_HOOK = os.getenv("DISCORD_HOOK", "NO_HOOK")
11
+ ENVIRONMENT = os.getenv("ENVIRONMENT", "NOT_LOCAL")
12
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "NO_KEY")
src/utils/constants.py CHANGED
@@ -12,13 +12,14 @@ class ConstantVariables:
12
  "gpt-4-turbo",
13
  "gpt-3.5-turbo",
14
  "o1-preview",
15
- "o1-mini"
16
  )
 
17
  default_model = "gpt-4o-mini"
18
-
19
- max_tokens = 180
20
- min_token = 20
21
- step = 20
22
  default = round(((max_tokens + min_token) / 2) / step) * step
23
  default_token = max(min_token, min(max_tokens, default))
24
 
 
12
  "gpt-4-turbo",
13
  "gpt-3.5-turbo",
14
  "o1-preview",
15
+ "o1-mini",
16
  )
17
+ provider = ("lm-studio", "OpenAI")
18
  default_model = "gpt-4o-mini"
19
+ default_provider = "lm-studio"
20
+ max_tokens = 1024
21
+ min_token = 32
22
+ step = 32
23
  default = round(((max_tokens + min_token) / 2) / step) * step
24
  default_token = max(min_token, min(max_tokens, default))
25
 
src/utils/logs.py CHANGED
@@ -1,98 +1,106 @@
1
- """Module doc string"""
2
-
3
- import asyncio
4
- import logging
5
- import sys
6
- import time
7
- from functools import wraps
8
-
9
- from colorama import Back, Fore, Style, init
10
-
11
- from .config import LOGGER_LEVEL
12
-
13
- # Initialize colorama
14
- init(autoreset=True)
15
-
16
- logger = logging.getLogger(__name__)
17
-
18
- if not logger.hasHandlers():
19
- logger.propagate = False
20
- logger.setLevel(LOGGER_LEVEL)
21
-
22
- # Define color codes for different log levels
23
- log_colors = {
24
- logging.DEBUG: Fore.CYAN,
25
- logging.INFO: Fore.GREEN,
26
- logging.WARNING: Fore.YELLOW,
27
- logging.ERROR: Fore.RED,
28
- logging.CRITICAL: Fore.RED + Back.WHITE + Style.BRIGHT,
29
- }
30
-
31
- class ColoredFormatter(logging.Formatter):
32
- """Module doc string"""
33
-
34
- def format(self, record):
35
- """Module doc string"""
36
-
37
- levelno = record.levelno
38
- color = log_colors.get(levelno, "")
39
-
40
- # Format the message
41
- message = record.getMessage()
42
-
43
- # Format the rest of the log details
44
- details = self._fmt % {
45
- "asctime": self.formatTime(record, self.datefmt),
46
- "levelname": record.levelname,
47
- "module": record.module,
48
- "funcName": record.funcName,
49
- "lineno": record.lineno,
50
- }
51
-
52
- # Combine details and colored message
53
- return f"{Fore.WHITE}{details} :: {color}{message}{Style.RESET_ALL}"
54
-
55
- normal_handler = logging.StreamHandler(sys.stdout)
56
- normal_handler.setLevel(logging.DEBUG)
57
- normal_handler.addFilter(lambda logRecord: logRecord.levelno < logging.WARNING)
58
-
59
- error_handler = logging.StreamHandler(sys.stderr)
60
- error_handler.setLevel(logging.WARNING)
61
-
62
- formatter = ColoredFormatter(
63
- "%(asctime)s :: %(levelname)s :: %(module)s :: %(funcName)s :: %(lineno)d"
64
- )
65
-
66
- normal_handler.setFormatter(formatter)
67
- error_handler.setFormatter(formatter)
68
-
69
- logger.addHandler(normal_handler)
70
- logger.addHandler(error_handler)
71
-
72
-
73
- def log_execution_time(func):
74
- """Module doc string"""
75
-
76
- @wraps(func)
77
- def sync_wrapper(*args, **kwargs):
78
- start_time = time.time()
79
- result = func(*args, **kwargs)
80
- end_time = time.time()
81
- execution_time = end_time - start_time
82
- message_string = f"{func.__name__} executed in {execution_time:.4f} seconds"
83
- logger.debug(message_string)
84
- return result
85
-
86
- @wraps(func)
87
- async def async_wrapper(*args, **kwargs):
88
- start_time = time.time()
89
- result = await func(*args, **kwargs)
90
- end_time = time.time()
91
- execution_time = end_time - start_time
92
- message_string = f"{func.__name__} executed in {execution_time:.4f} seconds"
93
- logger.debug(message_string)
94
- return result
95
-
96
- if asyncio.iscoroutinefunction(func):
97
- return async_wrapper
98
- return sync_wrapper
 
 
 
 
 
 
 
 
 
1
+ """Module doc string"""
2
+
3
+ import asyncio
4
+ import logging
5
+ import sys
6
+ import time
7
+ from functools import wraps
8
+
9
+ from colorama import Back, Fore, Style, init
10
+
11
+ from .config import LOGGER_LEVEL
12
+
13
+ # Initialize colorama
14
+ init(autoreset=True)
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ if not logger.hasHandlers():
19
+ logger.propagate = False
20
+ logger.setLevel(LOGGER_LEVEL)
21
+
22
+ # Define color codes for different log levels
23
+ log_colors = {
24
+ logging.DEBUG: Fore.CYAN,
25
+ logging.INFO: Fore.GREEN,
26
+ logging.WARNING: Fore.YELLOW,
27
+ logging.ERROR: Fore.RED,
28
+ logging.CRITICAL: Fore.RED + Back.WHITE + Style.BRIGHT,
29
+ }
30
+
31
+ class ColoredFormatter(logging.Formatter):
32
+ """Module doc string"""
33
+
34
+ def format(self, record):
35
+ """Module doc string"""
36
+
37
+ levelno = record.levelno
38
+ color = log_colors.get(levelno, "")
39
+
40
+ # Format the message
41
+ message = record.getMessage()
42
+
43
+ # Format the rest of the log details
44
+ details = self._fmt % {
45
+ "asctime": self.formatTime(record, self.datefmt),
46
+ "levelname": record.levelname,
47
+ "module": record.module,
48
+ "funcName": record.funcName,
49
+ "lineno": record.lineno,
50
+ }
51
+
52
+ # Combine details and colored message
53
+ return (
54
+ f"{Fore.WHITE}{details} :: {color}{message}{Style.RESET_ALL}"
55
+ )
56
+
57
+ normal_handler = logging.StreamHandler(sys.stdout)
58
+ normal_handler.setLevel(logging.DEBUG)
59
+ normal_handler.addFilter(
60
+ lambda logRecord: logRecord.levelno < logging.WARNING
61
+ )
62
+
63
+ error_handler = logging.StreamHandler(sys.stderr)
64
+ error_handler.setLevel(logging.WARNING)
65
+
66
+ formatter = ColoredFormatter(
67
+ "%(asctime)s :: %(levelname)s :: %(module)s :: %(funcName)s :: %(lineno)d"
68
+ )
69
+
70
+ normal_handler.setFormatter(formatter)
71
+ error_handler.setFormatter(formatter)
72
+
73
+ logger.addHandler(normal_handler)
74
+ logger.addHandler(error_handler)
75
+
76
+
77
+ def log_execution_time(func):
78
+ """Module doc string"""
79
+
80
+ @wraps(func)
81
+ def sync_wrapper(*args, **kwargs):
82
+ start_time = time.time()
83
+ result = func(*args, **kwargs)
84
+ end_time = time.time()
85
+ execution_time = end_time - start_time
86
+ message_string = (
87
+ f"{func.__name__} executed in {execution_time:.4f} seconds"
88
+ )
89
+ logger.debug(message_string)
90
+ return result
91
+
92
+ @wraps(func)
93
+ async def async_wrapper(*args, **kwargs):
94
+ start_time = time.time()
95
+ result = await func(*args, **kwargs)
96
+ end_time = time.time()
97
+ execution_time = end_time - start_time
98
+ message_string = (
99
+ f"{func.__name__} executed in {execution_time:.4f} seconds"
100
+ )
101
+ logger.debug(message_string)
102
+ return result
103
+
104
+ if asyncio.iscoroutinefunction(func):
105
+ return async_wrapper
106
+ return sync_wrapper
src/utils/openai_utils.py CHANGED
@@ -15,21 +15,27 @@ class OpenAIFunctions:
15
  @staticmethod
16
  def invoke_model():
17
  """_summary_"""
 
18
  logger.debug("OpenAI invoked")
19
  with st.chat_message("assistant"):
20
  messages = [
21
  {"role": m["role"], "content": m["content"]}
22
  for m in st.session_state.messages
23
  ]
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- stream = completion(
26
- api_key=st.session_state.openai_api_key,
27
- model=st.session_state["openai_model"],
28
- messages=messages,
29
- max_tokens=st.session_state["openai_maxtokens"],
30
- stream=True,
31
- stream_options={"include_usage": True},
32
- )
33
 
34
  def stream_data():
35
  for chunk in stream:
@@ -50,24 +56,28 @@ class OpenAIFunctions:
50
  @staticmethod
51
  def check_openai_api_key():
52
  """_summary_"""
53
- logger.info("Checking OpenAI Key")
54
- try:
55
- client = OpenAI(api_key=st.session_state.openai_api_key)
56
- client.models.list()
57
- logger.debug("OpenAI key Working")
58
  return True
59
- except openai.AuthenticationError as auth_error:
60
- with st.chat_message("assistant"):
61
- st.error(str(auth_error))
62
- logger.error("AuthenticationError: %s", auth_error)
63
- return False
64
- except openai.OpenAIError as openai_error:
65
- with st.chat_message("assistant"):
66
- st.error(str(openai_error))
67
- logger.error("OpenAIError: %s", openai_error)
68
- return False
69
- except Exception as general_error:
70
- with st.chat_message("assistant"):
71
- st.error(str(general_error))
72
- logger.error("Unexpected error: %s", general_error)
73
- return False
 
 
 
 
 
 
 
 
15
  @staticmethod
16
  def invoke_model():
17
  """_summary_"""
18
+
19
  logger.debug("OpenAI invoked")
20
  with st.chat_message("assistant"):
21
  messages = [
22
  {"role": m["role"], "content": m["content"]}
23
  for m in st.session_state.messages
24
  ]
25
+ comp_args = {}
26
+ if st.session_state.provider_select == "OpenAI":
27
+ comp_args["api_key"] = st.session_state.openai_api_key
28
+ comp_args["model"] = st.session_state["openai_model"]
29
+ elif st.session_state.provider_select == "lm-studio":
30
+ comp_args["base_url"] = "http://localhost:1234/v1"
31
+ comp_args["api_key"] = st.session_state.provider_select
32
+ comp_args["model"] = "gpt-4o-mini"
33
+ comp_args["messages"] = messages
34
+ comp_args["max_tokens"] = st.session_state["openai_maxtokens"]
35
+ comp_args["stream"] = True
36
+ comp_args["stream_options"] = {"include_usage": True}
37
 
38
+ stream = completion(**comp_args)
 
 
 
 
 
 
 
39
 
40
  def stream_data():
41
  for chunk in stream:
 
56
  @staticmethod
57
  def check_openai_api_key():
58
  """_summary_"""
59
+ if st.session_state.provider_select == "lm-studio":
60
+ logger.info("Local Provider is Sekected")
 
 
 
61
  return True
62
+ else:
63
+ logger.info("Checking OpenAI Key")
64
+ try:
65
+ client = OpenAI(api_key=st.session_state.openai_api_key)
66
+ client.models.list()
67
+ logger.debug("OpenAI key Working")
68
+ return True
69
+ except openai.AuthenticationError as auth_error:
70
+ with st.chat_message("assistant"):
71
+ st.error(str(auth_error))
72
+ logger.error("AuthenticationError: %s", auth_error)
73
+ return False
74
+ except openai.OpenAIError as openai_error:
75
+ with st.chat_message("assistant"):
76
+ st.error(str(openai_error))
77
+ logger.error("OpenAIError: %s", openai_error)
78
+ return False
79
+ except Exception as general_error:
80
+ with st.chat_message("assistant"):
81
+ st.error(str(general_error))
82
+ logger.error("Unexpected error: %s", general_error)
83
+ return False
src/utils/streamlit_utils.py CHANGED
@@ -25,36 +25,53 @@ class StreamlitFunctions:
25
  def streamlit_side_bar():
26
  """_summary_"""
27
  with st.sidebar:
28
- st.text_input(
29
- label="OpenAI API key",
30
- value=ConstantVariables.api_key,
31
- help="This will not be saved or stored.",
32
- type="password",
33
- key="api_key",
34
- )
35
-
36
  st.selectbox(
37
- "Select the GPT model",
38
- ConstantVariables.model_list_tuple,
39
- key="openai_model",
40
- )
41
- st.slider(
42
- "Max Tokens",
43
- min_value=ConstantVariables.min_token,
44
- max_value=ConstantVariables.max_tokens,
45
- step=ConstantVariables.step,
46
- key="openai_maxtokens",
47
- )
48
- st.button(
49
- "Start Chat",
50
- on_click=StreamlitFunctions.start_app,
51
- use_container_width=True,
52
- )
53
- st.button(
54
- "Reset History",
55
- on_click=StreamlitFunctions.reset_history,
56
- use_container_width=True,
57
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  @staticmethod
60
  def streamlit_initialize_variables():
@@ -66,15 +83,23 @@ class StreamlitFunctions:
66
  if "openai_model" not in st.session_state:
67
  st.session_state["openai_model"] = ConstantVariables.default_model
68
 
 
 
 
69
  if "openai_api_key" not in st.session_state:
70
  st.session_state["openai_api_key"] = None
71
 
72
  if "openai_maxtokens" not in st.session_state:
73
- st.session_state["openai_maxtokens"] = ConstantVariables.default_token
 
 
74
 
75
  if "start_app" not in st.session_state:
76
  st.session_state["start_app"] = False
77
 
 
 
 
78
  @staticmethod
79
  def reset_history():
80
  """_summary_"""
@@ -102,7 +127,11 @@ class StreamlitFunctions:
102
  if prompt := st.chat_input("Type your Query"):
103
  with st.chat_message("user"):
104
  st.markdown(prompt)
105
- st.session_state.messages.append({"role": "user", "content": prompt})
 
 
106
  response = OpenAIFunctions.invoke_model()
107
  logger.debug(response)
108
- st.session_state.messages.append({"role": "assistant", "content": response[0]})
 
 
 
25
  def streamlit_side_bar():
26
  """_summary_"""
27
  with st.sidebar:
 
 
 
 
 
 
 
 
28
  st.selectbox(
29
+ "Select Provider",
30
+ ConstantVariables.provider,
31
+ placeholder="Choose an option",
32
+ key="provider_select",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  )
34
+ if st.session_state.provider_select is not None:
35
+ if st.session_state.provider_select == "OpenAI":
36
+ st.text_input(
37
+ label="OpenAI API key",
38
+ value=ConstantVariables.api_key,
39
+ help="This will not be saved or stored.",
40
+ type="password",
41
+ key="api_key",
42
+ )
43
+
44
+ st.selectbox(
45
+ "Select the GPT model",
46
+ ConstantVariables.model_list_tuple,
47
+ key="openai_model",
48
+ )
49
+
50
+ elif st.session_state.provider_select == "lm-studio":
51
+ st.header("NOTE")
52
+ st.text(
53
+ "lm-studio is configured to work on `http://localhost:1234/v1`"
54
+ )
55
+
56
+ st.slider(
57
+ "Max Tokens",
58
+ min_value=ConstantVariables.min_token,
59
+ max_value=ConstantVariables.max_tokens,
60
+ step=ConstantVariables.step,
61
+ key="openai_maxtokens",
62
+ )
63
+
64
+ st.button(
65
+ "Start Chat",
66
+ on_click=StreamlitFunctions.start_app,
67
+ use_container_width=True,
68
+ )
69
+
70
+ st.button(
71
+ "Reset History",
72
+ on_click=StreamlitFunctions.reset_history,
73
+ use_container_width=True,
74
+ )
75
 
76
  @staticmethod
77
  def streamlit_initialize_variables():
 
83
  if "openai_model" not in st.session_state:
84
  st.session_state["openai_model"] = ConstantVariables.default_model
85
 
86
+ if "provider_select" not in st.session_state:
87
+ st.session_state["provider_select"] = None
88
+
89
  if "openai_api_key" not in st.session_state:
90
  st.session_state["openai_api_key"] = None
91
 
92
  if "openai_maxtokens" not in st.session_state:
93
+ st.session_state["openai_maxtokens"] = (
94
+ ConstantVariables.default_token
95
+ )
96
 
97
  if "start_app" not in st.session_state:
98
  st.session_state["start_app"] = False
99
 
100
+ if "api_key" not in st.session_state:
101
+ st.session_state["api_key"] = None
102
+
103
  @staticmethod
104
  def reset_history():
105
  """_summary_"""
 
127
  if prompt := st.chat_input("Type your Query"):
128
  with st.chat_message("user"):
129
  st.markdown(prompt)
130
+ st.session_state.messages.append(
131
+ {"role": "user", "content": prompt}
132
+ )
133
  response = OpenAIFunctions.invoke_model()
134
  logger.debug(response)
135
+ st.session_state.messages.append(
136
+ {"role": "assistant", "content": response[0]}
137
+ )