From 8b927904ee0dec805c89aaf9172f4459296ed6e8 Mon Sep 17 00:00:00 2001 From: Ian Cottrell Date: Fri, 12 Jul 2019 00:43:12 -0400 Subject: [PATCH] internal/jsonrpc2: extract logic to handler hooks Change-Id: Ief531e4b68fcb0dbc71e263c185fb285a9042479 Reviewed-on: https://go-review.googlesource.com/c/tools/+/185983 Reviewed-by: Rebecca Stambler --- internal/jsonrpc2/handler.go | 43 ++++++++- internal/jsonrpc2/jsonrpc2.go | 148 +++++++++++++++++++---------- internal/jsonrpc2/jsonrpc2_test.go | 61 ++++++++---- internal/jsonrpc2/wire.go | 8 +- internal/lsp/cmd/serve.go | 70 +++++++++++++- internal/lsp/protocol/protocol.go | 7 +- 6 files changed, 255 insertions(+), 82 deletions(-) diff --git a/internal/jsonrpc2/handler.go b/internal/jsonrpc2/handler.go index a1175bf0..0b04c903 100644 --- a/internal/jsonrpc2/handler.go +++ b/internal/jsonrpc2/handler.go @@ -6,8 +6,6 @@ package jsonrpc2 import ( "context" - "encoding/json" - "time" ) // Handler is the interface used to hook into the mesage handling of an rpc @@ -38,7 +36,26 @@ type Handler interface { // method is the method name specified in the message // payload is the parameters for a call or notification, and the result for a // response - Log(direction Direction, id *ID, elapsed time.Duration, method string, payload *json.RawMessage, err *Error) + + // Request is called near the start of processing any request. + Request(ctx context.Context, direction Direction, r *WireRequest) context.Context + // Response is called near the start of processing any response. + Response(ctx context.Context, direction Direction, r *WireResponse) context.Context + // Done is called when any request is fully processed. + // For calls, this means the response has also been processed, for notifies + // this is as soon as the message has been written to the stream. + // If err is set, it implies the request failed. + Done(ctx context.Context, err error) + // Read is called with a count each time some data is read from the stream. + // The read calls are delayed until after the data has been interpreted so + // that it can be attributed to a request/response. + Read(ctx context.Context, bytes int64) context.Context + // Wrote is called each time some data is written to the stream. + Wrote(ctx context.Context, bytes int64) context.Context + // Error is called with errors that cannot be delivered through the normal + // mechanisms, for instance a failure to process a notify cannot be delivered + // back to the other party. + Error(ctx context.Context, err error) } // Direction is used to indicate to a logger whether the logged message was being @@ -73,9 +90,27 @@ func (EmptyHandler) Cancel(ctx context.Context, conn *Conn, id ID, cancelled boo return false } -func (EmptyHandler) Log(direction Direction, id *ID, elapsed time.Duration, method string, payload *json.RawMessage, err *Error) { +func (EmptyHandler) Request(ctx context.Context, direction Direction, r *WireRequest) context.Context { + return ctx } +func (EmptyHandler) Response(ctx context.Context, direction Direction, r *WireResponse) context.Context { + return ctx +} + +func (EmptyHandler) Done(ctx context.Context, err error) { +} + +func (EmptyHandler) Read(ctx context.Context, bytes int64) context.Context { + return ctx +} + +func (EmptyHandler) Wrote(ctx context.Context, bytes int64) context.Context { + return ctx +} + +func (EmptyHandler) Error(ctx context.Context, err error) {} + type defaultHandler struct{ EmptyHandler } func (defaultHandler) Deliver(ctx context.Context, r *Request, delivered bool) bool { diff --git a/internal/jsonrpc2/jsonrpc2.go b/internal/jsonrpc2/jsonrpc2.go index 7b765d23..61aa4e21 100644 --- a/internal/jsonrpc2/jsonrpc2.go +++ b/internal/jsonrpc2/jsonrpc2.go @@ -11,6 +11,7 @@ import ( "context" "encoding/json" "fmt" + "log" "sync" "sync/atomic" "time" @@ -28,7 +29,7 @@ type Conn struct { stream Stream err error pendingMu sync.Mutex // protects the pending map - pending map[ID]chan *wireResponse + pending map[ID]chan *WireResponse handlingMu sync.Mutex // protects the handling map handling map[ID]*Request } @@ -47,18 +48,11 @@ const ( type Request struct { conn *Conn cancel context.CancelFunc - start time.Time state requestState nextRequest chan struct{} - // Method is a string containing the method name to invoke. - Method string - // Params is either a struct or an array with the parameters of the method. - Params *json.RawMessage - // The id of this request, used to tie the response back to the request. - // Will be either a string or a number. If not set, the request is a notify, - // and no response is possible. - ID *ID + // The Wire values of the request. + WireRequest } type rpcStats struct { @@ -115,9 +109,9 @@ func NewErrorf(code int64, format string, args ...interface{}) *Error { // You must call Run for the connection to be active. func NewConn(s Stream) *Conn { conn := &Conn{ - handlers: []Handler{defaultHandler{}}, + handlers: []Handler{defaultHandler{}, &tracer{}}, stream: s, - pending: make(map[ID]chan *wireResponse), + pending: make(map[ID]chan *WireResponse), handling: make(map[ID]*Request), } return conn @@ -150,14 +144,11 @@ func (c *Conn) Cancel(id ID) { // It will return as soon as the notification has been sent, as no response is // possible. func (c *Conn) Notify(ctx context.Context, method string, params interface{}) (err error) { - ctx, rpcStats := start(ctx, false, method, nil) - defer rpcStats.end(ctx, &err) - jsonParams, err := marshalToRaw(params) if err != nil { return fmt.Errorf("marshalling notify parameters: %v", err) } - request := &wireRequest{ + request := &WireRequest{ Method: method, Params: jsonParams, } @@ -166,10 +157,17 @@ func (c *Conn) Notify(ctx context.Context, method string, params interface{}) (e return fmt.Errorf("marshalling notify request: %v", err) } for _, h := range c.handlers { - h.Log(Send, nil, -1, request.Method, request.Params, nil) + ctx = h.Request(ctx, Send, request) } + defer func() { + for _, h := range c.handlers { + h.Done(ctx, err) + } + }() n, err := c.stream.Write(ctx, data) - telemetry.SentBytes.Record(ctx, n) + for _, h := range c.handlers { + ctx = h.Wrote(ctx, n) + } return err } @@ -179,13 +177,11 @@ func (c *Conn) Notify(ctx context.Context, method string, params interface{}) (e func (c *Conn) Call(ctx context.Context, method string, params, result interface{}) (err error) { // generate a new request identifier id := ID{Number: atomic.AddInt64(&c.seq, 1)} - ctx, rpcStats := start(ctx, false, method, &id) - defer rpcStats.end(ctx, &err) jsonParams, err := marshalToRaw(params) if err != nil { return fmt.Errorf("marshalling call parameters: %v", err) } - request := &wireRequest{ + request := &WireRequest{ ID: &id, Method: method, Params: jsonParams, @@ -195,9 +191,12 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface if err != nil { return fmt.Errorf("marshalling call request: %v", err) } + for _, h := range c.handlers { + ctx = h.Request(ctx, Send, request) + } // we have to add ourselves to the pending map before we send, otherwise we // are racing the response - rchan := make(chan *wireResponse) + rchan := make(chan *WireResponse) c.pendingMu.Lock() c.pending[id] = rchan c.pendingMu.Unlock() @@ -206,14 +205,15 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface c.pendingMu.Lock() delete(c.pending, id) c.pendingMu.Unlock() + for _, h := range c.handlers { + h.Done(ctx, err) + } }() // now we are ready to send - before := time.Now() - for _, h := range c.handlers { - h.Log(Send, request.ID, -1, request.Method, request.Params, nil) - } n, err := c.stream.Write(ctx, data) - telemetry.SentBytes.Record(ctx, n) + for _, h := range c.handlers { + ctx = h.Wrote(ctx, n) + } if err != nil { // sending failed, we will never get a response, so don't leave it pending return err @@ -221,9 +221,8 @@ func (c *Conn) Call(ctx context.Context, method string, params, result interface // now wait for the response select { case response := <-rchan: - elapsed := time.Since(before) for _, h := range c.handlers { - h.Log(Receive, response.ID, elapsed, request.Method, response.Result, response.Error) + ctx = h.Response(ctx, Receive, response) } // is it an error response? if response.Error != nil { @@ -283,9 +282,6 @@ func (r *Request) Reply(ctx context.Context, result interface{}, err error) erro if r.IsNotify() { return fmt.Errorf("reply not invoked with a valid call") } - ctx, close := trace.StartSpan(ctx, r.Method+":reply") - defer close() - // reply ends the handling phase of a call, so if we are not yet // parallel we should be now. The go routine is allowed to continue // to do work after replying, which is why it is important to unlock @@ -293,12 +289,11 @@ func (r *Request) Reply(ctx context.Context, result interface{}, err error) erro r.Parallel() r.state = requestReplied - elapsed := time.Since(r.start) var raw *json.RawMessage if err == nil { raw, err = marshalToRaw(result) } - response := &wireResponse{ + response := &WireResponse{ Result: raw, ID: r.ID, } @@ -314,10 +309,12 @@ func (r *Request) Reply(ctx context.Context, result interface{}, err error) erro return err } for _, h := range r.conn.handlers { - h.Log(Send, response.ID, elapsed, r.Method, response.Result, response.Error) + ctx = h.Response(ctx, Send, response) } n, err := r.conn.stream.Write(ctx, data) - telemetry.SentBytes.Record(ctx, n) + for _, h := range r.conn.handlers { + ctx = h.Wrote(ctx, n) + } if err != nil { // TODO(iancottrell): if a stream write fails, we really need to shut down @@ -374,7 +371,7 @@ func (c *Conn) Run(ctx context.Context) error { // a badly formed message arrived, log it and continue // we trust the stream to have isolated the error to just this message for _, h := range c.handlers { - h.Log(Receive, nil, -1, "", nil, NewErrorf(0, "unmarshal failed: %v", err)) + h.Error(ctx, fmt.Errorf("unmarshal failed: %v", err)) } continue } @@ -382,19 +379,23 @@ func (c *Conn) Run(ctx context.Context) error { switch { case msg.Method != "": // if method is set it must be a request - reqCtx, cancelReq := context.WithCancel(ctx) - reqCtx, rpcStats := start(reqCtx, true, msg.Method, msg.ID) - telemetry.ReceivedBytes.Record(ctx, n) + ctx, cancelReq := context.WithCancel(ctx) thisRequest := nextRequest nextRequest = make(chan struct{}) req := &Request{ conn: c, cancel: cancelReq, nextRequest: nextRequest, - start: time.Now(), - Method: msg.Method, - Params: msg.Params, - ID: msg.ID, + WireRequest: WireRequest{ + VersionTag: msg.VersionTag, + Method: msg.Method, + Params: msg.Params, + ID: msg.ID, + }, + } + for _, h := range c.handlers { + ctx = h.Request(ctx, Receive, &req.WireRequest) + ctx = h.Read(ctx, n) } c.setHandling(req, true) go func() { @@ -403,16 +404,17 @@ func (c *Conn) Run(ctx context.Context) error { defer func() { c.setHandling(req, false) if !req.IsNotify() && req.state < requestReplied { - req.Reply(reqCtx, nil, NewErrorf(CodeInternalError, "method %q did not reply", req.Method)) + req.Reply(ctx, nil, NewErrorf(CodeInternalError, "method %q did not reply", req.Method)) } req.Parallel() - rpcStats.end(reqCtx, nil) + for _, h := range c.handlers { + h.Done(ctx, err) + } cancelReq() }() delivered := false for _, h := range c.handlers { - h.Log(Receive, req.ID, -1, req.Method, req.Params, nil) - if h.Deliver(reqCtx, req, delivered) { + if h.Deliver(ctx, req, delivered) { delivered = true } } @@ -426,7 +428,7 @@ func (c *Conn) Run(ctx context.Context) error { } c.pendingMu.Unlock() // and send the reply to the channel - response := &wireResponse{ + response := &WireResponse{ Result: msg.Result, Error: msg.Error, ID: msg.ID, @@ -435,7 +437,7 @@ func (c *Conn) Run(ctx context.Context) error { close(rchan) default: for _, h := range c.handlers { - h.Log(Receive, nil, -1, "", nil, NewErrorf(0, "message not a call, notify or response, ignoring")) + h.Error(ctx, fmt.Errorf("message not a call, notify or response, ignoring")) } } } @@ -449,3 +451,49 @@ func marshalToRaw(obj interface{}) (*json.RawMessage, error) { raw := json.RawMessage(data) return &raw, nil } + +type statsKeyType int + +const statsKey = statsKeyType(0) + +type tracer struct { +} + +func (h *tracer) Deliver(ctx context.Context, r *Request, delivered bool) bool { + return false +} + +func (h *tracer) Cancel(ctx context.Context, conn *Conn, id ID, cancelled bool) bool { + return false +} + +func (h *tracer) Request(ctx context.Context, direction Direction, r *WireRequest) context.Context { + ctx, stats := start(ctx, direction == Receive, r.Method, r.ID) + ctx = context.WithValue(ctx, statsKey, stats) + return ctx +} + +func (h *tracer) Response(ctx context.Context, direction Direction, r *WireResponse) context.Context { + return ctx +} + +func (h *tracer) Done(ctx context.Context, err error) { + stats, ok := ctx.Value(statsKey).(*rpcStats) + if ok && stats != nil { + stats.end(ctx, &err) + } +} + +func (h *tracer) Read(ctx context.Context, bytes int64) context.Context { + telemetry.SentBytes.Record(ctx, bytes) + return ctx +} + +func (h *tracer) Wrote(ctx context.Context, bytes int64) context.Context { + telemetry.ReceivedBytes.Record(ctx, bytes) + return ctx +} + +func (h *tracer) Error(ctx context.Context, err error) { + log.Printf("%v", err) +} diff --git a/internal/jsonrpc2/jsonrpc2_test.go b/internal/jsonrpc2/jsonrpc2_test.go index 731f5883..53afface 100644 --- a/internal/jsonrpc2/jsonrpc2_test.go +++ b/internal/jsonrpc2/jsonrpc2_test.go @@ -108,7 +108,7 @@ func run(ctx context.Context, t *testing.T, withHeaders bool, r io.ReadCloser, w stream = jsonrpc2.NewStream(r, w) } conn := jsonrpc2.NewConn(stream) - conn.AddHandler(handle{}) + conn.AddHandler(&handle{log: *logRPC}) go func() { defer func() { r.Close() @@ -121,9 +121,11 @@ func run(ctx context.Context, t *testing.T, withHeaders bool, r io.ReadCloser, w return conn } -type handle struct{ jsonrpc2.EmptyHandler } +type handle struct { + log bool +} -func (handle) Deliver(ctx context.Context, r *jsonrpc2.Request, delivered bool) bool { +func (h *handle) Deliver(ctx context.Context, r *jsonrpc2.Request, delivered bool) bool { switch r.Method { case "no_args": if r.Params != nil { @@ -158,18 +160,43 @@ func (handle) Deliver(ctx context.Context, r *jsonrpc2.Request, delivered bool) return true } -func (handle) Log(direction jsonrpc2.Direction, id *jsonrpc2.ID, elapsed time.Duration, method string, payload *json.RawMessage, err *jsonrpc2.Error) { - if !*logRPC { - return - } - switch { - case err != nil: - log.Printf("%v failure [%v] %s %v", direction, id, method, err) - case id == nil: - log.Printf("%v notification %s %s", direction, method, *payload) - case elapsed >= 0: - log.Printf("%v response in %v [%v] %s %s", direction, elapsed, id, method, *payload) - default: - log.Printf("%v call [%v] %s %s", direction, id, method, *payload) - } +func (h *handle) Cancel(ctx context.Context, conn *jsonrpc2.Conn, id jsonrpc2.ID, cancelled bool) bool { + return false +} + +func (h *handle) Request(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireRequest) context.Context { + if h.log { + if r.ID != nil { + log.Printf("%v call [%v] %s %s", direction, r.ID, r.Method, r.Params) + } else { + log.Printf("%v notification %s %s", direction, r.Method, r.Params) + } + ctx = context.WithValue(ctx, "method", r.Method) + ctx = context.WithValue(ctx, "start", time.Now()) + } + return ctx +} + +func (h *handle) Response(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireResponse) context.Context { + if h.log { + method := ctx.Value("method") + elapsed := time.Since(ctx.Value("start").(time.Time)) + log.Printf("%v response in %v [%v] %s %s", direction, elapsed, r.ID, method, r.Result) + } + return ctx +} + +func (h *handle) Done(ctx context.Context, err error) { +} + +func (h *handle) Read(ctx context.Context, bytes int64) context.Context { + return ctx +} + +func (h *handle) Wrote(ctx context.Context, bytes int64) context.Context { + return ctx +} + +func (h *handle) Error(ctx context.Context, err error) { + log.Printf("%v", err) } diff --git a/internal/jsonrpc2/wire.go b/internal/jsonrpc2/wire.go index 1e126a41..3e891e07 100644 --- a/internal/jsonrpc2/wire.go +++ b/internal/jsonrpc2/wire.go @@ -34,8 +34,8 @@ const ( CodeServerOverloaded = -32000 ) -// wireRequest is sent to a server to represent a Call or Notify operaton. -type wireRequest struct { +// WireRequest is sent to a server to represent a Call or Notify operaton. +type WireRequest struct { // VersionTag is always encoded as the string "2.0" VersionTag VersionTag `json:"jsonrpc"` // Method is a string containing the method name to invoke. @@ -48,11 +48,11 @@ type wireRequest struct { ID *ID `json:"id,omitempty"` } -// wireResponse is a reply to a Request. +// WireResponse is a reply to a Request. // It will always have the ID field set to tie it back to a request, and will // have either the Result or Error fields set depending on whether it is a // success or failure response. -type wireResponse struct { +type WireResponse struct { // VersionTag is always encoded as the string "2.0" VersionTag VersionTag `json:"jsonrpc"` // Result is the response value, and is required on success. diff --git a/internal/lsp/cmd/serve.go b/internal/lsp/cmd/serve.go index db7cffbf..ac023128 100644 --- a/internal/lsp/cmd/serve.go +++ b/internal/lsp/cmd/serve.go @@ -120,6 +120,18 @@ type handler struct { out io.Writer } +type rpcStats struct { + method string + direction jsonrpc2.Direction + id *jsonrpc2.ID + payload *json.RawMessage + start time.Time +} + +type statsKeyType int + +const statsKey = statsKeyType(0) + func (h *handler) Deliver(ctx context.Context, r *jsonrpc2.Request, delivered bool) bool { return false } @@ -128,7 +140,63 @@ func (h *handler) Cancel(ctx context.Context, conn *jsonrpc2.Conn, id jsonrpc2.I return false } -func (h *handler) Log(direction jsonrpc2.Direction, id *jsonrpc2.ID, elapsed time.Duration, method string, payload *json.RawMessage, err *jsonrpc2.Error) { +func (h *handler) Request(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireRequest) context.Context { + if !h.trace { + return ctx + } + stats := &rpcStats{ + method: r.Method, + direction: direction, + start: time.Now(), + payload: r.Params, + } + ctx = context.WithValue(ctx, statsKey, stats) + return ctx +} + +func (h *handler) Response(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireResponse) context.Context { + stats := h.getStats(ctx) + h.log(direction, r.ID, 0, stats.method, r.Result, nil) + return ctx +} + +func (h *handler) Done(ctx context.Context, err error) { + if !h.trace { + return + } + stats := h.getStats(ctx) + h.log(stats.direction, stats.id, time.Since(stats.start), stats.method, stats.payload, err) +} + +func (h *handler) Read(ctx context.Context, bytes int64) context.Context { + return ctx +} + +func (h *handler) Wrote(ctx context.Context, bytes int64) context.Context { + return ctx +} + +const eol = "\r\n\r\n\r\n" + +func (h *handler) Error(ctx context.Context, err error) { + if !h.trace { + return + } + stats := h.getStats(ctx) + h.log(stats.direction, stats.id, 0, stats.method, nil, err) +} + +func (h *handler) getStats(ctx context.Context) *rpcStats { + stats, ok := ctx.Value(statsKey).(*rpcStats) + if !ok || stats == nil { + stats = &rpcStats{ + method: "???", + } + } + return stats +} + +func (h *handler) log(direction jsonrpc2.Direction, id *jsonrpc2.ID, elapsed time.Duration, method string, payload *json.RawMessage, err error) { if !h.trace { return } diff --git a/internal/lsp/protocol/protocol.go b/internal/lsp/protocol/protocol.go index a4146a42..c1bbb767 100644 --- a/internal/lsp/protocol/protocol.go +++ b/internal/lsp/protocol/protocol.go @@ -6,8 +6,6 @@ package protocol import ( "context" - "encoding/json" - "time" "golang.org/x/tools/internal/jsonrpc2" "golang.org/x/tools/internal/lsp/telemetry/trace" @@ -17,7 +15,7 @@ import ( type DocumentUri = string -type canceller struct{} +type canceller struct{ jsonrpc2.EmptyHandler } type clientHandler struct { canceller @@ -42,9 +40,6 @@ func (canceller) Cancel(ctx context.Context, conn *jsonrpc2.Conn, id jsonrpc2.ID return true } -func (canceller) Log(direction jsonrpc2.Direction, id *jsonrpc2.ID, elapsed time.Duration, method string, payload *json.RawMessage, err *jsonrpc2.Error) { -} - func NewClient(stream jsonrpc2.Stream, client Client) (*jsonrpc2.Conn, Server, xlog.Logger) { log := xlog.New(NewLogger(client)) conn := jsonrpc2.NewConn(stream)