You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
114 lines
2.4 KiB
Go
114 lines
2.4 KiB
Go
package cli
|
|
|
|
import (
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/tal-tech/go-zero/tools/goctl/util/console"
|
|
)
|
|
|
|
// test is one table-driven case shared by the tests in this file.
type test struct {
	// source is the input handed to the function under test: a list of
	// command-line arguments, or (for removePluginFlag) a single value in
	// source[0].
	source []string

	// expected is the result wanted when the call should succeed.
	expected string
	// expectedErr, when non-nil, is the error the call must return instead.
	expectedErr error
}
|
|
|
|
func Test_GetSourceProto(t *testing.T) {
|
|
pwd, err := os.Getwd()
|
|
if err != nil {
|
|
console.Error(err.Error())
|
|
return
|
|
}
|
|
|
|
var testData = []test{
|
|
{
|
|
source: []string{"a.proto"},
|
|
expected: filepath.Join(pwd, "a.proto"),
|
|
},
|
|
{
|
|
source: []string{"/foo/bar/a.proto"},
|
|
expected: "/foo/bar/a.proto",
|
|
},
|
|
{
|
|
source: []string{"a.proto", "b.proto"},
|
|
expectedErr: errMultiInput,
|
|
},
|
|
{
|
|
source: []string{"", "--go_out=."},
|
|
expectedErr: errInvalidInput,
|
|
},
|
|
}
|
|
|
|
for _, d := range testData {
|
|
ret, err := getSourceProto(d.source, pwd)
|
|
if d.expectedErr != nil {
|
|
assert.Equal(t, d.expectedErr, err)
|
|
continue
|
|
}
|
|
|
|
assert.Equal(t, d.expected, ret)
|
|
}
|
|
}
|
|
|
|
func Test_RemoveGoctlFlag(t *testing.T) {
|
|
var testData = []test{
|
|
{
|
|
source: strings.Fields("protoc foo.proto --go_out=. --go_opt=bar --zrpc_out=. --style go-zero --home=foo"),
|
|
expected: "protoc foo.proto --go_out=. --go_opt=bar",
|
|
},
|
|
{
|
|
source: strings.Fields("foo bar foo.proto"),
|
|
expected: "foo bar foo.proto",
|
|
},
|
|
{
|
|
source: strings.Fields("protoc foo.proto --go_out . --style=go_zero --home ."),
|
|
expected: "protoc foo.proto --go_out .",
|
|
},
|
|
{
|
|
source: strings.Fields(`protoc foo.proto --go_out . --style="go_zero" --home="."`),
|
|
expected: "protoc foo.proto --go_out .",
|
|
},
|
|
{
|
|
source: strings.Fields(`protoc foo.proto --go_opt=. --zrpc_out . --style=goZero --home=bar`),
|
|
expected: "protoc foo.proto --go_opt=.",
|
|
},
|
|
{
|
|
source: strings.Fields(`protoc foo.proto --go_opt=. --zrpc_out="bar" --style=goZero --home=bar`),
|
|
expected: "protoc foo.proto --go_opt=.",
|
|
},
|
|
}
|
|
for _, e := range testData {
|
|
cmd := strings.Join(removeGoctlFlag(e.source), " ")
|
|
assert.Equal(t, e.expected, cmd)
|
|
}
|
|
}
|
|
|
|
func Test_RemovePluginFlag(t *testing.T) {
|
|
var testData = []test{
|
|
{
|
|
source: strings.Fields("plugins=grpc:."),
|
|
expected: ".",
|
|
},
|
|
{
|
|
source: strings.Fields("plugins=g1,g2:."),
|
|
expected: ".",
|
|
},
|
|
{
|
|
source: strings.Fields("g1,g2:."),
|
|
expected: ".",
|
|
},
|
|
{
|
|
source: strings.Fields("plugins=g1,g2:foo"),
|
|
expected: "foo",
|
|
},
|
|
}
|
|
|
|
for _, e := range testData {
|
|
data := removePluginFlag(e.source[0])
|
|
assert.Equal(t, e.expected, data)
|
|
}
|
|
}
|