optimize code (#579)

* optimize code

* optimize returns & unit test
master v1.1.6
anqiansong 4 years ago committed by GitHub
parent bd623aaac3
commit 888551627c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -103,17 +103,9 @@ func genComponents(dir, packetName string, api *spec.ApiSpec) error {
} }
func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type) error { func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type) error {
defineStruct, ok := ty.(spec.DefineStruct) defineStruct, done, err := c.checkStruct(ty)
if !ok { if done {
return errors.New("unsupported type %s" + ty.Name()) return err
}
for _, item := range c.requestTypes {
if item.Name() == defineStruct.Name() {
if len(defineStruct.GetFormMembers())+len(defineStruct.GetBodyMembers()) == 0 {
return nil
}
}
} }
modelFile := util.Title(ty.Name()) + ".java" modelFile := util.Title(ty.Name()) + ".java"
@ -181,6 +173,22 @@ func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type
return err return err
} }
func (c *componentsContext) checkStruct(ty spec.Type) (spec.DefineStruct, bool, error) {
defineStruct, ok := ty.(spec.DefineStruct)
if !ok {
return spec.DefineStruct{}, true, errors.New("unsupported type %s" + ty.Name())
}
for _, item := range c.requestTypes {
if item.Name() == defineStruct.Name() {
if len(defineStruct.GetFormMembers())+len(defineStruct.GetBodyMembers()) == 0 {
return spec.DefineStruct{}, true, nil
}
}
}
return defineStruct, false, nil
}
func (c *componentsContext) buildProperties(defineStruct spec.DefineStruct) (string, error) { func (c *componentsContext) buildProperties(defineStruct spec.DefineStruct) (string, error) {
var builder strings.Builder var builder strings.Builder
if err := c.writeType(&builder, defineStruct); err != nil { if err := c.writeType(&builder, defineStruct); err != nil {

@ -95,17 +95,9 @@ func specTypeToJava(tp spec.Type) (string, error) {
return "", err return "", err
} }
switch valueType { s := getBaseType(valueType)
case "int": if len(s) == 0 {
return "Integer[]", nil return s, errors.New("unsupported primitive type " + tp.Name())
case "long":
return "Long[]", nil
case "float":
return "Float[]", nil
case "double":
return "Double[]", nil
case "boolean":
return "Boolean[]", nil
} }
return fmt.Sprintf("java.util.ArrayList<%s>", util.Title(valueType)), nil return fmt.Sprintf("java.util.ArrayList<%s>", util.Title(valueType)), nil
@ -118,6 +110,23 @@ func specTypeToJava(tp spec.Type) (string, error) {
return "", errors.New("unsupported primitive type " + tp.Name()) return "", errors.New("unsupported primitive type " + tp.Name())
} }
func getBaseType(valueType string) string {
switch valueType {
case "int":
return "Integer[]"
case "long":
return "Long[]"
case "float":
return "Float[]"
case "double":
return "Double[]"
case "boolean":
return "Boolean[]"
default:
return ""
}
}
func primitiveType(tp string) (string, bool) { func primitiveType(tp string) (string, bool) {
switch tp { switch tp {
case "string": case "string":

@ -6,6 +6,8 @@ import (
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/gen"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/config" "github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
@ -19,7 +21,10 @@ var (
) )
func TestFromDDl(t *testing.T) { func TestFromDDl(t *testing.T) {
err := fromDDl("./user.sql", t.TempDir(), cfg, true, false) err := gen.Clean()
assert.Nil(t, err)
err = fromDDl("./user.sql", t.TempDir(), cfg, true, false)
assert.Equal(t, errNotMatched, err) assert.Equal(t, errNotMatched, err)
// case dir is not exists // case dir is not exists

@ -25,27 +25,7 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
var list []string var list []string
camelTableName := table.Name.ToCamel() camelTableName := table.Name.ToCamel()
for _, key := range table.UniqueCacheKey { for _, key := range table.UniqueCacheKey {
var inJoin, paramJoin, argJoin Join in, paramJoinString, originalFieldString := convertJoin(key)
for _, f := range key.Fields {
param := stringx.From(f.Name.ToCamel()).Untitle()
inJoin = append(inJoin, fmt.Sprintf("%s %s", param, f.DataType))
paramJoin = append(paramJoin, param)
argJoin = append(argJoin, fmt.Sprintf("%s = ?", wrapWithRawString(f.Name.Source())))
}
var in string
if len(inJoin) > 0 {
in = inJoin.With(", ").Source()
}
var paramJoinString string
if len(paramJoin) > 0 {
paramJoinString = paramJoin.With(",").Source()
}
var originalFieldString string
if len(argJoin) > 0 {
originalFieldString = argJoin.With(" and ").Source()
}
output, err := t.Execute(map[string]interface{}{ output, err := t.Execute(map[string]interface{}{
"upperStartCamelObject": camelTableName, "upperStartCamelObject": camelTableName,
@ -125,3 +105,25 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
findOneInterfaceMethod: strings.Join(listMethod, util.NL), findOneInterfaceMethod: strings.Join(listMethod, util.NL),
}, nil }, nil
} }
func convertJoin(key Key) (in, paramJoinString, originalFieldString string) {
var inJoin, paramJoin, argJoin Join
for _, f := range key.Fields {
param := stringx.From(f.Name.ToCamel()).Untitle()
inJoin = append(inJoin, fmt.Sprintf("%s %s", param, f.DataType))
paramJoin = append(paramJoin, param)
argJoin = append(argJoin, fmt.Sprintf("%s = ?", wrapWithRawString(f.Name.Source())))
}
if len(inJoin) > 0 {
in = inJoin.With(", ").Source()
}
if len(paramJoin) > 0 {
paramJoinString = paramJoin.With(",").Source()
}
if len(argJoin) > 0 {
originalFieldString = argJoin.With(" and ").Source()
}
return in, paramJoinString, originalFieldString
}

@ -102,6 +102,17 @@ func Parse(ddl string) (*Table, error) {
} }
} }
checkDuplicateUniqueIndex(uniqueIndex, tableName, normalIndex)
return &Table{
Name: stringx.From(tableName),
PrimaryKey: primaryKey,
UniqueIndex: uniqueIndex,
NormalIndex: normalIndex,
Fields: fields,
}, nil
}
func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string, normalIndex map[string][]*Field) {
log := console.NewColorConsole() log := console.NewColorConsole()
uniqueSet := collection.NewSet() uniqueSet := collection.NewSet()
for k, i := range uniqueIndex { for k, i := range uniqueIndex {
@ -136,14 +147,6 @@ func Parse(ddl string) (*Table, error) {
normalIndexSet.Add(joinRet) normalIndexSet.Add(joinRet)
} }
return &Table{
Name: stringx.From(tableName),
PrimaryKey: primaryKey,
UniqueIndex: uniqueIndex,
NormalIndex: normalIndex,
Fields: fields,
}, nil
} }
func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) (Primary, map[string]*Field, error) { func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) (Primary, map[string]*Field, error) {
@ -289,28 +292,10 @@ func ConvertDataType(table *model.Table) (*Table, error) {
AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"), AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"),
} }
fieldM := make(map[string]*Field) fieldM, err := getTableFields(table)
for _, each := range table.Columns {
isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
dt, err := converter.ConvertDataType(each.DataType, isDefaultNull)
if err != nil { if err != nil {
return nil, err return nil, err
} }
columnSeqInIndex := 0
if each.Index != nil {
columnSeqInIndex = each.Index.SeqInIndex
}
field := &Field{
Name: stringx.From(each.Name),
DataBaseType: each.DataType,
DataType: dt,
Comment: each.Comment,
SeqInIndex: columnSeqInIndex,
OrdinalPosition: each.OrdinalPosition,
}
fieldM[each.Name] = field
}
for _, each := range fieldM { for _, each := range fieldM {
reply.Fields = append(reply.Fields, each) reply.Fields = append(reply.Fields, each)
@ -379,3 +364,29 @@ func ConvertDataType(table *model.Table) (*Table, error) {
return &reply, nil return &reply, nil
} }
func getTableFields(table *model.Table) (map[string]*Field, error) {
fieldM := make(map[string]*Field)
for _, each := range table.Columns {
isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
dt, err := converter.ConvertDataType(each.DataType, isDefaultNull)
if err != nil {
return nil, err
}
columnSeqInIndex := 0
if each.Index != nil {
columnSeqInIndex = each.Index.SeqInIndex
}
field := &Field{
Name: stringx.From(each.Name),
DataBaseType: each.DataType,
DataType: dt,
Comment: each.Comment,
SeqInIndex: columnSeqInIndex,
OrdinalPosition: each.OrdinalPosition,
}
fieldM[each.Name] = field
}
return fieldM, nil
}

Loading…
Cancel
Save