// Copyright 2019 GoAdmin Core Team. All rights reserved.
// Use of this source code is governed by a Apache-2.0 style
// license that can be found in the LICENSE file.

package dialect

import (
	"sort"
	"strings"

	"github.com/GoAdminGroup/go-admin/modules/config"
)

// Dialect is the set of SQL-generation methods that every database driver
// flavour implements. Each method returns SQL text as a string.
type Dialect interface {
	// GetName returns the name of the dialect.
	GetName() string

	// GetDelimiter returns the identifier delimiter used by the dialect.
	GetDelimiter() string

	// ShowTables returns the SQL that lists the tables of the database.
	ShowTables() string

	// ShowColumns returns the SQL that lists the columns of the given table.
	ShowColumns(table string) string

	// ShowColumnsWithComment returns the SQL that lists the columns of the
	// given table in the given schema together with their comments.
	ShowColumnsWithComment(schema, table string) string

	// Select returns the SELECT statement described by comp.
	Select(comp *SQLComponent) string

	// Insert returns the INSERT statement described by comp.
	Insert(comp *SQLComponent) string

	// Update returns the UPDATE statement described by comp.
	Update(comp *SQLComponent) string

	// Delete returns the DELETE statement described by comp.
	Delete(comp *SQLComponent) string
}

// GetDialect returns the Dialect that matches the driver of the default
// database connection from the global configuration.
func GetDialect() Dialect {
	defaultDB := config.GetDatabases().GetDefault()
	return GetDialectByDriver(defaultDB.Driver)
}

// GetDialectByDriver returns the Dialect for the named driver.
//
// mssql quotes identifiers with square brackets and postgresql with double
// quotes. mysql, sqlite and oceanbase use backticks. Any unrecognised driver
// falls back to a plain commonDialect that also uses backticks.
func GetDialectByDriver(driver string) Dialect {
	backtick := commonDialect{delimiter: "`", delimiter2: "`"}

	switch driver {
	case "mysql":
		return mysql{commonDialect: backtick}
	case "mssql":
		return mssql{commonDialect: commonDialect{delimiter: "[", delimiter2: "]"}}
	case "postgresql":
		return postgresql{commonDialect: commonDialect{delimiter: `"`, delimiter2: `"`}}
	case "sqlite":
		return sqlite{commonDialect: backtick}
	case "oceanbase":
		return oceanbase{commonDialect: backtick}
	}

	return backtick
}

// H is a shorthand for map[string]interface{}. SQLComponent.Values uses it to
// map column names (the keys) to the values bound for those columns.
type H map[string]interface{}

// SQLComponent collects the parts of a SQL statement. The helper methods in
// this file render those parts into text for a particular identifier
// delimiter pair.
type SQLComponent struct {
	// Fields are the selected column names; "table.column" is split and
	// quoted part by part when Leftjoins is non-empty.
	Fields []string
	// Functions holds an optional function name per entry of Fields (same
	// index); an empty string means the field is used as-is. Expected to be
	// the same length as Fields.
	Functions []string
	// TableName is the target table, quoted with the dialect delimiters.
	TableName string
	// Wheres are conditions joined with " and " in the WHERE clause.
	Wheres []Where
	// Leftjoins are rendered as " left join ... on ..." clauses.
	Leftjoins []Join
	// Args are the positional arguments bound to the statement's placeholders.
	Args []interface{}
	// Order is inserted verbatim after "order by" when non-empty.
	Order string
	// Offset is inserted verbatim after "offset" when non-empty.
	Offset string
	// Limit is inserted verbatim after "limit" when non-empty.
	Limit string
	// WhereRaws is a raw condition appended after the Wheres (joined by "and").
	WhereRaws string
	// UpdateRaws are raw SET expressions (with their own args) for UPDATE.
	UpdateRaws []RawUpdate
	// Group is inserted verbatim after "group by" when non-empty.
	Group string
	// Statement receives the SQL text produced by prepareInsert/prepareUpdate.
	Statement string
	// Values maps column names to values for INSERT and UPDATE.
	Values H
}

