muryshev commited on
Commit
1e5d06f
·
1 Parent(s): 308de05
common/auth.py CHANGED
@@ -11,10 +11,14 @@ SECRET_KEY = os.environ.get("JWT_SECRET", "ooooooh_thats_my_super_secret_key")
11
  ALGORITHM = "HS256"
12
  ACCESS_TOKEN_EXPIRE_MINUTES = 1440
13
 
 
 
 
 
14
  # Захардкоженные пользователи
15
  USERS = [
16
- {"username": "admin", "password": "admin123"},
17
- {"username": "demo", "password": "sTrUPsORPA"},
18
  ]
19
 
20
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login/token")
@@ -39,7 +43,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
39
  username: str = payload.get("sub")
40
  if username is None:
41
  raise HTTPException(status_code=401, detail="Invalid token")
42
- user = next((u for u in USERS if u["username"] == username), None)
43
  if user is None:
44
  raise HTTPException(status_code=401, detail="User not found")
45
  return user
 
11
  ALGORITHM = "HS256"
12
  ACCESS_TOKEN_EXPIRE_MINUTES = 1440
13
 
14
+ class User(BaseModel):
15
+ username: str
16
+ password: str
17
+
18
  # Захардкоженные пользователи
19
  USERS = [
20
+ User(username="admin", password="admin123"),
21
+ User(username="demo", password="sTrUPsORPA")
22
  ]
23
 
24
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login/token")
 
43
  username: str = payload.get("sub")
44
  if username is None:
45
  raise HTTPException(status_code=401, detail="Invalid token")
46
+ user = next((u for u in USERS if u.username == username), None)
47
  if user is None:
48
  raise HTTPException(status_code=401, detail="User not found")
49
  return user
components/llm/deepinfra_api.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import json
2
  from typing import AsyncGenerator, Optional, List
3
  import httpx
@@ -256,13 +257,17 @@ class DeepInfraApi(LlmApi):
256
  logging.error(f"Request failed: status code {response.status_code}")
257
  logging.error(response.text)
258
 
259
- async def predict_chat_stream(self, request: ChatRequest, system_prompt, params: LlmPredictParams) -> str:
260
  """
261
  Выполняет запрос к API с поддержкой потокового вывода (SSE) и возвращает результат.
262
-
263
  Args:
264
- prompt (str): Входной текст для предсказания.
265
-
 
 
 
 
266
  Returns:
267
  str: Сгенерированный текст.
268
  """
@@ -270,32 +275,41 @@ class DeepInfraApi(LlmApi):
270
  request = self.create_chat_request(request, system_prompt, params)
271
  request["stream"] = True
272
 
273
- print(super().create_headers())
274
- async with client.stream("POST", f"{self.params.url}/v1/openai/chat/completions", json=request, headers=super().create_headers()) as response:
275
- if response.status_code != 200:
276
- # Если ошибка, читаем ответ для получения подробностей
277
- error_content = await response.aread()
278
- raise Exception(f"API error: {error_content.decode('utf-8')}")
279
-
280
- # Для хранения результата
281
- generated_text = ""
282
-
283
- # Асинхронное чтение построчно
284
- async for line in response.aiter_lines():
285
- if line.startswith("data: "): # SSE-сообщения начинаются с "data: "
286
- try:
287
- # Парсим JSON из строки
288
- data = json.loads(line[len("data: "):].strip())
289
- if data == "[DONE]": # Конец потока
290
- break
291
- if "choices" in data and data["choices"]:
292
- # Получаем текст из текущего токена
293
- token_value = data["choices"][0].get("delta", {}).get("content", "")
294
- generated_text += token_value
295
- except json.JSONDecodeError:
296
- continue # Игнорируем строки, которые не удается декодировать
 
 
 
 
 
 
 
 
 
 
297
 
298
- return generated_text.strip()
299
 
300
  async def get_predict_chat_generator(self, request: ChatRequest, system_prompt: str,
301
  params: LlmPredictParams) -> AsyncGenerator[str, None]:
 
1
+ import asyncio
2
  import json
3
  from typing import AsyncGenerator, Optional, List
4
  import httpx
 
257
  logging.error(f"Request failed: status code {response.status_code}")
258
  logging.error(response.text)
259
 
