// Package javagen generates Java HTTP packet classes from a goctl API spec.
package javagen

import (
	"bufio"
	"bytes"
	"fmt"
	"io"
	"os"
	"strings"
	"text/template"

	"github.com/tal-tech/go-zero/core/stringx"
	"github.com/tal-tech/go-zero/tools/goctl/api/spec"
	apiutil "github.com/tal-tech/go-zero/tools/goctl/api/util"
	"github.com/tal-tech/go-zero/tools/goctl/util"
)

// packetTemplate is the text/template used to render one Java packet class
// per route. The keys it reads are supplied by createWith.
//
// NOTE(review): the template text is kept exactly as found, on a single line.
// The fragments produced by paramsSet and declarationForRoute embed "\n" and
// tab indentation, which suggests the template was presumably multi-line at
// some point — confirm against the upstream template before reformatting it.
const packetTemplate = `package com.xhb.logic.http.packet.{{.packet}}; import com.google.gson.Gson; import com.xhb.commons.JSON; import com.xhb.commons.JsonParser; import com.xhb.core.network.HttpRequestClient; import com.xhb.core.packet.HttpRequestPacket; import com.xhb.core.response.HttpResponseData; import com.xhb.logic.http.DeProguardable; {{.import}} import org.jetbrains.annotations.NotNull; import org.json.JSONObject; public class {{.packetName}} extends HttpRequestPacket<{{.packetName}}.{{.packetName}}Response> { {{.paramsDeclaration}} public {{.packetName}}({{.params}}{{.requestType}} request) { super(request); this.request = request;{{.paramsSet}} } @Override public HttpRequestClient.Method requestMethod() { return HttpRequestClient.Method.{{.method}}; } @Override public String requestUri() { return {{.uri}}; } @Override public {{.packetName}}Response newInstanceFrom(JSON json) { return new {{.packetName}}Response(json); } public static class {{.packetName}}Response extends HttpResponseData { private {{.responseType}} responseData; {{.packetName}}Response(@NotNull JSON json) { super(json); JSONObject jsonObject = json.asObject(); if (JsonParser.hasKey(jsonObject, "data")) { Gson gson = new Gson(); JSONObject dataJson = JsonParser.getJSONObject(jsonObject, "data"); responseData = gson.fromJson(dataJson.toString(), {{.responseType}}.class); } } public {{.responseType}} get{{.responseType}} () { return responseData; } } {{.types}} } `

// genPacket renders a Java packet file into dir for every route of the
// service described by api. packetName is the Java sub-package name used in
// the generated package and import statements. It stops at the first error.
func genPacket(dir, packetName string, api *spec.ApiSpec) error {
	for _, route := range api.Service.Routes {
		if err := createWith(dir, api, route, packetName); err != nil {
			return err
		}
	}
	return nil
}

// createWith renders the packet class for a single route.
//
// The Java class name is derived from the route's "handler" value in the
// "server" annotation, with the first "Handler" replaced by "Packet". If
// apiutil.MaybeCreateFile reports that it did not create the file, the route
// is skipped and nil is returned. Errors from type generation, template
// execution and writing the output file are returned to the caller.
func createWith(dir string, api *spec.ApiSpec, route spec.Route, packetName string) error {
	handler, ok := apiutil.GetAnnotationValue(route.Annotations, "server", "handler")
	if !ok {
		return fmt.Errorf("missing packet annotation for %q", route.Path)
	}
	packet := strings.Replace(handler, "Handler", "Packet", 1)

	javaFile := packet + ".java"
	fp, created, err := apiutil.MaybeCreateFile(dir, "", javaFile)
	if err != nil {
		return err
	}
	if !created {
		return nil
	}
	defer fp.Close()

	// Render the route-local types, separated by one blank line.
	var builder strings.Builder
	for i, tp := range apiutil.GetLocalTypes(api, route) {
		if i > 0 {
			fmt.Fprintln(&builder)
		}
		if err := genType(&builder, tp); err != nil {
			return err
		}
	}
	types := builder.String()

	params := paramsForRoute(route)
	paramsDeclaration := declarationForRoute(route)
	paramAssignments := paramsSet(route)

	t := template.Must(template.New("packetTemplate").Parse(packetTemplate))
	var tmplBytes bytes.Buffer
	err = t.Execute(&tmplBytes, map[string]string{
		"packetName":        packet,
		"method":            strings.ToUpper(route.Method),
		"uri":               processUri(route),
		"types":             strings.TrimSpace(types),
		"responseType":      stringx.TakeOne(util.Title(route.ResponseType.Name), "Object"),
		"params":            params,
		"paramsDeclaration": strings.TrimSpace(paramsDeclaration),
		"paramsSet":         paramAssignments,
		"packet":            packetName,
		"requestType":       util.Title(route.RequestType.Name),
		"import":            getImports(api, route, packetName),
	})
	if err != nil {
		return err
	}

	// Surface write failures instead of reporting success for a truncated file.
	return formatFile(&tmplBytes, fp)
}

