|
|
|
|
|
|
|
|
|
package db |
|
|
|
import ( |
|
dbsql "database/sql" |
|
"errors" |
|
"regexp" |
|
"strconv" |
|
"strings" |
|
"sync" |
|
|
|
"github.com/GoAdminGroup/go-admin/modules/db/dialect" |
|
"github.com/GoAdminGroup/go-admin/modules/logger" |
|
) |
|
|
|
|
|
// SQL is a chainable query builder. The embedded dialect.SQLComponent
// accumulates the clauses (fields, wheres, joins, order, limit, ...) and the
// remaining fields hold what is needed to execute the rendered statement.
//
// Builders are pooled (see SQLPool). The terminal methods (First, All,
// Update, Delete, Exec, Insert, ShowTables, ...) hand the builder back to the
// pool through RecycleSQL, so a *SQL must not be used again after one of them
// has been called.
type SQL struct {
	dialect.SQLComponent

	// diver is the connection used to run statements (the field name is a
	// misspelling of "driver", kept because other code refers to it).
	diver Connection

	// dialect renders SQLComponent into SQLComponent.Statement for the
	// driver's SQL flavour.
	dialect dialect.Dialect

	// conn is the connection name passed to the driver's query/exec calls.
	conn string

	// tx, when non-nil, is passed to QueryWith/ExecWith so statements run
	// inside that transaction.
	tx *dbsql.Tx
}
|
|
|
|
|
var SQLPool = sync.Pool{ |
|
New: func() interface{} { |
|
return &SQL{ |
|
SQLComponent: dialect.SQLComponent{ |
|
Fields: make([]string, 0), |
|
TableName: "", |
|
Args: make([]interface{}, 0), |
|
Wheres: make([]dialect.Where, 0), |
|
Leftjoins: make([]dialect.Join, 0), |
|
UpdateRaws: make([]dialect.RawUpdate, 0), |
|
WhereRaws: "", |
|
Order: "", |
|
Group: "", |
|
Limit: "", |
|
}, |
|
diver: nil, |
|
dialect: nil, |
|
} |
|
}, |
|
} |
|
|
|
|
|
// H is a shorthand for map[string]interface{}.
type H map[string]interface{}
|
|
|
|
|
func newSQL() *SQL { |
|
return SQLPool.Get().(*SQL) |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
func Table(table string) *SQL { |
|
sql := newSQL() |
|
sql.TableName = table |
|
sql.conn = "default" |
|
return sql |
|
} |
|
|
|
|
|
func WithDriver(conn Connection) *SQL { |
|
sql := newSQL() |
|
sql.diver = conn |
|
sql.dialect = dialect.GetDialectByDriver(conn.Name()) |
|
sql.conn = "default" |
|
return sql |
|
} |
|
|
|
|
|
func WithDriverAndConnection(connName string, conn Connection) *SQL { |
|
sql := newSQL() |
|
sql.diver = conn |
|
sql.dialect = dialect.GetDialectByDriver(conn.Name()) |
|
sql.conn = connName |
|
return sql |
|
} |
|
|
|
|
|
func (sql *SQL) WithDriver(conn Connection) *SQL { |
|
sql.diver = conn |
|
sql.dialect = dialect.GetDialectByDriver(conn.Name()) |
|
return sql |
|
} |
|
|
|
|
|
func (sql *SQL) WithConnection(conn string) *SQL { |
|
sql.conn = conn |
|
return sql |
|
} |
|
|
|
|
|
func (sql *SQL) WithTx(tx *dbsql.Tx) *SQL { |
|
sql.tx = tx |
|
return sql |
|
} |
|
|
|
|
|
func (sql *SQL) Table(table string) *SQL { |
|
sql.clean() |
|
sql.TableName = table |
|
return sql |
|
} |
|
|
|
|
|
func (sql *SQL) Select(fields ...string) *SQL { |
|
sql.Fields = fields |
|
sql.Functions = make([]string, len(fields)) |
|
reg, _ := regexp.Compile(`(.*?)\((.*?)\)`) |
|
for k, field := range fields { |
|
res := reg.FindAllStringSubmatch(field, -1) |
|
if len(res) > 0 && len(res[0]) > 2 { |
|
sql.Functions[k] = res[0][1] |
|
sql.Fields[k] = res[0][2] |
|
} |
|
} |
|
return sql |
|
} |
|
|
|
|
|
func (sql *SQL) OrderBy(fields ...string) *SQL { |
|
if len(fields) == 0 { |
|
panic("wrong order field") |
|
} |
|
for i := 0; i < len(fields); i++ { |
|
if i == len(fields)-2 { |
|
sql.Order += " " + sql.wrap(fields[i]) + " " + fields[i+1] |
|
return sql |
|
} |
|
sql.Order += " " + sql.wrap(fields[i]) + " and " |
|
} |
|
return sql |
|
} |
|
|
|
|
|
func (sql *SQL) OrderByRaw(order string) *SQL { |
|
if order != "" { |
|
sql.Order += " " + order |
|
} |
|
return sql |
|
} |
|
|
|
func (sql *SQL) GroupBy(fields ...string) *SQL { |
|
if len(fields) == 0 { |
|
panic("wrong group by field") |
|
} |
|
for i := 0; i < len(fields); i++ { |
|
if i == len(fields)-1 { |
|
sql.Group += " " + sql.wrap(fields[i]) |
|
} else { |
|
sql.Group += " " + sql.wrap(fields[i]) + "," |
|
} |
|
} |
|
return sql |
|
} |
|
|
|
|
|
func (sql *SQL) GroupByRaw(group string) *SQL { |
|
if group != "" { |
|
sql.Group += " " + group |
|
} |
|
return sql |
|
} |
|
|
|
|
|
func (sql *SQL) Skip(offset int) *SQL { |
|
sql.Offset = strconv.Itoa(offset) |
|
return sql |
|
} |
|
|
|
|
|
func (sql *SQL) Take(take int) *SQL { |
|
sql.Limit = strconv.Itoa(take) |
|
return sql |
|
} |
|
|
|
|
|
func (sql *SQL) Where(field string, operation string, arg interface{}) *SQL { |
|
sql.Wheres = append(sql.Wheres, dialect.Where{ |
|
Field: field, |
|
Operation: operation, |
|
Qmark: "?", |
|
}) |
|
sql.Args = append(sql.Args, arg) |
|
return sql |
|
} |
|
|
|
|
|
func (sql *SQL) WhereIn(field string, arg []interface{}) *SQL { |
|
if len(arg) == 0 { |
|
panic("wrong parameter") |
|
} |
|
sql.Wheres = append(sql.Wheres, dialect.Where{ |
|
Field: field, |
|
Operation: "in", |
|
Qmark: "(" + strings.Repeat("?,", len(arg)-1) + "?)", |
|
}) |
|
sql.Args = append(sql.Args, arg...) |
|
return sql |
|
} |
|
|
|
|
|
func (sql *SQL) WhereNotIn(field string, arg []interface{}) *SQL { |
|
if len(arg) == 0 { |
|
panic("wrong parameter") |
|
} |
|
sql.Wheres = append(sql.Wheres, dialect.Where{ |
|
Field: field, |
|
Operation: "not in", |
|
Qmark: "(" + strings.Repeat("?,", len(arg)-1) + "?)", |
|
}) |
|
sql.Args = append(sql.Args, arg...) |
|
return sql |
|
} |
|
|
|
|
|
func (sql *SQL) Find(arg interface{}) (map[string]interface{}, error) { |
|
return sql.Where("id", "=", arg).First() |
|
} |
|
|
|
|
|
func (sql *SQL) Count() (int64, error) { |
|
var ( |
|
res map[string]interface{} |
|
err error |
|
driver = sql.diver.Name() |
|
) |
|
|
|
if res, err = sql.Select("count(*)").First(); err != nil { |
|
return 0, err |
|
} |
|
|
|
if driver == DriverPostgresql { |
|
return res["count"].(int64), nil |
|
} else if driver == DriverMssql { |
|
return res[""].(int64), nil |
|
} |
|
|
|
return res["count(*)"].(int64), nil |
|
} |
|
|
|
|
|
func (sql *SQL) Sum(field string) (float64, error) { |
|
var ( |
|
res map[string]interface{} |
|
err error |
|
key = "sum(" + sql.wrap(field) + ")" |
|
) |
|
if res, err = sql.Select("sum(" + field + ")").First(); err != nil { |
|
return 0, err |
|
} |
|
|
|
if res == nil { |
|
return 0, nil |
|
} |
|
|
|
if r, ok := res[key].(float64); ok { |
|
return r, nil |
|
} else if r, ok := res[key].([]uint8); ok { |
|
return strconv.ParseFloat(string(r), 64) |
|
} else { |
|
return 0, nil |
|
} |
|
} |
|
|
|
|
|
func (sql *SQL) Max(field string) (interface{}, error) { |
|
var ( |
|
res map[string]interface{} |
|
err error |
|
key = "max(" + sql.wrap(field) + ")" |
|
) |
|
if res, err = sql.Select("max(" + field + ")").First(); err != nil { |
|
return 0, err |
|
} |
|
|
|
if res == nil { |
|
return 0, nil |
|
} |
|
|
|
return res[key], nil |
|
} |
|
|
|
|
|
func (sql *SQL) Min(field string) (interface{}, error) { |
|
var ( |
|
res map[string]interface{} |
|
err error |
|
key = "min(" + sql.wrap(field) + ")" |
|
) |
|
if res, err = sql.Select("min(" + field + ")").First(); err != nil { |
|
return 0, err |
|
} |
|
|
|
if res == nil { |
|
return 0, nil |
|
} |
|
|
|
return res[key], nil |
|
} |
|
|
|
|
|
func (sql *SQL) Avg(field string) (interface{}, error) { |
|
var ( |
|
res map[string]interface{} |
|
err error |
|
key = "avg(" + sql.wrap(field) + ")" |
|
) |
|
if res, err = sql.Select("avg(" + field + ")").First(); err != nil { |
|
return 0, err |
|
} |
|
|
|
if res == nil { |
|
return 0, nil |
|
} |
|
|
|
return res[key], nil |
|
} |
|
|
|
|
|
func (sql *SQL) WhereRaw(raw string, args ...interface{}) *SQL { |
|
sql.WhereRaws = raw |
|
sql.Args = append(sql.Args, args...) |
|
return sql |
|
} |
|
|
|
|
|
func (sql *SQL) UpdateRaw(raw string, args ...interface{}) *SQL { |
|
sql.UpdateRaws = append(sql.UpdateRaws, dialect.RawUpdate{ |
|
Expression: raw, |
|
Args: args, |
|
}) |
|
return sql |
|
} |
|
|
|
|
|
func (sql *SQL) LeftJoin(table string, fieldA string, operation string, fieldB string) *SQL { |
|
sql.Leftjoins = append(sql.Leftjoins, dialect.Join{ |
|
FieldA: fieldA, |
|
FieldB: fieldB, |
|
Table: table, |
|
Operation: operation, |
|
}) |
|
return sql |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
// TxFn is the unit of work run by WithTransaction and WithTransactionByLevel.
// A non-nil error makes those methods roll the transaction back; otherwise
// they commit it. Note the unusual (error, result) return order.
type TxFn func(tx *dbsql.Tx) (error, map[string]interface{})
|
|
|
|
|
|
|
func (sql *SQL) WithTransaction(fn TxFn) (res map[string]interface{}, err error) { |
|
|
|
tx := sql.diver.BeginTxAndConnection(sql.conn) |
|
|
|
defer func() { |
|
if p := recover(); p != nil { |
|
|
|
_ = tx.Rollback() |
|
panic(p) |
|
} else if err != nil { |
|
|
|
_ = tx.Rollback() |
|
} else { |
|
|
|
err = tx.Commit() |
|
} |
|
}() |
|
|
|
err, res = fn(tx) |
|
return |
|
} |
|
|
|
|
|
|
|
func (sql *SQL) WithTransactionByLevel(level dbsql.IsolationLevel, fn TxFn) (res map[string]interface{}, err error) { |
|
|
|
tx := sql.diver.BeginTxWithLevelAndConnection(sql.conn, level) |
|
|
|
defer func() { |
|
if p := recover(); p != nil { |
|
|
|
_ = tx.Rollback() |
|
panic(p) |
|
} else if err != nil { |
|
|
|
_ = tx.Rollback() |
|
} else { |
|
|
|
err = tx.Commit() |
|
} |
|
}() |
|
|
|
err, res = fn(tx) |
|
return |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func (sql *SQL) First() (map[string]interface{}, error) { |
|
defer RecycleSQL(sql) |
|
|
|
sql.dialect.Select(&sql.SQLComponent) |
|
|
|
res, err := sql.diver.QueryWith(sql.tx, sql.conn, sql.Statement, sql.Args...) |
|
|
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
if len(res) < 1 { |
|
return nil, errors.New("out of index") |
|
} |
|
return res[0], nil |
|
} |
|
|
|
|
|
func (sql *SQL) All() ([]map[string]interface{}, error) { |
|
defer RecycleSQL(sql) |
|
|
|
sql.dialect.Select(&sql.SQLComponent) |
|
|
|
return sql.diver.QueryWith(sql.tx, sql.conn, sql.Statement, sql.Args...) |
|
} |
|
|
|
|
|
func (sql *SQL) ShowColumns() ([]map[string]interface{}, error) { |
|
defer RecycleSQL(sql) |
|
|
|
return sql.diver.QueryWithConnection(sql.conn, sql.dialect.ShowColumns(sql.TableName)) |
|
} |
|
|
|
|
|
func (sql *SQL) ShowColumnsWithComment(database string) ([]map[string]interface{}, error) { |
|
defer RecycleSQL(sql) |
|
|
|
return sql.diver.QueryWithConnection(sql.conn, sql.dialect.ShowColumnsWithComment(database, sql.TableName)) |
|
} |
|
|
|
|
|
func (sql *SQL) ShowTables() ([]string, error) { |
|
defer RecycleSQL(sql) |
|
|
|
models, err := sql.diver.QueryWithConnection(sql.conn, sql.dialect.ShowTables()) |
|
|
|
if err != nil { |
|
return []string{}, err |
|
} |
|
|
|
tables := make([]string, 0) |
|
if len(models) == 0 { |
|
return tables, nil |
|
} |
|
|
|
key := "Tables_in_" + sql.TableName |
|
if sql.diver.Name() == DriverPostgresql || sql.diver.Name() == DriverSqlite { |
|
key = "tablename" |
|
} else if sql.diver.Name() == DriverMssql { |
|
key = "TABLE_NAME" |
|
} else if _, ok := models[0][key].(string); !ok { |
|
key = "Tables_in_" + strings.ToLower(sql.TableName) |
|
} |
|
|
|
for i := 0; i < len(models); i++ { |
|
|
|
if sql.diver.Name() == DriverSqlite && models[i][key].(string) == "sqlite_sequence" { |
|
continue |
|
} |
|
|
|
tables = append(tables, models[i][key].(string)) |
|
} |
|
|
|
return tables, nil |
|
} |
|
|
|
|
|
func (sql *SQL) Update(values dialect.H) (int64, error) { |
|
defer RecycleSQL(sql) |
|
|
|
sql.Values = values |
|
|
|
sql.dialect.Update(&sql.SQLComponent) |
|
|
|
res, err := sql.diver.ExecWith(sql.tx, sql.conn, sql.Statement, sql.Args...) |
|
|
|
if err != nil { |
|
return 0, err |
|
} |
|
|
|
if affectRow, _ := res.RowsAffected(); affectRow < 1 { |
|
return 0, errors.New("no affect row") |
|
} |
|
|
|
return res.LastInsertId() |
|
} |
|
|
|
|
|
func (sql *SQL) Delete() error { |
|
defer RecycleSQL(sql) |
|
|
|
sql.dialect.Delete(&sql.SQLComponent) |
|
|
|
res, err := sql.diver.ExecWith(sql.tx, sql.conn, sql.Statement, sql.Args...) |
|
|
|
if err != nil { |
|
return err |
|
} |
|
|
|
if affectRow, _ := res.RowsAffected(); affectRow < 1 { |
|
return errors.New("no affect row") |
|
} |
|
|
|
return nil |
|
} |
|
|
|
|
|
func (sql *SQL) Exec() (int64, error) { |
|
defer RecycleSQL(sql) |
|
|
|
sql.dialect.Update(&sql.SQLComponent) |
|
|
|
res, err := sql.diver.ExecWith(sql.tx, sql.conn, sql.Statement, sql.Args...) |
|
|
|
if err != nil { |
|
return 0, err |
|
} |
|
|
|
if affectRow, _ := res.RowsAffected(); affectRow < 1 { |
|
return 0, errors.New("no affect row") |
|
} |
|
|
|
return res.LastInsertId() |
|
} |
|
|
|
const postgresInsertCheckTableName = "goadmin_menu|goadmin_permissions|goadmin_roles|goadmin_users" |
|
|
|
|
|
func (sql *SQL) Insert(values dialect.H) (int64, error) { |
|
defer RecycleSQL(sql) |
|
|
|
sql.Values = values |
|
|
|
sql.dialect.Insert(&sql.SQLComponent) |
|
|
|
if sql.diver.Name() == DriverPostgresql && (strings.Contains(postgresInsertCheckTableName, sql.TableName)) { |
|
|
|
resMap, err := sql.diver.QueryWith(sql.tx, sql.conn, sql.Statement+" RETURNING id", sql.Args...) |
|
|
|
if err != nil { |
|
|
|
|
|
_, err := sql.diver.QueryWith(sql.tx, sql.conn, sql.Statement, sql.Args...) |
|
|
|
if err != nil { |
|
return 0, err |
|
} |
|
|
|
res, err := sql.diver.QueryWithConnection(sql.conn, `SELECT max("id") as "id" FROM "`+sql.TableName+`"`) |
|
|
|
if err != nil { |
|
return 0, err |
|
} |
|
|
|
if len(res) != 0 { |
|
return res[0]["id"].(int64), nil |
|
} |
|
|
|
return 0, err |
|
} |
|
|
|
if len(resMap) == 0 { |
|
return 0, errors.New("no affect row") |
|
} |
|
|
|
return resMap[0]["id"].(int64), nil |
|
} |
|
|
|
res, err := sql.diver.ExecWith(sql.tx, sql.conn, sql.Statement, sql.Args...) |
|
|
|
if err != nil { |
|
return 0, err |
|
} |
|
|
|
if affectRow, _ := res.RowsAffected(); affectRow < 1 { |
|
return 0, errors.New("no affect row") |
|
} |
|
|
|
return res.LastInsertId() |
|
} |
|
|
|
func (sql *SQL) wrap(field string) string { |
|
return sql.diver.GetDelimiter() + field + sql.diver.GetDelimiter2() |
|
} |
|
|
|
func (sql *SQL) clean() { |
|
sql.Functions = make([]string, 0) |
|
sql.Group = "" |
|
sql.Values = make(map[string]interface{}) |
|
sql.Fields = make([]string, 0) |
|
sql.TableName = "" |
|
sql.Wheres = make([]dialect.Where, 0) |
|
sql.Leftjoins = make([]dialect.Join, 0) |
|
sql.Args = make([]interface{}, 0) |
|
sql.Order = "" |
|
sql.Offset = "" |
|
sql.Limit = "" |
|
sql.WhereRaws = "" |
|
sql.UpdateRaws = make([]dialect.RawUpdate, 0) |
|
sql.Statement = "" |
|
} |
|
|
|
|
|
func RecycleSQL(sql *SQL) { |
|
|
|
logger.LogSQL(sql.Statement, sql.Args) |
|
|
|
sql.clean() |
|
|
|
sql.conn = "" |
|
sql.diver = nil |
|
sql.tx = nil |
|
sql.dialect = nil |
|
|
|
SQLPool.Put(sql) |
|
} |
|
|