260
+ async def predict_chat_stream(self, request: ChatRequest, system_prompt, params: LlmPredictParams, max_retries: int = 3, retry_delay: float = 0.5) -> str:
261
  """
262
  Выполняет запрос к API с поддержкой потокового вывода (SSE) и возвращает результат.
263
+
264
  Args:
265
+ request (ChatRequest): Запрос чата
266
+ system_prompt: Системный промпт
267
+ params (LlmPredictParams): Параметры предсказания
268
+ max_retries (int): Максимальное количество попыток переподключения (по умолчанию 3)
269
+ retry_delay (float): Задержка между попытками в секундах (по умолчанию 0.5)
270
+
271
  Returns:
272
  str: Сгенерированный текст.
273
  """
 
275
  request = self.create_chat_request(request, system_prompt, params)
276
  request["stream"] = True
277
 
278
+ for attempt in range(max_retries + 1):
279
+ try:
280
+ async with client.stream("POST", f"{self.params.url}/v1/openai/chat/completions",
281
+ json=request,
282
+ headers=super().create_headers()) as response:
283
+
284
+ if response.status_code != 200:
285
+ error_content = await response.aread()
286
+ raise Exception(f"API error: {error_content.decode('utf-8')}")
287
+
288
+ generated_text = ""
289
+
290
+ async for line in response.aiter_lines():
291
+ if line.startswith("data: "):
292
+ try:
293
+ data = json.loads(line[len("data: "):].strip())
294
+ if data == "[DONE]":
295
+ break
296
+ if "choices" in data and data["choices"]:
297
+ token_value = data["choices"][0].get("delta", {}).get("content", "")
298
+ generated_text += token_value
299
+ except json.JSONDecodeError:
300
+ continue
301
+
302
+ return generated_text.strip()
303
+
304
+ except Exception as e:
305
+ if attempt < max_retries:
306
+ # Ждем перед следующей попыткой, если это не последняя попытка
307
+ await asyncio.sleep(retry_delay)
308
+ continue
309
+ else:
310
+ # Если исчерпаны все попытки, пробрасываем исключение
311
+ raise Exception(f"predict_chat_stream failed after {max_retries} retries: {str(e)}")
312
 
 
313
 
314
  async def get_predict_chat_generator(self, request: ChatRequest, system_prompt: str,
315
  params: LlmPredictParams) -> AsyncGenerator[str, None]:
components/services/dialogue.py CHANGED
@@ -1,7 +1,7 @@
1
  import logging
2
  import os
3
  import re
4
- from typing import List
5
 
6
  from pydantic import BaseModel
7
 
@@ -19,6 +19,7 @@ logger = logging.getLogger(__name__)
19
  class QEResult(BaseModel):
20
  use_search: bool
21
  search_query: str | None
 
22
 
23
 
24
  class DialogueService:
@@ -71,6 +72,7 @@ class DialogueService:
71
  return QEResult(
72
  use_search=from_chat is not None,
73
  search_query=from_chat.content if from_chat else None,
 
74
  )
75
 
76
  def get_qe_result_from_chat(self, history: List[Message]) -> QEResult:
@@ -129,7 +131,8 @@ class DialogueService:
129
  else:
130
  raise ValueError("Первая часть текста должна содержать 'ДА' или 'НЕТ'.")
131
 
132
- return QEResult(use_search=bool_var, search_query=second_part)
 
133
 
134
  def _get_search_query(self, history: List[Message]) -> Message | None:
135
  """
 
1
  import logging
2
  import os
3
  import re
4
+ from typing import List, Optional, Tuple
5
 
6
  from pydantic import BaseModel
7
 
 
19
  class QEResult(BaseModel):
20
  use_search: bool
21
  search_query: str | None
22
+ debug_message: Optional[str | None] = ""
23
 
24
 
25
  class DialogueService:
 
72
  return QEResult(
73
  use_search=from_chat is not None,
74
  search_query=from_chat.content if from_chat else None,
75
+ debug_message=response
76
  )
77
 
78
  def get_qe_result_from_chat(self, history: List[Message]) -> QEResult:
 
131
  else:
132
  raise ValueError("Первая часть текста должна содержать 'ДА' или 'НЕТ'.")
133
 
134
+ return QEResult(use_search=bool_var, search_query=second_part,
135
+ debug_message=input_text)
136
 
137
  def _get_search_query(self, history: List[Message]) -> Message | None:
138
  """