// Where is a single condition of a WHERE clause, rendered as
// "<Field> <Operation> <Qmark>".
type Where struct {
	// Operation is the comparison operator text, inserted verbatim.
	Operation string
	// Field is the column name; a "table.column" form keeps the table part
	// unquoted and quotes only the column.
	Field string
	// Qmark is the placeholder (or value) text, inserted verbatim.
	Qmark string
}

// Join describes one LEFT JOIN, rendered as
// " left join <Table> on <FieldA> <Operation> <FieldB> ".
type Join struct {
	// Table is the joined table name, quoted with the dialect delimiters.
	Table string
	// FieldA is the left operand, expected in "table.column" form.
	FieldA string
	// Operation is the comparison operator text, inserted verbatim.
	Operation string
	// FieldB is the right operand, expected in "table.column" form.
	FieldB string
}

// RawUpdate is a raw SET expression for an UPDATE statement together with the
// arguments bound to its placeholders.
type RawUpdate struct {
	// Expression is inserted verbatim into the SET list.
	Expression string
	// Args are appended to the statement arguments in expression order.
	Args []interface{}
}

// *******************************
// internal helper functions
// *******************************

// getLimit renders the LIMIT clause, or "" when no limit is set.
func (sql *SQLComponent) getLimit() string {
	if limit := sql.Limit; limit != "" {
		return " limit " + limit + " "
	}
	return ""
}

// getOffset renders the OFFSET clause, or "" when no offset is set.
func (sql *SQLComponent) getOffset() string {
	if offset := sql.Offset; offset != "" {
		return " offset " + offset + " "
	}
	return ""
}

// getOrderBy renders the ORDER BY clause, or "" when no order is set.
func (sql *SQLComponent) getOrderBy() string {
	if order := sql.Order; order != "" {
		return " order by " + order + " "
	}
	return ""
}

// getGroupBy renders the GROUP BY clause, or "" when no grouping is set.
func (sql *SQLComponent) getGroupBy() string {
	if group := sql.Group; group != "" {
		return " group by " + group + " "
	}
	return ""
}

// getJoins renders every entry of Leftjoins as
// " left join <table> on <fieldA> <op> <fieldB> ", concatenated in order.
// It returns "" when there are no joins.
func (sql *SQLComponent) getJoins(delimiter, delimiter2 string) string {
	var b strings.Builder
	for _, j := range sql.Leftjoins {
		left := sql.processLeftJoinField(j.FieldA, delimiter, delimiter2)
		right := sql.processLeftJoinField(j.FieldB, delimiter, delimiter2)
		b.WriteString(" left join " + wrap(delimiter, delimiter2, j.Table) + " on ")
		b.WriteString(left + " " + j.Operation + " " + right + " ")
	}
	return b.String()
}

// processLeftJoinField quotes a "table.column" join operand as
// <delimiter>table<delimiter2>.<delimiter>column<delimiter2>.
// Only the first two dot-separated parts are used. An operand without a "."
// is returned unchanged.
func (sql *SQLComponent) processLeftJoinField(field, delimiter, delimiter2 string) string {
	arr := strings.Split(field, ".")
	// strings.Split always returns at least one element, so a dot is present
	// only when there are two or more parts; checking > 0 would index arr[1]
	// out of range for a dot-less field.
	if len(arr) > 1 {
		return delimiter + arr[0] + delimiter2 + "." + delimiter + arr[1] + delimiter2
	}
	return field
}

// getFields renders the comma-separated select list.
//
// It returns "*" when no fields were requested. When the query has left
// joins, a "table.column" field is quoted part by part. Otherwise each field
// is quoted as a whole. In both cases a non-empty Functions[k] wraps the
// quoted Fields[k] as fn(<field>). A "*" field is never quoted.
func (sql *SQLComponent) getFields(delimiter, delimiter2 string) string {
	if len(sql.Fields) == 0 {
		return "*"
	}

	hasJoins := len(sql.Leftjoins) > 0
	columns := make([]string, 0, len(sql.Fields))
	for k, field := range sql.Fields {
		column := wrap(delimiter, delimiter2, field)
		if hasJoins {
			if arr := strings.Split(field, "."); len(arr) > 1 {
				column = wrap(delimiter, delimiter2, arr[0]) + "." + wrap(delimiter, delimiter2, arr[1])
			}
		}
		// Functions is normally parallel to Fields; the bounds check keeps a
		// shorter slice from panicking. Functions must also apply when joins
		// are present, otherwise e.g. count(x) silently degrades to x.
		if k < len(sql.Functions) && sql.Functions[k] != "" {
			column = sql.Functions[k] + "(" + column + ")"
		}
		columns = append(columns, column)
	}
	return strings.Join(columns, ",")
}

