Feature model postgresql (#842)

* Support postgresql generate

* Update template Var

* Support to generate postgresql model

* Support to generate postgresql model

* Update template

Co-authored-by: anqiansong <anqiansong@xiaoheiban.cn>
master
anqiansong 3 years ago committed by GitHub
parent 476026e393
commit 089cdaa75f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -455,7 +455,48 @@ var (
Usage: "for idea plugin [optional]", Usage: "for idea plugin [optional]",
}, },
}, },
Action: model.MyDataSource, Action: model.MySqlDataSource,
},
},
},
{
Name: "postgresql",
Usage: `generate postgresql model`,
Subcommands: []cli.Command{
{
Name: "datasource",
Usage: `generate model from datasource`,
Flags: []cli.Flag{
cli.StringFlag{
Name: "url",
Usage: `the data source of database,like "root:password@tcp(127.0.0.1:3306)/database`,
},
cli.StringFlag{
Name: "table, t",
Usage: `the table or table globbing patterns in the database`,
},
cli.StringFlag{
Name: "schema, s",
Usage: `the table schema, default is [public]`,
},
cli.BoolFlag{
Name: "cache, c",
Usage: "generate code with cache [optional]",
},
cli.StringFlag{
Name: "dir, d",
Usage: "the target dir",
},
cli.StringFlag{
Name: "style",
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
},
cli.BoolFlag{
Name: "idea",
Usage: "for idea plugin [optional]",
},
},
Action: model.PostgreSqlDataSource,
}, },
}, },
}, },

@ -3,6 +3,7 @@ package builderx
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"strings"
"github.com/go-xorm/builder" "github.com/go-xorm/builder"
) )
@ -81,13 +82,18 @@ func FieldNames(in interface{}) []string {
} }
// RawFieldNames converts golang struct field into slice string // RawFieldNames converts golang struct field into slice string
func RawFieldNames(in interface{}) []string { func RawFieldNames(in interface{}, postgresSql ...bool) []string {
out := make([]string, 0) out := make([]string, 0)
v := reflect.ValueOf(in) v := reflect.ValueOf(in)
if v.Kind() == reflect.Ptr { if v.Kind() == reflect.Ptr {
v = v.Elem() v = v.Elem()
} }
var pg bool
if len(postgresSql) > 0 {
pg = postgresSql[0]
}
// we only accept structs // we only accept structs
if v.Kind() != reflect.Struct { if v.Kind() != reflect.Struct {
panic(fmt.Errorf("ToMap only accepts structs; got %T", v)) panic(fmt.Errorf("ToMap only accepts structs; got %T", v))
@ -98,11 +104,32 @@ func RawFieldNames(in interface{}) []string {
// gets us a StructField // gets us a StructField
fi := typ.Field(i) fi := typ.Field(i)
if tagv := fi.Tag.Get(dbTag); tagv != "" { if tagv := fi.Tag.Get(dbTag); tagv != "" {
out = append(out, fmt.Sprintf("`%s`", tagv)) if pg {
out = append(out, fmt.Sprintf("%s", tagv))
} else {
out = append(out, fmt.Sprintf("`%s`", tagv))
}
} else { } else {
out = append(out, fmt.Sprintf(`"%s"`, fi.Name)) if pg {
out = append(out, fmt.Sprintf("%s", fi.Name))
} else {
out = append(out, fmt.Sprintf("`%s`", fi.Name))
}
} }
} }
return out return out
} }
func PostgreSqlJoin(elems []string) string {
var b = new(strings.Builder)
for index, e := range elems {
b.WriteString(fmt.Sprintf("%s = $%d, ", e, index+1))
}
if b.Len() == 0 {
return b.String()
}
return b.String()[0 : b.Len()-2]
}

@ -118,3 +118,8 @@ func TestBuildSqlLike(t *testing.T) {
assert.Equal(t, sql, actualSQL) assert.Equal(t, sql, actualSQL)
assert.Equal(t, args, actualArgs) assert.Equal(t, args, actualArgs)
} }
func TestJoin(t *testing.T) {
ret := PostgreSqlJoin([]string{"name", "age"})
assert.Equal(t, "name = $1, age = $2", ret)
}

