import asyncio
from functools import partial
from typing import Any, Callable

from openhands.core.exceptions import UserCancelledError
from openhands.core.logger import openhands_logger as logger
from openhands.llm.async_llm import LLM_RETRY_EXCEPTIONS, AsyncLLM
from openhands.llm.llm import REASONING_EFFORT_SUPPORTED_MODELS


class StreamingLLM(AsyncLLM):
    """Streaming LLM class."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        self._async_streaming_completion = partial(
            self._call_acompletion,
            model=self.config.model,
            api_key=self.config.api_key.get_secret_value()
            if self.config.api_key
            else None,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            max_tokens=self.config.max_output_tokens,
            timeout=self.config.timeout,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            drop_params=self.config.drop_params,
            stream=True,  # Ensure streaming is enabled
        )
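
        # Keep a direct reference to the unwrapped partial so the streaming
        # wrapper below can call it without going through the wrapper itself.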
        async_streaming_completion_unwrapped = self._async_streaming_completion

        async def async_streaming_completion_wrapper(*args: Any, **kwargs: Any) -> Any:
            messages: list[dict[str, Any]] | dict[str, Any] = []

            # some callers might send the model and messages directly
            # litellm allows positional args, like completion(model, messages, **kwargs)
            # see llm.py for more details
            if len(args) > 1:
                messages = args[1] if len(args) > 1 else args[0]
                kwargs['messages'] = messages

                # remove the first args, they're sent in kwargs
                args = args[2:]
            elif 'messages' in kwargs:
                messages = kwargs['messages']

            # ensure we work with a list of messages
            messages = messages if isinstance(messages, list) else [messages]

            # if we have no messages, something went very wrong
            if not messages:
                raise ValueError(
                    'The messages list is empty. At least one message is required.'
                )

            # Set reasoning effort for models that support it
            if self.config.model.lower() in REASONING_EFFORT_SUPPORTED_MODELS:
                kwargs['reasoning_effort'] = self.config.reasoning_effort

            self.log_prompt(messages)

            try:
                # Directly call and await litellm_acompletion
                resp = await async_streaming_completion_unwrapped(*args, **kwargs)

                # For streaming we iterate over the chunks
                async for chunk in resp:
                    # Check for cancellation before yielding the chunk
                    if (
                        hasattr(self.config, 'on_cancel_requested_fn')
                        and self.config.on_cancel_requested_fn is not None
                        and await self.config.on_cancel_requested_fn()
                    ):
                        raise UserCancelledError(
                            'LLM request cancelled due to CANCELLED state'
                        )

                    # with streaming, it is "delta", not "message"!
                    message_back = chunk['choices'][0]['delta'].get('content', '')
                    if message_back:
                        self.log_response(message_back)
                        self._post_completion(chunk)

                    yield chunk
            except UserCancelledError:
                logger.debug('LLM request cancelled by user.')
                raise
            except Exception as e:
                logger.error(f'Completion Error occurred:\n{e}')
                raise
            finally:
                # sleep for 0.1 seconds to allow the stream to be flushed
                if kwargs.get('stream', False):
                    await asyncio.sleep(0.1)

        self._async_streaming_completion = async_streaming_completion_wrapper

    @property
    def async_streaming_completion(self) -> Callable:
        """Decorator for the async litellm acompletion function with streaming."""
        return self._async_streaming_completion
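

# A minimal usage sketch (illustrative only, not part of the original module):
# it assumes an `LLMConfig` from openhands.core.config with a valid model and
# API key, and that the caller drives it from an asyncio event loop. The
# `_stream_demo` helper below is hypothetical.
#
#     from openhands.core.config import LLMConfig
#
#     async def _stream_demo() -> None:
#         llm = StreamingLLM(config=LLMConfig(model='gpt-4o'))
#         async for chunk in llm.async_streaming_completion(
#             messages=[{'role': 'user', 'content': 'Say hello.'}]
#         ):
#             print(chunk['choices'][0]['delta'].get('content', ''), end='')
#
#     asyncio.run(_stream_demo())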