package main

import (
	"bufio"
	"embed"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"log"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/joho/godotenv"
)

// staticFiles embeds the files matched by web/* into the binary; main exposes
// them (through fs.Sub) at the /web route.
//
//go:embed web/*
var staticFiles embed.FS

// Config holds the runtime settings assembled from the environment in init.
type Config struct {
	// APIKey is the expected Bearer token compared by authMiddleware.
	APIKey string
	// APIPrefix is declared but not populated or read anywhere in this file.
	APIPrefix string
	// FakeHeaders are browser-like headers copied onto the requests sent to
	// DuckDuckGo.
	FakeHeaders map[string]string
	// MaxRetryCount is the total number of token-request attempts made by
	// requestToken (not the number of retries after the first attempt).
	MaxRetryCount int
	// RetryDelay is the pause between consecutive token-request attempts.
	RetryDelay time.Duration
}

// config is the process-wide configuration, filled in by init.
var config Config

// init loads an optional .env file and builds the global config from the
// environment (API_KEY, MAX_RETRY_COUNT, and RETRY_DELAY in milliseconds),
// together with the fixed browser-like headers sent to DuckDuckGo.
//
// The Origin header is a serialized origin (scheme://host[:port]) and never
// carries a path, so it is sent without the trailing slash a browser would
// never include.
func init() {
	// A missing .env file is normal because settings may come from the real
	// environment. A file that exists but cannot be read or parsed should be
	// reported rather than silently ignored.
	if err := godotenv.Load(); err != nil && !errors.Is(err, fs.ErrNotExist) {
		log.Printf("加载 .env 文件失败: %v", err)
	}

	config = Config{
		APIKey:        getEnv("API_KEY", ""),
		MaxRetryCount: getIntEnv("MAX_RETRY_COUNT", 3),
		RetryDelay:    getDurationEnv("RETRY_DELAY", 5000),
		FakeHeaders: map[string]string{
			"Accept":             "*/*",
			"Accept-Encoding":    "gzip, deflate, br, zstd",
			"Accept-Language":    "zh-CN,zh;q=0.9",
			"Origin":             "https://duckduckgo.com",
			"Cookie":             "l=wt-wt; ah=wt-wt; dcm=6",
			"Dnt":                "1",
			"Priority":           "u=1, i",
			"Referer":            "https://duckduckgo.com/",
			"Sec-Ch-Ua":          `"Microsoft Edge";v="129", "Not(A:Brand";v="8", "Chromium";v="129"`,
			"Sec-Ch-Ua-Mobile":   "?0",
			"Sec-Ch-Ua-Platform": `"Windows"`,
			"Sec-Fetch-Dest":     "empty",
			"Sec-Fetch-Mode":     "cors",
			"Sec-Fetch-Site":     "same-origin",
			"User-Agent":         "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36",
		},
	}
}

// authMiddleware checks the optional API key on the API routes.
//
// The credential comes from the Authorization header. When that header is
// absent, it comes from the api_key query parameter instead. In both cases it
// is only checked when it carries the "Bearer " prefix.
//
// Requests without a Bearer credential are admitted anonymously. When no
// API_KEY is configured, every request is admitted, so clients that always
// send some Bearer token keep working.
//
// NOTE(review): because anonymous requests are admitted, the key only rejects
// wrong keys; it does not restrict who can use the API.
func authMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		// No key configured means authentication is disabled. Previously any
		// non-empty Bearer token was rejected with 401 in this situation.
		if config.APIKey == "" {
			c.Next()
			return
		}

		credential := c.GetHeader("Authorization")
		if credential == "" {
			credential = c.Query("api_key")
		}

		// No Bearer credential: allow anonymous access.
		if !strings.HasPrefix(credential, "Bearer ") {
			c.Next()
			return
		}

		if strings.TrimPrefix(credential, "Bearer ") != config.APIKey {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid API key"})
			return
		}
		c.Next()
	}
}

