You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
go-zero/tools/goctl/model/sqlmodel/genmodel.go

186 lines
4.7 KiB
Go

package sqlmodel
import (
"bytes"
"fmt"
"go/ast"
"go/parser"
"go/token"
"strings"
"text/template"
"github.com/tal-tech/go-zero/core/stringx"
"github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
var (
commonTemplate = `
var (
{{.modelWithLowerStart}}FieldNames = builderx.FieldNames(&{{.model}}{})
{{.modelWithLowerStart}}Rows = strings.Join({{.modelWithLowerStart}}FieldNames, ",")
{{.modelWithLowerStart}}RowsExpectAutoSet = strings.Join(stringx.Remove({{.modelWithLowerStart}}FieldNames, {{.expected}}), ",")
{{.modelWithLowerStart}}RowsWithPlaceHolder = strings.Join(stringx.Remove({{.modelWithLowerStart}}FieldNames, {{.expected}}), "=?,") + "=?"
)
`
)
type (
fieldExp struct {
name string
tag string
ty string
}
structExp struct {
genStruct GenStruct
name string
idAutoIncrement bool
Fields []fieldExp
primaryKey string
conditions []string
ignoreFields []string
}
)
func NewStructExp(s GenStruct) (*structExp, error) {
src := s.TableStruct
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "", src, 0)
if err != nil {
return nil, err
}
if len(file.Decls) == 0 {
return nil, sqlError("no struct inspect")
}
typeDecl := file.Decls[0].(*ast.GenDecl)
if len(typeDecl.Specs) == 0 {
return nil, sqlError("no type specs")
}
typeSpec := typeDecl.Specs[0].(*ast.TypeSpec)
structDecl := typeSpec.Type.(*ast.StructType)
var strExp structExp
strExp.genStruct = s
strExp.primaryKey = s.PrimaryKey
strExp.name = typeSpec.Name.Name
fields := structDecl.Fields.List
for _, field := range fields {
typeExpr := field.Type
start := typeExpr.Pos() - 1
end := typeExpr.End() - 1
typeInSource := src[start:end]
if len(field.Names) == 0 {
return nil, sqlError("field name empty")
}
var name = field.Names[0].Name
var tag = util.Untitle(name)
if field.Tag != nil {
tag = src[field.Tag.Pos():field.Tag.End()]
}
dbTag := getDBColumnName(tag, name)
strExp.Fields = append(strExp.Fields, fieldExp{
name: name,
ty: typeInSource,
tag: dbTag,
})
if dbTag == strExp.primaryKey && typeInSource == "int64" {
strExp.ignoreFields = append(strExp.ignoreFields, dbTag)
strExp.idAutoIncrement = true
}
if name == "UpdateTime" || name == "CreateTime" {
strExp.ignoreFields = append(strExp.ignoreFields, dbTag)
}
}
return &strExp, nil
}
func (s *structExp) genMysqlCRUD() (string, error) {
commonStr, err := s.genCommon()
if err != nil {
return "", err
}
insertStr, err := s.genInsert()
if err != nil {
return "", err
}
updateStr, err := s.genUpdate()
if err != nil {
return "", err
}
deleteStr, err := s.genDelete()
if err != nil {
return "", err
}
queryOneStr, err := s.genQueryOne()
if err != nil {
return "", err
}
queryListStr, err := s.genQueryList()
if err != nil {
return "", err
}
return strings.Join([]string{"package model \n", commonStr, s.genStruct.TableModel, queryOneStr, queryListStr, deleteStr, insertStr, updateStr}, "\n"), nil
}
func getDBColumnName(tag, name string) string {
matches := spec.TagRe.FindStringSubmatch(tag)
for i := range matches {
name := spec.TagSubNames[i]
if name == "name" {
return matches[i]
}
}
return util.Untitle(name)
}
func (s *structExp) genCommon() (string, error) {
templateName := commonTemplate
t := template.Must(template.New("commonTemplate").Parse(templateName))
var tmplBytes bytes.Buffer
var ignoreFieldsQuota []string
for _, item := range s.ignoreFields {
ignoreFieldsQuota = append(ignoreFieldsQuota, fmt.Sprintf("\"%s\"", item))
}
err := t.Execute(&tmplBytes, map[string]string{
"model": s.name,
"expected": strings.Join(ignoreFieldsQuota, ", "),
"modelWithLowerStart": fmtUnderLine2Camel(s.name, false),
})
if err != nil {
return "", err
}
return tmplBytes.String(), nil
}
func (s *structExp) buildCondition() (string, string) {
var conditionExp []string
var valueConditions []string
for _, field := range s.Fields {
if stringx.Contains(s.conditions, strings.ToLower(field.tag)) ||
stringx.Contains(s.conditions, strings.ToLower(field.name)) {
conditionExp = append(conditionExp, fmt.Sprintf("%s %s", util.Untitle(field.name), field.ty))
valueConditions = append(valueConditions, fmt.Sprintf("%s = ?", field.tag))
}
}
return strings.Join(conditionExp, ", "), strings.Join(valueConditions, " and ")
}
func (s *structExp) conditionNames() []string {
var conditionExp []string
for _, field := range s.Fields {
if stringx.Contains(s.conditions, strings.ToLower(field.tag)) ||
stringx.Contains(s.conditions, strings.ToLower(field.name)) {
conditionExp = append(conditionExp, util.Untitle(field.name))
}
}
return conditionExp
}