From c3f57e9b0a5bd3fb3dcf0f1fdee2b3411c48705d Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Sun, 30 Jul 2023 21:37:41 +0800 Subject: [PATCH] chore: fix potential nil pointer errors (#3454) --- core/iox/bufferpool.go | 4 ++++ core/iox/bufferpool_test.go | 23 +++++++++++++++++++++++ core/mapping/unmarshaler.go | 2 +- core/mapping/utils.go | 2 +- core/mapping/utils_test.go | 8 ++++---- core/stores/sqlx/orm.go | 4 ++-- core/trace/agent.go | 8 +++++--- 7 files changed, 40 insertions(+), 11 deletions(-) diff --git a/core/iox/bufferpool.go b/core/iox/bufferpool.go index d2546c23..8a486f1f 100644 --- a/core/iox/bufferpool.go +++ b/core/iox/bufferpool.go @@ -32,6 +32,10 @@ func (bp *BufferPool) Get() *bytes.Buffer { // Put returns buf into bp. func (bp *BufferPool) Put(buf *bytes.Buffer) { + if buf == nil { + return + } + if buf.Cap() < bp.capability { bp.pool.Put(buf) } diff --git a/core/iox/bufferpool_test.go b/core/iox/bufferpool_test.go index 254e5efd..54bd64f5 100644 --- a/core/iox/bufferpool_test.go +++ b/core/iox/bufferpool_test.go @@ -13,3 +13,26 @@ func TestBufferPool(t *testing.T) { pool.Put(bytes.NewBuffer(make([]byte, 0, 2*capacity))) assert.True(t, pool.Get().Cap() <= capacity) } + +func TestBufferPool_Put(t *testing.T) { + t.Run("with nil buf", func(t *testing.T) { + pool := NewBufferPool(1024) + pool.Put(nil) + val := pool.Get() + assert.IsType(t, new(bytes.Buffer), val) + }) + + t.Run("with less-cap buf", func(t *testing.T) { + pool := NewBufferPool(1024) + pool.Put(bytes.NewBuffer(make([]byte, 0, 512))) + val := pool.Get() + assert.IsType(t, new(bytes.Buffer), val) + }) + + t.Run("with more-cap buf", func(t *testing.T) { + pool := NewBufferPool(1024) + pool.Put(bytes.NewBuffer(make([]byte, 0, 1024<<1))) + val := pool.Get() + assert.IsType(t, new(bytes.Buffer), val) + }) +} diff --git a/core/mapping/unmarshaler.go b/core/mapping/unmarshaler.go index 0de92762..b6281e42 100644 --- a/core/mapping/unmarshaler.go +++ b/core/mapping/unmarshaler.go @@ -878,7 +878,7 @@ func (u *Unmarshaler) processNamedFieldWithoutValue(fieldType reflect.Type, valu func (u *Unmarshaler) unmarshalWithFullName(m valuerWithParent, v any, fullName string) error { rv := reflect.ValueOf(v) - if err := ValidatePtr(&rv); err != nil { + if err := ValidatePtr(rv); err != nil { return err } diff --git a/core/mapping/utils.go b/core/mapping/utils.go index ae558959..3597c97e 100644 --- a/core/mapping/utils.go +++ b/core/mapping/utils.go @@ -79,7 +79,7 @@ func SetMapIndexValue(tp reflect.Type, value, key, target reflect.Value) { } // ValidatePtr validates v if it's a valid pointer. -func ValidatePtr(v *reflect.Value) error { +func ValidatePtr(v reflect.Value) error { // sequence is very important, IsNil must be called after checking Kind() with reflect.Ptr, // panic otherwise if !v.IsValid() || v.Kind() != reflect.Ptr || v.IsNil() { diff --git a/core/mapping/utils_test.go b/core/mapping/utils_test.go index b8e7ea18..cbb1208d 100644 --- a/core/mapping/utils_test.go +++ b/core/mapping/utils_test.go @@ -218,25 +218,25 @@ func TestParseSegments(t *testing.T) { func TestValidatePtrWithNonPtr(t *testing.T) { var foo string rve := reflect.ValueOf(foo) - assert.NotNil(t, ValidatePtr(&rve)) + assert.NotNil(t, ValidatePtr(rve)) } func TestValidatePtrWithPtr(t *testing.T) { var foo string rve := reflect.ValueOf(&foo) - assert.Nil(t, ValidatePtr(&rve)) + assert.Nil(t, ValidatePtr(rve)) } func TestValidatePtrWithNilPtr(t *testing.T) { var foo *string rve := reflect.ValueOf(foo) - assert.NotNil(t, ValidatePtr(&rve)) + assert.NotNil(t, ValidatePtr(rve)) } func TestValidatePtrWithZeroValue(t *testing.T) { var s string e := reflect.Zero(reflect.TypeOf(s)) - assert.NotNil(t, ValidatePtr(&e)) + assert.NotNil(t, ValidatePtr(e)) } func TestSetValueNotSettable(t *testing.T) { diff --git a/core/stores/sqlx/orm.go b/core/stores/sqlx/orm.go index d250059a..9924eefc 100644 --- a/core/stores/sqlx/orm.go +++ b/core/stores/sqlx/orm.go @@ -146,7 +146,7 @@ func unmarshalRow(v any, scanner rowsScanner, strict bool) error { } rv := reflect.ValueOf(v) - if err := mapping.ValidatePtr(&rv); err != nil { + if err := mapping.ValidatePtr(rv); err != nil { return err } @@ -182,7 +182,7 @@ func unmarshalRow(v any, scanner rowsScanner, strict bool) error { func unmarshalRows(v any, scanner rowsScanner, strict bool) error { rv := reflect.ValueOf(v) - if err := mapping.ValidatePtr(&rv); err != nil { + if err := mapping.ValidatePtr(rv); err != nil { return err } diff --git a/core/trace/agent.go b/core/trace/agent.go index 840e28ce..19ae59cd 100644 --- a/core/trace/agent.go +++ b/core/trace/agent.go @@ -26,6 +26,7 @@ const ( kindOtlpGrpc = "otlpgrpc" kindOtlpHttp = "otlphttp" kindFile = "file" + protocolUdp = "udp" ) var ( @@ -65,9 +66,10 @@ func createExporter(c Config) (sdktrace.SpanExporter, error) { // Just support jaeger and zipkin now, more for later switch c.Batcher { case kindJaeger: - u, _ := url.Parse(c.Endpoint) - if u.Scheme == "udp" { - return jaeger.New(jaeger.WithAgentEndpoint(jaeger.WithAgentHost(u.Hostname()), jaeger.WithAgentPort(u.Port()))) + u, err := url.Parse(c.Endpoint) + if err == nil && u.Scheme == protocolUdp { + return jaeger.New(jaeger.WithAgentEndpoint(jaeger.WithAgentHost(u.Hostname()), + jaeger.WithAgentPort(u.Port()))) } return jaeger.New(jaeger.WithCollectorEndpoint(jaeger.WithEndpoint(c.Endpoint))) case kindZipkin: