diff --git a/tools/goctl/api/format/format.go b/tools/goctl/api/format/format.go index 1e237bf0..645af2c5 100644 --- a/tools/goctl/api/format/format.go +++ b/tools/goctl/api/format/format.go @@ -4,25 +4,17 @@ import ( "bufio" "errors" "fmt" - "go/format" "go/scanner" "io/ioutil" "os" "path/filepath" - "regexp" - "strconv" "strings" "github.com/tal-tech/go-zero/core/errorx" - "github.com/tal-tech/go-zero/tools/goctl/api/parser" "github.com/tal-tech/go-zero/tools/goctl/api/util" "github.com/urfave/cli" ) -var ( - reg = regexp.MustCompile("type (?P.*)[\\s]+{") -) - func GoFormatApi(c *cli.Context) error { useStdin := c.Bool("stdin") @@ -65,10 +57,7 @@ func ApiFormatByStdin() error { return err } - result, err := apiFormat(string(data)) - if err != nil { - return err - } + result := apiFormat(string(data)) _, err = fmt.Print(result) if err != nil { @@ -83,98 +72,28 @@ func ApiFormatByPath(apiFilePath string) error { return err } - result, err := apiFormat(string(data)) - if err != nil { - return err - } - + result := apiFormat(string(data)) if err := ioutil.WriteFile(apiFilePath, []byte(result), os.ModePerm); err != nil { return err } return nil } -func apiFormat(data string) (string, error) { - r := reg.ReplaceAllStringFunc(data, func(m string) string { - parts := reg.FindStringSubmatch(m) - if len(parts) < 2 { - return m - } - if !strings.Contains(m, "struct") { - return "type " + parts[1] + " struct {" - } - return m - }) - - apiStruct, err := parser.ParseApi(r) - if err != nil { - return "", err - } - info := strings.TrimSpace(apiStruct.Info) - if len(apiStruct.Service) == 0 { - return data, nil - } - - fs, err := format.Source([]byte(strings.TrimSpace(apiStruct.Type))) - if err != nil { - str := err.Error() - lineNumber := strings.Index(str, ":") - if lineNumber > 0 { - ln, err := strconv.ParseInt(str[:lineNumber], 10, 64) - if err != nil { - return "", err - } - pn := 0 - if len(info) > 0 { - pn = countRune(info, '\n') + 1 - } - number := int(ln) + pn + 1 - return "", errors.New(fmt.Sprintf("line: %d, %s", number, str[lineNumber+1:])) - } - return "", err - } - - var result string - if len(strings.TrimSpace(info)) > 0 { - result += strings.TrimSpace(info) + "\n\n" - } - if len(strings.TrimSpace(apiStruct.Imports)) > 0 { - result += strings.TrimSpace(apiStruct.Imports) + "\n\n" - } - if len(strings.TrimSpace(string(fs))) > 0 { - result += strings.TrimSpace(string(fs)) + "\n\n" - } - if len(strings.TrimSpace(apiStruct.Service)) > 0 { - result += formatService(apiStruct.Service) + "\n\n" - } - - return strings.TrimSpace(result), nil -} - -func formatService(str string) string { +func apiFormat(data string) string { var builder strings.Builder - scanner := bufio.NewScanner(strings.NewReader(str)) + scanner := bufio.NewScanner(strings.NewReader(data)) var tapCount = 0 for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) - if line == ")" || line == "}" { + noCommentLine := util.RemoveComment(line) + if noCommentLine == ")" || noCommentLine == "}" { tapCount -= 1 } util.WriteIndent(&builder, tapCount) builder.WriteString(line + "\n") - if strings.HasSuffix(line, "(") || strings.HasSuffix(line, "{") { + if strings.HasSuffix(noCommentLine, "(") || strings.HasSuffix(noCommentLine, "{") { tapCount += 1 } } return strings.TrimSpace(builder.String()) } - -func countRune(s string, r rune) int { - count := 0 - for _, c := range s { - if c == r { - count++ - } - } - return count -} diff --git a/tools/goctl/api/format/format_test.go b/tools/goctl/api/format/format_test.go index a0f9444a..77701c35 100644 --- a/tools/goctl/api/format/format_test.go +++ b/tools/goctl/api/format/format_test.go @@ -41,7 +41,6 @@ service A-api { ) func TestInlineTypeNotExist(t *testing.T) { - r, err := apiFormat(notFormattedStr) - assert.Nil(t, err) + r := apiFormat(notFormattedStr) assert.Equal(t, r, formattedStr) } diff --git a/tools/goctl/api/gogen/gen_test.go b/tools/goctl/api/gogen/gen_test.go index 5d595f2e..8f984daf 100644 --- a/tools/goctl/api/gogen/gen_test.go +++ b/tools/goctl/api/gogen/gen_test.go @@ -283,6 +283,26 @@ service A-api { } ` +const noStructTagApi = ` +type Request { + Name string ` + "`" + `path:"name,options=you|me"` + "`" + ` +} + +type XXX { +} + +type ( + Response { + Message string ` + "`" + `json:"message"` + "`" + ` + } +) + +service A-api { + @handler GreetHandler + get /greet/from/:name(Request) returns (Response) +} +` + func TestParser(t *testing.T) { filename := "greet.api" err := ioutil.WriteFile(filename, []byte(testApiTemplate), os.ModePerm) @@ -501,6 +521,22 @@ func TestHasImportApi(t *testing.T) { validate(t, filename) } +func TestNoStructApi(t *testing.T) { + filename := "greet.api" + err := ioutil.WriteFile(filename, []byte(noStructTagApi), os.ModePerm) + assert.Nil(t, err) + defer os.Remove(filename) + + parser, err := parser.NewParser(filename) + assert.Nil(t, err) + + spec, err := parser.Parse() + assert.Nil(t, err) + assert.Equal(t, len(spec.Types), 3) + + validate(t, filename) +} + func validate(t *testing.T, api string) { dir := "_go" os.RemoveAll(dir) diff --git a/tools/goctl/api/new/newservice.go b/tools/goctl/api/new/newservice.go index b5f506c5..18f811ab 100644 --- a/tools/goctl/api/new/newservice.go +++ b/tools/goctl/api/new/newservice.go @@ -11,11 +11,11 @@ import ( ) const apiTemplate = ` -type Request struct { +type Request { Name string ` + "`" + `path:"name,options=you|me"` + "`" + ` } -type Response struct { +type Response { Message string ` + "`" + `json:"message"` + "`" + ` } diff --git a/tools/goctl/api/parser/apifileparser.go b/tools/goctl/api/parser/apifileparser.go index 8e13cb7d..5d5db3e2 100644 --- a/tools/goctl/api/parser/apifileparser.go +++ b/tools/goctl/api/parser/apifileparser.go @@ -7,6 +7,9 @@ import ( "fmt" "io" "strings" + + "github.com/tal-tech/go-zero/core/stringx" + "github.com/tal-tech/go-zero/tools/goctl/api/util" ) const ( @@ -15,6 +18,7 @@ const ( tokenType = "type" tokenService = "service" tokenServiceAnnotation = "@server" + tokenStruct = "struct" ) type ( @@ -72,7 +76,7 @@ func ParseApi(src string) (*ApiStruct, error) { } } -func (s *apiRootState) process(api *ApiStruct, token string) (apiFileState, error) { +func (s *apiRootState) process(api *ApiStruct, _ string) (apiFileState, error) { var builder strings.Builder for { ch, err := s.readSkipComment() @@ -124,7 +128,7 @@ func (s *apiInfoState) process(api *ApiStruct, token string) (apiFileState, erro return nil, err } - api.Info += "\n" + token + line + api.Info += newline + token + line token = "" if strings.TrimSpace(line) == string(rightParenthesis) { return &apiRootState{s.baseState}, nil @@ -139,12 +143,12 @@ func (s *apiImportState) process(api *ApiStruct, token string) (apiFileState, er } line = token + line - line = removeComment(line) + line = util.RemoveComment(line) if len(strings.Fields(line)) != 2 { return nil, errors.New("import syntax error: " + line) } - api.Imports += "\n" + line + api.Imports += newline + line return &apiRootState{s.baseState}, nil } @@ -156,11 +160,14 @@ func (s *apiTypeState) process(api *ApiStruct, token string) (apiFileState, erro return nil, err } - api.Type += "\n\n" + token + line - token = "" - line = strings.TrimSpace(line) - line = removeComment(line) + line = token + line + if blockCount <= 1 { + line = mayInsertStructKeyword(line) + } + api.Type += newline + newline + line line = strings.TrimSpace(line) + line = util.RemoveComment(line) + token = "" if strings.HasSuffix(line, leftBrace) { blockCount++ @@ -191,10 +198,9 @@ func (s *apiServiceState) process(api *ApiStruct, token string) (apiFileState, e line = token + line token = "" - api.Service += "\n" + line - line = strings.TrimSpace(line) - line = removeComment(line) + api.Service += newline + line line = strings.TrimSpace(line) + line = util.RemoveComment(line) if strings.HasSuffix(line, leftBrace) { blockCount++ @@ -215,10 +221,30 @@ func (s *apiServiceState) process(api *ApiStruct, token string) (apiFileState, e } } -func removeComment(line string) string { - var commentIdx = strings.Index(line, "//") - if commentIdx >= 0 { - return line[:commentIdx] +func mayInsertStructKeyword(line string) string { + line = util.RemoveComment(line) + if !strings.HasSuffix(line, leftBrace) { + return line } - return line + + fields := strings.Fields(line) + if stringx.Contains(fields, tokenStruct) || stringx.Contains(fields, tokenStruct+leftBrace) || len(fields) <= 1 { + return line + } + + var insertIndex int + if fields[0] == tokenType { + insertIndex = 2 + } else { + insertIndex = 1 + } + if insertIndex >= len(fields) { + return line + } + + var result []string + result = append(result, fields[:insertIndex]...) + result = append(result, tokenStruct) + result = append(result, fields[insertIndex:]...) + return strings.Join(result, " ") } diff --git a/tools/goctl/api/parser/parser.go b/tools/goctl/api/parser/parser.go index 0fc83abf..ce6f079c 100644 --- a/tools/goctl/api/parser/parser.go +++ b/tools/goctl/api/parser/parser.go @@ -15,9 +15,8 @@ import ( ) type Parser struct { - r *bufio.Reader - typeDef string - api *ApiStruct + r *bufio.Reader + api *ApiStruct } func NewParser(filename string) (*Parser, error) { @@ -73,15 +72,14 @@ func NewParser(filename string) (*Parser, error) { var buffer = new(bytes.Buffer) buffer.WriteString(apiStruct.Service) return &Parser{ - r: bufio.NewReader(buffer), - typeDef: apiStruct.Type, - api: apiStruct, + r: bufio.NewReader(buffer), + api: apiStruct, }, nil } func (p *Parser) Parse() (api *spec.ApiSpec, err error) { api = new(spec.ApiSpec) - var sp = StructParser{Src: p.typeDef} + var sp = StructParser{Src: p.api.Type} types, err := sp.Parse() if err != nil { return nil, err diff --git a/tools/goctl/api/parser/vars.go b/tools/goctl/api/parser/vars.go index ba6ed42b..f7811b86 100644 --- a/tools/goctl/api/parser/vars.go +++ b/tools/goctl/api/parser/vars.go @@ -16,4 +16,5 @@ const ( multilineBeginTag = '>' multilineEndTag = '<' semicolon = ';' + newline = "\n" ) diff --git a/tools/goctl/api/util/util.go b/tools/goctl/api/util/util.go index 5ddf7a26..12f8a994 100644 --- a/tools/goctl/api/util/util.go +++ b/tools/goctl/api/util/util.go @@ -81,3 +81,11 @@ func WriteIndent(writer io.Writer, indent int) { fmt.Fprint(writer, "\t") } } + +func RemoveComment(line string) string { + var commentIdx = strings.Index(line, "//") + if commentIdx >= 0 { + return strings.TrimSpace(line[:commentIdx]) + } + return strings.TrimSpace(line) +}