Spaces:
Sleeping
Sleeping
import os | |
import json | |
from g4f.providers.response import Reasoning, JsonConversation, FinishReason | |
from g4f.typing import AsyncResult, Messages | |
import json | |
import re | |
import time | |
import logging | |
from urllib.parse import quote_plus | |
from fastapi import FastAPI, Response, Request | |
from fastapi.responses import RedirectResponse | |
from g4f.image import images_dir, copy_images | |
import g4f.api | |
import g4f.Provider | |
from g4f.Provider.base_provider import AsyncGeneratorProvider, ProviderModelMixin | |
from g4f.typing import AsyncResult, Messages | |
from g4f.requests import StreamSession | |
from g4f.providers.response import ProviderInfo, JsonConversation, PreviewResponse, SynthesizeData, TitleGeneration, RequestLogin | |
from g4f.providers.response import Parameters, FinishReason, Usage, Reasoning | |
from g4f.errors import ModelNotSupportedError | |
from g4f import debug | |
import demo | |
class BackendApi(AsyncGeneratorProvider, ProviderModelMixin):
    """Provider that forwards conversations to a remote g4f backend API.

    Requests are POSTed to ``{url}/backend-api/v2/conversation``. Each line
    of the reply is parsed as one JSON event and translated back into the
    matching g4f response object.
    """

    url = "https://ahe.hopto.org"
    working = True
    # Handed to the backend request as its `ssl` argument.
    ssl = False
    models = [
        *g4f.Provider.OpenaiAccount.get_models(),
        *g4f.Provider.PerplexityLabs.get_models(),
        "flux",
        "flux-pro",
        "MiniMax-01",
        "Microsoft Copilot",
    ]

    # Markdown for a linked image: [![alt](preview)](target)
    _linked_image = re.compile(r'\[\!\[(.+?)\]\(([^)]+?)\)\]\(([^)]+?)\)')

    @classmethod
    def get_model(cls, model):
        """Map a requested model name to the name sent to the backend.

        Raises:
            ModelNotSupportedError: if the name matches none of the known
                families or upstream model lists.
        """
        if "MiniMax" in model:
            model = "MiniMax"
        elif "Copilot" in model:
            model = "Copilot"
        elif "FLUX" in model:
            model = f"flux-{model.split('-')[-1]}"
        elif "flux" in model:
            model = model.split(' ')[-1]
        elif model in g4f.Provider.OpenaiAccount.get_models():
            pass
        elif model in g4f.Provider.PerplexityLabs.get_models():
            pass
        else:
            raise ModelNotSupportedError(f"Model: {model}")
        return model

    @classmethod
    def get_provider(cls, model):
        """Return the backend provider name for ``model``, or None if no
        specific provider applies."""
        if model.startswith("MiniMax"):
            return "HailuoAI"
        elif model == "Copilot" or "dall-e" in model:
            return "CopilotAccount"
        elif model in g4f.Provider.OpenaiAccount.get_models():
            return "OpenaiAccount"
        elif model in g4f.Provider.PerplexityLabs.get_models():
            return "PerplexityLabs"
        return None

    @classmethod
    def _rewrite_images(cls, text):
        """Point linked images in ``text`` at the local download/images URLs."""
        def on_image(match):
            alt, target = match.group(1), match.group(3)
            extension = target.split(".")[-1].split("?")[0]
            extension = "" if not extension or len(extension) > 4 else f".{extension}"
            filename = f"{int(time.time())}_{quote_plus(alt[:100], '')}{extension}"
            download_url = f"/download/{filename}?url={cls.url}{target}"
            # Keep the image inside the link. The previous return value was an
            # empty link and left `download_url` unused.
            return f"[![{alt}]({download_url})](/images/{filename})"
        return cls._linked_image.sub(on_image, text)

    @classmethod
    def _split_thinking(cls, text, thinking_since):
        """Split one content chunk into plain text and reasoning parts.

        Args:
            text: the chunk's content string.
            thinking_since: start time of the open ``<think>`` section, or 0
                when no section is open.

        Returns:
            ``(parts, thinking_since)``: the items to yield in order, and the
            updated section start time (0 once ``</think>`` has been seen).
        """
        parts = []
        opened = False
        if not thinking_since and "<think>" in text:
            before, text = text.split("<think>", 1)
            if before:
                parts.append(cls._rewrite_images(before))
            thinking_since = time.time()
            opened = True
        if not thinking_since:
            parts.append(cls._rewrite_images(text))
            return parts, thinking_since
        closed = "</think>" in text
        after = ""
        if closed:
            text, after = text.split("</think>", 1)
        # Empty fragments produced by the splits are not emitted.
        if text:
            parts.append(Reasoning(text))
        if opened:
            parts.append(Reasoning(None, "Is thinking..."))
        if closed:
            parts.append(Reasoning(None, f"Finished in {round(time.time()-thinking_since, 2)} seconds"))
            if after:
                parts.append(cls._rewrite_images(after))
            thinking_since = 0
        return parts, thinking_since

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        api_key: str = None,
        proxy: str = None,
        timeout: int = 0,
        **kwargs
    ) -> AsyncResult:
        """Stream a conversation from the backend as g4f response objects.

        Args:
            model: requested model name; normalised with ``get_model``.
            messages: conversation history. Not modified.
            api_key: accepted but not forwarded by this method.
            proxy: proxy passed to the HTTP session.
            timeout: timeout passed to the HTTP session.
            **kwargs: extra fields merged into the JSON request body.

        Raises:
            ModelNotSupportedError: if ``model`` is not supported.
        """
        # Log only whether a key was supplied, never the secret itself.
        debug.log(f"{__name__}: api_key {'set' if api_key else 'not set'}")
        if "dall-e" in model and "prompt" not in kwargs and messages:
            kwargs["prompt"] = messages[-1]["content"]
            # Build a new list so the caller's messages are left untouched.
            messages = [
                *messages[:-1],
                {**messages[-1], "content": f"Generate a image: {kwargs['prompt']}"},
            ]
        model = cls.get_model(model)
        provider = cls.get_provider(model)
        async with StreamSession(
            proxy=proxy,
            headers={"Accept": "text/event-stream", **demo.headers},
            timeout=timeout
        ) as session:
            async with session.post(f"{cls.url}/backend-api/v2/conversation", json={
                **kwargs,
                "model": model,
                "messages": messages,
                "provider": provider
            }, ssl=cls.ssl) as response:
                # Check once up front so an error status with an empty body
                # is still reported.
                response.raise_for_status()
                thinking_since = 0
                async for line in response.iter_lines():
                    if not line.strip():
                        continue
                    data = json.loads(line)
                    data_type = data.pop("type")
                    if data_type == "provider":
                        yield ProviderInfo(**data[data_type])
                        provider = data[data_type]["name"]
                    elif data_type == "conversation":
                        conversations = data[data_type]
                        yield JsonConversation(**(
                            conversations[provider]
                            if provider in conversations
                            else conversations[""]
                        ))
                    elif data_type == "conversation_id":
                        pass
                    elif data_type == "message":
                        # Yielded, not raised: the consumer receives it as an item.
                        yield Exception(data)
                    elif data_type == "preview":
                        yield PreviewResponse(data[data_type])
                    elif data_type == "content":
                        parts, thinking_since = cls._split_thinking(data[data_type], thinking_since)
                        for part in parts:
                            yield part
                    elif data_type == "synthesize":
                        yield SynthesizeData(**data[data_type])
                    elif data_type == "parameters":
                        yield Parameters(**data[data_type])
                    elif data_type == "usage":
                        yield Usage(**data[data_type])
                    elif data_type == "reasoning":
                        yield Reasoning(**data)
                    elif data_type == "login":
                        pass
                    elif data_type == "title":
                        yield TitleGeneration(data[data_type])
                    elif data_type == "finish":
                        yield FinishReason(data[data_type]["reason"])
                    elif data_type == "log":
                        debug.log(data[data_type])
                    else:
                        debug.log(f"Unknown data: ({data_type}) {data}")