// getImports returns one Java import line for every type used by route
// (apiutil.GetAllTypes) whose name also appears among the API's shared types
// (apiutil.GetSharedTypes). Each line is terminated by "\n".
func getImports(api *spec.ApiSpec, route spec.Route, packetName string) string {
	var builder strings.Builder
	allTypes := apiutil.GetAllTypes(api, route)
	sharedTypes := apiutil.GetSharedTypes(api)
	for _, at := range allTypes {
		for _, item := range sharedTypes {
			if item.Name == at.Name {
				fmt.Fprintf(&builder, "import com.xhb.logic.http.packet.%s.model.%s;\n", packetName, item.Name)
				break
			}
		}
	}
	return builder.String()
}

// formatFile copies tmplBytes to file line by line, collapsing every run of
// consecutive blank (whitespace-only) lines into a single line. Each written
// line is terminated by "\n".
//
// It returns the first error encountered while scanning, writing or flushing.
func formatFile(tmplBytes *bytes.Buffer, file *os.File) error {
	scanner := bufio.NewScanner(tmplBytes)
	writer := bufio.NewWriter(file)

	prevBlank := false
	for scanner.Scan() {
		line := scanner.Text()
		blank := strings.TrimSpace(line) == ""
		if blank && prevBlank {
			continue
		}
		prevBlank = blank
		if _, err := writer.WriteString(line + "\n"); err != nil {
			return err
		}
	}
	if err := scanner.Err(); err != nil {
		return err
	}
	return writer.Flush()
}

// routePathParams returns, in order, the names of the ":name" segments of a
// slash-separated route path. Empty segments are ignored.
func routePathParams(path string) []string {
	var params []string
	for _, cop := range strings.Split(path, "/") {
		if strings.HasPrefix(cop, ":") {
			params = append(params, cop[1:])
		}
	}
	return params
}

// paramsSet returns the Java constructor statements that assign each path
// parameter of route to the field of the same name. Every statement is
// preceded by a newline and two tabs; the result is empty when the path has
// no parameters.
func paramsSet(route spec.Route) string {
	var builder strings.Builder
	for _, param := range routePathParams(route.Path) {
		fmt.Fprintf(&builder, "\n\t\tthis.%s = %s;", param, param)
	}
	return builder.String()
}

// paramsForRoute returns the Java constructor parameter list fragment for the
// path parameters of route, e.g. "String id, String name, ". The trailing
// ", " is intentional: the template appends the request parameter after it.
func paramsForRoute(route spec.Route) string {
	var builder strings.Builder
	for _, param := range routePathParams(route.Path) {
		fmt.Fprintf(&builder, "String %s, ", param)
	}
	return builder.String()
}

// declarationForRoute returns the Java field declarations for the path
// parameters of route, one "private String <name>;" per line. The result is
// whitespace-trimmed and, when non-empty, prefixed with a newline.
func declarationForRoute(route spec.Route) string {
	var builder strings.Builder
	writeIndent(&builder, 1)
	for _, param := range routePathParams(route.Path) {
		writeIndent(&builder, 1)
		fmt.Fprintf(&builder, "private String %s;\n", param)
	}

	result := strings.TrimSpace(builder.String())
	if len(result) > 0 {
		result = "\n" + result
	}
	return result
}

// processUri converts the route path into a Java string expression in which
// literal segments are quoted and ":name" segments become references to the
// variable name, e.g. "/users/:id/profile" yields
//
//	"/users/" + id + "/profile"
//
// Empty segments (leading, trailing or doubled slashes) are ignored, so a
// trailing slash still produces a properly terminated literal. A path with no
// segments yields the literal "/" rather than an empty expression.
func processUri(route spec.Route) string {
	var (
		parts   []string
		literal strings.Builder
	)
	// flush closes the pending string literal, if any, and records it.
	flush := func() {
		if literal.Len() > 0 {
			parts = append(parts, "\""+literal.String()+"\"")
			literal.Reset()
		}
	}

	for _, cop := range strings.Split(route.Path, "/") {
		if len(cop) == 0 {
			continue
		}
		// The separating slash always belongs to the literal preceding the segment.
		literal.WriteString("/")
		if strings.HasPrefix(cop, ":") {
			flush()
			parts = append(parts, cop[1:])
			continue
		}
		literal.WriteString(cop)
	}
	flush()

	if len(parts) == 0 {
		return "\"/\""
	}
	return strings.Join(parts, " + ")
}

// genType writes a nested Java "static class <Name> implements DeProguardable"
// for tp to writer: one property per member via writeProperty, followed by
// the output of genGetSet and the closing brace. It returns the first error
// reported by writeProperty.
func genType(writer io.Writer, tp spec.Type) error {
	writeIndent(writer, 1)
	fmt.Fprintf(writer, "static class %s implements DeProguardable {\n", util.Title(tp.Name))

	for _, member := range tp.Members {
		if err := writeProperty(writer, member, 2); err != nil {
			return err
		}
	}

	writeBreakline(writer)
	writeIndent(writer, 1)
	// NOTE(review): any result of genGetSet is not inspected here; confirm
	// whether it can fail.
	genGetSet(writer, tp, 2)

	writeIndent(writer, 1)
	fmt.Fprintln(writer, "}")
	return nil
}