File size: 7,317 Bytes
79859e3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
from __future__ import annotations
import logging
import os
import asyncio
from typing import Iterator
from flask import send_from_directory
from inspect import signature
from g4f import version, models
from g4f import get_last_provider, ChatCompletion
from g4f.errors import VersionNotFoundError
from g4f.image import ImagePreview, ImageResponse, copy_images, ensure_images_dir, images_dir
from g4f.Provider import ProviderType, __providers__, __map__
from g4f.providers.base_provider import ProviderModelMixin
from g4f.providers.response import BaseConversation, FinishReason, SynthesizeData
from g4f.client.service import convert_to_provider
from g4f import debug
logger = logging.getLogger(__name__)
# Conversation objects kept between requests, keyed first by provider name and
# then by conversation id. Filled by Api._create_response_stream and read back
# by Api._prepare_conversation_kwargs.
conversations: dict[str, dict[str, BaseConversation]] = {}
class Api:
    """Request helpers exposing model/provider metadata, version info,
    stored images and streamed chat completions."""

    @staticmethod
    def get_models():
        """Return every model registered in ``g4f.models`` (``models._all_models``)."""
        return models._all_models

    @staticmethod
    def get_provider_models(provider: str, api_key: str = None):
        """Describe the models offered by the provider named *provider*.

        Args:
            provider: Provider name, looked up in ``__map__``.
            api_key: Optional key, forwarded only to providers whose
                ``get_models`` accepts an ``api_key`` parameter.

        Returns:
            A list of dicts with keys ``model``, ``default``, ``vision`` and
            ``image``. The list is empty when the name is unknown or the
            provider is not a ``ProviderModelMixin`` subclass.
        """
        if provider not in __map__:
            return []
        provider_cls: ProviderType = __map__[provider]
        if not issubclass(provider_cls, ProviderModelMixin):
            return []
        # Only pass the key to providers whose get_models() declares it.
        if api_key is not None and "api_key" in signature(provider_cls.get_models).parameters:
            model_names = provider_cls.get_models(api_key=api_key)
        else:
            model_names = provider_cls.get_models()
        # Use getattr for optional capability attributes, as get_providers does,
        # so providers lacking them do not raise AttributeError.
        vision_default = getattr(provider_cls, "default_vision_model", None)
        vision_models = getattr(provider_cls, "vision_models", [])
        image_models = getattr(provider_cls, "image_models", None)
        return [
            {
                "model": model,
                "default": model == provider_cls.default_model,
                "vision": vision_default == model or model in vision_models,
                "image": False if image_models is None else model in image_models,
            }
            for model in model_names
        ]

    @staticmethod
    def get_providers() -> dict[str, str]:
        """Map each working provider's class name to a display label.

        The label is the provider's ``label`` (or its class name), followed by
        capability tags for image generation, image upload, WebDriver use and
        required authentication.
        """
        return {
            provider.__name__: (provider.label if hasattr(provider, "label") else provider.__name__)
            + (" (Image Generation)" if getattr(provider, "image_models", None) else "")
            + (" (Image Upload)" if getattr(provider, "default_vision_model", None) else "")
            + (" (WebDriver)" if "webdriver" in provider.get_parameters() else "")
            + (" (Auth)" if provider.needs_auth else "")
            for provider in __providers__
            if provider.working
        }

    @staticmethod
    def get_version() -> dict:
        """Return the installed and latest g4f versions.

        ``version`` is None when the installed version cannot be determined
        (VersionNotFoundError).
        """
        try:
            current_version = version.utils.current_version
        except VersionNotFoundError:
            current_version = None
        return {
            "version": current_version,
            "latest_version": version.utils.latest_version,
        }

    def serve_images(self, name):
        """Send file *name* from the images directory, creating the directory if needed."""
        ensure_images_dir()
        return send_from_directory(os.path.abspath(images_dir), name)

    def _prepare_conversation_kwargs(self, json_data: dict, kwargs: dict):
        """Build the keyword arguments for ``ChatCompletion.create`` from a request.

        Args:
            json_data: Request payload. ``messages`` is required. Optional keys
                read here are ``model``, ``provider``, ``api_key``,
                ``web_search``, ``auto_continue`` and ``conversation_id``.
            kwargs: Extra arguments. This dict is updated in place, and its
                entries are merged into (and override) the returned dict.

        Note:
            When a local web search is performed, the content of the last
            message in ``json_data['messages']`` is rewritten in place.
        """
        model = json_data.get('model') or models.default
        provider = json_data.get('provider')
        messages = json_data['messages']
        api_key = json_data.get("api_key")
        if api_key is not None:
            kwargs["api_key"] = api_key
        do_web_search = json_data.get('web_search')
        if do_web_search and provider:
            provider_handler = convert_to_provider(provider)
            # Prefer the provider's own web search when it advertises one.
            if hasattr(provider_handler, "get_parameters") and "web_search" in provider_handler.get_parameters():
                kwargs['web_search'] = True
                do_web_search = False
        # Otherwise augment the last message locally. Skip this when there is
        # no message to augment rather than failing with IndexError.
        if do_web_search and messages:
            from .internet import get_search_message
            messages[-1]["content"] = get_search_message(messages[-1]["content"])
        if json_data.get("auto_continue"):
            kwargs['auto_continue'] = True
        conversation_id = json_data.get("conversation_id")
        if conversation_id and provider:
            # Resume a conversation object stored by _create_response_stream.
            provider_conversations = conversations.get(provider, {})
            if conversation_id in provider_conversations:
                kwargs["conversation"] = provider_conversations[conversation_id]
        return {
            "model": model,
            "provider": provider,
            "messages": messages,
            "stream": True,
            "ignore_stream": True,
            "return_conversation": True,
            **kwargs
        }

    def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str, download_images: bool = True) -> Iterator:
        """Run ``ChatCompletion.create(**kwargs)`` and yield event dicts.

        Each yielded item comes from ``_format_json``. Its ``type`` is one of
        ``provider``, ``conversation``, ``message``, ``preview``, ``content``,
        ``synthesize``, ``log`` or ``error``. Conversation objects are stored
        in ``conversations[provider][conversation_id]`` when *provider* is set.
        With *download_images*, generated images are copied locally via
        ``copy_images`` before being reported.
        """
        debug.logs = []
        print_callback = debug.log_handler

        def log_handler(text: str):
            debug.logs.append(text)
            print_callback(text)

        debug.log_handler = log_handler
        try:
            result = ChatCompletion.create(**kwargs)
            if isinstance(result, ImageResponse):
                yield self._format_json("provider", get_last_provider(True))
                yield self._format_json("content", str(result))
            else:
                first = True
                for chunk in result:
                    if first:
                        first = False
                        yield self._format_json("provider", get_last_provider(True))
                    if isinstance(chunk, BaseConversation):
                        if provider:
                            if provider not in conversations:
                                conversations[provider] = {}
                            conversations[provider][conversation_id] = chunk
                            yield self._format_json("conversation", conversation_id)
                    elif isinstance(chunk, Exception):
                        logger.exception(chunk)
                        yield self._format_json("message", get_error_message(chunk))
                    elif isinstance(chunk, ImagePreview):
                        yield self._format_json("preview", chunk.to_string())
                    elif isinstance(chunk, ImageResponse):
                        images = chunk
                        if download_images:
                            images = asyncio.run(copy_images(chunk.get_list(), chunk.options.get("cookies")))
                            images = ImageResponse(images, chunk.alt)
                        yield self._format_json("content", str(images))
                    elif isinstance(chunk, SynthesizeData):
                        yield self._format_json("synthesize", chunk.to_json())
                    elif not isinstance(chunk, FinishReason):
                        yield self._format_json("content", str(chunk))
                    # Forward log lines captured since the previous chunk.
                    if debug.logs:
                        for log in debug.logs:
                            yield self._format_json("log", str(log))
                        debug.logs = []
        except Exception as e:
            logger.exception(e)
            yield self._format_json('error', get_error_message(e))
        finally:
            # Restore the previous handler. Without this, each call would wrap
            # the last call's wrapper, so every log line would be recorded once
            # per earlier request and the handler chain would grow without bound.
            debug.log_handler = print_callback

    def _format_json(self, response_type: str, content):
        """Return an event dict ``{'type': response_type, response_type: content}``."""
        return {
            'type': response_type,
            response_type: content
        }
def get_error_message(exception: Exception) -> str:
    """Format *exception* as ``"<ExceptionType>: <message>"``.

    When ``get_last_provider()`` returns a provider, the text is prefixed
    with that provider's class name: ``"<Provider>: <ExceptionType>: <message>"``.
    """
    text = f"{type(exception).__name__}: {exception}"
    last_provider = get_last_provider()
    if last_provider is not None:
        text = f"{last_provider.__name__}: {text}"
    return text