@ -7,6 +7,7 @@ import (
"github.com/go-sql-driver/mysql" "github.com/go-sql-driver/mysql"
"github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/stores/postgres"
"github.com/tal-tech/go-zero/core/stores/sqlx" "github.com/tal-tech/go-zero/core/stores/sqlx"
"github.com/tal-tech/go-zero/tools/goctl/config" "github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/gen" "github.com/tal-tech/go-zero/tools/goctl/model/sql/gen"
@ -17,14 +18,15 @@ import (
) )
const ( const (
flagSrc = "src" flagSrc = "src"
flagDir = "dir" flagDir = "dir"
flagCache = "cache" flagCache = "cache"
flagIdea = "idea" flagIdea = "idea"
flagURL = "url" flagURL = "url"
flagTable = "table" flagTable = "table"
flagStyle = "style" flagStyle = "style"
flagDatabase = "database" flagDatabase = "database"
flagSchema = "schema"
) )
var errNotMatched = errors.New("sql not matched") var errNotMatched = errors.New("sql not matched")
@ -45,8 +47,8 @@ func MysqlDDL(ctx *cli.Context) error {
return fromDDl(src, dir, cfg, cache, idea, database) return fromDDl(src, dir, cfg, cache, idea, database)
} }
// MyDataSource generates model code from datasource // MySqlDataSource generates model code from datasource
func MyDataSource(ctx *cli.Context) error { func MySqlDataSource(ctx *cli.Context) error {
url := strings.TrimSpace(ctx.String(flagURL)) url := strings.TrimSpace(ctx.String(flagURL))
dir := strings.TrimSpace(ctx.String(flagDir)) dir := strings.TrimSpace(ctx.String(flagDir))
cache := ctx.Bool(flagCache) cache := ctx.Bool(flagCache)
@ -58,7 +60,28 @@ func MyDataSource(ctx *cli.Context) error {
return err return err
} }
return fromDataSource(url, pattern, dir, cfg, cache, idea) return fromMysqlDataSource(url, pattern, dir, cfg, cache, idea)
}
// PostgreSqlDataSource generates model code from datasource
func PostgreSqlDataSource(ctx *cli.Context) error {
url := strings.TrimSpace(ctx.String(flagURL))
dir := strings.TrimSpace(ctx.String(flagDir))
cache := ctx.Bool(flagCache)
idea := ctx.Bool(flagIdea)
style := ctx.String(flagStyle)
schema := ctx.String(flagSchema)
if len(schema) == 0 {
schema = "public"
}
pattern := strings.TrimSpace(ctx.String(flagTable))
cfg, err := config.NewConfig(style)
if err != nil {
return err
}
return fromPostgreSqlDataSource(url, pattern, dir, schema, cfg, cache, idea)
} }
func fromDDl(src, dir string, cfg *config.Config, cache, idea bool, database string) error { func fromDDl(src, dir string, cfg *config.Config, cache, idea bool, database string) error {
@ -92,7 +115,7 @@ func fromDDl(src, dir string, cfg *config.Config, cache, idea bool, database str
return nil return nil
} }
func fromDataSource(url, pattern, dir string, cfg *config.Config, cache, idea bool) error { func fromMysqlDataSource(url, pattern, dir string, cfg *config.Config, cache, idea bool) error {
log := console.NewConsole(idea) log := console.NewConsole(idea)
if len(url) == 0 { if len(url) == 0 {
log.Error("%v", "expected data source of mysql, but nothing found") log.Error("%v", "expected data source of mysql, but nothing found")
@ -154,3 +177,58 @@ func fromDataSource(url, pattern, dir string, cfg *config.Config, cache, idea bo
return generator.StartFromInformationSchema(matchTables, cache) return generator.StartFromInformationSchema(matchTables, cache)
} }
func fromPostgreSqlDataSource(url, pattern, dir, schema string, cfg *config.Config, cache, idea bool) error {
log := console.NewConsole(idea)
if len(url) == 0 {
log.Error("%v", "expected data source of mysql, but nothing found")
return nil
}
if len(pattern) == 0 {
log.Error("%v", "expected table or table globbing patterns, but nothing found")
return nil
}
db := postgres.New(url)
im := model.NewPostgreSqlModel(db)
tables, err := im.GetAllTables(schema)
if err != nil {
return err
}
matchTables := make(map[string]*model.Table)
for _, item := range tables {
match, err := filepath.Match(pattern, item)
if err != nil {
return err
}
if !match {
continue
}
columnData, err := im.FindColumns(schema, item)
if err != nil {
return err
}
table, err := columnData.Convert()
if err != nil {
return err
}
matchTables[item] = table
}
if len(matchTables) == 0 {
return errors.New("no tables matched")
}
generator, err := gen.NewDefaultGenerator(dir, cfg, gen.WithConsoleOption(log), gen.WithPostgreSql())
if err != nil {
return err
}
return generator.StartFromInformationSchema(matchTables, cache)
}

@ -9,7 +9,7 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
) )
func genDelete(table Table, withCache bool) (string, string, error) { func genDelete(table Table, withCache, postgreSql bool) (string, string, error) {
keySet := collection.NewSet() keySet := collection.NewSet()
keyVariableSet := collection.NewSet() keyVariableSet := collection.NewSet()
keySet.AddStr(table.PrimaryCacheKey.KeyExpression) keySet.AddStr(table.PrimaryCacheKey.KeyExpression)
@ -34,8 +34,9 @@ func genDelete(table Table, withCache bool) (string, string, error) {
"lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).Untitle(), "lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).Untitle(),
"dataType": table.PrimaryKey.DataType, "dataType": table.PrimaryKey.DataType,
"keys": strings.Join(keySet.KeysStr(), "\n"), "keys": strings.Join(keySet.KeysStr(), "\n"),
"originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()), "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source(), postgreSql),
"keyValues": strings.Join(keyVariableSet.KeysStr(), ", "), "keyValues": strings.Join(keyVariableSet.KeysStr(), ", "),
"postgreSql": postgreSql,
}) })
if err != nil { if err != nil {
return "", "", err return "", "", err

@ -6,7 +6,7 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
) )
func genFindOne(table Table, withCache bool) (string, string, error) { func genFindOne(table Table, withCache, postgreSql bool) (string, string, error) {
camel := table.Name.ToCamel() camel := table.Name.ToCamel()
text, err := util.LoadTemplate(category, findOneTemplateFile, template.FindOne) text, err := util.LoadTemplate(category, findOneTemplateFile, template.FindOne)
if err != nil { if err != nil {
@ -19,11 +19,12 @@ func genFindOne(table Table, withCache bool) (string, string, error) {
"withCache": withCache, "withCache": withCache,
"upperStartCamelObject": camel, "upperStartCamelObject": camel,
"lowerStartCamelObject": stringx.From(camel).Untitle(), "lowerStartCamelObject": stringx.From(camel).Untitle(),
"originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()), "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source(), postgreSql),
"lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).Untitle(), "lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).Untitle(),
"dataType": table.PrimaryKey.DataType, "dataType": table.PrimaryKey.DataType,
"cacheKey": table.PrimaryCacheKey.KeyExpression, "cacheKey": table.PrimaryCacheKey.KeyExpression,
"cacheKeyVariable": table.PrimaryCacheKey.KeyLeft, "cacheKeyVariable": table.PrimaryCacheKey.KeyLeft,
"postgreSql": postgreSql,
}) })
if err != nil { if err != nil {
return "", "", err return "", "", err

@ -15,7 +15,7 @@ type findOneCode struct {
cacheExtra string cacheExtra string
} }
func genFindOneByField(table Table, withCache bool) (*findOneCode, error) { func genFindOneByField(table Table, withCache, postgreSql bool) (*findOneCode, error) {
text, err := util.LoadTemplate(category, findOneByFieldTemplateFile, template.FindOneByField) text, err := util.LoadTemplate(category, findOneByFieldTemplateFile, template.FindOneByField)
if err != nil { if err != nil {
return nil, err return nil, err
@ -25,7 +25,7 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
var list []string var list []string
camelTableName := table.Name.ToCamel() camelTableName := table.Name.ToCamel()
for _, key := range table.UniqueCacheKey { for _, key := range table.UniqueCacheKey {
in, paramJoinString, originalFieldString := convertJoin(key) in, paramJoinString, originalFieldString := convertJoin(key, postgreSql)
output, err := t.Execute(map[string]interface{}{ output, err := t.Execute(map[string]interface{}{
"upperStartCamelObject": camelTableName, "upperStartCamelObject": camelTableName,
@ -38,6 +38,7 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
"lowerStartCamelField": paramJoinString, "lowerStartCamelField": paramJoinString,
"upperStartCamelPrimaryKey": table.PrimaryKey.Name.ToCamel(), "upperStartCamelPrimaryKey": table.PrimaryKey.Name.ToCamel(),
"originalField": originalFieldString, "originalField": originalFieldString,
"postgreSql": postgreSql,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -87,7 +88,8 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
"upperStartCamelObject": camelTableName, "upperStartCamelObject": camelTableName,
"primaryKeyLeft": table.PrimaryCacheKey.VarLeft, "primaryKeyLeft": table.PrimaryCacheKey.VarLeft,
"lowerStartCamelObject": stringx.From(camelTableName).Untitle(), "lowerStartCamelObject": stringx.From(camelTableName).Untitle(),
"originalPrimaryField": wrapWithRawString(table.PrimaryKey.Name.Source()), "originalPrimaryField": wrapWithRawString(table.PrimaryKey.Name.Source(), postgreSql),
"postgreSql": postgreSql,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -106,13 +108,17 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
}, nil }, nil
} }
func convertJoin(key Key) (in, paramJoinString, originalFieldString string) { func convertJoin(key Key, postgreSql bool) (in, paramJoinString, originalFieldString string) {
var inJoin, paramJoin, argJoin Join var inJoin, paramJoin, argJoin Join
for _, f := range key.Fields { for index, f := range key.Fields {
param := stringx.From(f.Name.ToCamel()).Untitle() param := stringx.From(f.Name.ToCamel()).Untitle()
inJoin = append(inJoin, fmt.Sprintf("%s %s", param, f.DataType)) inJoin = append(inJoin, fmt.Sprintf("%s %s", param, f.DataType))
paramJoin = append(paramJoin, param) paramJoin = append(paramJoin, param)
argJoin = append(argJoin, fmt.Sprintf("%s = ?", wrapWithRawString(f.Name.Source()))) if postgreSql {
argJoin = append(argJoin, fmt.Sprintf("%s = $%d", wrapWithRawString(f.Name.Source(), postgreSql), index+1))
} else {
argJoin = append(argJoin, fmt.Sprintf("%s = ?", wrapWithRawString(f.Name.Source(), postgreSql)))
}
} }
if len(inJoin) > 0 { if len(inJoin) > 0 {
in = inJoin.With(", ").Source() in = inJoin.With(", ").Source()

@ -29,8 +29,9 @@ type (
// source string // source string
dir string dir string
console.Console console.Console
pkg string pkg string
cfg *config.Config cfg *config.Config
isPostgreSql bool
} }
// Option defines a function with argument defaultGenerator // Option defines a function with argument defaultGenerator
@ -84,6 +85,13 @@ func WithConsoleOption(c console.Console) Option {
} }
} }
// WithPostgreSql marks defaultGenerator.isPostgreSql true
func WithPostgreSql() Option {
return func(generator *defaultGenerator) {
generator.isPostgreSql = true
}
}
func newDefaultOption() Option { func newDefaultOption() Option {
return func(generator *defaultGenerator) { return func(generator *defaultGenerator) {
generator.Console = console.NewColorConsole() generator.Console = console.NewColorConsole()
@ -219,34 +227,34 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
table.UniqueCacheKey = uniqueKey table.UniqueCacheKey = uniqueKey
table.ContainsUniqueCacheKey = len(uniqueKey) > 0 table.ContainsUniqueCacheKey = len(uniqueKey) > 0
varsCode, err := genVars(table, withCache) varsCode, err := genVars(table, withCache, g.isPostgreSql)
if err != nil { if err != nil {
return "", err return "", err
} }
insertCode, insertCodeMethod, err := genInsert(table, withCache) insertCode, insertCodeMethod, err := genInsert(table, withCache, g.isPostgreSql)
if err != nil { if err != nil {
return "", err return "", err
} }
findCode := make([]string, 0) findCode := make([]string, 0)
findOneCode, findOneCodeMethod, err := genFindOne(table, withCache) findOneCode, findOneCodeMethod, err := genFindOne(table, withCache, g.isPostgreSql)
if err != nil { if err != nil {
return "", err return "", err
} }
ret, err := genFindOneByField(table, withCache) ret, err := genFindOneByField(table, withCache, g.isPostgreSql)
if err != nil { if err != nil {
return "", err return "", err
} }
findCode = append(findCode, findOneCode, ret.findOneMethod) findCode = append(findCode, findOneCode, ret.findOneMethod)
updateCode, updateCodeMethod, err := genUpdate(table, withCache) updateCode, updateCodeMethod, err := genUpdate(table, withCache, g.isPostgreSql)
if err != nil { if err != nil {
return "", err return "", err
} }
deleteCode, deleteCodeMethod, err := genDelete(table, withCache) deleteCode, deleteCodeMethod, err := genDelete(table, withCache, g.isPostgreSql)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -258,7 +266,7 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
return "", err return "", err
} }
newCode, err := genNew(table, withCache) newCode, err := genNew(table, withCache, g.isPostgreSql)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -309,7 +317,11 @@ func (g *defaultGenerator) executeModel(code *code) (*bytes.Buffer, error) {
return output, nil return output, nil
} }
func wrapWithRawString(v string) string { func wrapWithRawString(v string, postgreSql bool) string {
if postgreSql {
return v
}
if v == "`" { if v == "`" {
return v return v
} }

@ -92,10 +92,11 @@ func TestNamingModel(t *testing.T) {
} }
func TestWrapWithRawString(t *testing.T) { func TestWrapWithRawString(t *testing.T) {
assert.Equal(t, "``", wrapWithRawString("")) assert.Equal(t, "``", wrapWithRawString("", false))
assert.Equal(t, "``", wrapWithRawString("``")) assert.Equal(t, "``", wrapWithRawString("``", false))
assert.Equal(t, "`a`", wrapWithRawString("a")) assert.Equal(t, "`a`", wrapWithRawString("a", false))
assert.Equal(t, "` `", wrapWithRawString(" ")) assert.Equal(t, "a", wrapWithRawString("a", true))
assert.Equal(t, "` `", wrapWithRawString(" ", false))
} }
func TestFields(t *testing.T) { func TestFields(t *testing.T) {

@ -1,6 +1,7 @@
package gen package gen
import ( import (
"fmt"
"strings" "strings"
"github.com/tal-tech/go-zero/core/collection" "github.com/tal-tech/go-zero/core/collection"
@ -9,7 +10,7 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
) )
func genInsert(table Table, withCache bool) (string, string, error) { func genInsert(table Table, withCache, postgreSql bool) (string, string, error) {
keySet := collection.NewSet() keySet := collection.NewSet()
keyVariableSet := collection.NewSet() keyVariableSet := collection.NewSet()
for _, key := range table.UniqueCacheKey { for _, key := range table.UniqueCacheKey {
@ -19,6 +20,7 @@ func genInsert(table Table, withCache bool) (string, string, error) {
expressions := make([]string, 0) expressions := make([]string, 0)
expressionValues := make([]string, 0) expressionValues := make([]string, 0)
var count int
for _, field := range table.Fields { for _, field := range table.Fields {
camel := field.Name.ToCamel() camel := field.Name.ToCamel()
if camel == "CreateTime" || camel == "UpdateTime" { if camel == "CreateTime" || camel == "UpdateTime" {
@ -31,7 +33,12 @@ func genInsert(table Table, withCache bool) (string, string, error) {
} }
} }
expressions = append(expressions, "?") count += 1
if postgreSql {
expressions = append(expressions, fmt.Sprintf("$%d", count))
} else {
expressions = append(expressions, "?")
}
expressionValues = append(expressionValues, "data."+camel) expressionValues = append(expressionValues, "data."+camel)
} }

@ -1,20 +1,27 @@
package gen package gen
import ( import (
"fmt"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template" "github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
) )
func genNew(table Table, withCache bool) (string, error) { func genNew(table Table, withCache, postgreSql bool) (string, error) {
text, err := util.LoadTemplate(category, modelNewTemplateFile, template.New) text, err := util.LoadTemplate(category, modelNewTemplateFile, template.New)
if err != nil { if err != nil {
return "", err return "", err
} }
t := fmt.Sprintf(`"%s"`, wrapWithRawString(table.Name.Source(), postgreSql))
if postgreSql {
t = "`" + fmt.Sprintf(`"%s"."%s"`, table.Db.Source(), table.Name.Source()) + "`"
}
output, err := util.With("new"). output, err := util.With("new").
Parse(text). Parse(text).
Execute(map[string]interface{}{ Execute(map[string]interface{}{
"table": wrapWithRawString(table.Name.Source()), "table": t,
"withCache": withCache, "withCache": withCache,
"upperStartCamelObject": table.Name.ToCamel(), "upperStartCamelObject": table.Name.ToCamel(),
}) })

@ -9,7 +9,7 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
) )
func genUpdate(table Table, withCache bool) (string, string, error) { func genUpdate(table Table, withCache, postgreSql bool) (string, string, error) {
expressionValues := make([]string, 0) expressionValues := make([]string, 0)
for _, field := range table.Fields { for _, field := range table.Fields {
camel := field.Name.ToCamel() camel := field.Name.ToCamel()
@ -50,8 +50,9 @@ func genUpdate(table Table, withCache bool) (string, string, error) {
"primaryCacheKey": table.PrimaryCacheKey.DataKeyExpression, "primaryCacheKey": table.PrimaryCacheKey.DataKeyExpression,
"primaryKeyVariable": table.PrimaryCacheKey.KeyLeft, "primaryKeyVariable": table.PrimaryCacheKey.KeyLeft,
"lowerStartCamelObject": stringx.From(camelTableName).Untitle(), "lowerStartCamelObject": stringx.From(camelTableName).Untitle(),
"originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()), "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source(), postgreSql),
"expressionValues": strings.Join(expressionValues, ", "), "expressionValues": strings.Join(expressionValues, ", "),
"postgreSql": postgreSql,
}) })
if err != nil { if err != nil {
return "", "", nil return "", "", nil

@ -8,7 +8,7 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
) )
func genVars(table Table, withCache bool) (string, error) { func genVars(table Table, withCache, postgreSql bool) (string, error) {
keys := make([]string, 0) keys := make([]string, 0)
keys = append(keys, table.PrimaryCacheKey.VarExpression) keys = append(keys, table.PrimaryCacheKey.VarExpression)
for _, v := range table.UniqueCacheKey { for _, v := range table.UniqueCacheKey {
@ -27,8 +27,9 @@ func genVars(table Table, withCache bool) (string, error) {
"upperStartCamelObject": camel, "upperStartCamelObject": camel,
"cacheKeys": strings.Join(keys, "\n"), "cacheKeys": strings.Join(keys, "\n"),
"autoIncrement": table.PrimaryKey.AutoIncrement, "autoIncrement": table.PrimaryKey.AutoIncrement,
"originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()), "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source(), postgreSql),
"withCache": withCache, "withCache": withCache,
"postgreSql": postgreSql,
}) })
if err != nil { if err != nil {
return "", err return "", err

@ -0,0 +1,234 @@
package model
import (
"database/sql"
"strings"
"github.com/tal-tech/go-zero/core/stores/sqlx"
)
var (
p2m = map[string]string{
"int8": "bigint",
"numeric": "bigint",
"float8": "double",
"float4": "float",
"int2": "smallint",
"int4": "integer",
}
)
// PostgreSqlModel gets table information from information_schema、pg_catalog
type PostgreSqlModel struct {
conn sqlx.SqlConn
}
// PostgreColumn describes a column in table
type PostgreColumn struct {
Num sql.NullInt32 `db:"num"`
Field sql.NullString `db:"field"`
Type sql.NullString `db:"type"`
NotNull sql.NullBool `db:"not_null"`
Comment sql.NullString `db:"comment"`
ColumnDefault sql.NullString `db:"column_default"`
IdentityIncrement sql.NullInt32 `db:"identity_increment"`
}
// PostgreIndex describes an index for a column
type PostgreIndex struct {
IndexName sql.NullString `db:"index_name"`
IndexId sql.NullInt32 `db:"index_id"`
IsUnique sql.NullBool `db:"is_unique"`
IsPrimary sql.NullBool `db:"is_primary"`
ColumnName sql.NullString `db:"column_name"`
IndexSort sql.NullInt32 `db:"index_sort"`
}
// NewPostgreSqlModel creates an instance and return
func NewPostgreSqlModel(conn sqlx.SqlConn) *PostgreSqlModel {
return &PostgreSqlModel{
conn: conn,
}
}
// GetAllTables selects all tables from TABLE_SCHEMA
func (m *PostgreSqlModel) GetAllTables(schema string) ([]string, error) {
query := `select table_name from information_schema.tables where table_schema = $1`
var tables []string
err := m.conn.QueryRows(&tables, query, schema)
if err != nil {
return nil, err
}
return tables, nil
}
// FindColumns return columns in specified database and table
func (m *PostgreSqlModel) FindColumns(schema, table string) (*ColumnData, error) {
querySql := `select t.num,t.field,t.type,t.not_null,t.comment, c.column_default, identity_increment
from (
SELECT a.attnum AS num,
c.relname,
a.attname AS field,
t.typname AS type,
a.atttypmod AS lengthvar,
a.attnotnull AS not_null,
b.description AS comment
FROM pg_class c,
pg_attribute a
LEFT OUTER JOIN pg_description b ON a.attrelid = b.objoid AND a.attnum = b.objsubid,
pg_type t
WHERE c.relname = $1
and a.attnum > 0
and a.attrelid = c.oid
and a.atttypid = t.oid
ORDER BY a.attnum) AS t
left join information_schema.columns AS c on t.relname = c.table_name
and t.field = c.column_name and c.table_schema = $2`
var reply []*PostgreColumn
err := m.conn.QueryRowsPartial(&reply, querySql, table, schema)
if err != nil {
return nil, err
}
list, err := m.getColumns(schema, table, reply)
if err != nil {
return nil, err
}
var columnData ColumnData
columnData.Db = schema
columnData.Table = table
columnData.Columns = list
return &columnData, nil
}
func (m *PostgreSqlModel) getColumns(schema, table string, in []*PostgreColumn) ([]*Column, error) {
index, err := m.getIndex(schema, table)
if err != nil {
return nil, err
}
var list []*Column
for _, e := range in {
var dft interface{}
if len(e.ColumnDefault.String) > 0 {
dft = e.ColumnDefault
}
isNullAble := "YES"
if e.NotNull.Bool {
isNullAble = "NO"
}
extra := "auto_increment"
if e.IdentityIncrement.Int32 != 1 {
extra = ""
}
if len(index[e.Field.String]) > 0 {
for _, i := range index[e.Field.String] {
list = append(list, &Column{
DbColumn: &DbColumn{
Name: e.Field.String,
DataType: m.convertPostgreSqlTypeIntoMysqlType(e.Type.String),
Extra: extra,
Comment: e.Comment.String,
ColumnDefault: dft,
IsNullAble: isNullAble,
OrdinalPosition: int(e.Num.Int32),
},
Index: i,
})
}
} else {
list = append(list, &Column{
DbColumn: &DbColumn{
Name: e.Field.String,
DataType: m.convertPostgreSqlTypeIntoMysqlType(e.Type.String),
Extra: extra,
Comment: e.Comment.String,
ColumnDefault: dft,
IsNullAble: isNullAble,
OrdinalPosition: int(e.Num.Int32),
},
})
}
}
return list, nil
}
func (m *PostgreSqlModel) convertPostgreSqlTypeIntoMysqlType(in string) string {
r, ok := p2m[strings.ToLower(in)]
if ok {
return r
}
return in
}
func (m *PostgreSqlModel) getIndex(schema, table string) (map[string][]*DbIndex, error) {
indexes, err := m.FindIndex(schema, table)
if err != nil {
return nil, err
}
var index = make(map[string][]*DbIndex)
for _, e := range indexes {
if e.IsPrimary.Bool {
index[e.ColumnName.String] = append(index[e.ColumnName.String], &DbIndex{
IndexName: indexPri,
SeqInIndex: int(e.IndexSort.Int32),
})
continue
}
nonUnique := 0
if !e.IsUnique.Bool {
nonUnique = 1
}
index[e.ColumnName.String] = append(index[e.ColumnName.String], &DbIndex{
IndexName: e.IndexName.String,
NonUnique: nonUnique,
SeqInIndex: int(e.IndexSort.Int32),
})
}
return index, nil
}
// FindIndex finds index with given schema, table and column.
func (m *PostgreSqlModel) FindIndex(schema, table string) ([]*PostgreIndex, error) {
querySql := `select A.INDEXNAME AS index_name,
C.INDEXRELID AS index_id,
C.INDISUNIQUE AS is_unique,
C.INDISPRIMARY AS is_primary,
G.ATTNAME AS column_name,
G.attnum AS index_sort
from PG_AM B
left join PG_CLASS F on
B.OID = F.RELAM
left join PG_STAT_ALL_INDEXES E on
F.OID = E.INDEXRELID
left join PG_INDEX C on
E.INDEXRELID = C.INDEXRELID
left outer join PG_DESCRIPTION D on
C.INDEXRELID = D.OBJOID,
PG_INDEXES A,
pg_attribute G
where A.SCHEMANAME = E.SCHEMANAME
and A.TABLENAME = E.RELNAME
and A.INDEXNAME = E.INDEXRELNAME
and F.oid = G.attrelid
and E.SCHEMANAME = $1
and E.RELNAME = $2
order by C.INDEXRELID,G.attnum`
var reply []*PostgreIndex
err := m.conn.QueryRowsPartial(&reply, querySql, schema, table)
if err != nil {
return nil, err
}
return reply, nil
}

@ -10,9 +10,9 @@ func (m *default{{.upperStartCamelObject}}Model) Delete({{.lowerStartCamelPrimar
{{.keys}} {{.keys}}
_, err {{if .containsIndexCache}}={{else}}:={{end}} m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { _, err {{if .containsIndexCache}}={{else}}:={{end}} m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
query := fmt.Sprintf("delete from %s where {{.originalPrimaryKey}} = ?", m.table) query := fmt.Sprintf("delete from %s where {{.originalPrimaryKey}} = {{if .postgreSql}}$1{{else}}?{{end}}", m.table)
return conn.Exec(query, {{.lowerStartCamelPrimaryKey}}) return conn.Exec(query, {{.lowerStartCamelPrimaryKey}})
}, {{.keyValues}}){{else}}query := fmt.Sprintf("delete from %s where {{.originalPrimaryKey}} = ?", m.table) }, {{.keyValues}}){{else}}query := fmt.Sprintf("delete from %s where {{.originalPrimaryKey}} = {{if .postgreSql}}$1{{else}}?{{end}}", m.table)
_,err:=m.conn.Exec(query, {{.lowerStartCamelPrimaryKey}}){{end}} _,err:=m.conn.Exec(query, {{.lowerStartCamelPrimaryKey}}){{end}}
return err return err
} }

@ -6,7 +6,7 @@ func (m *default{{.upperStartCamelObject}}Model) FindOne({{.lowerStartCamelPrima
{{if .withCache}}{{.cacheKey}} {{if .withCache}}{{.cacheKey}}
var resp {{.upperStartCamelObject}} var resp {{.upperStartCamelObject}}
err := m.QueryRow(&resp, {{.cacheKeyVariable}}, func(conn sqlx.SqlConn, v interface{}) error { err := m.QueryRow(&resp, {{.cacheKeyVariable}}, func(conn sqlx.SqlConn, v interface{}) error {
query := fmt.Sprintf("select %s from %s where {{.originalPrimaryKey}} = ? limit 1", {{.lowerStartCamelObject}}Rows, m.table) query := fmt.Sprintf("select %s from %s where {{.originalPrimaryKey}} = {{if .postgreSql}}$1{{else}}?{{end}} limit 1", {{.lowerStartCamelObject}}Rows, m.table)
return conn.QueryRow(v, query, {{.lowerStartCamelPrimaryKey}}) return conn.QueryRow(v, query, {{.lowerStartCamelPrimaryKey}})
}) })
switch err { switch err {
@ -16,7 +16,7 @@ func (m *default{{.upperStartCamelObject}}Model) FindOne({{.lowerStartCamelPrima
return nil, ErrNotFound return nil, ErrNotFound
default: default:
return nil, err return nil, err
}{{else}}query := fmt.Sprintf("select %s from %s where {{.originalPrimaryKey}} = ? limit 1", {{.lowerStartCamelObject}}Rows, m.table) }{{else}}query := fmt.Sprintf("select %s from %s where {{.originalPrimaryKey}} = {{if .postgreSql}}$1{{else}}?{{end}} limit 1", {{.lowerStartCamelObject}}Rows, m.table)
var resp {{.upperStartCamelObject}} var resp {{.upperStartCamelObject}}
err := m.conn.QueryRow(&resp, query, {{.lowerStartCamelPrimaryKey}}) err := m.conn.QueryRow(&resp, query, {{.lowerStartCamelPrimaryKey}})
switch err { switch err {
@ -71,7 +71,7 @@ func (m *default{{.upperStartCamelObject}}Model) formatPrimary(primary interface
} }
func (m *default{{.upperStartCamelObject}}Model) queryPrimary(conn sqlx.SqlConn, v, primary interface{}) error { func (m *default{{.upperStartCamelObject}}Model) queryPrimary(conn sqlx.SqlConn, v, primary interface{}) error {
query := fmt.Sprintf("select %s from %s where {{.originalPrimaryField}} = ? limit 1", {{.lowerStartCamelObject}}Rows, m.table ) query := fmt.Sprintf("select %s from %s where {{.originalPrimaryField}} = {{if .postgreSql}}$1{{else}}?{{end}} limit 1", {{.lowerStartCamelObject}}Rows, m.table )
return conn.QueryRow(v, query, primary) return conn.QueryRow(v, query, primary)
} }
` `

@ -5,7 +5,7 @@ var New = `
func New{{.upperStartCamelObject}}Model(conn sqlx.SqlConn{{if .withCache}}, c cache.CacheConf{{end}}) {{.upperStartCamelObject}}Model { func New{{.upperStartCamelObject}}Model(conn sqlx.SqlConn{{if .withCache}}, c cache.CacheConf{{end}}) {{.upperStartCamelObject}}Model {
return &default{{.upperStartCamelObject}}Model{ return &default{{.upperStartCamelObject}}Model{
{{if .withCache}}CachedConn: sqlc.NewConn(conn, c){{else}}conn:conn{{end}}, {{if .withCache}}CachedConn: sqlc.NewConn(conn, c){{else}}conn:conn{{end}},
table: "{{.table}}", table: {{.table}},
} }
} }
` `

@ -5,9 +5,9 @@ var Update = `
func (m *default{{.upperStartCamelObject}}Model) Update(data {{.upperStartCamelObject}}) error { func (m *default{{.upperStartCamelObject}}Model) Update(data {{.upperStartCamelObject}}) error {
{{if .withCache}}{{.keys}} {{if .withCache}}{{.keys}}
_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { _, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
query := fmt.Sprintf("update %s set %s where {{.originalPrimaryKey}} = ?", m.table, {{.lowerStartCamelObject}}RowsWithPlaceHolder) query := fmt.Sprintf("update %s set %s where {{.originalPrimaryKey}} = {{if .postgreSql}}$1{{else}}?{{end}}", m.table, {{.lowerStartCamelObject}}RowsWithPlaceHolder)
return conn.Exec(query, {{.expressionValues}}) return conn.Exec(query, {{.expressionValues}})
}, {{.keyValues}}){{else}}query := fmt.Sprintf("update %s set %s where {{.originalPrimaryKey}} = ?", m.table, {{.lowerStartCamelObject}}RowsWithPlaceHolder) }, {{.keyValues}}){{else}}query := fmt.Sprintf("update %s set %s where {{.originalPrimaryKey}} = {{if .postgreSql}}$1{{else}}?{{end}}", m.table, {{.lowerStartCamelObject}}RowsWithPlaceHolder)
_,err:=m.conn.Exec(query, {{.expressionValues}}){{end}} _,err:=m.conn.Exec(query, {{.expressionValues}}){{end}}
return err return err
} }

@ -5,11 +5,14 @@ import "fmt"
// Vars defines a template for var block in model // Vars defines a template for var block in model
var Vars = fmt.Sprintf(` var Vars = fmt.Sprintf(`
var ( var (
{{.lowerStartCamelObject}}FieldNames = builderx.RawFieldNames(&{{.upperStartCamelObject}}{}) {{.lowerStartCamelObject}}FieldNames = builderx.RawFieldNames(&{{.upperStartCamelObject}}{}{{if .postgreSql}},true{{end}})
{{.lowerStartCamelObject}}Rows = strings.Join({{.lowerStartCamelObject}}FieldNames, ",") {{.lowerStartCamelObject}}Rows = strings.Join({{.lowerStartCamelObject}}FieldNames, ",")
{{.lowerStartCamelObject}}RowsExpectAutoSet = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, {{if .autoIncrement}}"{{.originalPrimaryKey}}",{{end}} "%screate_time%s", "%supdate_time%s"), ",") {{.lowerStartCamelObject}}RowsExpectAutoSet = {{if .postgreSql}}strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, {{if .autoIncrement}}"{{.originalPrimaryKey}}",{{end}} "%screate_time%s", "%supdate_time%s"), ","){{else}}strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, {{if .autoIncrement}}"{{.originalPrimaryKey}}",{{end}} "%screate_time%s", "%supdate_time%s"), ","){{end}}
{{.lowerStartCamelObject}}RowsWithPlaceHolder = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, "{{.originalPrimaryKey}}", "%screate_time%s", "%supdate_time%s"), "=?,") + "=?" {{.lowerStartCamelObject}}RowsWithPlaceHolder = {{if .postgreSql}}builderx.PostgreSqlJoin(stringx.Remove({{.lowerStartCamelObject}}FieldNames, "{{.originalPrimaryKey}}", "%screate_time%s", "%supdate_time%s")){{else}}strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, "{{.originalPrimaryKey}}", "%screate_time%s", "%supdate_time%s"), "=?,") + "=?"{{end}}
{{if .withCache}}{{.cacheKeys}}{{end}} {{if .withCache}}{{.cacheKeys}}{{end}}
) )
`, "`", "`", "`", "`", "`", "`", "`", "`") `, "", "", "", "", // postgreSql mode
"`", "`", "`", "`",
"", "", "", "", // postgreSql mode
"`", "`", "`", "`")

Loading…
Cancel
Save