You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
go-zero/rest/httpc/requests.go

197 lines
4.5 KiB
Go

package httpc
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	nurl "net/url"
	"sort"
	"strings"

	"github.com/zeromicro/go-zero/core/lang"
	"github.com/zeromicro/go-zero/core/mapping"
	"github.com/zeromicro/go-zero/core/trace"
	"github.com/zeromicro/go-zero/rest/httpc/internal"
	"github.com/zeromicro/go-zero/rest/internal/header"
	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/codes"
	"go.opentelemetry.io/otel/propagation"
	semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
	oteltrace "go.opentelemetry.io/otel/trace"
)
// interceptors lists the hooks that request applies, in slice order, to every
// outgoing request before it is sent.
var interceptors = []internal.Interceptor{internal.LogInterceptor}
// Do builds an HTTP request from method, url and data, sends it, and returns the
// HTTP response.
//
// data is marshaled with mapping.Marshal, and the result fills the request's path
// variables, query string, headers and JSON body. Typically data is a request type
// defined in an API file. A nil data sends the request with no extra parameters.
func Do(ctx context.Context, method, url string, data interface{}) (*http.Response, error) {
	req, err := buildRequest(ctx, method, url, data)
	if err == nil {
		return DoRequest(req)
	}

	return nil, err
}
// DoRequest sends an HTTP request and returns an HTTP response.
func DoRequest(r *http.Request) (*http.Response, error) {
return request(r, defaultClient{})
}
// client is the transport used by request to send an *http.Request. Any
// implementation can be passed to request in place of defaultClient.
type client interface {
	do(r *http.Request) (*http.Response, error)
}

