Spaces:
Running
on
T4
Running
on
T4
update
Browse files- common/auth.py +7 -3
- components/llm/deepinfra_api.py +43 -29
- components/services/dialogue.py +5 -2
- routes/auth.py +10 -6
- routes/llm.py +7 -1
common/auth.py
CHANGED
@@ -11,10 +11,14 @@ SECRET_KEY = os.environ.get("JWT_SECRET", "ooooooh_thats_my_super_secret_key")
|
|
11 |
ALGORITHM = "HS256"
|
12 |
ACCESS_TOKEN_EXPIRE_MINUTES = 1440
|
13 |
|
|
|
|
|
|
|
|
|
14 |
# Захардкоженные пользователи
|
15 |
USERS = [
|
16 |
-
|
17 |
-
|
18 |
]
|
19 |
|
20 |
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login/token")
|
@@ -39,7 +43,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
|
|
39 |
username: str = payload.get("sub")
|
40 |
if username is None:
|
41 |
raise HTTPException(status_code=401, detail="Invalid token")
|
42 |
-
user = next((u for u in USERS if u
|
43 |
if user is None:
|
44 |
raise HTTPException(status_code=401, detail="User not found")
|
45 |
return user
|
|
|
11 |
ALGORITHM = "HS256"
|
12 |
ACCESS_TOKEN_EXPIRE_MINUTES = 1440
|
13 |
|
14 |
+
class User(BaseModel):
|
15 |
+
username: str
|
16 |
+
password: str
|
17 |
+
|
18 |
# Захардкоженные пользователи
|
19 |
USERS = [
|
20 |
+
User(username="admin", password="admin123"),
|
21 |
+
User(username="demo", password="sTrUPsORPA")
|
22 |
]
|
23 |
|
24 |
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login/token")
|
|
|
43 |
username: str = payload.get("sub")
|
44 |
if username is None:
|
45 |
raise HTTPException(status_code=401, detail="Invalid token")
|
46 |
+
user = next((u for u in USERS if u.username == username), None)
|
47 |
if user is None:
|
48 |
raise HTTPException(status_code=401, detail="User not found")
|
49 |
return user
|
components/llm/deepinfra_api.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import json
|
2 |
from typing import AsyncGenerator, Optional, List
|
3 |
import httpx
|
@@ -256,13 +257,17 @@ class DeepInfraApi(LlmApi):
|
|
256 |
logging.error(f"Request failed: status code {response.status_code}")
|
257 |
logging.error(response.text)
|
258 |
|
259 |
-
async def predict_chat_stream(self, request: ChatRequest, system_prompt, params: LlmPredictParams) -> str:
|
260 |
"""
|
261 |
Выполняет запрос к API с поддержкой потокового вывода (SSE) и возвращает результат.
|
262 |
-
|
263 |
Args:
|
264 |
-
|
265 |
-
|
|
|
|
|
|
|
|
|
266 |
Returns:
|
267 |
str: Сгенерированный текст.
|
268 |
"""
|
@@ -270,32 +275,41 @@ class DeepInfraApi(LlmApi):
|
|
270 |
request = self.create_chat_request(request, system_prompt, params)
|
271 |
request["stream"] = True
|
272 |
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
297 |
|
298 |
-
return generated_text.strip()
|
299 |
|
300 |
async def get_predict_chat_generator(self, request: ChatRequest, system_prompt: str,
|
301 |
params: LlmPredictParams) -> AsyncGenerator[str, None]:
|
|
|
1 |
+
import asyncio
|
2 |
import json
|
3 |
from typing import AsyncGenerator, Optional, List
|
4 |
import httpx
|
|
|
257 |
logging.error(f"Request failed: status code {response.status_code}")
|
258 |
logging.error(response.text)
|
259 |
|
260 |
+
async def predict_chat_stream(self, request: ChatRequest, system_prompt, params: LlmPredictParams, max_retries: int = 3, retry_delay: float = 0.5) -> str:
|
261 |
"""
|
262 |
Выполняет запрос к API с поддержкой потокового вывода (SSE) и возвращает результат.
|
263 |
+
|
264 |
Args:
|
265 |
+
request (ChatRequest): Запрос чата
|
266 |
+
system_prompt: Системный промпт
|
267 |
+
params (LlmPredictParams): Параметры предсказания
|
268 |
+
max_retries (int): Максимальное количество попыток переподключения (по умолчанию 3)
|
269 |
+
retry_delay (float): Задержка между попытками в секундах (по умолчанию 0.5)
|
270 |
+
|
271 |
Returns:
|
272 |
str: Сгенерированный текст.
|
273 |
"""
|
|
|
275 |
request = self.create_chat_request(request, system_prompt, params)
|
276 |
request["stream"] = True
|
277 |
|
278 |
+
for attempt in range(max_retries + 1):
|
279 |
+
try:
|
280 |
+
async with client.stream("POST", f"{self.params.url}/v1/openai/chat/completions",
|
281 |
+
json=request,
|
282 |
+
headers=super().create_headers()) as response:
|
283 |
+
|
284 |
+
if response.status_code != 200:
|
285 |
+
error_content = await response.aread()
|
286 |
+
raise Exception(f"API error: {error_content.decode('utf-8')}")
|
287 |
+
|
288 |
+
generated_text = ""
|
289 |
+
|
290 |
+
async for line in response.aiter_lines():
|
291 |
+
if line.startswith("data: "):
|
292 |
+
try:
|
293 |
+
data = json.loads(line[len("data: "):].strip())
|
294 |
+
if data == "[DONE]":
|
295 |
+
break
|
296 |
+
if "choices" in data and data["choices"]:
|
297 |
+
token_value = data["choices"][0].get("delta", {}).get("content", "")
|
298 |
+
generated_text += token_value
|
299 |
+
except json.JSONDecodeError:
|
300 |
+
continue
|
301 |
+
|
302 |
+
return generated_text.strip()
|
303 |
+
|
304 |
+
except Exception as e:
|
305 |
+
if attempt < max_retries:
|
306 |
+
# Ждем перед следующей попыткой, если это не последняя попытка
|
307 |
+
await asyncio.sleep(retry_delay)
|
308 |
+
continue
|
309 |
+
else:
|
310 |
+
# Если исчерпаны все попытки, пробрасываем исключение
|
311 |
+
raise Exception(f"predict_chat_stream failed after {max_retries} retries: {str(e)}")
|
312 |
|
|
|
313 |
|
314 |
async def get_predict_chat_generator(self, request: ChatRequest, system_prompt: str,
|
315 |
params: LlmPredictParams) -> AsyncGenerator[str, None]:
|
components/services/dialogue.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import logging
|
2 |
import os
|
3 |
import re
|
4 |
-
from typing import List
|
5 |
|
6 |
from pydantic import BaseModel
|
7 |
|
@@ -19,6 +19,7 @@ logger = logging.getLogger(__name__)
|
|
19 |
class QEResult(BaseModel):
|
20 |
use_search: bool
|
21 |
search_query: str | None
|
|
|
22 |
|
23 |
|
24 |
class DialogueService:
|
@@ -71,6 +72,7 @@ class DialogueService:
|
|
71 |
return QEResult(
|
72 |
use_search=from_chat is not None,
|
73 |
search_query=from_chat.content if from_chat else None,
|
|
|
74 |
)
|
75 |
|
76 |
def get_qe_result_from_chat(self, history: List[Message]) -> QEResult:
|
@@ -129,7 +131,8 @@ class DialogueService:
|
|
129 |
else:
|
130 |
raise ValueError("Первая часть текста должна содержать 'ДА' или 'НЕТ'.")
|
131 |
|
132 |
-
return QEResult(use_search=bool_var, search_query=second_part
|
|
|
133 |
|
134 |
def _get_search_query(self, history: List[Message]) -> Message | None:
|
135 |
"""
|
|
|
1 |
import logging
|
2 |
import os
|
3 |
import re
|
4 |
+
from typing import List, Optional, Tuple
|
5 |
|
6 |
from pydantic import BaseModel
|
7 |
|
|
|
19 |
class QEResult(BaseModel):
|
20 |
use_search: bool
|
21 |
search_query: str | None
|
22 |
+
debug_message: Optional[str | None] = ""
|
23 |
|
24 |
|
25 |
class DialogueService:
|
|
|
72 |
return QEResult(
|
73 |
use_search=from_chat is not None,
|
74 |
search_query=from_chat.content if from_chat else None,
|
75 |
+
debug_message=response
|
76 |
)
|
77 |
|
78 |
def get_qe_result_from_chat(self, history: List[Message]) -> QEResult:
|
|
|
131 |
else:
|
132 |
raise ValueError("Первая часть текста должна содержать 'ДА' или 'НЕТ'.")
|
133 |
|
134 |
+
return QEResult(use_search=bool_var, search_query=second_part,
|
135 |
+
debug_message=input_text)
|
136 |
|
137 |
def _get_search_query(self, history: List[Message]) -> Message | None:
|
138 |
"""
|
routes/auth.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
-
from typing import Optional
|
2 |
-
from fastapi import APIRouter, Body, Form, HTTPException
|
3 |
from datetime import timedelta
|
4 |
import common.auth as auth
|
5 |
|
@@ -7,8 +7,8 @@ router = APIRouter(prefix="/auth", tags=["Auth"])
|
|
7 |
|
8 |
def authenticate_user(username: str, password: str):
|
9 |
"""Проверяет, существует ли пользователь и правильный ли пароль."""
|
10 |
-
user = next((u for u in auth.USERS if u
|
11 |
-
if not user or user
|
12 |
raise HTTPException(status_code=401, detail="Неверный логин или пароль")
|
13 |
return user
|
14 |
|
@@ -20,7 +20,7 @@ def generate_access_token(username: str):
|
|
20 |
async def login_common(username: str, password: str):
|
21 |
"""Общий метод аутентификации."""
|
22 |
user = authenticate_user(username, password)
|
23 |
-
access_token = generate_access_token(user
|
24 |
return {"access_token": access_token, "token_type": "bearer"}
|
25 |
|
26 |
@router.post("/login", summary="Авторизация через JSON")
|
@@ -31,4 +31,8 @@ async def login_json(request: auth.LoginRequest = Body(...)):
|
|
31 |
@router.post("/login/token", summary="Авторизация через Form-Data")
|
32 |
async def login_form(username: str = Form(...), password: str = Form(...)):
|
33 |
"""Принимает Form-Data (x-www-form-urlencoded)."""
|
34 |
-
return await login_common(username, password)
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Annotated, Optional
|
2 |
+
from fastapi import APIRouter, Body, Depends, Form, HTTPException
|
3 |
from datetime import timedelta
|
4 |
import common.auth as auth
|
5 |
|
|
|
7 |
|
8 |
def authenticate_user(username: str, password: str):
|
9 |
"""Проверяет, существует ли пользователь и правильный ли пароль."""
|
10 |
+
user = next((u for u in auth.USERS if u.username == username), None)
|
11 |
+
if not user or user.password != password:
|
12 |
raise HTTPException(status_code=401, detail="Неверный логин или пароль")
|
13 |
return user
|
14 |
|
|
|
20 |
async def login_common(username: str, password: str):
|
21 |
"""Общий метод аутентификации."""
|
22 |
user = authenticate_user(username, password)
|
23 |
+
access_token = generate_access_token(user.username)
|
24 |
return {"access_token": access_token, "token_type": "bearer"}
|
25 |
|
26 |
@router.post("/login", summary="Авторизация через JSON")
|
|
|
31 |
@router.post("/login/token", summary="Авторизация через Form-Data")
|
32 |
async def login_form(username: str = Form(...), password: str = Form(...)):
|
33 |
"""Принимает Form-Data (x-www-form-urlencoded)."""
|
34 |
+
return await login_common(username, password)
|
35 |
+
|
36 |
+
@router.post("/checktoken", summary="Проверяет, аутентифицирован ли пользователь")
|
37 |
+
async def check_token(current_user: Annotated[auth.User, Depends(auth.get_current_user)]):
|
38 |
+
return {"current_user": current_user.username}
|
routes/llm.py
CHANGED
@@ -123,7 +123,13 @@ async def sse_generator(request: ChatRequest, llm_api: DeepInfraApi, system_prom
|
|
123 |
"""
|
124 |
try:
|
125 |
qe_result = await dialogue_service.get_qe_result(request.history)
|
126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
except Exception as e:
|
128 |
logger.error(f"Error in SSE chat stream while dialogue_service.get_qe_result: {str(e)}", stack_info=True)
|
129 |
yield "data: {\"event\": \"error\", \"data\":\""+str(e)+"\" }\n\n"
|
|
|
123 |
"""
|
124 |
try:
|
125 |
qe_result = await dialogue_service.get_qe_result(request.history)
|
126 |
+
qe_event = {
|
127 |
+
"event": "debug",
|
128 |
+
"data": {
|
129 |
+
"text": qe_result.debug_message
|
130 |
+
}
|
131 |
+
}
|
132 |
+
yield f"data: {json.dumps(qe_event, ensure_ascii=False)}\n\n"
|
133 |
except Exception as e:
|
134 |
logger.error(f"Error in SSE chat stream while dialogue_service.get_qe_result: {str(e)}", stack_info=True)
|
135 |
yield "data: {\"event\": \"error\", \"data\":\""+str(e)+"\" }\n\n"
|