# Register BackendApi in g4f's provider map under the key "Feature".
g4f.Provider.__map__["Feature"] = BackendApi
import asyncio | |
import uuid | |
from aiohttp import ClientSession, ClientError | |
from g4f.typing import Optional, Cookies | |
from g4f.image import is_accepted_format | |
async def copy_images(
    images: list[str],
    cookies: Optional[Cookies] = None,
    headers: dict = None,
    proxy: Optional[str] = None,
    add_url: bool = True,
    target: str = None,
    ssl: bool = None
) -> list[str]:
    """Download images into ``images_dir`` and return their local URLs.

    Args:
        images: source image URLs.
        cookies: cookies for the download session. When given, the source
            URL is not appended to the returned local URL.
        headers: headers for the download session.
        proxy: proxy used for each download request.
        add_url: append ``?url=<source>`` to each returned URL (only when
            no cookies are used).
        target: destination path; only honoured for a single image,
            otherwise a unique name is generated per image.
        ssl: passed through to the download request.

    Returns:
        One entry per input image: ``/images/<file>`` on success, or the
        original value when the download failed or the input is a ``data:``
        URI.
    """
    if add_url:
        add_url = not cookies
    # NOTE: images_dir is expected to exist already; it is not created here.
    async with ClientSession(
        cookies=cookies,
        headers=headers,
    ) as session:
        def has_extension(path: str) -> bool:
            # Look at the file name only: a dot in a directory name must not
            # count as an extension.
            return "." in os.path.basename(path)

        async def copy_image(image: str, target: str = None) -> str:
            if image.startswith("data:"):
                # Data URIs are not written to disk here. Hand them back
                # unchanged instead of failing on a file that was never created.
                return image
            if target is None or len(images) > 1:
                target = os.path.join(images_dir, f"{int(time.time())}_{str(uuid.uuid4())}")
            created = False
            try:
                try:
                    async with session.get(image, proxy=proxy, ssl=ssl) as response:
                        response.raise_for_status()
                        with open(target, "wb") as f:
                            created = True
                            async for chunk in response.content.iter_chunked(4096):
                                f.write(chunk)
                except ClientError as e:
                    debug.log(f"copy_images failed: {e.__class__.__name__}: {e}")
                    # Do not leave a truncated file behind; a later existence
                    # check would mistake it for a finished download.
                    if created and os.path.exists(target):
                        os.unlink(target)
                    return image
                if not has_extension(target):
                    with open(target, "rb") as f:
                        extension = is_accepted_format(f.read(12)).split("/")[-1]
                    extension = "jpg" if extension == "jpeg" else extension
                    # Rename after the file is closed.
                    new_target = f"{target}.{extension}"
                    os.rename(target, new_target)
                    target = new_target
            finally:
                # An extension-less file at this point means format detection
                # or the rename failed; remove the leftover.
                if not has_extension(target) and os.path.exists(target):
                    os.unlink(target)
            return f"/images/{os.path.basename(target)}{'?url=' + image if add_url else ''}"

        return await asyncio.gather(*[copy_image(image, target) for image in images])