// defaultClient is the client used by DoRequest.
type defaultClient struct{}
// do sends r with the shared http.DefaultClient.
func (defaultClient) do(r *http.Request) (*http.Response, error) {
	shared := http.DefaultClient
	return shared.Do(r)
}
// buildFormQuery merges val into the query parameters already present in u and
// returns the encoded query string.
//
// Each value is formatted with fmt.Sprint and appended after any existing values
// for the same key. The result is encoded by url.Values.Encode, which sorts by
// key. u itself is not modified, because u.Query returns a fresh copy.
func buildFormQuery(u *nurl.URL, val map[string]interface{}) string {
	params := u.Query()
	for name := range val {
		params.Add(name, fmt.Sprint(val[name]))
	}

	encoded := params.Encode()
	return encoded
}
// buildRequest creates an *http.Request for method and url, filled from data.
//
// When data is non-nil it is marshaled with mapping.Marshal, and each resulting
// section is applied to the request:
//   - pathKey values replace the path-variable segments of url (see fillPath);
//   - formKey values are merged into the URL query string;
//   - headerKey values are added to the request headers;
//   - jsonKey values are JSON-encoded as the request body, and Content-Type is set
//     to header.JsonContentType.
//
// ErrGetWithBody is returned when a JSON body is present for a GET request. Parse,
// marshal, path and encoding errors are returned unchanged.
func buildRequest(ctx context.Context, method, url string, data interface{}) (*http.Request, error) {
	u, err := nurl.Parse(url)
	if err != nil {
		return nil, err
	}

	var sections map[string]map[string]interface{}
	if data != nil {
		if sections, err = mapping.Marshal(data); err != nil {
			return nil, err
		}
	}

	if err = fillPath(u, sections[pathKey]); err != nil {
		return nil, err
	}

	var body io.Reader
	jsonVars, hasJSONBody := sections[jsonKey]
	if hasJSONBody {
		if method == http.MethodGet {
			return nil, ErrGetWithBody
		}

		payload := new(bytes.Buffer)
		if err = json.NewEncoder(payload).Encode(jsonVars); err != nil {
			return nil, err
		}
		body = payload
	}

	req, err := http.NewRequestWithContext(ctx, method, u.String(), body)
	if err != nil {
		return nil, err
	}

	req.URL.RawQuery = buildFormQuery(u, sections[formKey])
	fillHeader(req, sections[headerKey])
	if hasJSONBody {
		// Set runs after fillHeader, so this overrides any Content-Type from data.
		req.Header.Set(header.ContentType, header.JsonContentType)
	}

	return req, nil
}
// fillHeader adds each entry of val to r's headers, formatting the value with
// fmt.Sprint. Keys are canonicalized by http.Header.Add. Existing values are kept,
// and the new value is appended after them.
func fillHeader(r *http.Request, val map[string]interface{}) {
	for name := range val {
		r.Header.Add(name, fmt.Sprint(val[name]))
	}
}
// fillPath substitutes the path variables in u.Path with the values in val.
//
// u.Path is split on slash. Every segment that starts with colon names a variable,
// and that segment is replaced by fmt.Sprint of val[name]. An error is returned if:
//   - a referenced variable is missing from val;
//   - a variable formats to an empty string;
//   - val holds variables the path never references. These are listed in sorted
//     order, so the message is deterministic.
//
// u.Path is only assigned on success, so u is unchanged on error. val is never
// modified; it is owned by the caller.
//
// NOTE(review): substituted values are not escaped, so a value containing the path
// separator introduces extra segments. Confirm whether that is intended.
func fillPath(u *nurl.URL, val map[string]interface{}) error {
	used := make(map[string]lang.PlaceholderType)
	fields := strings.Split(u.Path, slash)
	for i, field := range fields {
		if len(field) == 0 || field[0] != colon {
			continue
		}

		name := field[1:]
		ival, ok := val[name]
		if !ok {
			return fmt.Errorf("missing path variable %q", name)
		}

		value := fmt.Sprint(ival)
		if len(value) == 0 {
			return fmt.Errorf("empty path variable %q", name)
		}

		fields[i] = value
		used[name] = lang.Placeholder
	}

	if len(val) != len(used) {
		// Collect the leftovers without deleting from val: the map belongs to the caller.
		unused := make([]string, 0, len(val)-len(used))
		for key := range val {
			if _, ok := used[key]; !ok {
				unused = append(unused, key)
			}
		}
		// Map iteration order is random; sort so the error message is reproducible.
		sort.Strings(unused)
		return fmt.Errorf("more path variables are provided: %q", strings.Join(unused, ", "))
	}

	u.Path = strings.Join(fields, slash)
	return nil
}
// request sends r through cli inside an OpenTelemetry client span named after
// r.URL.Path.
//
// The sequence is:
//  1. The span is started with HTTP client attributes taken from the incoming r.
//  2. Each package interceptor is applied to r in order.
//  3. The span context is attached to the request and injected into its headers
//     with the global text-map propagator.
//  4. After cli.do returns, the interceptors' response handlers run in reverse
//     order with the response and error.
//
// A transport error is recorded on the span and returned together with whatever
// response cli produced. Otherwise the span is annotated with the response status.
func request(r *http.Request, cli client) (*http.Response, error) {
	propagator := otel.GetTextMapPropagator()
	ctx, span := otel.GetTracerProvider().Tracer(trace.TraceName).Start(
		r.Context(),
		r.URL.Path,
		oteltrace.WithSpanKind(oteltrace.SpanKindClient),
		oteltrace.WithAttributes(semconv.HTTPClientAttributesFromHTTPRequest(r)...),
	)
	defer span.End()

	handlers := make([]internal.ResponseHandler, 0, len(interceptors))
	for _, intercept := range interceptors {
		var handle internal.ResponseHandler
		r, handle = intercept(r)
		handlers = append(handlers, handle)
	}

	r = r.WithContext(ctx)
	propagator.Inject(ctx, propagation.HeaderCarrier(r.Header))

	resp, err := cli.do(r)
	// Unwind in reverse so the first interceptor sees the response last.
	for i := len(handlers) - 1; i >= 0; i-- {
		handlers[i](resp, err)
	}

	if err != nil {
		span.RecordError(err)
		span.SetStatus(codes.Error, err.Error())
		return resp, err
	}

	span.SetAttributes(semconv.HTTPAttributesFromHTTPStatusCode(resp.StatusCode)...)
	span.SetStatus(semconv.SpanStatusFromHTTPStatusCode(resp.StatusCode))
	return resp, nil
}