From 089cdaa75f0530527f822474887766b6c94ee30f Mon Sep 17 00:00:00 2001 From: anqiansong Date: Fri, 23 Jul 2021 11:45:15 +0800 Subject: [PATCH] Feature model postgresql (#842) * Support postgresql generate * Update template Var * Support to generate postgresql model * Support to generate postgresql model * Update template Co-authored-by: anqiansong --- tools/goctl/goctl.go | 43 +++- tools/goctl/model/sql/builderx/builder.go | 33 ++- .../goctl/model/sql/builderx/builder_test.go | 5 + tools/goctl/model/sql/command/command.go | 100 +++++++- tools/goctl/model/sql/gen/delete.go | 5 +- tools/goctl/model/sql/gen/findone.go | 5 +- tools/goctl/model/sql/gen/findonebyfield.go | 18 +- tools/goctl/model/sql/gen/gen.go | 32 ++- tools/goctl/model/sql/gen/gen_test.go | 9 +- tools/goctl/model/sql/gen/insert.go | 11 +- tools/goctl/model/sql/gen/new.go | 11 +- tools/goctl/model/sql/gen/update.go | 5 +- tools/goctl/model/sql/gen/vars.go | 5 +- .../goctl/model/sql/model/postgresqlmodel.go | 234 ++++++++++++++++++ tools/goctl/model/sql/template/delete.go | 4 +- tools/goctl/model/sql/template/find.go | 6 +- tools/goctl/model/sql/template/new.go | 2 +- tools/goctl/model/sql/template/update.go | 4 +- tools/goctl/model/sql/template/vars.go | 11 +- 19 files changed, 484 insertions(+), 59 deletions(-) create mode 100644 tools/goctl/model/sql/model/postgresqlmodel.go diff --git a/tools/goctl/goctl.go b/tools/goctl/goctl.go index 89cd49c9..1dd3c335 100644 --- a/tools/goctl/goctl.go +++ b/tools/goctl/goctl.go @@ -455,7 +455,48 @@ var ( Usage: "for idea plugin [optional]", }, }, - Action: model.MyDataSource, + Action: model.MySqlDataSource, + }, + }, + }, + { + Name: "postgresql", + Usage: `generate postgresql model`, + Subcommands: []cli.Command{ + { + Name: "datasource", + Usage: `generate model from datasource`, + Flags: []cli.Flag{ + cli.StringFlag{ + Name: "url", + Usage: `the data source of database,like "root:password@tcp(127.0.0.1:3306)/database`, + }, + cli.StringFlag{ + Name: "table, t", + Usage: `the table or table globbing patterns in the database`, + }, + cli.StringFlag{ + Name: "schema, s", + Usage: `the table schema, default is [public]`, + }, + cli.BoolFlag{ + Name: "cache, c", + Usage: "generate code with cache [optional]", + }, + cli.StringFlag{ + Name: "dir, d", + Usage: "the target dir", + }, + cli.StringFlag{ + Name: "style", + Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]", + }, + cli.BoolFlag{ + Name: "idea", + Usage: "for idea plugin [optional]", + }, + }, + Action: model.PostgreSqlDataSource, }, }, }, diff --git a/tools/goctl/model/sql/builderx/builder.go b/tools/goctl/model/sql/builderx/builder.go index 8b7ed2cf..a50e8915 100644 --- a/tools/goctl/model/sql/builderx/builder.go +++ b/tools/goctl/model/sql/builderx/builder.go @@ -3,6 +3,7 @@ package builderx import ( "fmt" "reflect" + "strings" "github.com/go-xorm/builder" ) @@ -81,13 +82,18 @@ func FieldNames(in interface{}) []string { } // RawFieldNames converts golang struct field into slice string -func RawFieldNames(in interface{}) []string { +func RawFieldNames(in interface{}, postgresSql ...bool) []string { out := make([]string, 0) v := reflect.ValueOf(in) if v.Kind() == reflect.Ptr { v = v.Elem() } + var pg bool + if len(postgresSql) > 0 { + pg = postgresSql[0] + } + // we only accept structs if v.Kind() != reflect.Struct { panic(fmt.Errorf("ToMap only accepts structs; got %T", v)) @@ -98,11 +104,32 @@ func RawFieldNames(in interface{}) []string { // gets us a StructField fi := typ.Field(i) if tagv := fi.Tag.Get(dbTag); tagv != "" { - out = append(out, fmt.Sprintf("`%s`", tagv)) + if pg { + out = append(out, fmt.Sprintf("%s", tagv)) + } else { + out = append(out, fmt.Sprintf("`%s`", tagv)) + } } else { - out = append(out, fmt.Sprintf(`"%s"`, fi.Name)) + if pg { + out = append(out, fmt.Sprintf("%s", fi.Name)) + } else { + out = append(out, fmt.Sprintf("`%s`", fi.Name)) + } } } return out } + +func PostgreSqlJoin(elems []string) string { + var b = new(strings.Builder) + for index, e := range elems { + b.WriteString(fmt.Sprintf("%s = $%d, ", e, index+1)) + } + + if b.Len() == 0 { + return b.String() + } + + return b.String()[0 : b.Len()-2] +} diff --git a/tools/goctl/model/sql/builderx/builder_test.go b/tools/goctl/model/sql/builderx/builder_test.go index 8199963b..31bf7680 100644 --- a/tools/goctl/model/sql/builderx/builder_test.go +++ b/tools/goctl/model/sql/builderx/builder_test.go @@ -118,3 +118,8 @@ func TestBuildSqlLike(t *testing.T) { assert.Equal(t, sql, actualSQL) assert.Equal(t, args, actualArgs) } + +func TestJoin(t *testing.T) { + ret := PostgreSqlJoin([]string{"name", "age"}) + assert.Equal(t, "name = $1, age = $2", ret) +} diff --git a/tools/goctl/model/sql/command/command.go b/tools/goctl/model/sql/command/command.go index c3a228b0..282b6989 100644 --- a/tools/goctl/model/sql/command/command.go +++ b/tools/goctl/model/sql/command/command.go @@ -7,6 +7,7 @@ import ( "github.com/go-sql-driver/mysql" "github.com/tal-tech/go-zero/core/logx" + "github.com/tal-tech/go-zero/core/stores/postgres" "github.com/tal-tech/go-zero/core/stores/sqlx" "github.com/tal-tech/go-zero/tools/goctl/config" "github.com/tal-tech/go-zero/tools/goctl/model/sql/gen" @@ -17,14 +18,15 @@ import ( ) const ( - flagSrc = "src" - flagDir = "dir" - flagCache = "cache" - flagIdea = "idea" - flagURL = "url" - flagTable = "table" - flagStyle = "style" + flagSrc = "src" + flagDir = "dir" + flagCache = "cache" + flagIdea = "idea" + flagURL = "url" + flagTable = "table" + flagStyle = "style" flagDatabase = "database" + flagSchema = "schema" ) var errNotMatched = errors.New("sql not matched") @@ -45,8 +47,8 @@ func MysqlDDL(ctx *cli.Context) error { return fromDDl(src, dir, cfg, cache, idea, database) } -// MyDataSource generates model code from datasource -func MyDataSource(ctx *cli.Context) error { +// MySqlDataSource generates model code from datasource +func MySqlDataSource(ctx *cli.Context) error { url := strings.TrimSpace(ctx.String(flagURL)) dir := strings.TrimSpace(ctx.String(flagDir)) cache := ctx.Bool(flagCache) @@ -58,7 +60,28 @@ func MyDataSource(ctx *cli.Context) error { return err } - return fromDataSource(url, pattern, dir, cfg, cache, idea) + return fromMysqlDataSource(url, pattern, dir, cfg, cache, idea) +} + +// PostgreSqlDataSource generates model code from datasource +func PostgreSqlDataSource(ctx *cli.Context) error { + url := strings.TrimSpace(ctx.String(flagURL)) + dir := strings.TrimSpace(ctx.String(flagDir)) + cache := ctx.Bool(flagCache) + idea := ctx.Bool(flagIdea) + style := ctx.String(flagStyle) + schema := ctx.String(flagSchema) + if len(schema) == 0 { + schema = "public" + } + + pattern := strings.TrimSpace(ctx.String(flagTable)) + cfg, err := config.NewConfig(style) + if err != nil { + return err + } + + return fromPostgreSqlDataSource(url, pattern, dir, schema, cfg, cache, idea) } func fromDDl(src, dir string, cfg *config.Config, cache, idea bool, database string) error { @@ -92,7 +115,7 @@ func fromDDl(src, dir string, cfg *config.Config, cache, idea bool, database str return nil } -func fromDataSource(url, pattern, dir string, cfg *config.Config, cache, idea bool) error { +func fromMysqlDataSource(url, pattern, dir string, cfg *config.Config, cache, idea bool) error { log := console.NewConsole(idea) if len(url) == 0 { log.Error("%v", "expected data source of mysql, but nothing found") @@ -154,3 +177,58 @@ func fromDataSource(url, pattern, dir string, cfg *config.Config, cache, idea bo return generator.StartFromInformationSchema(matchTables, cache) } + +func fromPostgreSqlDataSource(url, pattern, dir, schema string, cfg *config.Config, cache, idea bool) error { + log := console.NewConsole(idea) + if len(url) == 0 { + log.Error("%v", "expected data source of mysql, but nothing found") + return nil + } + + if len(pattern) == 0 { + log.Error("%v", "expected table or table globbing patterns, but nothing found") + return nil + } + db := postgres.New(url) + im := model.NewPostgreSqlModel(db) + + tables, err := im.GetAllTables(schema) + if err != nil { + return err + } + + matchTables := make(map[string]*model.Table) + for _, item := range tables { + match, err := filepath.Match(pattern, item) + if err != nil { + return err + } + + if !match { + continue + } + + columnData, err := im.FindColumns(schema, item) + if err != nil { + return err + } + + table, err := columnData.Convert() + if err != nil { + return err + } + + matchTables[item] = table + } + + if len(matchTables) == 0 { + return errors.New("no tables matched") + } + + generator, err := gen.NewDefaultGenerator(dir, cfg, gen.WithConsoleOption(log), gen.WithPostgreSql()) + if err != nil { + return err + } + + return generator.StartFromInformationSchema(matchTables, cache) +} diff --git a/tools/goctl/model/sql/gen/delete.go b/tools/goctl/model/sql/gen/delete.go index f6758ce2..c7ee7b79 100644 --- a/tools/goctl/model/sql/gen/delete.go +++ b/tools/goctl/model/sql/gen/delete.go @@ -9,7 +9,7 @@ import ( "github.com/tal-tech/go-zero/tools/goctl/util/stringx" ) -func genDelete(table Table, withCache bool) (string, string, error) { +func genDelete(table Table, withCache, postgreSql bool) (string, string, error) { keySet := collection.NewSet() keyVariableSet := collection.NewSet() keySet.AddStr(table.PrimaryCacheKey.KeyExpression) @@ -34,8 +34,9 @@ func genDelete(table Table, withCache bool) (string, string, error) { "lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).Untitle(), "dataType": table.PrimaryKey.DataType, "keys": strings.Join(keySet.KeysStr(), "\n"), - "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()), + "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source(), postgreSql), "keyValues": strings.Join(keyVariableSet.KeysStr(), ", "), + "postgreSql": postgreSql, }) if err != nil { return "", "", err diff --git a/tools/goctl/model/sql/gen/findone.go b/tools/goctl/model/sql/gen/findone.go index e9219206..c07b9a9c 100644 --- a/tools/goctl/model/sql/gen/findone.go +++ b/tools/goctl/model/sql/gen/findone.go @@ -6,7 +6,7 @@ import ( "github.com/tal-tech/go-zero/tools/goctl/util/stringx" ) -func genFindOne(table Table, withCache bool) (string, string, error) { +func genFindOne(table Table, withCache, postgreSql bool) (string, string, error) { camel := table.Name.ToCamel() text, err := util.LoadTemplate(category, findOneTemplateFile, template.FindOne) if err != nil { @@ -19,11 +19,12 @@ func genFindOne(table Table, withCache bool) (string, string, error) { "withCache": withCache, "upperStartCamelObject": camel, "lowerStartCamelObject": stringx.From(camel).Untitle(), - "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()), + "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source(), postgreSql), "lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).Untitle(), "dataType": table.PrimaryKey.DataType, "cacheKey": table.PrimaryCacheKey.KeyExpression, "cacheKeyVariable": table.PrimaryCacheKey.KeyLeft, + "postgreSql": postgreSql, }) if err != nil { return "", "", err diff --git a/tools/goctl/model/sql/gen/findonebyfield.go b/tools/goctl/model/sql/gen/findonebyfield.go index 53cccacf..9a020084 100644 --- a/tools/goctl/model/sql/gen/findonebyfield.go +++ b/tools/goctl/model/sql/gen/findonebyfield.go @@ -15,7 +15,7 @@ type findOneCode struct { cacheExtra string } -func genFindOneByField(table Table, withCache bool) (*findOneCode, error) { +func genFindOneByField(table Table, withCache, postgreSql bool) (*findOneCode, error) { text, err := util.LoadTemplate(category, findOneByFieldTemplateFile, template.FindOneByField) if err != nil { return nil, err @@ -25,7 +25,7 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) { var list []string camelTableName := table.Name.ToCamel() for _, key := range table.UniqueCacheKey { - in, paramJoinString, originalFieldString := convertJoin(key) + in, paramJoinString, originalFieldString := convertJoin(key, postgreSql) output, err := t.Execute(map[string]interface{}{ "upperStartCamelObject": camelTableName, @@ -38,6 +38,7 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) { "lowerStartCamelField": paramJoinString, "upperStartCamelPrimaryKey": table.PrimaryKey.Name.ToCamel(), "originalField": originalFieldString, + "postgreSql": postgreSql, }) if err != nil { return nil, err @@ -87,7 +88,8 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) { "upperStartCamelObject": camelTableName, "primaryKeyLeft": table.PrimaryCacheKey.VarLeft, "lowerStartCamelObject": stringx.From(camelTableName).Untitle(), - "originalPrimaryField": wrapWithRawString(table.PrimaryKey.Name.Source()), + "originalPrimaryField": wrapWithRawString(table.PrimaryKey.Name.Source(), postgreSql), + "postgreSql": postgreSql, }) if err != nil { return nil, err @@ -106,13 +108,17 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) { }, nil } -func convertJoin(key Key) (in, paramJoinString, originalFieldString string) { +func convertJoin(key Key, postgreSql bool) (in, paramJoinString, originalFieldString string) { var inJoin, paramJoin, argJoin Join - for _, f := range key.Fields { + for index, f := range key.Fields { param := stringx.From(f.Name.ToCamel()).Untitle() inJoin = append(inJoin, fmt.Sprintf("%s %s", param, f.DataType)) paramJoin = append(paramJoin, param) - argJoin = append(argJoin, fmt.Sprintf("%s = ?", wrapWithRawString(f.Name.Source()))) + if postgreSql { + argJoin = append(argJoin, fmt.Sprintf("%s = $%d", wrapWithRawString(f.Name.Source(), postgreSql), index+1)) + } else { + argJoin = append(argJoin, fmt.Sprintf("%s = ?", wrapWithRawString(f.Name.Source(), postgreSql))) + } } if len(inJoin) > 0 { in = inJoin.With(", ").Source() diff --git a/tools/goctl/model/sql/gen/gen.go b/tools/goctl/model/sql/gen/gen.go index 5ed1c20c..300c47cd 100644 --- a/tools/goctl/model/sql/gen/gen.go +++ b/tools/goctl/model/sql/gen/gen.go @@ -29,8 +29,9 @@ type ( // source string dir string console.Console - pkg string - cfg *config.Config + pkg string + cfg *config.Config + isPostgreSql bool } // Option defines a function with argument defaultGenerator @@ -84,6 +85,13 @@ func WithConsoleOption(c console.Console) Option { } } +// WithPostgreSql marks defaultGenerator.isPostgreSql true +func WithPostgreSql() Option { + return func(generator *defaultGenerator) { + generator.isPostgreSql = true + } +} + func newDefaultOption() Option { return func(generator *defaultGenerator) { generator.Console = console.NewColorConsole() @@ -219,34 +227,34 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er table.UniqueCacheKey = uniqueKey table.ContainsUniqueCacheKey = len(uniqueKey) > 0 - varsCode, err := genVars(table, withCache) + varsCode, err := genVars(table, withCache, g.isPostgreSql) if err != nil { return "", err } - insertCode, insertCodeMethod, err := genInsert(table, withCache) + insertCode, insertCodeMethod, err := genInsert(table, withCache, g.isPostgreSql) if err != nil { return "", err } findCode := make([]string, 0) - findOneCode, findOneCodeMethod, err := genFindOne(table, withCache) + findOneCode, findOneCodeMethod, err := genFindOne(table, withCache, g.isPostgreSql) if err != nil { return "", err } - ret, err := genFindOneByField(table, withCache) + ret, err := genFindOneByField(table, withCache, g.isPostgreSql) if err != nil { return "", err } findCode = append(findCode, findOneCode, ret.findOneMethod) - updateCode, updateCodeMethod, err := genUpdate(table, withCache) + updateCode, updateCodeMethod, err := genUpdate(table, withCache, g.isPostgreSql) if err != nil { return "", err } - deleteCode, deleteCodeMethod, err := genDelete(table, withCache) + deleteCode, deleteCodeMethod, err := genDelete(table, withCache, g.isPostgreSql) if err != nil { return "", err } @@ -258,7 +266,7 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er return "", err } - newCode, err := genNew(table, withCache) + newCode, err := genNew(table, withCache, g.isPostgreSql) if err != nil { return "", err } @@ -309,7 +317,11 @@ func (g *defaultGenerator) executeModel(code *code) (*bytes.Buffer, error) { return output, nil } -func wrapWithRawString(v string) string { +func wrapWithRawString(v string, postgreSql bool) string { + if postgreSql { + return v + } + if v == "`" { return v } diff --git a/tools/goctl/model/sql/gen/gen_test.go b/tools/goctl/model/sql/gen/gen_test.go index a173a36d..2d87e003 100644 --- a/tools/goctl/model/sql/gen/gen_test.go +++ b/tools/goctl/model/sql/gen/gen_test.go @@ -92,10 +92,11 @@ func TestNamingModel(t *testing.T) { } func TestWrapWithRawString(t *testing.T) { - assert.Equal(t, "``", wrapWithRawString("")) - assert.Equal(t, "``", wrapWithRawString("``")) - assert.Equal(t, "`a`", wrapWithRawString("a")) - assert.Equal(t, "` `", wrapWithRawString(" ")) + assert.Equal(t, "``", wrapWithRawString("", false)) + assert.Equal(t, "``", wrapWithRawString("``", false)) + assert.Equal(t, "`a`", wrapWithRawString("a", false)) + assert.Equal(t, "a", wrapWithRawString("a", true)) + assert.Equal(t, "` `", wrapWithRawString(" ", false)) } func TestFields(t *testing.T) { diff --git a/tools/goctl/model/sql/gen/insert.go b/tools/goctl/model/sql/gen/insert.go index a82d46cb..7a055afe 100644 --- a/tools/goctl/model/sql/gen/insert.go +++ b/tools/goctl/model/sql/gen/insert.go @@ -1,6 +1,7 @@ package gen import ( + "fmt" "strings" "github.com/tal-tech/go-zero/core/collection" @@ -9,7 +10,7 @@ import ( "github.com/tal-tech/go-zero/tools/goctl/util/stringx" ) -func genInsert(table Table, withCache bool) (string, string, error) { +func genInsert(table Table, withCache, postgreSql bool) (string, string, error) { keySet := collection.NewSet() keyVariableSet := collection.NewSet() for _, key := range table.UniqueCacheKey { @@ -19,6 +20,7 @@ func genInsert(table Table, withCache bool) (string, string, error) { expressions := make([]string, 0) expressionValues := make([]string, 0) + var count int for _, field := range table.Fields { camel := field.Name.ToCamel() if camel == "CreateTime" || camel == "UpdateTime" { @@ -31,7 +33,12 @@ func genInsert(table Table, withCache bool) (string, string, error) { } } - expressions = append(expressions, "?") + count += 1 + if postgreSql { + expressions = append(expressions, fmt.Sprintf("$%d", count)) + } else { + expressions = append(expressions, "?") + } expressionValues = append(expressionValues, "data."+camel) } diff --git a/tools/goctl/model/sql/gen/new.go b/tools/goctl/model/sql/gen/new.go index 79d72d3e..50e1dd09 100644 --- a/tools/goctl/model/sql/gen/new.go +++ b/tools/goctl/model/sql/gen/new.go @@ -1,20 +1,27 @@ package gen import ( + "fmt" + "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" "github.com/tal-tech/go-zero/tools/goctl/util" ) -func genNew(table Table, withCache bool) (string, error) { +func genNew(table Table, withCache, postgreSql bool) (string, error) { text, err := util.LoadTemplate(category, modelNewTemplateFile, template.New) if err != nil { return "", err } + t := fmt.Sprintf(`"%s"`, wrapWithRawString(table.Name.Source(), postgreSql)) + if postgreSql { + t = "`" + fmt.Sprintf(`"%s"."%s"`, table.Db.Source(), table.Name.Source()) + "`" + } + output, err := util.With("new"). Parse(text). Execute(map[string]interface{}{ - "table": wrapWithRawString(table.Name.Source()), + "table": t, "withCache": withCache, "upperStartCamelObject": table.Name.ToCamel(), }) diff --git a/tools/goctl/model/sql/gen/update.go b/tools/goctl/model/sql/gen/update.go index 4fb2bc1e..89248284 100644 --- a/tools/goctl/model/sql/gen/update.go +++ b/tools/goctl/model/sql/gen/update.go @@ -9,7 +9,7 @@ import ( "github.com/tal-tech/go-zero/tools/goctl/util/stringx" ) -func genUpdate(table Table, withCache bool) (string, string, error) { +func genUpdate(table Table, withCache, postgreSql bool) (string, string, error) { expressionValues := make([]string, 0) for _, field := range table.Fields { camel := field.Name.ToCamel() @@ -50,8 +50,9 @@ func genUpdate(table Table, withCache bool) (string, string, error) { "primaryCacheKey": table.PrimaryCacheKey.DataKeyExpression, "primaryKeyVariable": table.PrimaryCacheKey.KeyLeft, "lowerStartCamelObject": stringx.From(camelTableName).Untitle(), - "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()), + "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source(), postgreSql), "expressionValues": strings.Join(expressionValues, ", "), + "postgreSql": postgreSql, }) if err != nil { return "", "", nil diff --git a/tools/goctl/model/sql/gen/vars.go b/tools/goctl/model/sql/gen/vars.go index 6be4fc70..4d8cede3 100644 --- a/tools/goctl/model/sql/gen/vars.go +++ b/tools/goctl/model/sql/gen/vars.go @@ -8,7 +8,7 @@ import ( "github.com/tal-tech/go-zero/tools/goctl/util/stringx" ) -func genVars(table Table, withCache bool) (string, error) { +func genVars(table Table, withCache, postgreSql bool) (string, error) { keys := make([]string, 0) keys = append(keys, table.PrimaryCacheKey.VarExpression) for _, v := range table.UniqueCacheKey { @@ -27,8 +27,9 @@ func genVars(table Table, withCache bool) (string, error) { "upperStartCamelObject": camel, "cacheKeys": strings.Join(keys, "\n"), "autoIncrement": table.PrimaryKey.AutoIncrement, - "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()), + "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source(), postgreSql), "withCache": withCache, + "postgreSql": postgreSql, }) if err != nil { return "", err diff --git a/tools/goctl/model/sql/model/postgresqlmodel.go b/tools/goctl/model/sql/model/postgresqlmodel.go new file mode 100644 index 00000000..2f8e6329 --- /dev/null +++ b/tools/goctl/model/sql/model/postgresqlmodel.go @@ -0,0 +1,234 @@ +package model + +import ( + "database/sql" + "strings" + + "github.com/tal-tech/go-zero/core/stores/sqlx" +) + +var ( + p2m = map[string]string{ + "int8": "bigint", + "numeric": "bigint", + "float8": "double", + "float4": "float", + "int2": "smallint", + "int4": "integer", + } +) + +// PostgreSqlModel gets table information from information_schema、pg_catalog +type PostgreSqlModel struct { + conn sqlx.SqlConn +} + +// PostgreColumn describes a column in table +type PostgreColumn struct { + Num sql.NullInt32 `db:"num"` + Field sql.NullString `db:"field"` + Type sql.NullString `db:"type"` + NotNull sql.NullBool `db:"not_null"` + Comment sql.NullString `db:"comment"` + ColumnDefault sql.NullString `db:"column_default"` + IdentityIncrement sql.NullInt32 `db:"identity_increment"` +} + +// PostgreIndex describes an index for a column +type PostgreIndex struct { + IndexName sql.NullString `db:"index_name"` + IndexId sql.NullInt32 `db:"index_id"` + IsUnique sql.NullBool `db:"is_unique"` + IsPrimary sql.NullBool `db:"is_primary"` + ColumnName sql.NullString `db:"column_name"` + IndexSort sql.NullInt32 `db:"index_sort"` +} + +// NewPostgreSqlModel creates an instance and return +func NewPostgreSqlModel(conn sqlx.SqlConn) *PostgreSqlModel { + return &PostgreSqlModel{ + conn: conn, + } +} + +// GetAllTables selects all tables from TABLE_SCHEMA +func (m *PostgreSqlModel) GetAllTables(schema string) ([]string, error) { + query := `select table_name from information_schema.tables where table_schema = $1` + var tables []string + err := m.conn.QueryRows(&tables, query, schema) + if err != nil { + return nil, err + } + + return tables, nil +} + +// FindColumns return columns in specified database and table +func (m *PostgreSqlModel) FindColumns(schema, table string) (*ColumnData, error) { + querySql := `select t.num,t.field,t.type,t.not_null,t.comment, c.column_default, identity_increment +from ( + SELECT a.attnum AS num, + c.relname, + a.attname AS field, + t.typname AS type, + a.atttypmod AS lengthvar, + a.attnotnull AS not_null, + b.description AS comment + FROM pg_class c, + pg_attribute a + LEFT OUTER JOIN pg_description b ON a.attrelid = b.objoid AND a.attnum = b.objsubid, + pg_type t + WHERE c.relname = $1 + and a.attnum > 0 + and a.attrelid = c.oid + and a.atttypid = t.oid + ORDER BY a.attnum) AS t + left join information_schema.columns AS c on t.relname = c.table_name + and t.field = c.column_name and c.table_schema = $2` + + var reply []*PostgreColumn + err := m.conn.QueryRowsPartial(&reply, querySql, table, schema) + if err != nil { + return nil, err + } + + list, err := m.getColumns(schema, table, reply) + if err != nil { + return nil, err + } + + var columnData ColumnData + columnData.Db = schema + columnData.Table = table + columnData.Columns = list + return &columnData, nil +} + +func (m *PostgreSqlModel) getColumns(schema, table string, in []*PostgreColumn) ([]*Column, error) { + index, err := m.getIndex(schema, table) + if err != nil { + return nil, err + } + var list []*Column + for _, e := range in { + var dft interface{} + if len(e.ColumnDefault.String) > 0 { + dft = e.ColumnDefault + } + + isNullAble := "YES" + if e.NotNull.Bool { + isNullAble = "NO" + } + + extra := "auto_increment" + if e.IdentityIncrement.Int32 != 1 { + extra = "" + } + + if len(index[e.Field.String]) > 0 { + for _, i := range index[e.Field.String] { + list = append(list, &Column{ + DbColumn: &DbColumn{ + Name: e.Field.String, + DataType: m.convertPostgreSqlTypeIntoMysqlType(e.Type.String), + Extra: extra, + Comment: e.Comment.String, + ColumnDefault: dft, + IsNullAble: isNullAble, + OrdinalPosition: int(e.Num.Int32), + }, + Index: i, + }) + } + } else { + list = append(list, &Column{ + DbColumn: &DbColumn{ + Name: e.Field.String, + DataType: m.convertPostgreSqlTypeIntoMysqlType(e.Type.String), + Extra: extra, + Comment: e.Comment.String, + ColumnDefault: dft, + IsNullAble: isNullAble, + OrdinalPosition: int(e.Num.Int32), + }, + }) + } + } + + return list, nil +} + +func (m *PostgreSqlModel) convertPostgreSqlTypeIntoMysqlType(in string) string { + r, ok := p2m[strings.ToLower(in)] + if ok { + return r + } + + return in +} + +func (m *PostgreSqlModel) getIndex(schema, table string) (map[string][]*DbIndex, error) { + indexes, err := m.FindIndex(schema, table) + if err != nil { + return nil, err + } + var index = make(map[string][]*DbIndex) + for _, e := range indexes { + if e.IsPrimary.Bool { + index[e.ColumnName.String] = append(index[e.ColumnName.String], &DbIndex{ + IndexName: indexPri, + SeqInIndex: int(e.IndexSort.Int32), + }) + continue + } + + nonUnique := 0 + if !e.IsUnique.Bool { + nonUnique = 1 + } + + index[e.ColumnName.String] = append(index[e.ColumnName.String], &DbIndex{ + IndexName: e.IndexName.String, + NonUnique: nonUnique, + SeqInIndex: int(e.IndexSort.Int32), + }) + } + return index, nil +} + +// FindIndex finds index with given schema, table and column. +func (m *PostgreSqlModel) FindIndex(schema, table string) ([]*PostgreIndex, error) { + querySql := `select A.INDEXNAME AS index_name, + C.INDEXRELID AS index_id, + C.INDISUNIQUE AS is_unique, + C.INDISPRIMARY AS is_primary, + G.ATTNAME AS column_name, + G.attnum AS index_sort +from PG_AM B + left join PG_CLASS F on + B.OID = F.RELAM + left join PG_STAT_ALL_INDEXES E on + F.OID = E.INDEXRELID + left join PG_INDEX C on + E.INDEXRELID = C.INDEXRELID + left outer join PG_DESCRIPTION D on + C.INDEXRELID = D.OBJOID, + PG_INDEXES A, + pg_attribute G +where A.SCHEMANAME = E.SCHEMANAME + and A.TABLENAME = E.RELNAME + and A.INDEXNAME = E.INDEXRELNAME + and F.oid = G.attrelid + and E.SCHEMANAME = $1 + and E.RELNAME = $2 + order by C.INDEXRELID,G.attnum` + + var reply []*PostgreIndex + err := m.conn.QueryRowsPartial(&reply, querySql, schema, table) + if err != nil { + return nil, err + } + + return reply, nil +} diff --git a/tools/goctl/model/sql/template/delete.go b/tools/goctl/model/sql/template/delete.go index 4bb6bc76..b6dac7d8 100644 --- a/tools/goctl/model/sql/template/delete.go +++ b/tools/goctl/model/sql/template/delete.go @@ -10,9 +10,9 @@ func (m *default{{.upperStartCamelObject}}Model) Delete({{.lowerStartCamelPrimar {{.keys}} _, err {{if .containsIndexCache}}={{else}}:={{end}} m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { - query := fmt.Sprintf("delete from %s where {{.originalPrimaryKey}} = ?", m.table) + query := fmt.Sprintf("delete from %s where {{.originalPrimaryKey}} = {{if .postgreSql}}$1{{else}}?{{end}}", m.table) return conn.Exec(query, {{.lowerStartCamelPrimaryKey}}) - }, {{.keyValues}}){{else}}query := fmt.Sprintf("delete from %s where {{.originalPrimaryKey}} = ?", m.table) + }, {{.keyValues}}){{else}}query := fmt.Sprintf("delete from %s where {{.originalPrimaryKey}} = {{if .postgreSql}}$1{{else}}?{{end}}", m.table) _,err:=m.conn.Exec(query, {{.lowerStartCamelPrimaryKey}}){{end}} return err } diff --git a/tools/goctl/model/sql/template/find.go b/tools/goctl/model/sql/template/find.go index 2e23531a..c1962d65 100644 --- a/tools/goctl/model/sql/template/find.go +++ b/tools/goctl/model/sql/template/find.go @@ -6,7 +6,7 @@ func (m *default{{.upperStartCamelObject}}Model) FindOne({{.lowerStartCamelPrima {{if .withCache}}{{.cacheKey}} var resp {{.upperStartCamelObject}} err := m.QueryRow(&resp, {{.cacheKeyVariable}}, func(conn sqlx.SqlConn, v interface{}) error { - query := fmt.Sprintf("select %s from %s where {{.originalPrimaryKey}} = ? limit 1", {{.lowerStartCamelObject}}Rows, m.table) + query := fmt.Sprintf("select %s from %s where {{.originalPrimaryKey}} = {{if .postgreSql}}$1{{else}}?{{end}} limit 1", {{.lowerStartCamelObject}}Rows, m.table) return conn.QueryRow(v, query, {{.lowerStartCamelPrimaryKey}}) }) switch err { @@ -16,7 +16,7 @@ func (m *default{{.upperStartCamelObject}}Model) FindOne({{.lowerStartCamelPrima return nil, ErrNotFound default: return nil, err - }{{else}}query := fmt.Sprintf("select %s from %s where {{.originalPrimaryKey}} = ? limit 1", {{.lowerStartCamelObject}}Rows, m.table) + }{{else}}query := fmt.Sprintf("select %s from %s where {{.originalPrimaryKey}} = {{if .postgreSql}}$1{{else}}?{{end}} limit 1", {{.lowerStartCamelObject}}Rows, m.table) var resp {{.upperStartCamelObject}} err := m.conn.QueryRow(&resp, query, {{.lowerStartCamelPrimaryKey}}) switch err { @@ -71,7 +71,7 @@ func (m *default{{.upperStartCamelObject}}Model) formatPrimary(primary interface } func (m *default{{.upperStartCamelObject}}Model) queryPrimary(conn sqlx.SqlConn, v, primary interface{}) error { - query := fmt.Sprintf("select %s from %s where {{.originalPrimaryField}} = ? limit 1", {{.lowerStartCamelObject}}Rows, m.table ) + query := fmt.Sprintf("select %s from %s where {{.originalPrimaryField}} = {{if .postgreSql}}$1{{else}}?{{end}} limit 1", {{.lowerStartCamelObject}}Rows, m.table ) return conn.QueryRow(v, query, primary) } ` diff --git a/tools/goctl/model/sql/template/new.go b/tools/goctl/model/sql/template/new.go index 77ada32e..e7c0657c 100644 --- a/tools/goctl/model/sql/template/new.go +++ b/tools/goctl/model/sql/template/new.go @@ -5,7 +5,7 @@ var New = ` func New{{.upperStartCamelObject}}Model(conn sqlx.SqlConn{{if .withCache}}, c cache.CacheConf{{end}}) {{.upperStartCamelObject}}Model { return &default{{.upperStartCamelObject}}Model{ {{if .withCache}}CachedConn: sqlc.NewConn(conn, c){{else}}conn:conn{{end}}, - table: "{{.table}}", + table: {{.table}}, } } ` diff --git a/tools/goctl/model/sql/template/update.go b/tools/goctl/model/sql/template/update.go index 17486623..04fdfe1e 100644 --- a/tools/goctl/model/sql/template/update.go +++ b/tools/goctl/model/sql/template/update.go @@ -5,9 +5,9 @@ var Update = ` func (m *default{{.upperStartCamelObject}}Model) Update(data {{.upperStartCamelObject}}) error { {{if .withCache}}{{.keys}} _, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { - query := fmt.Sprintf("update %s set %s where {{.originalPrimaryKey}} = ?", m.table, {{.lowerStartCamelObject}}RowsWithPlaceHolder) + query := fmt.Sprintf("update %s set %s where {{.originalPrimaryKey}} = {{if .postgreSql}}$1{{else}}?{{end}}", m.table, {{.lowerStartCamelObject}}RowsWithPlaceHolder) return conn.Exec(query, {{.expressionValues}}) - }, {{.keyValues}}){{else}}query := fmt.Sprintf("update %s set %s where {{.originalPrimaryKey}} = ?", m.table, {{.lowerStartCamelObject}}RowsWithPlaceHolder) + }, {{.keyValues}}){{else}}query := fmt.Sprintf("update %s set %s where {{.originalPrimaryKey}} = {{if .postgreSql}}$1{{else}}?{{end}}", m.table, {{.lowerStartCamelObject}}RowsWithPlaceHolder) _,err:=m.conn.Exec(query, {{.expressionValues}}){{end}} return err } diff --git a/tools/goctl/model/sql/template/vars.go b/tools/goctl/model/sql/template/vars.go index af36e0f4..d56be3b2 100644 --- a/tools/goctl/model/sql/template/vars.go +++ b/tools/goctl/model/sql/template/vars.go @@ -5,11 +5,14 @@ import "fmt" // Vars defines a template for var block in model var Vars = fmt.Sprintf(` var ( - {{.lowerStartCamelObject}}FieldNames = builderx.RawFieldNames(&{{.upperStartCamelObject}}{}) + {{.lowerStartCamelObject}}FieldNames = builderx.RawFieldNames(&{{.upperStartCamelObject}}{}{{if .postgreSql}},true{{end}}) {{.lowerStartCamelObject}}Rows = strings.Join({{.lowerStartCamelObject}}FieldNames, ",") - {{.lowerStartCamelObject}}RowsExpectAutoSet = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, {{if .autoIncrement}}"{{.originalPrimaryKey}}",{{end}} "%screate_time%s", "%supdate_time%s"), ",") - {{.lowerStartCamelObject}}RowsWithPlaceHolder = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, "{{.originalPrimaryKey}}", "%screate_time%s", "%supdate_time%s"), "=?,") + "=?" + {{.lowerStartCamelObject}}RowsExpectAutoSet = {{if .postgreSql}}strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, {{if .autoIncrement}}"{{.originalPrimaryKey}}",{{end}} "%screate_time%s", "%supdate_time%s"), ","){{else}}strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, {{if .autoIncrement}}"{{.originalPrimaryKey}}",{{end}} "%screate_time%s", "%supdate_time%s"), ","){{end}} + {{.lowerStartCamelObject}}RowsWithPlaceHolder = {{if .postgreSql}}builderx.PostgreSqlJoin(stringx.Remove({{.lowerStartCamelObject}}FieldNames, "{{.originalPrimaryKey}}", "%screate_time%s", "%supdate_time%s")){{else}}strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, "{{.originalPrimaryKey}}", "%screate_time%s", "%supdate_time%s"), "=?,") + "=?"{{end}} {{if .withCache}}{{.cacheKeys}}{{end}} ) -`, "`", "`", "`", "`", "`", "`", "`", "`") +`, "", "", "", "", // postgreSql mode + "`", "`", "`", "`", + "", "", "", "", // postgreSql mode + "`", "`", "`", "`")