Spaces:
Sleeping
Sleeping
mattoofahad
commited on
Commit
·
f79cf2d
1
Parent(s):
ed24ac5
adding support ti run lm-studio
Browse files- .gitignore +2 -0
- src/app.py +1 -1
- src/utils/config.py +12 -12
- src/utils/constants.py +6 -5
- src/utils/logs.py +106 -98
- src/utils/openai_utils.py +38 -28
- src/utils/streamlit_utils.py +60 -31
.gitignore
CHANGED
@@ -158,3 +158,5 @@ cython_debug/
|
|
158 |
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
#.idea/
|
|
|
|
|
|
158 |
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
#.idea/
|
161 |
+
|
162 |
+
/notebooks
|
src/app.py
CHANGED
@@ -24,7 +24,7 @@ def main():
|
|
24 |
if (
|
25 |
st.session_state.openai_api_key is not None
|
26 |
and st.session_state.openai_api_key != ""
|
27 |
-
):
|
28 |
logger.info("OpenAI key Checking condition passed")
|
29 |
if OpenAIFunctions.check_openai_api_key():
|
30 |
logger.info("Inference Started")
|
|
|
24 |
if (
|
25 |
st.session_state.openai_api_key is not None
|
26 |
and st.session_state.openai_api_key != ""
|
27 |
+
) or st.session_state.provider_select != "OpenAI":
|
28 |
logger.info("OpenAI key Checking condition passed")
|
29 |
if OpenAIFunctions.check_openai_api_key():
|
30 |
logger.info("Inference Started")
|
src/utils/config.py
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
-
"""Module doc string"""
|
2 |
-
|
3 |
-
import os
|
4 |
-
|
5 |
-
from dotenv import find_dotenv, load_dotenv
|
6 |
-
|
7 |
-
load_dotenv(find_dotenv(), override=True)
|
8 |
-
|
9 |
-
LOGGER_LEVEL = os.getenv("LOGGER_LEVEL", "INFO")
|
10 |
-
DISCORD_HOOK = os.getenv("DISCORD_HOOK", "NO_HOOK")
|
11 |
-
ENVIRONMENT = os.getenv("ENVIRONMENT", "NOT_LOCAL")
|
12 |
-
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "NO_KEY")
|
|
|
1 |
+
"""Module doc string"""
|
2 |
+
|
3 |
+
import os
|
4 |
+
|
5 |
+
from dotenv import find_dotenv, load_dotenv
|
6 |
+
|
7 |
+
load_dotenv(find_dotenv(), override=True)
|
8 |
+
|
9 |
+
LOGGER_LEVEL = os.getenv("LOGGER_LEVEL", "INFO")
|
10 |
+
DISCORD_HOOK = os.getenv("DISCORD_HOOK", "NO_HOOK")
|
11 |
+
ENVIRONMENT = os.getenv("ENVIRONMENT", "NOT_LOCAL")
|
12 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "NO_KEY")
|
src/utils/constants.py
CHANGED
@@ -12,13 +12,14 @@ class ConstantVariables:
|
|
12 |
"gpt-4-turbo",
|
13 |
"gpt-3.5-turbo",
|
14 |
"o1-preview",
|
15 |
-
"o1-mini"
|
16 |
)
|
|
|
17 |
default_model = "gpt-4o-mini"
|
18 |
-
|
19 |
-
max_tokens =
|
20 |
-
min_token =
|
21 |
-
step =
|
22 |
default = round(((max_tokens + min_token) / 2) / step) * step
|
23 |
default_token = max(min_token, min(max_tokens, default))
|
24 |
|
|
|
12 |
"gpt-4-turbo",
|
13 |
"gpt-3.5-turbo",
|
14 |
"o1-preview",
|
15 |
+
"o1-mini",
|
16 |
)
|
17 |
+
provider = ("lm-studio", "OpenAI")
|
18 |
default_model = "gpt-4o-mini"
|
19 |
+
default_provider = "lm-studio"
|
20 |
+
max_tokens = 1024
|
21 |
+
min_token = 32
|
22 |
+
step = 32
|
23 |
default = round(((max_tokens + min_token) / 2) / step) * step
|
24 |
default_token = max(min_token, min(max_tokens, default))
|
25 |
|
src/utils/logs.py
CHANGED
@@ -1,98 +1,106 @@
|
|
1 |
-
"""Module doc string"""
|
2 |
-
|
3 |
-
import asyncio
|
4 |
-
import logging
|
5 |
-
import sys
|
6 |
-
import time
|
7 |
-
from functools import wraps
|
8 |
-
|
9 |
-
from colorama import Back, Fore, Style, init
|
10 |
-
|
11 |
-
from .config import LOGGER_LEVEL
|
12 |
-
|
13 |
-
# Initialize colorama
|
14 |
-
init(autoreset=True)
|
15 |
-
|
16 |
-
logger = logging.getLogger(__name__)
|
17 |
-
|
18 |
-
if not logger.hasHandlers():
|
19 |
-
logger.propagate = False
|
20 |
-
logger.setLevel(LOGGER_LEVEL)
|
21 |
-
|
22 |
-
# Define color codes for different log levels
|
23 |
-
log_colors = {
|
24 |
-
logging.DEBUG: Fore.CYAN,
|
25 |
-
logging.INFO: Fore.GREEN,
|
26 |
-
logging.WARNING: Fore.YELLOW,
|
27 |
-
logging.ERROR: Fore.RED,
|
28 |
-
logging.CRITICAL: Fore.RED + Back.WHITE + Style.BRIGHT,
|
29 |
-
}
|
30 |
-
|
31 |
-
class ColoredFormatter(logging.Formatter):
|
32 |
-
"""Module doc string"""
|
33 |
-
|
34 |
-
def format(self, record):
|
35 |
-
"""Module doc string"""
|
36 |
-
|
37 |
-
levelno = record.levelno
|
38 |
-
color = log_colors.get(levelno, "")
|
39 |
-
|
40 |
-
# Format the message
|
41 |
-
message = record.getMessage()
|
42 |
-
|
43 |
-
# Format the rest of the log details
|
44 |
-
details = self._fmt % {
|
45 |
-
"asctime": self.formatTime(record, self.datefmt),
|
46 |
-
"levelname": record.levelname,
|
47 |
-
"module": record.module,
|
48 |
-
"funcName": record.funcName,
|
49 |
-
"lineno": record.lineno,
|
50 |
-
}
|
51 |
-
|
52 |
-
# Combine details and colored message
|
53 |
-
return
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
normal_handler
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
)
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Module doc string"""
|
2 |
+
|
3 |
+
import asyncio
|
4 |
+
import logging
|
5 |
+
import sys
|
6 |
+
import time
|
7 |
+
from functools import wraps
|
8 |
+
|
9 |
+
from colorama import Back, Fore, Style, init
|
10 |
+
|
11 |
+
from .config import LOGGER_LEVEL
|
12 |
+
|
13 |
+
# Initialize colorama
|
14 |
+
init(autoreset=True)
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
if not logger.hasHandlers():
|
19 |
+
logger.propagate = False
|
20 |
+
logger.setLevel(LOGGER_LEVEL)
|
21 |
+
|
22 |
+
# Define color codes for different log levels
|
23 |
+
log_colors = {
|
24 |
+
logging.DEBUG: Fore.CYAN,
|
25 |
+
logging.INFO: Fore.GREEN,
|
26 |
+
logging.WARNING: Fore.YELLOW,
|
27 |
+
logging.ERROR: Fore.RED,
|
28 |
+
logging.CRITICAL: Fore.RED + Back.WHITE + Style.BRIGHT,
|
29 |
+
}
|
30 |
+
|
31 |
+
class ColoredFormatter(logging.Formatter):
|
32 |
+
"""Module doc string"""
|
33 |
+
|
34 |
+
def format(self, record):
|
35 |
+
"""Module doc string"""
|
36 |
+
|
37 |
+
levelno = record.levelno
|
38 |
+
color = log_colors.get(levelno, "")
|
39 |
+
|
40 |
+
# Format the message
|
41 |
+
message = record.getMessage()
|
42 |
+
|
43 |
+
# Format the rest of the log details
|
44 |
+
details = self._fmt % {
|
45 |
+
"asctime": self.formatTime(record, self.datefmt),
|
46 |
+
"levelname": record.levelname,
|
47 |
+
"module": record.module,
|
48 |
+
"funcName": record.funcName,
|
49 |
+
"lineno": record.lineno,
|
50 |
+
}
|
51 |
+
|
52 |
+
# Combine details and colored message
|
53 |
+
return (
|
54 |
+
f"{Fore.WHITE}{details} :: {color}{message}{Style.RESET_ALL}"
|
55 |
+
)
|
56 |
+
|
57 |
+
normal_handler = logging.StreamHandler(sys.stdout)
|
58 |
+
normal_handler.setLevel(logging.DEBUG)
|
59 |
+
normal_handler.addFilter(
|
60 |
+
lambda logRecord: logRecord.levelno < logging.WARNING
|
61 |
+
)
|
62 |
+
|
63 |
+
error_handler = logging.StreamHandler(sys.stderr)
|
64 |
+
error_handler.setLevel(logging.WARNING)
|
65 |
+
|
66 |
+
formatter = ColoredFormatter(
|
67 |
+
"%(asctime)s :: %(levelname)s :: %(module)s :: %(funcName)s :: %(lineno)d"
|
68 |
+
)
|
69 |
+
|
70 |
+
normal_handler.setFormatter(formatter)
|
71 |
+
error_handler.setFormatter(formatter)
|
72 |
+
|
73 |
+
logger.addHandler(normal_handler)
|
74 |
+
logger.addHandler(error_handler)
|
75 |
+
|
76 |
+
|
77 |
+
def log_execution_time(func):
|
78 |
+
"""Module doc string"""
|
79 |
+
|
80 |
+
@wraps(func)
|
81 |
+
def sync_wrapper(*args, **kwargs):
|
82 |
+
start_time = time.time()
|
83 |
+
result = func(*args, **kwargs)
|
84 |
+
end_time = time.time()
|
85 |
+
execution_time = end_time - start_time
|
86 |
+
message_string = (
|
87 |
+
f"{func.__name__} executed in {execution_time:.4f} seconds"
|
88 |
+
)
|
89 |
+
logger.debug(message_string)
|
90 |
+
return result
|
91 |
+
|
92 |
+
@wraps(func)
|
93 |
+
async def async_wrapper(*args, **kwargs):
|
94 |
+
start_time = time.time()
|
95 |
+
result = await func(*args, **kwargs)
|
96 |
+
end_time = time.time()
|
97 |
+
execution_time = end_time - start_time
|
98 |
+
message_string = (
|
99 |
+
f"{func.__name__} executed in {execution_time:.4f} seconds"
|
100 |
+
)
|
101 |
+
logger.debug(message_string)
|
102 |
+
return result
|
103 |
+
|
104 |
+
if asyncio.iscoroutinefunction(func):
|
105 |
+
return async_wrapper
|
106 |
+
return sync_wrapper
|
src/utils/openai_utils.py
CHANGED
@@ -15,21 +15,27 @@ class OpenAIFunctions:
|
|
15 |
@staticmethod
|
16 |
def invoke_model():
|
17 |
"""_summary_"""
|
|
|
18 |
logger.debug("OpenAI invoked")
|
19 |
with st.chat_message("assistant"):
|
20 |
messages = [
|
21 |
{"role": m["role"], "content": m["content"]}
|
22 |
for m in st.session_state.messages
|
23 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
-
stream = completion(
|
26 |
-
api_key=st.session_state.openai_api_key,
|
27 |
-
model=st.session_state["openai_model"],
|
28 |
-
messages=messages,
|
29 |
-
max_tokens=st.session_state["openai_maxtokens"],
|
30 |
-
stream=True,
|
31 |
-
stream_options={"include_usage": True},
|
32 |
-
)
|
33 |
|
34 |
def stream_data():
|
35 |
for chunk in stream:
|
@@ -50,24 +56,28 @@ class OpenAIFunctions:
|
|
50 |
@staticmethod
|
51 |
def check_openai_api_key():
|
52 |
"""_summary_"""
|
53 |
-
|
54 |
-
|
55 |
-
client = OpenAI(api_key=st.session_state.openai_api_key)
|
56 |
-
client.models.list()
|
57 |
-
logger.debug("OpenAI key Working")
|
58 |
return True
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
@staticmethod
|
16 |
def invoke_model():
|
17 |
"""_summary_"""
|
18 |
+
|
19 |
logger.debug("OpenAI invoked")
|
20 |
with st.chat_message("assistant"):
|
21 |
messages = [
|
22 |
{"role": m["role"], "content": m["content"]}
|
23 |
for m in st.session_state.messages
|
24 |
]
|
25 |
+
comp_args = {}
|
26 |
+
if st.session_state.provider_select == "OpenAI":
|
27 |
+
comp_args["api_key"] = st.session_state.openai_api_key
|
28 |
+
comp_args["model"] = st.session_state["openai_model"]
|
29 |
+
elif st.session_state.provider_select == "lm-studio":
|
30 |
+
comp_args["base_url"] = "http://localhost:1234/v1"
|
31 |
+
comp_args["api_key"] = st.session_state.provider_select
|
32 |
+
comp_args["model"] = "gpt-4o-mini"
|
33 |
+
comp_args["messages"] = messages
|
34 |
+
comp_args["max_tokens"] = st.session_state["openai_maxtokens"]
|
35 |
+
comp_args["stream"] = True
|
36 |
+
comp_args["stream_options"] = {"include_usage": True}
|
37 |
|
38 |
+
stream = completion(**comp_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
def stream_data():
|
41 |
for chunk in stream:
|
|
|
56 |
@staticmethod
|
57 |
def check_openai_api_key():
|
58 |
"""_summary_"""
|
59 |
+
if st.session_state.provider_select == "lm-studio":
|
60 |
+
logger.info("Local Provider is Sekected")
|
|
|
|
|
|
|
61 |
return True
|
62 |
+
else:
|
63 |
+
logger.info("Checking OpenAI Key")
|
64 |
+
try:
|
65 |
+
client = OpenAI(api_key=st.session_state.openai_api_key)
|
66 |
+
client.models.list()
|
67 |
+
logger.debug("OpenAI key Working")
|
68 |
+
return True
|
69 |
+
except openai.AuthenticationError as auth_error:
|
70 |
+
with st.chat_message("assistant"):
|
71 |
+
st.error(str(auth_error))
|
72 |
+
logger.error("AuthenticationError: %s", auth_error)
|
73 |
+
return False
|
74 |
+
except openai.OpenAIError as openai_error:
|
75 |
+
with st.chat_message("assistant"):
|
76 |
+
st.error(str(openai_error))
|
77 |
+
logger.error("OpenAIError: %s", openai_error)
|
78 |
+
return False
|
79 |
+
except Exception as general_error:
|
80 |
+
with st.chat_message("assistant"):
|
81 |
+
st.error(str(general_error))
|
82 |
+
logger.error("Unexpected error: %s", general_error)
|
83 |
+
return False
|
src/utils/streamlit_utils.py
CHANGED
@@ -25,36 +25,53 @@ class StreamlitFunctions:
|
|
25 |
def streamlit_side_bar():
|
26 |
"""_summary_"""
|
27 |
with st.sidebar:
|
28 |
-
st.text_input(
|
29 |
-
label="OpenAI API key",
|
30 |
-
value=ConstantVariables.api_key,
|
31 |
-
help="This will not be saved or stored.",
|
32 |
-
type="password",
|
33 |
-
key="api_key",
|
34 |
-
)
|
35 |
-
|
36 |
st.selectbox(
|
37 |
-
"Select
|
38 |
-
ConstantVariables.
|
39 |
-
|
40 |
-
|
41 |
-
st.slider(
|
42 |
-
"Max Tokens",
|
43 |
-
min_value=ConstantVariables.min_token,
|
44 |
-
max_value=ConstantVariables.max_tokens,
|
45 |
-
step=ConstantVariables.step,
|
46 |
-
key="openai_maxtokens",
|
47 |
-
)
|
48 |
-
st.button(
|
49 |
-
"Start Chat",
|
50 |
-
on_click=StreamlitFunctions.start_app,
|
51 |
-
use_container_width=True,
|
52 |
-
)
|
53 |
-
st.button(
|
54 |
-
"Reset History",
|
55 |
-
on_click=StreamlitFunctions.reset_history,
|
56 |
-
use_container_width=True,
|
57 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
@staticmethod
|
60 |
def streamlit_initialize_variables():
|
@@ -66,15 +83,23 @@ class StreamlitFunctions:
|
|
66 |
if "openai_model" not in st.session_state:
|
67 |
st.session_state["openai_model"] = ConstantVariables.default_model
|
68 |
|
|
|
|
|
|
|
69 |
if "openai_api_key" not in st.session_state:
|
70 |
st.session_state["openai_api_key"] = None
|
71 |
|
72 |
if "openai_maxtokens" not in st.session_state:
|
73 |
-
st.session_state["openai_maxtokens"] =
|
|
|
|
|
74 |
|
75 |
if "start_app" not in st.session_state:
|
76 |
st.session_state["start_app"] = False
|
77 |
|
|
|
|
|
|
|
78 |
@staticmethod
|
79 |
def reset_history():
|
80 |
"""_summary_"""
|
@@ -102,7 +127,11 @@ class StreamlitFunctions:
|
|
102 |
if prompt := st.chat_input("Type your Query"):
|
103 |
with st.chat_message("user"):
|
104 |
st.markdown(prompt)
|
105 |
-
st.session_state.messages.append(
|
|
|
|
|
106 |
response = OpenAIFunctions.invoke_model()
|
107 |
logger.debug(response)
|
108 |
-
st.session_state.messages.append(
|
|
|
|
|
|
25 |
def streamlit_side_bar():
|
26 |
"""_summary_"""
|
27 |
with st.sidebar:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
st.selectbox(
|
29 |
+
"Select Provider",
|
30 |
+
ConstantVariables.provider,
|
31 |
+
placeholder="Choose an option",
|
32 |
+
key="provider_select",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
)
|
34 |
+
if st.session_state.provider_select is not None:
|
35 |
+
if st.session_state.provider_select == "OpenAI":
|
36 |
+
st.text_input(
|
37 |
+
label="OpenAI API key",
|
38 |
+
value=ConstantVariables.api_key,
|
39 |
+
help="This will not be saved or stored.",
|
40 |
+
type="password",
|
41 |
+
key="api_key",
|
42 |
+
)
|
43 |
+
|
44 |
+
st.selectbox(
|
45 |
+
"Select the GPT model",
|
46 |
+
ConstantVariables.model_list_tuple,
|
47 |
+
key="openai_model",
|
48 |
+
)
|
49 |
+
|
50 |
+
elif st.session_state.provider_select == "lm-studio":
|
51 |
+
st.header("NOTE")
|
52 |
+
st.text(
|
53 |
+
"lm-studio is configured to work on `http://localhost:1234/v1`"
|
54 |
+
)
|
55 |
+
|
56 |
+
st.slider(
|
57 |
+
"Max Tokens",
|
58 |
+
min_value=ConstantVariables.min_token,
|
59 |
+
max_value=ConstantVariables.max_tokens,
|
60 |
+
step=ConstantVariables.step,
|
61 |
+
key="openai_maxtokens",
|
62 |
+
)
|
63 |
+
|
64 |
+
st.button(
|
65 |
+
"Start Chat",
|
66 |
+
on_click=StreamlitFunctions.start_app,
|
67 |
+
use_container_width=True,
|
68 |
+
)
|
69 |
+
|
70 |
+
st.button(
|
71 |
+
"Reset History",
|
72 |
+
on_click=StreamlitFunctions.reset_history,
|
73 |
+
use_container_width=True,
|
74 |
+
)
|
75 |
|
76 |
@staticmethod
|
77 |
def streamlit_initialize_variables():
|
|
|
83 |
if "openai_model" not in st.session_state:
|
84 |
st.session_state["openai_model"] = ConstantVariables.default_model
|
85 |
|
86 |
+
if "provider_select" not in st.session_state:
|
87 |
+
st.session_state["provider_select"] = None
|
88 |
+
|
89 |
if "openai_api_key" not in st.session_state:
|
90 |
st.session_state["openai_api_key"] = None
|
91 |
|
92 |
if "openai_maxtokens" not in st.session_state:
|
93 |
+
st.session_state["openai_maxtokens"] = (
|
94 |
+
ConstantVariables.default_token
|
95 |
+
)
|
96 |
|
97 |
if "start_app" not in st.session_state:
|
98 |
st.session_state["start_app"] = False
|
99 |
|
100 |
+
if "api_key" not in st.session_state:
|
101 |
+
st.session_state["api_key"] = None
|
102 |
+
|
103 |
@staticmethod
|
104 |
def reset_history():
|
105 |
"""_summary_"""
|
|
|
127 |
if prompt := st.chat_input("Type your Query"):
|
128 |
with st.chat_message("user"):
|
129 |
st.markdown(prompt)
|
130 |
+
st.session_state.messages.append(
|
131 |
+
{"role": "user", "content": prompt}
|
132 |
+
)
|
133 |
response = OpenAIFunctions.invoke_model()
|
134 |
logger.debug(response)
|
135 |
+
st.session_state.messages.append(
|
136 |
+
{"role": "assistant", "content": response[0]}
|
137 |
+
)
|