ZHZ1024 commited on
Commit
b717d11
·
verified ·
1 Parent(s): 8fa2510

Create main.go

Browse files
Files changed (1) hide show
  1. main.go +318 -0
main.go ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package main
2
+
3
+ import (
4
+ "bytes"
5
+ "encoding/json"
6
+ "io"
7
+ "log"
8
+ "net/http"
9
+ "strings"
10
+ )
11
+
12
+ const (
13
+ defaultPort = "8080"
14
+ geminiOpenAIEndpoint = "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions"
15
+ )
16
+
17
+ // RelayServer 中继服务器
18
+ type RelayServer struct {
19
+ client *http.Client
20
+ }
21
+
22
+ // NewRelayServer 创建新的中继服务器
23
+ func NewRelayServer() *RelayServer {
24
+ return &RelayServer{
25
+ client: &http.Client{},
26
+ }
27
+ }
28
+
29
+ // filterRequest 过滤掉Gemini不支持的参数
30
+ func filterRequest(body []byte) ([]byte, error) {
31
+ var requestData map[string]interface{}
32
+ if err := json.Unmarshal(body, &requestData); err != nil {
33
+ return body, nil // 如果解析失败,返回原始数据
34
+ }
35
+
36
+ // Gemini不支持的OpenAI参数列表
37
+ unsupportedParams := []string{
38
+ "frequency_penalty",
39
+ "presence_penalty",
40
+ "logit_bias",
41
+ "user",
42
+ "n",
43
+ "stop",
44
+ "suffix",
45
+ "logprobs",
46
+ "echo",
47
+ "best_of",
48
+ "response_format",
49
+ "seed",
50
+ "tools",
51
+ "tool_choice",
52
+ "parallel_tool_calls",
53
+ }
54
+
55
+ // 删除不支持的参数
56
+ for _, param := range unsupportedParams {
57
+ delete(requestData, param)
58
+ }
59
+
60
+ // 重新序列化
61
+ return json.Marshal(requestData)
62
+ }
63
+
64
+ // handleRequest 处理所有的API请求
65
+ func (s *RelayServer) handleRequest(w http.ResponseWriter, r *http.Request) {
66
+ // 检查是否有Authorization头
67
+ authHeader := r.Header.Get("Authorization")
68
+ if authHeader == "" {
69
+ w.Header().Set("Content-Type", "application/json")
70
+ w.WriteHeader(http.StatusUnauthorized)
71
+ w.Write([]byte(`{"error": {"message": "Missing Authorization header", "type": "invalid_request_error"}}`))
72
+ return
73
+ }
74
+
75
+ // 读取请求体
76
+ bodyBytes, err := io.ReadAll(r.Body)
77
+ if err != nil {
78
+ w.Header().Set("Content-Type", "application/json")
79
+ w.WriteHeader(http.StatusBadRequest)
80
+ w.Write([]byte(`{"error": {"message": "Failed to read request body", "type": "invalid_request_error"}}`))
81
+ return
82
+ }
83
+ defer r.Body.Close()
84
+
85
+ // 打印原始请求(调试用)
86
+ log.Printf("Original request body: %s", string(bodyBytes))
87
+
88
+ // 过滤请求参数
89
+ filteredBody, err := filterRequest(bodyBytes)
90
+ if err != nil {
91
+ log.Printf("Failed to filter request: %v", err)
92
+ filteredBody = bodyBytes // 使用原始数据
93
+ }
94
+
95
+ log.Printf("Filtered request body: %s", string(filteredBody))
96
+
97
+ // 创建新的请求
98
+ proxyReq, err := http.NewRequest("POST", geminiOpenAIEndpoint, bytes.NewReader(filteredBody))
99
+ if err != nil {
100
+ w.Header().Set("Content-Type", "application/json")
101
+ w.WriteHeader(http.StatusInternalServerError)
102
+ w.Write([]byte(`{"error": {"message": "Failed to create proxy request", "type": "server_error"}}`))
103
+ return
104
+ }
105
+
106
+ // 复制所有请求头
107
+ for name, values := range r.Header {
108
+ // 跳过Host和Content-Length,这些会自动设置
109
+ if name == "Host" || name == "Content-Length" {
110
+ continue
111
+ }
112
+ for _, value := range values {
113
+ proxyReq.Header.Add(name, value)
114
+ }
115
+ }
116
+
117
+ // 确保Authorization头被正确设置
118
+ proxyReq.Header.Set("Authorization", authHeader)
119
+ proxyReq.Header.Set("Content-Type", "application/json")
120
+
121
+ log.Printf("Request headers being sent to Gemini: %v", proxyReq.Header)
122
+
123
+ // 发送请求
124
+ resp, err := s.client.Do(proxyReq)
125
+ if err != nil {
126
+ log.Printf("Failed to send request to Gemini: %v", err)
127
+ w.Header().Set("Content-Type", "application/json")
128
+ w.WriteHeader(http.StatusBadGateway)
129
+ w.Write([]byte(`{"error": {"message": "Failed to connect to Gemini API", "type": "server_error"}}`))
130
+ return
131
+ }
132
+ defer resp.Body.Close()
133
+
134
+ // 打印响应状态(调试用)
135
+ log.Printf("Response status from Gemini: %d", resp.StatusCode)
136
+
137
+ // 复制响应头
138
+ for name, values := range resp.Header {
139
+ // 跳过一些头部
140
+ if name == "Content-Length" {
141
+ continue
142
+ }
143
+ for _, value := range values {
144
+ w.Header().Add(name, value)
145
+ }
146
+ }
147
+
148
+ // 添加CORS头
149
+ w.Header().Set("Access-Control-Allow-Origin", "*")
150
+
151
+ // 设置状态码
152
+ w.WriteHeader(resp.StatusCode)
153
+
154
+ // 处理流式响应
155
+ if strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
156
+ // 确保流式响应的头部设置正确
157
+ w.Header().Set("Cache-Control", "no-cache")
158
+ w.Header().Set("Connection", "keep-alive")
159
+
160
+ // 使用缓冲区进行流式传输
161
+ buf := make([]byte, 1024)
162
+ for {
163
+ n, err := resp.Body.Read(buf)
164
+ if n > 0 {
165
+ if _, writeErr := w.Write(buf[:n]); writeErr != nil {
166
+ log.Printf("Error writing response: %v", writeErr)
167
+ return
168
+ }
169
+ if flusher, ok := w.(http.Flusher); ok {
170
+ flusher.Flush()
171
+ }
172
+ }
173
+ if err != nil {
174
+ if err != io.EOF {
175
+ log.Printf("Error reading response: %v", err)
176
+ }
177
+ break
178
+ }
179
+ }
180
+ } else {
181
+ // 非流式响应
182
+ // 如果是错误响应,打印出来以便调试
183
+ if resp.StatusCode >= 400 {
184
+ bodyBytes, _ := io.ReadAll(resp.Body)
185
+ log.Printf("Error response from Gemini: %s", string(bodyBytes))
186
+ w.Write(bodyBytes)
187
+ } else {
188
+ io.Copy(w, resp.Body)
189
+ }
190
+ }
191
+ }
192
+
193
+ // handleHealth 健康检查端点
194
+ func (s *RelayServer) handleHealth(w http.ResponseWriter, r *http.Request) {
195
+ w.Header().Set("Content-Type", "application/json")
196
+ w.WriteHeader(http.StatusOK)
197
+ w.Write([]byte(`{"status": "ok", "service": "gemini-relay"}`))
198
+ }
199
+
200
+ // corsMiddleware CORS中间件
201
+ func corsMiddleware(next http.HandlerFunc) http.HandlerFunc {
202
+ return func(w http.ResponseWriter, r *http.Request) {
203
+ w.Header().Set("Access-Control-Allow-Origin", "*")
204
+ w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
205
+ w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
206
+ w.Header().Set("Access-Control-Max-Age", "86400")
207
+
208
+ if r.Method == "OPTIONS" {
209
+ w.WriteHeader(http.StatusOK)
210
+ return
211
+ }
212
+
213
+ next(w, r)
214
+ }
215
+ }
216
+
217
+ // loggingMiddleware 日志中间件
218
+ func loggingMiddleware(next http.HandlerFunc) http.HandlerFunc {
219
+ return func(w http.ResponseWriter, r *http.Request) {
220
+ log.Printf("[%s] %s %s %s", r.RemoteAddr, r.Method, r.URL.Path, r.UserAgent())
221
+ next(w, r)
222
+ }
223
+ }
224
+
225
+ // 模型映射
226
+ func (s *RelayServer) handleModels(w http.ResponseWriter, r *http.Request) {
227
+ // 检查Authorization
228
+ if r.Header.Get("Authorization") == "" {
229
+ w.Header().Set("Content-Type", "application/json")
230
+ w.WriteHeader(http.StatusUnauthorized)
231
+ w.Write([]byte(`{"error": {"message": "Missing Authorization header", "type": "invalid_request_error"}}`))
232
+ return
233
+ }
234
+
235
+ w.Header().Set("Content-Type", "application/json")
236
+ w.WriteHeader(http.StatusOK)
237
+ models := `{
238
+ "object": "list",
239
+ "data": [
240
+ {
241
+ "id": "gemini-1.5-pro",
242
+ "object": "model",
243
+ "created": 1686935002,
244
+ "owned_by": "google"
245
+ },
246
+ {
247
+ "id": "gemini-1.5-flash",
248
+ "object": "model",
249
+ "created": 1686935002,
250
+ "owned_by": "google"
251
+ },
252
+ {
253
+ "id": "gemini-1.5-flash-8b",
254
+ "object": "model",
255
+ "created": 1686935002,
256
+ "owned_by": "google"
257
+ },
258
+ {
259
+ "id": "gemini-2.0-flash-exp",
260
+ "object": "model",
261
+ "created": 1686935002,
262
+ "owned_by": "google"
263
+ }
264
+ ]
265
+ }`
266
+ w.Write([]byte(models))
267
+ }
268
+
269
+ func main() {
270
+ // 创建中继服务器
271
+ server := NewRelayServer()
272
+
273
+ // 设置路由
274
+ mux := http.NewServeMux()
275
+
276
+ // OpenAI兼容的端点
277
+ mux.HandleFunc("/v1/chat/completions", corsMiddleware(loggingMiddleware(server.handleRequest)))
278
+ mux.HandleFunc("/chat/completions", corsMiddleware(loggingMiddleware(server.handleRequest)))
279
+ mux.HandleFunc("/v1/models", corsMiddleware(loggingMiddleware(server.handleModels)))
280
+ mux.HandleFunc("/models", corsMiddleware(loggingMiddleware(server.handleModels)))
281
+
282
+ // 健康检查
283
+ mux.HandleFunc("/health", corsMiddleware(server.handleHealth))
284
+ mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
285
+ w.Header().Set("Content-Type", "application/json")
286
+ w.Write([]byte(`{
287
+ "service": "Gemini API Relay",
288
+ "version": "1.0.0",
289
+ "endpoints": {
290
+ "chat": "/v1/chat/completions",
291
+ "models": "/v1/models",
292
+ "health": "/health"
293
+ },
294
+ "supported_models": [
295
+ "gemini-1.5-pro",
296
+ "gemini-1.5-flash",
297
+ "gemini-1.5-flash-8b",
298
+ "gemini-2.0-flash-exp"
299
+ ],
300
+ "note": "Use Authorization header with 'Bearer YOUR_GEMINI_API_KEY'"
301
+ }`))
302
+ })
303
+
304
+ // 启动服务器
305
+ port := defaultPort
306
+ log.Printf("========================================")
307
+ log.Printf("Gemini API Relay Server")
308
+ log.Printf("Port: %s", port)
309
+ log.Printf("Endpoint: %s", geminiOpenAIEndpoint)
310
+ log.Printf("========================================")
311
+ log.Printf("Usage:")
312
+ log.Printf(" Authorization: Bearer YOUR_GEMINI_API_KEY")
313
+ log.Printf("========================================")
314
+
315
+ if err := http.ListenAndServe(":"+port, mux); err != nil {
316
+ log.Fatalf("Server failed to start: %v", err)
317
+ }
318
+ }