roxky committed on
Commit 8acf8eb · 1 Parent(s): b5b61d9

Add Custom Feature provider

Files changed (1)
  1. app.py +221 -1
app.py CHANGED
@@ -11,6 +11,226 @@ ssl.create_default_context = partial(
 import g4f.api
 import g4f.Provider
 
-g4f.Provider.__map__["Feature"] = g4f.Provider.Custom
+import json
+import time
+import requests
+
+from g4f.Provider.helper import filter_none
+from g4f.Provider.base_provider import AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
+from g4f.typing import Union, Optional, AsyncResult, Messages, ImagesType
+from g4f.requests import StreamSession, raise_for_status
+from g4f.providers.response import FinishReason, ToolCalls, Usage, Reasoning, ImageResponse
+from g4f.errors import MissingAuthError, ResponseError
+from g4f.image import to_data_uri
+from g4f import debug
+
+class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
+    api_base = ""
+    supports_message_history = True
+    supports_system_message = True
+    default_model = ""
+    fallback_models = []
+    sort_models = True
+    verify = None
+
+    @classmethod
+    def get_models(cls, api_key: str = None, api_base: str = None) -> list[str]:
+        if not cls.models:
+            try:
+                headers = {}
+                if api_base is None:
+                    api_base = cls.api_base
+                if api_key is not None:
+                    headers["authorization"] = f"Bearer {api_key}"
+                response = requests.get(f"{api_base}/models", headers=headers, verify=cls.verify)
+                raise_for_status(response)
+                data = response.json()
+                data = data.get("data") if isinstance(data, dict) else data
+                cls.image_models = [model.get("id") for model in data if model.get("image")]
+                cls.models = [model.get("id") for model in data]
+                if cls.sort_models:
+                    cls.models.sort()
+            except Exception as e:
+                debug.log(e)
+                return cls.fallback_models
+        return cls.models
+
+    @classmethod
+    async def create_async_generator(
+        cls,
+        model: str,
+        messages: Messages,
+        proxy: str = None,
+        timeout: int = 120,
+        images: ImagesType = None,
+        api_key: str = None,
+        api_endpoint: str = None,
+        api_base: str = None,
+        temperature: float = None,
+        max_tokens: int = None,
+        top_p: float = None,
+        stop: Union[str, list[str]] = None,
+        stream: bool = False,
+        prompt: str = None,
+        headers: dict = None,
+        impersonate: str = None,
+        tools: Optional[list] = None,
+        extra_data: dict = {},
+        **kwargs
+    ) -> AsyncResult:
+        if cls.needs_auth and api_key is None:
+            raise MissingAuthError('Add an "api_key"')
+        async with StreamSession(
+            proxy=proxy,
+            headers=cls.get_headers(stream, api_key, headers),
+            timeout=timeout,
+            impersonate=impersonate,
+        ) as session:
+            model = cls.get_model(model, api_key=api_key, api_base=api_base)
+            if api_base is None:
+                api_base = cls.api_base
+
+            # Proxy for image generation feature
+            if prompt and model and model in cls.image_models:
+                data = {
+                    "prompt": messages[-1]["content"] if prompt is None else prompt,
+                    "model": model,
+                }
+                async with session.post(f"{api_base.rstrip('/')}/images/generations", json=data, ssl=cls.verify) as response:
+                    data = await response.json()
+                    cls.raise_error(data)
+                    await raise_for_status(response)
+                    yield ImageResponse([image["url"] for image in data["data"]], prompt)
+                    return
+
+            if images is not None and messages:
+                if not model and hasattr(cls, "default_vision_model"):
+                    model = cls.default_vision_model
+                last_message = messages[-1].copy()
+                last_message["content"] = [
+                    *[{
+                        "type": "image_url",
+                        "image_url": {"url": to_data_uri(image)}
+                    } for image, _ in images],
+                    {
+                        "type": "text",
+                        "text": messages[-1]["content"]
+                    }
+                ]
+                messages[-1] = last_message
+            data = filter_none(
+                messages=messages,
+                model=model,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                top_p=top_p,
+                stop=stop,
+                stream=stream,
+                tools=tools,
+                **extra_data
+            )
+            if api_endpoint is None:
+                api_endpoint = f"{api_base.rstrip('/')}/chat/completions"
+            async with session.post(api_endpoint, json=data, ssl=cls.verify) as response:
+                content_type = response.headers.get("content-type", "text/event-stream" if stream else "application/json")
+                if content_type.startswith("application/json"):
+                    data = await response.json()
+                    cls.raise_error(data)
+                    await raise_for_status(response)
+                    choice = data["choices"][0]
+                    if "content" in choice["message"] and choice["message"]["content"]:
+                        yield choice["message"]["content"].strip()
+                    elif "tool_calls" in choice["message"]:
+                        yield ToolCalls(choice["message"]["tool_calls"])
+                    if "usage" in data:
+                        yield Usage(**data["usage"])
+                    if "finish_reason" in choice and choice["finish_reason"] is not None:
+                        yield FinishReason(choice["finish_reason"])
+                    return
+                elif content_type.startswith("text/event-stream"):
+                    await raise_for_status(response)
+                    first = True
+                    is_thinking = 0
+                    async for line in response.iter_lines():
+                        if line.startswith(b"data: "):
+                            chunk = line[6:]
+                            if chunk == b"[DONE]":
+                                break
+                            data = json.loads(chunk)
+                            cls.raise_error(data)
+                            choice = data["choices"][0]
+                            if "content" in choice["delta"] and choice["delta"]["content"]:
+                                delta = choice["delta"]["content"]
+                                if first:
+                                    delta = delta.lstrip()
+                                if delta:
+                                    first = False
+                                if is_thinking:
+                                    if "</think>" in delta:
+                                        yield Reasoning(None, f"Finished in {round(time.time()-is_thinking, 2)} seconds")
+                                        is_thinking = 0
+                                    else:
+                                        yield Reasoning(delta)
+                                elif "<think>" in delta:
+                                    is_thinking = time.time()
+                                    yield Reasoning(None, "Thinking...")
+                                else:
+                                    yield delta
+                            if "usage" in data and data["usage"]:
+                                yield Usage(**data["usage"])
+                            if "finish_reason" in choice and choice["finish_reason"] is not None:
+                                yield FinishReason(choice["finish_reason"])
+                                break
+                else:
+                    await raise_for_status(response)
+                    raise ResponseError(f"Unsupported content-type: {content_type}")
+
+    @classmethod
+    def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
+        return {
+            "Accept": "text/event-stream" if stream else "application/json",
+            "Content-Type": "application/json",
+            **(
+                {"Authorization": f"Bearer {api_key}"}
+                if api_key is not None else {}
+            ),
+            **({} if headers is None else headers)
+        }
+
+class Feature(OpenaiTemplate):
+    url = "https://ahe.hopto.org"
+    working = True
+    verify = False
+
+    models = [
+        *list(set(g4f.Provider.OpenaiAccount.get_models())),
+        *g4f.Provider.HuggingChat.get_models(),
+        "MiniMax"
+    ]
+
+    @classmethod
+    def get_model(cls, model, **kwargs):
+        if model == "MiniMax":
+            cls.api_base = f"{cls.url}/api/HailuoAI"
+        elif model in g4f.Provider.OpenaiAccount.get_models():
+            cls.api_base = f"{cls.url}/api/OpenaiAccount"
+        elif model in g4f.Provider.HuggingChat.get_models():
+            cls.api_base = f"{cls.url}/api/HuggingChat"
+        else:
+            cls.api_base = f"{cls.url}/v1"
+        return model
+
+    @classmethod
+    async def create_async_generator(
+        cls,
+        model: str,
+        messages: Messages,
+        api_key: str = None,
+        **kwargs
+    ) -> AsyncResult:
+        async for chunk in super().create_async_generator(model, messages, **kwargs):
+            yield chunk
+
+g4f.Provider.__map__["Feature"] = Feature
 
 app = g4f.api.create_app_with_demo_and_debug()
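
The new Feature class routes each model family to its backend by mutating the api_base class attribute inside get_model. A minimal sketch of that routing (illustrative only, not part of the commit; "unknown-model" is a hypothetical name used to hit the fallback branch):

    # get_model() rewrites Feature.api_base per model family before the
    # request is sent; the returned model name is unchanged.
    Feature.get_model("MiniMax")
    assert Feature.api_base == "https://ahe.hopto.org/api/HailuoAI"

    Feature.get_model("unknown-model")  # hypothetical model name
    assert Feature.api_base == "https://ahe.hopto.org/v1"

Note that api_base is class-level state, so concurrent requests targeting different model families can race on it.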
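Since Feature inherits create_async_generator from OpenaiTemplate, it can also be driven directly as an async generator, outside the API app. A minimal consumption sketch, assuming the upstream host in Feature.url is reachable; besides text chunks, the generator also yields metadata objects such as Usage and FinishReason:

    import asyncio

    async def main() -> None:
        # Non-streaming by default: the JSON branch yields the full message
        # text first, then Usage/FinishReason metadata objects.
        async for chunk in Feature.create_async_generator(
            model="MiniMax",
            messages=[{"role": "user", "content": "Say hello."}],
        ):
            print(chunk)

    asyncio.run(main())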