// main wires up the HTTP server and listens on $PORT (default 7860). It
// registers:
//   - CORS handling,
//   - the embedded web UI under /web,
//   - a /ping health check,
//   - the OpenAI-compatible API routes.
func main() {
	r := gin.Default()
	r.Use(corsMiddleware())

	// Serve the embedded web/ directory (staticFiles) at /web.
	subFS, err := fs.Sub(staticFiles, "web")
	if err != nil {
		log.Fatal(err)
	}
	r.StaticFS("/web", http.FS(subFS))

	// The root path redirects to the web UI.
	r.GET("/", func(c *gin.Context) {
		c.Redirect(http.StatusMovedPermanently, "/web")
	})

	// Health check.
	r.GET("/ping", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"message": "pong"})
	})

	// API routes. Supplying an API key is optional (see authMiddleware). The
	// same handlers are exposed under the original /hf/v1 prefix, the /api/v1
	// and /v1 aliases, and a bare /completions endpoint.
	apiGroup := r.Group("/")
	apiGroup.Use(authMiddleware())
	for _, prefix := range []string{"/hf/v1", "/api/v1", "/v1"} {
		apiGroup.GET(prefix+"/models", handleModels)
		apiGroup.POST(prefix+"/chat/completions", handleCompletion)
	}
	apiGroup.POST("/completions", handleCompletion)

	port := os.Getenv("PORT")
	if port == "" {
		port = "7860"
	}
	// Run only returns when serving fails (for example, the port is already
	// in use). Report that and exit non-zero instead of silently ending main.
	if err := r.Run(":" + port); err != nil {
		log.Fatalf("服务启动失败: %v", err)
	}
}

// handleModels answers the models endpoint with the model aliases this proxy
// advertises, in an OpenAI-style {"object":"list","data":[...]} envelope.
func handleModels(c *gin.Context) {
	ids := []string{"gpt-4o-mini", "claude-3-haiku", "llama-3.1-70b", "mixtral-8x7b"}

	data := make([]gin.H, 0, len(ids))
	for _, id := range ids {
		data = append(data, gin.H{"id": id, "object": "model", "owned_by": "ddg"})
	}

	c.JSON(http.StatusOK, gin.H{"object": "list", "data": data})
}

// completionID is the fixed id reported on every completion object and chunk
// this proxy returns.
const completionID = "chatcmpl-QXlha2FBbmROaXhpZUFyZUF3ZXNvbWUK"

// upstreamClient is shared by all completion requests. Its transport bounds
// only the wait for upstream response headers. It does not bound the whole
// body read, which would cut off long streamed answers. Client disconnects
// stop the call through the request context instead.
var upstreamClient = &http.Client{Transport: newUpstreamTransport()}

// newUpstreamTransport clones the default transport with a 30s response
// header timeout. If the default transport has been replaced by another
// type, it is used unchanged.
func newUpstreamTransport() http.RoundTripper {
	base, ok := http.DefaultTransport.(*http.Transport)
	if !ok {
		return http.DefaultTransport
	}
	t := base.Clone()
	t.ResponseHeaderTimeout = 30 * time.Second
	return t
}