def create_app():
    """Build the FastAPI application.

    Enables g4f debug logging and GUI mode, installs a permissive CORS
    middleware, registers the g4f API routes, adds the
    ``/download/{filename}`` route and mounts the GUI app at ``/``.

    Returns:
        The configured FastAPI instance.
    """
    from urllib.parse import urlsplit

    g4f.debug.logging = True
    g4f.api.AppConfig.gui = True
    g4f.api.AppConfig.demo = False

    app = FastAPI()

    # Add CORS middleware
    app.add_middleware(
        g4f.api.CORSMiddleware,
        allow_origin_regex=".*",
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    api = g4f.api.Api(app)
    api.register_routes()
    api.register_authorization()
    api.register_validation_exception_handler()

    def is_backend_url(url):
        # Compare scheme and host exactly. A plain prefix test would also
        # accept look-alike hosts such as "<backend-host>.evil.example" and
        # send them the backend headers.
        backend, candidate = urlsplit(BackendApi.url), urlsplit(url)
        return (candidate.scheme.lower(), candidate.netloc.lower()) == (
            backend.scheme.lower(),
            backend.netloc.lower(),
        )

    # Must be registered before the catch-all GUI mount below.
    @app.get("/download/{filename}")
    async def download(filename, request: Request):
        """Fetch the image named by ``?url=`` if it is not stored yet, then
        redirect to ``/images/<filename>``. Responds 404 when the file is
        still missing afterwards."""
        filename = os.path.basename(filename)
        if "." not in filename:
            filename = f"{filename}.jpg"
        target = os.path.join(images_dir, filename)
        if not os.path.exists(target):
            # Everything after "url=" is the source URL. It may itself contain
            # "?" and "&", so it is not read as a single query parameter.
            _, _, url = str(request.query_params).partition("url=")
            if url:
                source_url = url.replace("%2F", "/").replace("%3A", ":").replace("%3F", "?")
                # NOTE(review): this fetches a caller-supplied URL with
                # ssl=False; consider restricting the allowed hosts.
                await copy_images(
                    [source_url],
                    target=target,
                    ssl=False,
                    headers=demo.headers if is_backend_url(source_url) else None)
        if not os.path.exists(target):
            return Response(status_code=404)
        return RedirectResponse(f"/images/{filename}")

    gui_app = g4f.api.WSGIMiddleware(g4f.api.get_gui_app(g4f.api.AppConfig.demo))
    app.mount("/", gui_app)

    return app
class NoHomeFilter(logging.Filter):
    """Logging filter that drops records for successful home-page requests
    and for static-asset requests."""

    def filter(self, record):
        """Return False for records to suppress, True for all others."""
        # Format the message once and test both patterns against it.
        message = record.getMessage()
        return not (
            '"GET / HTTP/1.1" 200 OK' in message
            or '"GET /static/' in message
        )
# Keep home-page and static-asset requests out of the "uvicorn.access" log.
logging.getLogger("uvicorn.access").addFilter(NoHomeFilter())

# Application instance built at import time.
app = create_app()