// wrap surrounds field with the opening and closing delimiters, except for
// the "*" wildcard, which is returned untouched.
func wrap(delimiter, delimiter2, field string) string {
	if field != "*" {
		return delimiter + field + delimiter2
	}
	return field
}

// getWheres renders the WHERE clause.
//
// Each Where becomes "<column> <op> <qmark>". For a "table.column" field only
// the column part is quoted. The conditions are joined with " and ", and a
// non-empty WhereRaws is appended as the last condition. It returns "" when
// there is nothing to filter on.
func (sql *SQLComponent) getWheres(delimiter, delimiter2 string) string {
	conditions := make([]string, 0, len(sql.Wheres)+1)
	for _, w := range sql.Wheres {
		column := wrap(delimiter, delimiter2, w.Field)
		if parts := strings.Split(w.Field, "."); len(parts) > 1 {
			column = parts[0] + "." + wrap(delimiter, delimiter2, parts[1])
		}
		conditions = append(conditions, column+" "+w.Operation+" "+w.Qmark)
	}

	if sql.WhereRaws != "" {
		conditions = append(conditions, sql.WhereRaws)
	}

	if len(conditions) == 0 {
		return ""
	}
	return " where " + strings.Join(conditions, " and ")
}

// prepareUpdate builds an UPDATE statement into sql.Statement.
//
// The SET list contains one "<column> = ?" per entry of Values, followed by
// the raw expressions of UpdateRaws. Values are emitted in sorted column
// order so the statement text is stable from call to call. The SET
// arguments (values first, then each raw expression's args) are placed
// before the existing sql.Args, which hold the WHERE arguments. The WHERE
// clause comes from getWheres.
//
// It panics when neither Values nor UpdateRaws provides anything to set.
func (sql *SQLComponent) prepareUpdate(delimiter, delimiter2 string) {
	if len(sql.Values) == 0 && len(sql.UpdateRaws) == 0 {
		panic("prepareUpdate: wrong parameter")
	}

	// Map iteration order is random; sort so placeholders, args and the
	// generated SQL text are deterministic.
	keys := make([]string, 0, len(sql.Values))
	for key := range sql.Values {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	assignments := make([]string, 0, len(keys)+len(sql.UpdateRaws))
	args := make([]interface{}, 0, len(keys))
	for _, key := range keys {
		assignments = append(assignments, wrap(delimiter, delimiter2, key)+" = ?")
		args = append(args, sql.Values[key])
	}
	for _, raw := range sql.UpdateRaws {
		assignments = append(assignments, raw.Expression)
		args = append(args, raw.Args...)
	}

	sql.Args = append(args, sql.Args...)
	sql.Statement = "update " + delimiter + sql.TableName + delimiter2 + " set " +
		strings.Join(assignments, ", ") + sql.getWheres(delimiter, delimiter2)
}

// prepareInsert builds an INSERT statement into sql.Statement of the form
// "insert into <table> (<col>,...) values (?,...)".
//
// Columns come from Values in sorted order so the statement text is stable.
// Each value is appended to sql.Args in the same order as its placeholder.
// With no Values both lists are empty ("() values ()"), instead of the
// malformed " ) values )" the previous slicing produced.
func (sql *SQLComponent) prepareInsert(delimiter, delimiter2 string) {
	// Map iteration order is random; sort so columns and args line up
	// deterministically.
	keys := make([]string, 0, len(sql.Values))
	for key := range sql.Values {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	columns := make([]string, len(keys))
	marks := make([]string, len(keys))
	for i, key := range keys {
		columns[i] = wrap(delimiter, delimiter2, key)
		marks[i] = "?"
		sql.Args = append(sql.Args, sql.Values[key])
	}

	sql.Statement = "insert into " + delimiter + sql.TableName + delimiter2 +
		" (" + strings.Join(columns, ",") + ") values (" + strings.Join(marks, ",") + ")"
}