package main | |
import ( | |
"bufio" | |
"embed" | |
"encoding/json" | |
"errors" | |
"fmt" | |
"io" | |
"io/fs" | |
"log" | |
"net/http" | |
"os" | |
"strings" | |
"time" | |
"" | |
"" | |
) | |
// staticFiles holds the web/ directory embedded into the binary at build time;
// main serves it under the /web route.
//
//go:embed web/*
var staticFiles embed.FS
// Config holds the runtime settings that init loads from the environment.
type Config struct {
// APIPrefix is not populated by init and not read anywhere in this file.
APIPrefix string
// APIKey is compared against keys supplied to authMiddleware
// (env API_KEY, default empty).
APIKey string
// MaxRetryCount is the number of token-request attempts requestToken makes
// (env MAX_RETRY_COUNT, default 3).
MaxRetryCount int
// RetryDelay is the pause between token-request attempts
// (env RETRY_DELAY, interpreted as milliseconds, default 5000).
RetryDelay time.Duration
// FakeHeaders are browser-like headers copied onto every upstream request.
FakeHeaders map[string]string
}
// config is the process-wide configuration, populated once by init.
var config Config
func init() { | |
godotenv.Load() | |
config = Config{ | |
APIKey: getEnv("API_KEY", ""), | |
MaxRetryCount: getIntEnv("MAX_RETRY_COUNT", 3), | |
RetryDelay: getDurationEnv("RETRY_DELAY", 5000), | |
FakeHeaders: map[string]string{ | |
"Accept": "*/*", | |
"Accept-Encoding": "gzip, deflate, br, zstd", | |
"Accept-Language": "zh-CN,zh;q=0.9", | |
"Origin": "", | |
"Cookie": "l=wt-wt; ah=wt-wt; dcm=6", | |
"Dnt": "1", | |
"Priority": "u=1, i", | |
"Referer": "", | |
"Sec-Ch-Ua": `"Microsoft Edge";v="129", "Not(A:Brand";v="8", "Chromium";v="129"`, | |
"Sec-Ch-Ua-Mobile": "?0", | |
"Sec-Ch-Ua-Platform": `"Windows"`, | |
"Sec-Fetch-Dest": "empty", | |
"Sec-Fetch-Mode": "cors", | |
"Sec-Fetch-Site": "same-origin", | |
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/ Safari/537.36", | |
}, | |
} | |
} | |
func authMiddleware() gin.HandlerFunc { | |
return func(c *gin.Context) { | |
apiKey := c.GetHeader("Authorization") | |
if apiKey == "" { | |
apiKey = c.Query("api_key") | |
} | |
// 当未提供或者 API Key 不合法时,允许匿名访问 | |
if apiKey == "" || !strings.HasPrefix(apiKey, "Bearer ") { | |
c.Next() // 未提供 API 密钥时允许继续请求 | |
return | |
} | |
apiKey = strings.TrimPrefix(apiKey, "Bearer ") | |
if apiKey != config.APIKey { | |
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid API key"}) | |
c.Abort() | |
return | |
} | |
c.Next() | |
} | |
} | |
func main() { | |
r := gin.Default() | |
r.Use(corsMiddleware()) | |
// 1. 映射 Web 目录到 /web 路由 | |
//r.Static("/web", "./web") // 确保 ./web 文件夹存在,并包含 index.html | |
subFS, err := fs.Sub(staticFiles, "web") | |
if err != nil { | |
log.Fatal(err) | |
} | |
r.StaticFS("/web", http.FS(subFS)) | |
// 2. 根路径重定向到 /web | |
r.GET("/", func(c *gin.Context) { | |
c.Redirect(http.StatusMovedPermanently, "/web") | |
}) | |
// r.GET("/", func(c *gin.Context) { | |
// c.JSON(http.StatusOK, gin.H{"message": "API 服务运行中~"}) | |
// }) | |
// 3. 健康检查 | |
r.GET("/ping", func(c *gin.Context) { | |
c.JSON(http.StatusOK, gin.H{"message": "pong"}) | |
}) | |
// 4. API 路由组 | |
// authorized := r.Group("/") | |
// authorized.Use(authMiddleware()) | |
// { | |
// authorized.GET("/hf/v1/models", handleModels) | |
// authorized.POST("/hf/v1/chat/completions", handleCompletion) | |
// } | |
apiGroup := r.Group("/") | |
apiGroup.Use(authMiddleware()) // 可以选择性地提供 API 密钥 | |
{ | |
// 原始路径 /hf/v1/* | |
apiGroup.GET("/hf/v1/models", handleModels) | |
apiGroup.POST("/hf/v1/chat/completions", handleCompletion) | |
// 新路径 /api/v1/* | |
apiGroup.GET("/api/v1/models", handleModels) | |
apiGroup.POST("/api/v1/chat/completions", handleCompletion) | |
// 新路径 /v1/* | |
apiGroup.GET("/v1/models", handleModels) | |
apiGroup.POST("/v1/chat/completions", handleCompletion) | |
// 新路径 /completions | |
apiGroup.POST("/completions", handleCompletion) | |
} | |
// 5. 从环境变量中读取端口号 | |
port := os.Getenv("PORT") | |
if port == "" { | |
port = "7860" | |
} | |
r.Run(":" + port) | |
} | |
func handleModels(c *gin.Context) { | |
models := []gin.H{ | |
{"id": "gpt-4o-mini", "object": "model", "owned_by": "ddg"}, | |
{"id": "claude-3-haiku", "object": "model", "owned_by": "ddg"}, | |
{"id": "llama-3.1-70b", "object": "model", "owned_by": "ddg"}, | |
{"id": "mixtral-8x7b", "object": "model", "owned_by": "ddg"}, | |
} | |
c.JSON(http.StatusOK, gin.H{"object": "list", "data": models}) | |
} | |
func handleCompletion(c *gin.Context) { | |
var req struct { | |
Model string `json:"model"` | |
Messages []struct { | |
Role string `json:"role"` | |
Content interface{} `json:"content"` | |
} `json:"messages"` | |
Stream bool `json:"stream"` | |
} | |
if err := c.ShouldBindJSON(&req); err != nil { | |
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | |
return | |
} | |
model := convertModel(req.Model) | |
content := prepareMessages(req.Messages) | |
// log.Printf("messages: %v", content) | |
reqBody := map[string]interface{}{ | |
"model": model, | |
"messages": []map[string]interface{}{ | |
{ | |
"role": "user", | |
"content": content, | |
}, | |
}, | |
} | |
body, err := json.Marshal(reqBody) | |
if err != nil { | |
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("请求体序列化失败: %v", err)}) | |
return | |
} | |
token, err := requestToken() | |
if err != nil { | |
c.JSON(http.StatusInternalServerError, gin.H{"error": "无法获取token"}) | |
return | |
} | |
upstreamReq, err := http.NewRequest("POST", "", strings.NewReader(string(body))) | |
if err != nil { | |
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("创建请求失败: %v", err)}) | |
return | |
} | |
for k, v := range config.FakeHeaders { | |
upstreamReq.Header.Set(k, v) | |
} | |
upstreamReq.Header.Set("x-vqd-4", token) | |
upstreamReq.Header.Set("Content-Type", "application/json") | |
client := &http.Client{ | |
Timeout: 30 * time.Second, | |
} | |
resp, err := client.Do(upstreamReq) | |
if err != nil { | |
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("请求失败: %v", err)}) | |
return | |
} | |
defer resp.Body.Close() | |
if req.Stream { | |
// 启用 SSE 流式响应 | |
c.Writer.Header().Set("Content-Type", "text/event-stream") | |
c.Writer.Header().Set("Cache-Control", "no-cache") | |
c.Writer.Header().Set("Connection", "keep-alive") | |
flusher, ok := c.Writer.(http.Flusher) | |
if !ok { | |
c.JSON(http.StatusInternalServerError, gin.H{"error": "Streaming not supported"}) | |
return | |
} | |
reader := bufio.NewReader(resp.Body) | |
for { | |
line, err := reader.ReadString('\n') | |
if err != nil { | |
if err != io.EOF { | |
log.Printf("读取流式响应失败: %v", err) | |
} | |
break | |
} | |
if strings.HasPrefix(line, "data: ") { | |
// 解析响应中的 JSON 数据块 | |
line = strings.TrimPrefix(line, "data: ") | |
line = strings.TrimSpace(line) | |
// 忽略非 JSON 数据块(例如特殊标记 [DONE]) | |
if line == "[DONE]" { | |
//log.Printf("响应行 DONE, 即将跳过") | |
break | |
} | |
var chunk map[string]interface{} | |
if err := json.Unmarshal([]byte(line), &chunk); err != nil { | |
log.Printf("解析响应行失败: %v", err) | |
continue | |
} | |
// 检查 chunk 是否包含 message | |
if msg, exists := chunk["message"]; exists && msg != nil { | |
if msgStr, ok := msg.(string); ok { | |
response := map[string]interface{}{ | |
"id": "chatcmpl-QXlha2FBbmROaXhpZUFyZUF3ZXNvbWUK", | |
"object": "chat.completion.chunk", | |
"created": time.Now().Unix(), | |
"model": model, | |
"choices": []map[string]interface{}{ | |
{ | |
"index": 0, | |
"delta": map[string]string{ | |
"content": msgStr, | |
}, | |
"finish_reason": nil, | |
}, | |
}, | |
} | |
// 将响应格式化为 SSE 数据块 | |
sseData, _ := json.Marshal(response) | |
sseMessage := fmt.Sprintf("data: %s\n\n", sseData) | |
// 发送数据并刷新缓冲区 | |
_, writeErr := c.Writer.Write([]byte(sseMessage)) | |
if writeErr != nil { | |
log.Printf("写入响应失败: %v", writeErr) | |
break | |
} | |
flusher.Flush() | |
} else { | |
log.Printf("chunk[message] 不是字符串: %v", msg) | |
} | |
} else { | |
// 解析行中有空行 | |
log.Println("chunk 中未包含 message 或 message 为 nil") | |
} | |
} | |
} | |
} else { | |
// 非流式响应,返回完整的 JSON | |
var fullResponse strings.Builder | |
reader := bufio.NewReader(resp.Body) | |
for { | |
line, err := reader.ReadString('\n') | |
if err == io.EOF { | |
break | |
} else if err != nil { | |
log.Printf("读取响应失败: %v", err) | |
break | |
} | |
if strings.HasPrefix(line, "data: ") { | |
line = strings.TrimPrefix(line, "data: ") | |
line = strings.TrimSpace(line) | |
if line == "[DONE]" { | |
break | |
} | |
var chunk map[string]interface{} | |
if err := json.Unmarshal([]byte(line), &chunk); err != nil { | |
log.Printf("解析响应行失败: %v", err) | |
continue | |
} | |
if message, exists := chunk["message"]; exists { | |
if msgStr, ok := message.(string); ok { | |
fullResponse.WriteString(msgStr) | |
} | |
} | |
} | |
} | |
// 返回完整 JSON 响应 | |
response := map[string]interface{}{ | |
"id": "chatcmpl-QXlha2FBbmROaXhpZUFyZUF3ZXNvbWUK", | |
"object": "chat.completion", | |
"created": time.Now().Unix(), | |
"model": model, | |
"usage": map[string]int{ | |
"prompt_tokens": 0, | |
"completion_tokens": 0, | |
"total_tokens": 0, | |
}, | |
"choices": []map[string]interface{}{ | |
{ | |
"message": map[string]string{ | |
"role": "assistant", | |
"content": fullResponse.String(), | |
}, | |
"index": 0, | |
}, | |
}, | |
} | |
c.JSON(http.StatusOK, response) | |
} | |
} | |
//func requestToken() (string, error) { | |
// req, err := http.NewRequest("GET", "", nil) | |
// if err != nil { | |
// return "", fmt.Errorf("创建请求失败: %v", err) | |
// } | |
// for k, v := range config.FakeHeaders { | |
// req.Header.Set(k, v) | |
// } | |
// req.Header.Set("x-vqd-accept", "1") | |
// | |
// client := &http.Client{ | |
// Timeout: 10 * time.Second, | |
// } | |
// | |
// log.Println("发送 token 请求") | |
// resp, err := client.Do(req) | |
// if err != nil { | |
// return "", fmt.Errorf("请求失败: %v", err) | |
// } | |
// defer resp.Body.Close() | |
// | |
// if resp.StatusCode != http.StatusOK { | |
// bodyBytes, _ := io.ReadAll(resp.Body) | |
// bodyString := string(bodyBytes) | |
// log.Printf("requestToken: 非200响应: %d, 内容: %s\n", resp.StatusCode, bodyString) | |
// return "", fmt.Errorf("非200响应: %d, 内容: %s", resp.StatusCode, bodyString) | |
// } | |
// | |
// token := resp.Header.Get("x-vqd-4") | |
// if token == "" { | |
// return "", errors.New("响应中未包含x-vqd-4头") | |
// } | |
// | |
// // log.Printf("获取到的 token: %s\n", token) | |
// return token, nil | |
//} | |
func requestToken() (string, error) { | |
url := "" | |
client := &http.Client{ | |
Timeout: 15 * time.Second, // 设置超时时间 | |
} | |
maxRetries := config.MaxRetryCount | |
retryDelay := config.RetryDelay | |
for attempt := 0; attempt < maxRetries; attempt++ { | |
if attempt > 0 { | |
log.Printf("requestToken: 第 %d 次重试,等待 %v...", attempt, retryDelay) | |
time.Sleep(retryDelay) | |
} | |
log.Printf("requestToken: 发送 GET 请求到 %s", url) | |
// 创建请求 | |
req, err := http.NewRequest("GET", url, nil) | |
if err != nil { | |
log.Printf("requestToken: 创建请求失败: %v", err) | |
return "", fmt.Errorf("无法创建请求: %w", err) | |
} | |
// 添加假头部 | |
for k, v := range config.FakeHeaders { | |
req.Header.Set(k, v) | |
} | |
req.Header.Set("x-vqd-accept", "1") | |
// 发送请求 | |
resp, err := client.Do(req) | |
if err != nil { | |
log.Printf("requestToken: 请求失败: %v", err) | |
continue // 网络通信失败,进行重试 | |
} | |
defer resp.Body.Close() | |
// 检查状态码是否为 200 | |
if resp.StatusCode != http.StatusOK { | |
bodyBytes, _ := io.ReadAll(resp.Body) // 读取响应体,错误时也需要记录响应内容 | |
bodyString := string(bodyBytes) | |
log.Printf("requestToken: 非200响应,状态码=%d, 响应内容: %s", resp.StatusCode, bodyString) | |
continue | |
} | |
// 尝试从头部提取 token | |
token := resp.Header.Get("x-vqd-4") | |
if token == "" { | |
log.Println("requestToken: 响应中未包含 x-vqd-4 头部") | |
bodyBytes, _ := io.ReadAll(resp.Body) | |
bodyString := string(bodyBytes) | |
log.Printf("requestToken: 响应内容: %s", bodyString) | |
continue | |
} | |
// 成功获取到 token | |
log.Printf("requestToken: 成功获取到 token: %s", token) | |
return token, nil | |
} | |
// 如果所有重试均失败,返回错误 | |
return "", errors.New("requestToken: 无法获取到 token,多次重试仍失败") | |
} | |
// prepareMessages flattens chat messages into a single prompt string with one
// "role:content;\r\n" entry per message. The "system" role is rewritten to
// "user".
//
// Content handling:
//   - string: used as-is;
//   - []interface{}: the "text" string of every map part is concatenated and
//     other parts are skipped;
//   - nil (e.g. a JSON null content): contributes no text. It was previously
//     rendered as the literal "<nil>";
//   - anything else: formatted with %v.
func prepareMessages(messages []struct {
	Role    string      `json:"role"`
	Content interface{} `json:"content"`
}) string {
	var b strings.Builder
	for _, msg := range messages {
		role := msg.Role
		if role == "system" {
			role = "user"
		}
		b.WriteString(role)
		b.WriteString(":")

		switch v := msg.Content.(type) {
		case nil:
			// No text for a null content.
		case string:
			b.WriteString(v)
		case []interface{}:
			for _, item := range v {
				if part, ok := item.(map[string]interface{}); ok {
					if text, ok := part["text"].(string); ok {
						b.WriteString(text)
					}
				}
			}
		default:
			fmt.Fprintf(&b, "%v", v)
		}

		b.WriteString(";\r\n")
	}
	return b.String()
}
// convertModel maps a public model id (case-insensitive) to the upstream
// model name. Unknown ids, including "gpt-4o-mini" itself, map to
// "gpt-4o-mini".
func convertModel(inputModel string) string {
	upstream := "gpt-4o-mini"
	switch normalized := strings.ToLower(inputModel); normalized {
	case "claude-3-haiku":
		upstream = "claude-3-haiku-20240307"
	case "llama-3.1-70b":
		upstream = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
	case "mixtral-8x7b":
		upstream = "mistralai/Mixtral-8x7B-Instruct-v0.1"
	}
	return upstream
}
func corsMiddleware() gin.HandlerFunc { | |
return func(c *gin.Context) { | |
c.Writer.Header().Set("Access-Control-Allow-Origin", "*") | |
c.Writer.Header().Set("Access-Control-Allow-Methods", "*") | |
c.Writer.Header().Set("Access-Control-Allow-Headers", "*") | |
if c.Request.Method == http.MethodOptions { | |
c.AbortWithStatus(http.StatusNoContent) | |
return | |
} | |
c.Next() | |
} | |
} | |
// getEnv returns the value of environment variable key, or fallback when the
// variable is not set at all. A variable set to the empty string yields "".
func getEnv(key, fallback string) string {
	value, ok := os.LookupEnv(key)
	if !ok {
		return fallback
	}
	return value
}
// getIntEnv reads an integer from environment variable key.
//
// It returns fallback when the variable is unset or its value does not begin
// with a decimal integer. The scan error was previously ignored, so e.g.
// MAX_RETRY_COUNT="" or "abc" silently became 0 instead of the default.
func getIntEnv(key string, fallback int) int {
	value, exists := os.LookupEnv(key)
	if !exists {
		return fallback
	}
	var intValue int
	if _, err := fmt.Sscanf(value, "%d", &intValue); err != nil {
		log.Printf("环境变量 %s 的值 %q 不是有效整数,使用默认值 %d", key, value, fallback)
		return fallback
	}
	return intValue
}
func getDurationEnv(key string, fallback int) time.Duration { | |
return time.Duration(getIntEnv(key, fallback)) * time.Millisecond | |
} | |