routes/auth.py CHANGED
@@ -1,5 +1,5 @@
1
- from typing import Optional
2
- from fastapi import APIRouter, Body, Form, HTTPException
3
  from datetime import timedelta
4
  import common.auth as auth
5
 
@@ -7,8 +7,8 @@ router = APIRouter(prefix="/auth", tags=["Auth"])
7
 
8
  def authenticate_user(username: str, password: str):
9
  """Проверяет, существует ли пользователь и правильный ли пароль."""
10
- user = next((u for u in auth.USERS if u["username"] == username), None)
11
- if not user or user["password"] != password:
12
  raise HTTPException(status_code=401, detail="Неверный логин или пароль")
13
  return user
14
 
@@ -20,7 +20,7 @@ def generate_access_token(username: str):
20
  async def login_common(username: str, password: str):
21
  """Общий метод аутентификации."""
22
  user = authenticate_user(username, password)
23
- access_token = generate_access_token(user["username"])
24
  return {"access_token": access_token, "token_type": "bearer"}
25
 
26
  @router.post("/login", summary="Авторизация через JSON")
@@ -31,4 +31,8 @@ async def login_json(request: auth.LoginRequest = Body(...)):
31
  @router.post("/login/token", summary="Авторизация через Form-Data")
32
  async def login_form(username: str = Form(...), password: str = Form(...)):
33
  """Принимает Form-Data (x-www-form-urlencoded)."""
34
- return await login_common(username, password)
 
 
 
 
 
1
+ from typing import Annotated, Optional
2
+ from fastapi import APIRouter, Body, Depends, Form, HTTPException
3
  from datetime import timedelta
4
  import common.auth as auth
5
 
 
7
 
8
  def authenticate_user(username: str, password: str):
9
  """Проверяет, существует ли пользователь и правильный ли пароль."""
10
+ user = next((u for u in auth.USERS if u.username == username), None)
11
+ if not user or user.password != password:
12
  raise HTTPException(status_code=401, detail="Неверный логин или пароль")
13
  return user
14
 
 
20
  async def login_common(username: str, password: str):
21
  """Общий метод аутентификации."""
22
  user = authenticate_user(username, password)
23
+ access_token = generate_access_token(user.username)
24
  return {"access_token": access_token, "token_type": "bearer"}
25
 
26
  @router.post("/login", summary="Авторизация через JSON")
 
31
  @router.post("/login/token", summary="Авторизация через Form-Data")
32
  async def login_form(username: str = Form(...), password: str = Form(...)):
33
  """Принимает Form-Data (x-www-form-urlencoded)."""
34
+ return await login_common(username, password)
35
+
36
+ @router.post("/checktoken", summary="Проверяет, аутентифицирован ли пользователь")
37
+ async def check_token(current_user: Annotated[auth.User, Depends(auth.get_current_user)]):
38
+ return {"current_user": current_user.username}
routes/llm.py CHANGED
@@ -123,7 +123,13 @@ async def sse_generator(request: ChatRequest, llm_api: DeepInfraApi, system_prom
123
  """
124
  try:
125
  qe_result = await dialogue_service.get_qe_result(request.history)
126
-
 
 
 
 
 
 
127
  except Exception as e:
128
  logger.error(f"Error in SSE chat stream while dialogue_service.get_qe_result: {str(e)}", stack_info=True)
129
  yield "data: {\"event\": \"error\", \"data\":\""+str(e)+"\" }\n\n"
 
123
  """
124
  try:
125
  qe_result = await dialogue_service.get_qe_result(request.history)
126
+ qe_event = {
127
+ "event": "debug",
128
+ "data": {
129
+ "text": qe_result.debug_message
130
+ }
131
+ }
132
+ yield f"data: {json.dumps(qe_event, ensure_ascii=False)}\n\n"
133
  except Exception as e:
134
  logger.error(f"Error in SSE chat stream while dialogue_service.get_qe_result: {str(e)}", stack_info=True)
135
  yield "data: {\"event\": \"error\", \"data\":\""+str(e)+"\" }\n\n"