Spaces:
Building
Building
Update chat_handler.py
Browse files- chat_handler.py +40 -30
chat_handler.py
CHANGED
@@ -53,48 +53,60 @@ def _safe_intent_parse(raw: str) -> tuple[str, str]:
|
|
53 |
|
54 |
# βββββββββββββββββββββββββ CONFIG βββββββββββββββββββββββββ #
|
55 |
SPARK_URL = str(cfg.global_config.spark_endpoint).rstrip("/")
|
|
|
56 |
|
57 |
-
|
58 |
-
|
59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
from config_provider import ConfigProvider
|
61 |
-
# Force reload
|
62 |
ConfigProvider._instance = None
|
63 |
-
|
|
|
64 |
|
65 |
# βββββββββββββββββββββββββ SPARK βββββββββββββββββββββββββ #
|
66 |
def initialize_llm(force_reload=False):
|
67 |
"""Initialize LLM provider based on work_mode"""
|
68 |
global llm_provider
|
69 |
|
70 |
-
#
|
71 |
-
if force_reload
|
72 |
-
cfg =
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
|
|
|
|
|
|
80 |
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
else:
|
85 |
-
# Spark mode
|
86 |
-
spark_token = _get_spark_token()
|
87 |
-
if not spark_token:
|
88 |
-
raise ValueError("Spark token not configured")
|
89 |
-
|
90 |
-
spark_endpoint = str(cfg.global_config.spark_endpoint)
|
91 |
-
llm_provider = SparkLLM(spark_endpoint, spark_token)
|
92 |
-
log("β
Initialized Spark provider")
|
93 |
|
94 |
# βββββββββββββββββββββββββ SPARK βββββββββββββββββββββββββ #
|
95 |
def _get_spark_token() -> Optional[str]:
|
96 |
"""Get Spark token based on work_mode"""
|
97 |
-
cfg =
|
98 |
work_mode = cfg.global_config.work_mode
|
99 |
|
100 |
if work_mode in ("hfcloud", "cloud"):
|
@@ -108,7 +120,7 @@ def _get_spark_token() -> Optional[str]:
|
|
108 |
from dotenv import load_dotenv
|
109 |
load_dotenv()
|
110 |
return os.getenv("SPARK_TOKEN")
|
111 |
-
|
112 |
async def spark_generate(s: Session, prompt: str, user_msg: str) -> str:
|
113 |
"""Call LLM provider with proper error handling"""
|
114 |
try:
|
@@ -219,8 +231,6 @@ async def chat(body: ChatRequest, x_session_id: str = Header(...)):
|
|
219 |
# βββββββββββββββββββββββββ MESSAGE HANDLERS βββββββββββββββββββββββββ #
|
220 |
async def _handle_new_message(session: Session, user_input: str, version) -> str:
|
221 |
"""Handle new message (not parameter followup)"""
|
222 |
-
cfg = get_fresh_config() # Fresh config
|
223 |
-
|
224 |
# Build intent detection prompt
|
225 |
prompt = build_intent_prompt(
|
226 |
version.general_prompt,
|
|
|
53 |
|
54 |
# βββββββββββββββββββββββββ CONFIG βββββββββββββββββββββββββ #
|
55 |
SPARK_URL = str(cfg.global_config.spark_endpoint).rstrip("/")
|
56 |
+
# Module-level config cache; populated lazily on first get_config() call.
_cfg = None


def get_config():
    """Get or reload config"""
    global _cfg
    if _cfg is not None:
        return _cfg
    # First access: pull the singleton from the provider and cache it.
    from config_provider import ConfigProvider
    _cfg = ConfigProvider.get()
    return _cfg
|
65 |
+
|
66 |
+
def reload_config():
    """Force reload config"""
    global _cfg
    from config_provider import ConfigProvider

    # Drop the provider singleton so get() rebuilds state from scratch,
    # then refresh the module-level cache with the new instance.
    ConfigProvider._instance = None
    fresh = ConfigProvider.get()
    _cfg = fresh
    return fresh
73 |
|
74 |
# βββββββββββββββββββββββββ SPARK βββββββββββββββββββββββββ #
|
75 |
def initialize_llm(force_reload=False):
    """Initialize the module-level LLM provider based on the configured mode.

    Args:
        force_reload: When True, re-read the configuration from the provider
            (via ``reload_config``) instead of using the cached config.

    Side effects:
        Rebinds the module-level ``llm_provider`` global to either a
        ``GPT4oLLM`` or a ``SparkLLM`` instance.

    Raises:
        ValueError: If the API key / token required by the selected mode is
            not configured.
    """
    global llm_provider

    # Get fresh config if forced, cached config otherwise.
    cfg = reload_config() if force_reload else get_config()

    # NOTE(review): the original also assigned
    # `work_mode = cfg.global_config.work_mode` here but never read it —
    # the mode decision below goes through is_gpt_mode(); dead local removed.
    if cfg.global_config.is_gpt_mode():
        # GPT mode
        api_key = cfg.global_config.get_plain_token()
        if not api_key:
            raise ValueError("OpenAI API key not configured")

        model = cfg.global_config.get_gpt_model()
        llm_provider = GPT4oLLM(api_key, model)
        log(f"✅ Initialized {model} provider")
    else:
        # Spark mode
        spark_token = _get_spark_token()
        if not spark_token:
            raise ValueError("Spark token not configured")

        spark_endpoint = str(cfg.global_config.spark_endpoint)
        llm_provider = SparkLLM(spark_endpoint, spark_token)
        log("✅ Initialized Spark provider")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
# βββββββββββββββββββββββββ SPARK βββββββββββββββββββββββββ #
|
107 |
def _get_spark_token() -> Optional[str]:
|
108 |
"""Get Spark token based on work_mode"""
|
109 |
+
cfg = get_config()
|
110 |
work_mode = cfg.global_config.work_mode
|
111 |
|
112 |
if work_mode in ("hfcloud", "cloud"):
|
|
|
120 |
from dotenv import load_dotenv
|
121 |
load_dotenv()
|
122 |
return os.getenv("SPARK_TOKEN")
|
123 |
+
|
124 |
async def spark_generate(s: Session, prompt: str, user_msg: str) -> str:
|
125 |
"""Call LLM provider with proper error handling"""
|
126 |
try:
|
|
|
231 |
# βββββββββββββββββββββββββ MESSAGE HANDLERS βββββββββββββββββββββββββ #
|
232 |
async def _handle_new_message(session: Session, user_input: str, version) -> str:
|
233 |
"""Handle new message (not parameter followup)"""
|
|
|
|
|
234 |
# Build intent detection prompt
|
235 |
prompt = build_intent_prompt(
|
236 |
version.general_prompt,
|