from __future__ import annotations

import copy
import re
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Hashable, Iterator
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Optional, TypeVar, Union

from PIL import Image
from pydantic import Field, field_validator
from tenacity import retry, stop_after_attempt, stop_after_delay

from ...base import BotBase
from ...utils.env import EnvVar
from ...utils.general import LRUCache
from ...utils.registry import registry
from .prompt.base import _OUTPUT_PARSER, StrParser
from .prompt.parser import BaseOutputParser
from .prompt.prompt import PromptTemplate
from .schemas import Message

T = TypeVar("T", str, dict, list)


class BaseLLM(BotBase, ABC):
    cache: bool = False
    lru_cache: LRUCache = Field(default=LRUCache(EnvVar.LLM_CACHE_NUM))

    @property
    def workflow_instance_id(self) -> Optional[str]:
        if hasattr(self, "_parent"):
            return self._parent.workflow_instance_id
        return None

    @workflow_instance_id.setter
    def workflow_instance_id(self, value: str):
        if hasattr(self, "_parent"):
            self._parent.workflow_instance_id = value

    @abstractmethod
    def _call(self, records: List[Message], **kwargs) -> str:
        """Run the LLM on the given prompt and input."""

    @abstractmethod
    async def _acall(self, records: List[Message], **kwargs) -> str:
        """Run the LLM on the given prompt and input."""
        raise NotImplementedError("Async generation not implemented for this LLM.")

    def generate(self, records: List[Message], **kwargs) -> str:  # TODO: use python native lru cache
        """Run the LLM on the given prompt and input."""
        if self.cache:
            key = self._cache_key(records)
            cached_res = self.lru_cache.get(key)
            if cached_res:
                return cached_res
            else:
                gen = self._call(records, **kwargs)
                self.lru_cache.put(key, gen)
                return gen
        else:
            return self._call(records, **kwargs)

    async def agenerate(self, records: List[Message], **kwargs) -> str:
        """Run the LLM on the given prompt and input."""
        if self.cache:
            key = self._cache_key(records)
            cached_res = self.lru_cache.get(key)
            if cached_res:
                return cached_res
            else:
                gen = await self._acall(records, **kwargs)
                self.lru_cache.put(key, gen)
                return gen
        else:
            return await self._acall(records, **kwargs)

    def _cache_key(self, records: List[Message]) -> str:
        return str([item.model_dump() for item in records])

    def dict(self, *args, **kwargs):
        kwargs["exclude"] = {"lru_cache"}
        return super().model_dump(*args, **kwargs)

    def json(self, *args, **kwargs):
        kwargs["exclude"] = {"lru_cache"}
        return super().model_dump_json(*args, **kwargs)

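# Example (hedged sketch, not part of the library): a concrete LLM only needs to
# implement `_call` (and `_acall` for async use); `generate`/`agenerate` add LRU
# caching on top when `cache=True`. The `EchoLLM` name is hypothetical, and it is
# assumed that `Message` exposes the text of a record via a `content` attribute.
#
#     class EchoLLM(BaseLLM):
#         def _call(self, records: List[Message], **kwargs) -> str:
#             return records[-1].content if records else ""
#
#         async def _acall(self, records: List[Message], **kwargs) -> str:
#             return self._call(records, **kwargs)
#
#     llm = EchoLLM(cache=True)
#     llm.generate([Message.user("hello")])   # a second identical call is served from lru_cache

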
class BaseLLMBackend(BotBase, ABC):
    """Prompt preparation and LLM inference."""

    output_parser: Optional[BaseOutputParser] = None
    prompts: List[PromptTemplate] = []
    llm: BaseLLM

    @property
    def token_usage(self):
        if not hasattr(self, "workflow_instance_id"):
            raise AttributeError("workflow_instance_id not set")
        return dict(self.stm(self.workflow_instance_id).get("token_usage", defaultdict(int)))

    @field_validator("output_parser", mode="before")
    @classmethod
    def set_output_parser(cls, output_parser: Union[BaseOutputParser, Dict, None]):
        if output_parser is None:
            return StrParser()
        elif isinstance(output_parser, BaseOutputParser):
            return output_parser
        elif isinstance(output_parser, dict):
            return _OUTPUT_PARSER[output_parser["name"]](**output_parser)
        else:
            raise ValueError(
                "Output parser only supports None, dict and BaseOutputParser object"
            )

    @field_validator("prompts", mode="before")
    @classmethod
    def set_prompts(
        cls, prompts: List[Union[PromptTemplate, Path, Dict, str]]
    ) -> List[PromptTemplate]:
        init_prompts = []
        for prompt in prompts:
            prompt = copy.deepcopy(prompt)
            if isinstance(prompt, Path):
                if prompt.suffix == ".prompt":
                    init_prompts.append(PromptTemplate.from_file(prompt))
            elif isinstance(prompt, str):
                if prompt.endswith(".prompt"):
                    init_prompts.append(PromptTemplate.from_file(prompt))
                else:
                    init_prompts.append(PromptTemplate.from_template(prompt))
            elif isinstance(prompt, dict):
                init_prompts.append(PromptTemplate.from_config(prompt))
            elif isinstance(prompt, PromptTemplate):
                init_prompts.append(prompt)
            else:
                raise ValueError(
                    "Prompt only supports str, dict, Path and PromptTemplate object"
                )
        return init_prompts

    @field_validator("llm", mode="before")
    @classmethod
    def set_llm(cls, llm: Union[BaseLLM, Dict]):
        if isinstance(llm, dict):
            return registry.get_llm(llm["name"])(**llm)
        elif isinstance(llm, BaseLLM):
            return llm
        else:
            raise ValueError("LLM only supports dict and BaseLLM object")

    def prep_prompt(
        self, input_list: List[Dict[str, Any]], prompts=None, **kwargs
    ) -> List[List[Message]]:
        """Prepare prompts from inputs."""
        if prompts is None:
            prompts = self.prompts
        images = []
        if len(kwargs_images := kwargs.get("images", [])):
            images = kwargs_images
        processed_prompts = []
        for inputs in input_list:
            records = []
            for prompt in prompts:
                selected_inputs = {k: inputs.get(k, "") for k in prompt.input_variables}
                prompt_str = prompt.template
                # Split the template on {{placeholder}} markers so image values can be
                # interleaved with text instead of being stringified.
                parts = re.split(r"(\{\{.*?\}\})", prompt_str)
                formatted_parts = []
                for part in parts:
                    if part.startswith("{{") and part.endswith("}}"):
                        part = part[2:-2].strip()
                        value = selected_inputs[part]
                        if isinstance(value, (Image.Image, list)):
                            formatted_parts.extend(
                                [value] if isinstance(value, Image.Image) else value
                            )
                        else:
                            formatted_parts.append(str(value))
                    else:
                        formatted_parts.append(str(part))
                # A single text part collapses to a plain string message.
                formatted_parts = (
                    formatted_parts[0] if len(formatted_parts) == 1 else formatted_parts
                )
                if prompt.role == "system":
                    records.append(Message.system(formatted_parts))
                elif prompt.role == "user":
                    records.append(Message.user(formatted_parts))
            if len(images):
                records.append(Message.user(images))
            processed_prompts.append(records)
        return processed_prompts

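    # Worked example (hedged): for a user prompt templated as "Describe {{img}} in {{lang}}"
    # with inputs {"img": <PIL.Image.Image>, "lang": "French"}, the template is split on the
    # {{...}} markers and the resulting user Message content is the interleaved list
    # ["Describe ", <PIL.Image.Image>, " in ", "French"]; a text-only template instead
    # collapses to a single string.
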
    def infer(self, input_list: List[Dict[str, Any]], **kwargs) -> List[T]:
        prompts = self.prep_prompt(input_list, **kwargs)
        res = []
        stm_token_usage = self.stm(self.workflow_instance_id).get("token_usage", defaultdict(int))

        def process_stream(stream_output):
            # Accumulate token usage chunk by chunk while re-yielding the stream.
            for chunk in stream_output:
                if chunk.usage is not None:
                    for key, value in chunk.usage.dict().items():
                        if key in ["prompt_tokens", "completion_tokens", "total_tokens"]:
                            if value is not None:
                                stm_token_usage[key] += value
                    self.stm(self.workflow_instance_id)["token_usage"] = stm_token_usage
                yield chunk

        for prompt in prompts:
            output = self.llm.generate(prompt, **kwargs)
            if not isinstance(output, Iterator):
                for key, value in output.get("usage", {}).items():
                    if key in ["prompt_tokens", "completion_tokens", "total_tokens"]:
                        if value is not None:
                            stm_token_usage[key] += value
            if not self.llm.stream:
                for choice in output["choices"]:
                    if choice.get("message"):
                        choice["message"]["content"] = self.output_parser.parse(
                            choice["message"]["content"]
                        )
                res.append(output)
            else:
                res.append(process_stream(output))
        self.stm(self.workflow_instance_id)["token_usage"] = stm_token_usage
        return res

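    # Example (hedged sketch): with `self.llm.stream` enabled each result element is a
    # generator that the caller must drain; token usage is accumulated in short-term
    # memory in both modes and can be read back through `token_usage`.
    #
    #     outputs = backend.infer([{"question": "What is 2 + 2?"}])
    #     for chunk in outputs[0]:      # streaming case: iterate the re-yielded chunks
    #         ...
    #     print(backend.token_usage)    # {"prompt_tokens": ..., "completion_tokens": ..., "total_tokens": ...}
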
    async def ainfer(self, input_list: List[Dict[str, Any]], **kwargs) -> List[T]:
        prompts = self.prep_prompt(input_list)
        res = []
        stm_token_usage = self.stm(self.workflow_instance_id).get("token_usage", defaultdict(int))
        for prompt in prompts:
            output = await self.llm.agenerate(prompt, **kwargs)
            for key, value in output.get("usage", {}).items():
                if key in ["prompt_tokens", "completion_tokens", "total_tokens"]:
                    if value is not None:
                        stm_token_usage[key] += value
            for choice in output["choices"]:
                if choice.get("message"):
                    choice["message"]["content"] = self.output_parser.parse(
                        choice["message"]["content"]
                    )
            res.append(output)
        self.stm(self.workflow_instance_id)["token_usage"] = stm_token_usage
        return res

    def simple_infer(self, **kwargs: Any) -> T:
        return self.infer([kwargs])[0]

    async def simple_ainfer(self, **kwargs: Any) -> T:
        return (await self.ainfer([kwargs]))[0]

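    # Example (hedged sketch): `simple_infer` / `simple_ainfer` are single-input shortcuts
    # over `infer` / `ainfer`; the keyword arguments become the one input dict handed to
    # `prep_prompt`.
    #
    #     answer = backend.simple_infer(question="Summarize the document.")
    #     answer = await backend.simple_ainfer(question="Summarize the document.")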