Middleware Extensions

Introduction

Kitex is a lightweight RPC framework designed for extensibility. It offers two main extension mechanisms: middleware, a relatively low-level approach, and suites, a higher-level approach. This page covers the usage of middleware.

Middleware

Middleware is the lower-level extension mechanism, and most Kitex extensions and secondary development are built on it. Kitex’s middleware types are defined in pkg/endpoint/endpoint.go. The two most important are:

  1. Endpoint is a function that accepts ctx, req, and resp and returns err. See the code in the “Example” section below.
  2. Middleware (hereinafter MW) is a function that receives an Endpoint and returns an Endpoint:

type Middleware func(Endpoint) Endpoint

Because a middleware takes an Endpoint as input and returns an Endpoint as output, it is transparent to the application: the application does not know whether it is being decorated by middleware. For the same reason, middlewares can be nested and combined.

Middlewares run as a chain. Calling the provided next function returns the error produced by the subsequent middleware and fills in the response (if any). You can then process the result and either return an error to the previous middleware or set the response. Always check the error returned by next, and do not swallow it.

Client Middleware

There are two ways to add Client Middleware:

  1. client.WithMiddleware adds middleware to the current client. It is executed after the service-level circuit breaker and timeout middleware.
  2. client.WithInstanceMW adds middleware to the current client. It is executed after service discovery and load balancing, and after the instance circuit breaker if one is configured. It is not called when a proxy is used, such as in Mesh mode.

Note that both functions must be passed as options when creating the client.

Client Middleware call sequence:

  1. XDS routing, service-level circuit breaker, timeout
  2. ContextMiddleware
  3. Middleware set by client.WithMiddleware
  4. ACLMiddleware
  5. Service Discovery, instance circuit breaker, instance-level middleware/Service Discovery, proxy middleware
  6. IOErrorHandleMW

The above can be seen in https://github.com/cloudwego/kitex/blob/develop/client/client.go

Context Middleware

Context Middleware is essentially Client Middleware. The difference is that ctx controls whether a middleware is injected and which one. It was introduced so that Client Middleware can be injected globally or dynamically; a typical use is collecting statistics on which downstream interfaces are called. Inject middleware into ctx with ctx = client.WithContextMiddlewares(ctx, mw). Note: Context Middleware executes before the middleware set by client.WithMiddleware().

Server Middleware

Server-side middleware differs from client-side middleware in some respects. Add server-side middleware with server.WithMiddleware; as on the client, it is passed as an Option when creating the server.

Server Middleware call sequence:

  1. ErrHandleMW
  2. ACLMiddleware
  3. Middleware set by server.WithMiddleware

The above can be seen in https://github.com/cloudwego/kitex/blob/develop/server/server.go

Example

We can use the following example to see how to use Middleware.

Request/Response

To print the request content before the call and the response content after it, we can write the following middleware:

/*
type Request struct {
   Message        string            `thrift:"Message,1,required" frugal:"1,required,string" json:"Message"`
   Base           *base.Base        `thrift:"Base,255,optional" frugal:"255,optional,base.Base" json:"Base,omitempty"`
}

type Response struct {
   Message  string         `thrift:"Message,1,required" frugal:"1,required,string" json:"Message"`
   BaseResp *base.BaseResp `thrift:"BaseResp,255,optional" frugal:"255,optional,base.BaseResp" json:"BaseResp,omitempty"`
}
*/
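import (
	"context"

	"github.com/cloudwego/kitex/pkg/endpoint"
	"github.com/cloudwego/kitex/pkg/klog"
	"github.com/cloudwego/kitex/pkg/utils"
)

// ExampleMiddleware logs the request message before the call and the response message after it.
// echo is assumed to be the generated package that defines Request and Response above.
func ExampleMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
	return func(ctx context.Context, request, response interface{}) error {
		if arg, ok := request.(utils.KitexArgs); ok {
			// Check the asserted pointer too: the interface{} may wrap a nil *echo.Request.
			if req, ok := arg.GetFirstArgument().(*echo.Request); ok && req != nil {
				klog.Debugf("Request Message: %v", req.Message)
			}
		}
		err := next(ctx, request, response)
		if result, ok := response.(utils.KitexResult); ok {
			if resp, ok := result.GetResult().(*echo.Response); ok && resp != nil {
				klog.Debugf("Response Message: %v", resp.Message)
				// result.SetSuccess(...) could be used to replace the response with a customized one.
				// Note: the new value must have the same type as this method's response.
			}
		}
		return err
	}
}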

This example is for illustration only. Be cautious about logging like this in production: logging every request and response indiscriminately can hurt performance, especially with large response bodies.

Precautions

If custom middleware uses RPCInfo, be aware that RPCInfo is recycled after the RPC ends. Operating on RPCInfo from a goroutine started in middleware can therefore cause problems once the RPC has returned. Please avoid such operations.

gRPC Middleware

In addition to Thrift, Kitex also supports the protobuf and gRPC encoding/decoding protocols. Here, protobuf means that protobuf is used only to define the payload format and the service definition contains only unary methods. If streaming methods are introduced, Kitex uses the gRPC protocol for encoding/decoding and communication.

For services using protobuf (unary only), middleware is developed exactly as described above, because the design is identical.

If streaming methods are used, however, middleware development is completely different, so gRPC streaming middleware is explained separately here.

For streaming methods (client streaming, server streaming, bidirectional streaming, and so on), the sending and receiving of messages (Send and Recv) are controlled by business logic, so middleware cannot cover the messages themselves. To implement request/response logging at the message level during Send/Recv operations, wrap Kitex’s streaming.Stream as follows:

type wrappedStream struct {
        streaming.Stream
}

func (w *wrappedStream) RecvMsg(m interface{}) error {
        log.Printf("Receive a message: %T(%v)", m, m)
        return w.Stream.RecvMsg(m)
}

func (w *wrappedStream) SendMsg(m interface{}) error {
        log.Printf("Send a message: %T(%v)", m, m)
        return w.Stream.SendMsg(m)
}

func newWrappedStream(s streaming.Stream) streaming.Stream {
        return &wrappedStream{s}
}

Then, within the middleware, insert the wrapped streaming.Stream object at specific invocation points.

import (
    "context"

    "github.com/cloudwego/kitex/pkg/endpoint"
    "github.com/cloudwego/kitex/pkg/streaming"
)

// A middleware that can be used on both the client side and the server side in Kitex with gRPC/Thrift/TTHeader-protobuf
func DemoGRPCMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
    return func(ctx context.Context, req, res interface{}) error {

        var Nil interface{} // can not switch nil directly in go
        switch Nil {
        case req: // The current middleware is used for the client-side and specifically designed for streaming methods
            err := next(ctx, req, res)
            // The stream object can only be obtained after the final endpoint returns
            if tmp, ok := res.(*streaming.Result); err == nil && ok {
                tmp.Stream = newWrappedStream(tmp.Stream) // wrap stream object
            }
            return err
        case res: // The current middleware is used for the server-side and specifically designed for streaming methods
            if tmp, ok := req.(*streaming.Args); ok {
                tmp.Stream = newWrappedStream(tmp.Stream) // wrap stream object
            }
        default: // pure unary method, or thrift method
            // do something else
        }
        return next(ctx, req, res)
    }
}

The request/response parameter types that Kitex middleware receives in different gRPC scenarios:

| Scenario | Request Type | Response Type |
| --- | --- | --- |
| Kitex-gRPC Server Unary/Streaming | *streaming.Args | nil |
| Kitex-gRPC Client Unary | *xxxservice.XXXMethodArgs | *xxxservice.XXXMethodResult |
| Kitex-gRPC Client Streaming | nil | *streaming.Result |

Summary

Middleware is a lower-level extension mechanism, typically used to inject simple code with a specific function. In complex scenarios, a single middleware may not be enough to meet business requirements, and multiple middlewares or options need to be assembled into a complete middleware layer. Users can build this on suites; refer to Suite Extend.

FAQ

How to recover handler panic in middleware

Question: A user wants to recover, in middleware, a panic thrown by their own business handler, but finds that the framework has already recovered it.

Description: The framework recovers and reports panics in the handler. To capture a panic in custom middleware, check whether the error returned in the middleware is kerrors.ErrPanic.

func TestServerMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
   return func(ctx context.Context, req, resp interface{}) (err error) {
      err = next(ctx, req, resp)
      if errors.Is(err, kerrors.ErrPanic) {
         fmt.Println("capture panic")
      }
      return err
   }
}

How to get the real Request/Response in Middleware?

Due to implementation needs, the req and resp passed to middleware are not the req and resp passed by the user, but objects wrapped by Kitex, with structures similar to the following.

Thrift

// req
type ${XService}${XMethod}Args struct {
    Req *${XRequest} `thrift:"req,1" json:"req"`
}

func (p *${XService}${XMethod}Args) GetFirstArgument() interface{} {
    return p.Req
}


// resp
type ${XService}${XMethod}Result struct {
    Success *${XResponse} `thrift:"success,0" json:"success,omitempty"`
}

func (p *${XService}${XMethod}Result) GetResult() interface{} {
    return p.Success
}

Protobuf

// req
type ${XMethod}Args struct {
    Req *${XRequest}
}

func (p *${XMethod}Args) GetReq() *${XRequest} {
    if !p.IsSetReq() {
        return ${XMethod}Args_Req_DEFAULT
    }
    return p.Req
}


// resp
type ${XMethod}Result struct {
    Success *${XResponse}
}

func (p *${XMethod}Result) GetSuccess() *${XResponse} {
    if !p.IsSetSuccess() {
        return ${XMethod}Result_Success_DEFAULT
    }
    return p.Success
}

The generated code above can be found in the kitex_gen directory. There are therefore three ways for the business side to obtain the real req and resp:

  1. If you know which method is being called and the req type it uses, obtain the specific Args type through type assertion, then obtain the real req through the GetReq method.
  2. For Thrift-generated code, assert to an interface with GetFirstArgument or GetResult to obtain an interface{}, then type-assert it to the real req or resp. Note: because the returned interface{} carries a type, comparing it with nil does not catch the case where req/resp itself is a nil pointer, so you must also check whether the asserted req/resp is a nil pointer.
  3. Obtain the real request/response body through reflection, as in the following code:
var ExampleMW endpoint.Middleware = func(next endpoint.Endpoint) endpoint.Endpoint {
    return func(ctx context.Context, request, response interface{}) error {
        // GetReq/GetSuccess are generated methods. This assumes request and response are
        // non-nil generated Args/Result objects, which is not the case for streaming methods.
        reqV := reflect.ValueOf(request).MethodByName("GetReq").Call(nil)[0]
        klog.CtxInfof(ctx, "request: %T", reqV.Interface())
        err := next(ctx, request, response)
        respV := reflect.ValueOf(response).MethodByName("GetSuccess").Call(nil)[0]
        klog.CtxInfof(ctx, "response: %T", respV.Interface())
        return err
    }
}
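The nil-pointer note in solution 2 is easy to miss, so here is a small runnable sketch of it. Request, EchoArgs, firstArg, and realRequest are hypothetical names that only imitate the shape of generated code and of utils.KitexArgs; they are assumptions made for this illustration.

```go
package main

import "fmt"

// Request and EchoArgs are hypothetical stand-ins for generated code.
type Request struct{ Message string }

type EchoArgs struct{ Req *Request }

func (p *EchoArgs) GetFirstArgument() interface{} { return p.Req }

// firstArg mirrors the shape of utils.KitexArgs (assumption for this sketch).
type firstArg interface{ GetFirstArgument() interface{} }

// realRequest returns the concrete request, or false when it is missing or a nil pointer.
func realRequest(wrapped interface{}) (*Request, bool) {
	arg, ok := wrapped.(firstArg)
	if !ok {
		return nil, false
	}
	v := arg.GetFirstArgument()
	// v != nil holds even when Req is a nil *Request, because v carries the type.
	req, ok := v.(*Request)
	if !ok || req == nil {
		return nil, false
	}
	return req, true
}

func main() {
	var empty interface{} = (&EchoArgs{}).GetFirstArgument()
	fmt.Println(empty != nil) // true: a typed nil pointer inside an interface

	_, ok := realRequest(&EchoArgs{})
	fmt.Println(ok) // false

	req, ok := realRequest(&EchoArgs{Req: &Request{Message: "hi"}})
	fmt.Println(ok, req.Message) // true hi
}
```

Checking the asserted pointer, rather than the interface{} value, is what catches the empty request.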

Last modified January 5, 2024 : fix: link error (#917) (b1a0a93)