import logging import re from typing import List, Optional from pydantic import BaseModel from components.llm.common import ChatRequest, LlmPredictParams, Message from components.llm.deepinfra_api import DeepInfraApi from components.llm.prompts import PROMPT_QE from components.services.llm_config import LLMConfigService logger = logging.getLogger(__name__) class QEResult(BaseModel): use_search: bool search_query: str | None debug_message: Optional[str | None] = "" class DialogueService: def __init__( self, llm_api: DeepInfraApi, llm_config_service: LLMConfigService, ) -> None: self.prompt = PROMPT_QE self.llm_api = llm_api p = llm_config_service.get_default() self.llm_params = LlmPredictParams( temperature=p.temperature, top_p=p.top_p, min_p=p.min_p, seed=p.seed, frequency_penalty=p.frequency_penalty, presence_penalty=p.presence_penalty, n_predict=p.n_predict, ) async def get_qe_result(self, history: List[Message]) -> QEResult: """ Получает результат QE. Args: history: История диалога в виде списка сообщений Returns: QEResult: Результат QE """ request = self._get_qe_request(history) response = await self.llm_api.predict_chat_stream( request, "", self.llm_params, ) logger.info(f"QE response: {response}") try: return self._postprocess_qe(response) except Exception as e: logger.error(f"Error in _postprocess_qe: {e}") from_chat = self._get_search_query(history) return QEResult( use_search=from_chat is not None, search_query=from_chat.content if from_chat else None, debug_message=response, ) def get_qe_result_from_chat(self, history: List[Message]) -> QEResult: from_chat = self._get_search_query(history) return QEResult( use_search=from_chat is not None, search_query=from_chat.content if from_chat else None, ) def _get_qe_request(self, history: List[Message]) -> ChatRequest: """ Подготавливает полный промпт для QE запроса. Args: history: История диалога в виде списка сообщений Returns: str: Отформатированный промпт с историей диалога """ formatted_history = "\n".join( [self._format_message(msg) for msg in history] ).strip() message = self.prompt.format(history=formatted_history) return ChatRequest( history=[Message(role="user", content=message, searchResults='')] ) def _format_message(self, message: Message) -> str: """ Форматирует сообщение для запроса QE. Args: message: Сообщение для форматирования """ # if message.searchResults: # return f'{message.role}: {message.content}\n\n{message.searchResults}\n' return f'{message.role}: {message.content}' @staticmethod def _postprocess_qe(input_text: str) -> QEResult: # Находим все вхождения квадратных скобок matches = re.findall(r'\[([^\]]*)\]', input_text) # Проверяем количество найденных скобок if len(matches) != 2: raise ValueError("В тексте должно быть ровно две пары квадратных скобок.") # Извлекаем значения из скобок first_part = matches[0].strip().lower() second_part = matches[1].strip() if first_part == "да": bool_var = True elif first_part == "нет": bool_var = False else: raise ValueError("Первая часть текста должна содержать 'ДА' или 'НЕТ'.") return QEResult( use_search=bool_var, search_query=second_part, debug_message=input_text ) def _get_search_query(self, history: List[Message]) -> Message | None: """ Получает запрос для поиска на основе последнего сообщения пользователя. """ return next( ( msg for msg in reversed(history) if msg.role == "user" and (msg.searchResults is None or not msg.searchResults) ), None, )