// handleCompletion serves OpenAI-style chat completion requests by proxying
// them to DuckDuckGo's chat endpoint.
//
// It works in four steps:
//  1. Flatten the incoming messages into one user prompt (prepareMessages).
//  2. Fetch a token (requestToken).
//  3. Send the prompt upstream.
//  4. Return the upstream SSE reply, either relayed as
//     chat.completion.chunk events (stream=true) or collected into a single
//     chat.completion object.
func handleCompletion(c *gin.Context) {
	var req struct {
		Model    string `json:"model"`
		Messages []struct {
			Role    string      `json:"role"`
			Content interface{} `json:"content"`
		} `json:"messages"`
		Stream bool `json:"stream"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	model := convertModel(req.Model)
	body, err := json.Marshal(map[string]interface{}{
		"model": model,
		"messages": []map[string]interface{}{
			{"role": "user", "content": prepareMessages(req.Messages)},
		},
	})
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("请求体序列化失败: %v", err)})
		return
	}

	token, err := requestToken()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "无法获取token"})
		return
	}

	// Tie the upstream call to the client's request so a disconnect stops it.
	upstreamReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost,
		"https://duckduckgo.com/duckchat/v1/chat", strings.NewReader(string(body)))
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("创建请求失败: %v", err)})
		return
	}
	for k, v := range config.FakeHeaders {
		upstreamReq.Header.Set(k, v)
	}
	// Any explicitly set Accept-Encoding disables net/http's transparent gzip
	// decoding. The browser value also advertises br/zstd, which the standard
	// library cannot decode. Dropping it lets the transport negotiate gzip
	// itself, so the SSE body is always readable as plain text.
	upstreamReq.Header.Del("Accept-Encoding")
	upstreamReq.Header.Set("x-vqd-4", token)
	upstreamReq.Header.Set("Content-Type", "application/json")

	resp, err := upstreamClient.Do(upstreamReq)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("请求失败: %v", err)})
		return
	}
	defer resp.Body.Close()

	// Without this check a non-200 upstream reply was turned into an empty
	// 200 answer for the client.
	if resp.StatusCode != http.StatusOK {
		detail, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) // diagnostics only
		log.Printf("上游返回非200响应: %d, 内容: %s", resp.StatusCode, detail)
		c.JSON(http.StatusBadGateway, gin.H{"error": fmt.Sprintf("上游返回非200响应: %d", resp.StatusCode)})
		return
	}

	if req.Stream {
		streamCompletion(c, resp.Body, model)
		return
	}
	writeFullCompletion(c, resp.Body, model)
}

// streamCompletion relays upstream messages to the client as OpenAI
// chat.completion.chunk server-sent events. It then emits a final chunk with
// finish_reason "stop", followed by the "data: [DONE]" terminator.
func streamCompletion(c *gin.Context, upstream io.Reader, model string) {
	header := c.Writer.Header()
	header.Set("Content-Type", "text/event-stream")
	header.Set("Cache-Control", "no-cache")
	header.Set("Connection", "keep-alive")

	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Streaming not supported"})
		return
	}

	// send writes one SSE event and flushes it to the client immediately.
	send := func(data []byte) error {
		if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", data); err != nil {
			return err
		}
		flusher.Flush()
		return nil
	}
	// sendChunk wraps a delta in a chat.completion.chunk envelope and sends it.
	sendChunk := func(delta map[string]string, finishReason interface{}) error {
		data, err := json.Marshal(map[string]interface{}{
			"id":      completionID,
			"object":  "chat.completion.chunk",
			"created": time.Now().Unix(),
			"model":   model,
			"choices": []map[string]interface{}{
				{"index": 0, "delta": delta, "finish_reason": finishReason},
			},
		})
		if err != nil {
			return err
		}
		return send(data)
	}

	err := forEachUpstreamMessage(upstream, func(msg string) error {
		return sendChunk(map[string]string{"content": msg}, nil)
	})
	if err == nil {
		err = sendChunk(map[string]string{}, "stop")
	}
	if err == nil {
		err = send([]byte("[DONE]"))
	}
	if err != nil {
		log.Printf("流式响应失败: %v", err)
	}
}

// writeFullCompletion concatenates every upstream message and responds with a
// single OpenAI chat.completion object. If reading fails partway, the error is
// logged and whatever was received so far is returned.
func writeFullCompletion(c *gin.Context, upstream io.Reader, model string) {
	var answer strings.Builder
	if err := forEachUpstreamMessage(upstream, func(msg string) error {
		answer.WriteString(msg)
		return nil
	}); err != nil {
		log.Printf("读取响应失败: %v", err)
	}

	c.JSON(http.StatusOK, gin.H{
		"id":      completionID,
		"object":  "chat.completion",
		"created": time.Now().Unix(),
		"model":   model,
		"usage": map[string]int{
			"prompt_tokens":     0,
			"completion_tokens": 0,
			"total_tokens":      0,
		},
		"choices": []map[string]interface{}{
			{
				"index": 0,
				"message": map[string]string{
					"role":    "assistant",
					"content": answer.String(),
				},
				"finish_reason": "stop",
			},
		},
	})
}

// forEachUpstreamMessage reads the upstream SSE body line by line. For every
// "data: {...}" event whose "message" field is a string, it calls handle with
// that string.
//
// It stops at "data: [DONE]", at end of input, or when handle returns an
// error, which is returned unchanged. Events that are not valid JSON, or whose
// message is not a string, are logged and skipped.
func forEachUpstreamMessage(body io.Reader, handle func(msg string) error) error {
	reader := bufio.NewReader(body)
	for {
		line, readErr := reader.ReadString('\n')
		// Process the line even when readErr is io.EOF, so a final event that
		// lacks a trailing newline is not dropped.
		if strings.HasPrefix(line, "data: ") {
			data := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
			if data == "[DONE]" {
				return nil
			}
			var event map[string]interface{}
			if err := json.Unmarshal([]byte(data), &event); err != nil {
				log.Printf("解析响应行失败: %v", err)
			} else if msg, isString := event["message"].(string); isString {
				if err := handle(msg); err != nil {
					return err
				}
			} else if event["message"] != nil {
				log.Printf("chunk[message] 不是字符串: %v", event["message"])
			}
		}
		if readErr == io.EOF {
			return nil
		}
		if readErr != nil {
			return readErr
		}
	}
}

//func requestToken() (string, error) {
//	req, err := http.NewRequest("GET", "https://duckduckgo.com/duckchat/v1/status", nil)
//	if err != nil {
//		return "", fmt.Errorf("创建请求失败: %v", err)
//	}
//	for k, v := range config.FakeHeaders {
//		req.Header.Set(k, v)
//	}
//	req.Header.Set("x-vqd-accept", "1")
//
//	client := &http.Client{
//		Timeout: 10 * time.Second,
//	}
//
//	log.Println("发送 token 请求")
//	resp, err := client.Do(req)
//	if err != nil {
//		return "", fmt.Errorf("请求失败: %v", err)
//	}
//	defer resp.Body.Close()
//
//	if resp.StatusCode != http.StatusOK {
//		bodyBytes, _ := io.ReadAll(resp.Body)
//		bodyString := string(bodyBytes)
//		log.Printf("requestToken: 非200响应: %d, 内容: %s\n", resp.StatusCode, bodyString)
//		return "", fmt.Errorf("非200响应: %d, 内容: %s", resp.StatusCode, bodyString)
//	}
//
//	token := resp.Header.Get("x-vqd-4")
//	if token == "" {
//		return "", errors.New("响应中未包含x-vqd-4头")
//	}
//
//	// log.Printf("获取到的 token: %s\n", token)
//	return token, nil
//}

// requestToken obtains a DuckDuckGo chat token. It sends a GET to the duckchat
// status endpoint with "x-vqd-accept: 1" and reads the x-vqd-4 response
// header.
//
// It makes config.MaxRetryCount attempts in total (at least one), sleeping
// config.RetryDelay between them. Network errors, non-200 replies, and replies
// without the header are all retried. The returned error wraps the last
// failure.
func requestToken() (string, error) {
	const statusURL = "https://duckduckgo.com/duckchat/v1/status"
	client := &http.Client{
		Timeout: 15 * time.Second, // per attempt
	}

	// The request has no body, so the same instance can be sent on every attempt.
	req, err := http.NewRequest(http.MethodGet, statusURL, nil)
	if err != nil {
		log.Printf("requestToken: 创建请求失败: %v", err)
		return "", fmt.Errorf("无法创建请求: %w", err)
	}
	for k, v := range config.FakeHeaders {
		req.Header.Set(k, v)
	}
	req.Header.Set("x-vqd-accept", "1")

	// MaxRetryCount counts total attempts. A zero or negative value would
	// otherwise mean no request is ever sent.
	attempts := config.MaxRetryCount
	if attempts < 1 {
		attempts = 1
	}

	var lastErr error
	for attempt := 0; attempt < attempts; attempt++ {
		if attempt > 0 {
			log.Printf("requestToken: 第 %d 次重试,等待 %v...", attempt, config.RetryDelay)
			time.Sleep(config.RetryDelay)
		}
		token, err := fetchToken(client, req)
		if err == nil {
			// The token is a credential, so only success is logged, not its value.
			log.Println("requestToken: 成功获取到 token")
			return token, nil
		}
		log.Printf("requestToken: %v", err)
		lastErr = err
	}

	return "", fmt.Errorf("requestToken: 无法获取到 token,多次重试仍失败: %w", lastErr)
}

// fetchToken sends req once and returns the x-vqd-4 response header. The
// response body is closed before it returns, so failed attempts do not keep
// connections open until the whole retry loop finishes.
func fetchToken(client *http.Client, req *http.Request) (string, error) {
	log.Printf("requestToken: 发送 GET 请求到 %s", req.URL)
	resp, err := client.Do(req)
	if err != nil {
		return "", fmt.Errorf("请求失败: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: the body is read only to make the error diagnosable.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
		return "", fmt.Errorf("非200响应,状态码=%d, 响应内容: %s", resp.StatusCode, body)
	}

	token := resp.Header.Get("x-vqd-4")
	if token == "" {
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
		log.Printf("requestToken: 响应内容: %s", body)
		return "", errors.New("响应中未包含 x-vqd-4 头部")
	}
	return token, nil
}

// prepareMessages flattens a chat message list into one prompt string.
//
// Each message becomes "<role>:<content>;\r\n", and the "system" role is
// rewritten to "user". Content is handled as follows:
//   - a plain string is used as is;
//   - an array of content parts contributes only the "text" fields, joined
//     with no separator;
//   - JSON null or a missing content field contributes nothing;
//   - any other JSON value is formatted with %v.
func prepareMessages(messages []struct {
	Role    string      `json:"role"`
	Content interface{} `json:"content"`
}) string {
	var prompt strings.Builder

	for _, msg := range messages {
		role := msg.Role
		if role == "system" {
			role = "user"
		}
		prompt.WriteString(role)
		prompt.WriteByte(':')

		switch content := msg.Content.(type) {
		case nil:
			// Formatting nil with %v would emit the literal "<nil>" into the prompt.
		case string:
			prompt.WriteString(content)
		case []interface{}:
			for _, part := range content {
				if partMap, ok := part.(map[string]interface{}); ok {
					if text, ok := partMap["text"].(string); ok {
						prompt.WriteString(text)
					}
				}
			}
		default:
			fmt.Fprintf(&prompt, "%v", content)
		}

		prompt.WriteString(";\r\n")
	}

	return prompt.String()
}

// ddgModelIDs maps lower-cased public model aliases to the model identifiers
// sent upstream. Any alias missing from this table resolves to "gpt-4o-mini".
var ddgModelIDs = map[string]string{
	"claude-3-haiku": "claude-3-haiku-20240307",
	"llama-3.1-70b":  "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
	"mixtral-8x7b":   "mistralai/Mixtral-8x7B-Instruct-v0.1",
}

// convertModel resolves a client-supplied model name (case-insensitively) to
// the upstream model identifier. Unrecognized names, including "gpt-4o-mini"
// itself, map to "gpt-4o-mini".
func convertModel(inputModel string) string {
	if upstream, found := ddgModelIDs[strings.ToLower(inputModel)]; found {
		return upstream
	}
	return "gpt-4o-mini"
}

// corsMiddleware adds permissive CORS headers to every response. It answers
// preflight (OPTIONS) requests directly with 204 No Content and does not pass
// them on to later handlers.
//
// Authorization is listed explicitly because the "*" wildcard in
// Access-Control-Allow-Headers does not cover it. Without it, cross-origin
// browser clients that send "Authorization: Bearer ..." fail the preflight.
func corsMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		h := c.Writer.Header()
		h.Set("Access-Control-Allow-Origin", "*")
		h.Set("Access-Control-Allow-Methods", "*")
		h.Set("Access-Control-Allow-Headers", "*, Authorization")
		if c.Request.Method == http.MethodOptions {
			c.AbortWithStatus(http.StatusNoContent)
			return
		}
		c.Next()
	}
}

// getEnv returns the value of the environment variable key. It returns
// fallback only when the variable is not set at all; a variable set to the
// empty string yields "".
func getEnv(key, fallback string) string {
	value, ok := os.LookupEnv(key)
	if !ok {
		return fallback
	}
	return value
}

// getIntEnv returns the environment variable key parsed as a decimal integer.
//
// It returns fallback when the variable is unset, or when its value does not
// begin with an integer (for example "" or "abc"). In that second case a
// warning is logged; previously such values silently became 0.
func getIntEnv(key string, fallback int) int {
	value, exists := os.LookupEnv(key)
	if !exists {
		return fallback
	}

	var intValue int
	if _, err := fmt.Sscanf(value, "%d", &intValue); err != nil {
		log.Printf("环境变量 %s=%q 不是有效整数, 使用默认值 %d", key, value, fallback)
		return fallback
	}
	return intValue
}

// getDurationEnv reads the environment variable key as a whole number of
// milliseconds via getIntEnv. fallback is also given in milliseconds and is
// used whenever getIntEnv falls back.
func getDurationEnv(key string, fallback int) time.Duration {
	millis := getIntEnv(key, fallback)
	return time.Millisecond * time.Duration(millis)
}