diff --git a/tools/goctl/api/javagen/gencomponents.go b/tools/goctl/api/javagen/gencomponents.go index 85dd9d4d..98bbff73 100644 --- a/tools/goctl/api/javagen/gencomponents.go +++ b/tools/goctl/api/javagen/gencomponents.go @@ -103,17 +103,9 @@ func genComponents(dir, packetName string, api *spec.ApiSpec) error { } func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type) error { - defineStruct, ok := ty.(spec.DefineStruct) - if !ok { - return errors.New("unsupported type %s" + ty.Name()) - } - - for _, item := range c.requestTypes { - if item.Name() == defineStruct.Name() { - if len(defineStruct.GetFormMembers())+len(defineStruct.GetBodyMembers()) == 0 { - return nil - } - } + defineStruct, done, err := c.checkStruct(ty) + if done { + return err } modelFile := util.Title(ty.Name()) + ".java" @@ -181,6 +173,22 @@ func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type return err } +func (c *componentsContext) checkStruct(ty spec.Type) (spec.DefineStruct, bool, error) { + defineStruct, ok := ty.(spec.DefineStruct) + if !ok { + return spec.DefineStruct{}, true, errors.New("unsupported type %s" + ty.Name()) + } + + for _, item := range c.requestTypes { + if item.Name() == defineStruct.Name() { + if len(defineStruct.GetFormMembers())+len(defineStruct.GetBodyMembers()) == 0 { + return spec.DefineStruct{}, true, nil + } + } + } + return defineStruct, false, nil +} + func (c *componentsContext) buildProperties(defineStruct spec.DefineStruct) (string, error) { var builder strings.Builder if err := c.writeType(&builder, defineStruct); err != nil { diff --git a/tools/goctl/api/javagen/util.go b/tools/goctl/api/javagen/util.go index fd55a68a..80a70b05 100644 --- a/tools/goctl/api/javagen/util.go +++ b/tools/goctl/api/javagen/util.go @@ -95,17 +95,9 @@ func specTypeToJava(tp spec.Type) (string, error) { return "", err } - switch valueType { - case "int": - return "Integer[]", nil - case "long": - return "Long[]", nil - case "float": - return "Float[]", nil - case "double": - return "Double[]", nil - case "boolean": - return "Boolean[]", nil + s := getBaseType(valueType) + if len(s) == 0 { + return s, errors.New("unsupported primitive type " + tp.Name()) } return fmt.Sprintf("java.util.ArrayList<%s>", util.Title(valueType)), nil @@ -118,6 +110,23 @@ func specTypeToJava(tp spec.Type) (string, error) { return "", errors.New("unsupported primitive type " + tp.Name()) } +func getBaseType(valueType string) string { + switch valueType { + case "int": + return "Integer[]" + case "long": + return "Long[]" + case "float": + return "Float[]" + case "double": + return "Double[]" + case "boolean": + return "Boolean[]" + default: + return "" + } +} + func primitiveType(tp string) (string, bool) { switch tp { case "string": diff --git a/tools/goctl/model/sql/command/command_test.go b/tools/goctl/model/sql/command/command_test.go index a78ae555..6dbb9917 100644 --- a/tools/goctl/model/sql/command/command_test.go +++ b/tools/goctl/model/sql/command/command_test.go @@ -6,6 +6,8 @@ import ( "path/filepath" "testing" + "github.com/tal-tech/go-zero/tools/goctl/model/sql/gen" + "github.com/stretchr/testify/assert" "github.com/tal-tech/go-zero/tools/goctl/config" "github.com/tal-tech/go-zero/tools/goctl/util" @@ -19,7 +21,10 @@ var ( ) func TestFromDDl(t *testing.T) { - err := fromDDl("./user.sql", t.TempDir(), cfg, true, false) + err := gen.Clean() + assert.Nil(t, err) + + err = fromDDl("./user.sql", t.TempDir(), cfg, true, false) assert.Equal(t, errNotMatched, err) // case dir is not exists diff --git a/tools/goctl/model/sql/gen/findonebyfield.go b/tools/goctl/model/sql/gen/findonebyfield.go index 9c23261a..53cccacf 100644 --- a/tools/goctl/model/sql/gen/findonebyfield.go +++ b/tools/goctl/model/sql/gen/findonebyfield.go @@ -25,27 +25,7 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) { var list []string camelTableName := table.Name.ToCamel() for _, key := range table.UniqueCacheKey { - var inJoin, paramJoin, argJoin Join - for _, f := range key.Fields { - param := stringx.From(f.Name.ToCamel()).Untitle() - inJoin = append(inJoin, fmt.Sprintf("%s %s", param, f.DataType)) - paramJoin = append(paramJoin, param) - argJoin = append(argJoin, fmt.Sprintf("%s = ?", wrapWithRawString(f.Name.Source()))) - } - var in string - if len(inJoin) > 0 { - in = inJoin.With(", ").Source() - } - - var paramJoinString string - if len(paramJoin) > 0 { - paramJoinString = paramJoin.With(",").Source() - } - - var originalFieldString string - if len(argJoin) > 0 { - originalFieldString = argJoin.With(" and ").Source() - } + in, paramJoinString, originalFieldString := convertJoin(key) output, err := t.Execute(map[string]interface{}{ "upperStartCamelObject": camelTableName, @@ -125,3 +105,25 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) { findOneInterfaceMethod: strings.Join(listMethod, util.NL), }, nil } + +func convertJoin(key Key) (in, paramJoinString, originalFieldString string) { + var inJoin, paramJoin, argJoin Join + for _, f := range key.Fields { + param := stringx.From(f.Name.ToCamel()).Untitle() + inJoin = append(inJoin, fmt.Sprintf("%s %s", param, f.DataType)) + paramJoin = append(paramJoin, param) + argJoin = append(argJoin, fmt.Sprintf("%s = ?", wrapWithRawString(f.Name.Source()))) + } + if len(inJoin) > 0 { + in = inJoin.With(", ").Source() + } + + if len(paramJoin) > 0 { + paramJoinString = paramJoin.With(",").Source() + } + + if len(argJoin) > 0 { + originalFieldString = argJoin.With(" and ").Source() + } + return in, paramJoinString, originalFieldString +} diff --git a/tools/goctl/model/sql/parser/parser.go b/tools/goctl/model/sql/parser/parser.go index cfe45273..15d3d9a4 100644 --- a/tools/goctl/model/sql/parser/parser.go +++ b/tools/goctl/model/sql/parser/parser.go @@ -102,6 +102,17 @@ func Parse(ddl string) (*Table, error) { } } + checkDuplicateUniqueIndex(uniqueIndex, tableName, normalIndex) + return &Table{ + Name: stringx.From(tableName), + PrimaryKey: primaryKey, + UniqueIndex: uniqueIndex, + NormalIndex: normalIndex, + Fields: fields, + }, nil +} + +func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string, normalIndex map[string][]*Field) { log := console.NewColorConsole() uniqueSet := collection.NewSet() for k, i := range uniqueIndex { @@ -136,14 +147,6 @@ func Parse(ddl string) (*Table, error) { normalIndexSet.Add(joinRet) } - - return &Table{ - Name: stringx.From(tableName), - PrimaryKey: primaryKey, - UniqueIndex: uniqueIndex, - NormalIndex: normalIndex, - Fields: fields, - }, nil } func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) (Primary, map[string]*Field, error) { @@ -289,27 +292,9 @@ func ConvertDataType(table *model.Table) (*Table, error) { AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"), } - fieldM := make(map[string]*Field) - for _, each := range table.Columns { - isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES" - dt, err := converter.ConvertDataType(each.DataType, isDefaultNull) - if err != nil { - return nil, err - } - columnSeqInIndex := 0 - if each.Index != nil { - columnSeqInIndex = each.Index.SeqInIndex - } - - field := &Field{ - Name: stringx.From(each.Name), - DataBaseType: each.DataType, - DataType: dt, - Comment: each.Comment, - SeqInIndex: columnSeqInIndex, - OrdinalPosition: each.OrdinalPosition, - } - fieldM[each.Name] = field + fieldM, err := getTableFields(table) + if err != nil { + return nil, err } for _, each := range fieldM { @@ -379,3 +364,29 @@ func ConvertDataType(table *model.Table) (*Table, error) { return &reply, nil } + +func getTableFields(table *model.Table) (map[string]*Field, error) { + fieldM := make(map[string]*Field) + for _, each := range table.Columns { + isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES" + dt, err := converter.ConvertDataType(each.DataType, isDefaultNull) + if err != nil { + return nil, err + } + columnSeqInIndex := 0 + if each.Index != nil { + columnSeqInIndex = each.Index.SeqInIndex + } + + field := &Field{ + Name: stringx.From(each.Name), + DataBaseType: each.DataType, + DataType: dt, + Comment: each.Comment, + SeqInIndex: columnSeqInIndex, + OrdinalPosition: each.OrdinalPosition, + } + fieldM[each.Name] = field + } + return fieldM, nil +}