diff --git a/internal/jsonrpc2/jsonrpc2.go b/internal/jsonrpc2/jsonrpc2.go index 76919a87..6eb6e9ef 100644 --- a/internal/jsonrpc2/jsonrpc2.go +++ b/internal/jsonrpc2/jsonrpc2.go @@ -62,7 +62,7 @@ func NewErrorf(code int64, format string, args ...interface{}) *Error { // You must call Run for the connection to be active. func NewConn(s Stream) *Conn { conn := &Conn{ - handlers: []Handler{defaultHandler{}, &tracer{}}, + handlers: []Handler{defaultHandler{}}, stream: s, pending: make(map[ID]chan *WireResponse), handling: make(map[ID]*Request), diff --git a/internal/lsp/cmd/serve.go b/internal/lsp/cmd/serve.go index 275b0020..5e7d46da 100644 --- a/internal/lsp/cmd/serve.go +++ b/internal/lsp/cmd/serve.go @@ -83,7 +83,7 @@ func (s *Serve) Run(ctx context.Context, args ...string) error { // For debugging purposes only. run := func(ctx context.Context, srv *lsp.Server) { - srv.Conn.AddHandler(&handler{trace: s.Trace, out: out}) + srv.Conn.AddHandler(&handler{loggingRPCs: s.Trace, out: out}) go srv.Run(ctx) } if s.Address != "" { @@ -94,7 +94,7 @@ func (s *Serve) Run(ctx context.Context, args ...string) error { } stream := jsonrpc2.NewHeaderStream(os.Stdin, os.Stdout) ctx, srv := lsp.NewServer(ctx, s.app.cache, stream) - srv.Conn.AddHandler(&handler{trace: s.Trace, out: out}) + srv.Conn.AddHandler(&handler{loggingRPCs: s.Trace, out: out}) return srv.Run(ctx) } @@ -119,8 +119,8 @@ func (s *Serve) forward() error { } type handler struct { - trace bool - out io.Writer + loggingRPCs bool + out io.Writer } type rpcStats struct { @@ -129,6 +129,7 @@ type rpcStats struct { id *jsonrpc2.ID payload *json.RawMessage start time.Time + close func() } type statsKeyType int @@ -144,49 +145,63 @@ func (h *handler) Cancel(ctx context.Context, conn *jsonrpc2.Conn, id jsonrpc2.I } func (h *handler) Request(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireRequest) context.Context { - if !h.trace { - return ctx + if r.Method == "" { + panic("no method in rpc stats") } - stats := &rpcStats{ + s := &rpcStats{ method: r.Method, - direction: direction, start: time.Now(), + direction: direction, payload: r.Params, } - ctx = context.WithValue(ctx, statsKey, stats) + mode := telemetry.Outbound + if direction == jsonrpc2.Receive { + mode = telemetry.Inbound + } + ctx, s.close = trace.StartSpan(ctx, r.Method, + tag.Tag{Key: telemetry.Method, Value: r.Method}, + tag.Tag{Key: telemetry.RPCDirection, Value: mode}, + tag.Tag{Key: telemetry.RPCID, Value: r.ID}, + ) + telemetry.Started.Record(ctx, 1) return ctx } func (h *handler) Response(ctx context.Context, direction jsonrpc2.Direction, r *jsonrpc2.WireResponse) context.Context { stats := h.getStats(ctx) - h.log(direction, r.ID, 0, stats.method, r.Result, nil) + h.logRPC(direction, r.ID, 0, stats.method, r.Result, nil) return ctx } func (h *handler) Done(ctx context.Context, err error) { - if !h.trace { - return - } stats := h.getStats(ctx) - h.log(stats.direction, stats.id, time.Since(stats.start), stats.method, stats.payload, err) + h.logRPC(stats.direction, stats.id, time.Since(stats.start), stats.method, stats.payload, err) + if err != nil { + ctx = telemetry.StatusCode.With(ctx, "ERROR") + } else { + ctx = telemetry.StatusCode.With(ctx, "OK") + } + elapsedTime := time.Since(stats.start) + latencyMillis := float64(elapsedTime) / float64(time.Millisecond) + telemetry.Latency.Record(ctx, latencyMillis) + stats.close() } func (h *handler) Read(ctx context.Context, bytes int64) context.Context { + telemetry.SentBytes.Record(ctx, bytes) return ctx } func (h *handler) Wrote(ctx context.Context, bytes int64) context.Context { + telemetry.ReceivedBytes.Record(ctx, bytes) return ctx } const eol = "\r\n\r\n\r\n" func (h *handler) Error(ctx context.Context, err error) { - if !h.trace { - return - } stats := h.getStats(ctx) - h.log(stats.direction, stats.id, 0, stats.method, nil, err) + h.logRPC(stats.direction, stats.id, 0, stats.method, nil, err) } func (h *handler) getStats(ctx context.Context) *rpcStats { @@ -194,13 +209,14 @@ func (h *handler) getStats(ctx context.Context) *rpcStats { if !ok || stats == nil { stats = &rpcStats{ method: "???", + close: func() {}, } } return stats } -func (h *handler) log(direction jsonrpc2.Direction, id *jsonrpc2.ID, elapsed time.Duration, method string, payload *json.RawMessage, err error) { - if !h.trace { +func (h *handler) logRPC(direction jsonrpc2.Direction, id *jsonrpc2.ID, elapsed time.Duration, method string, payload *json.RawMessage, err error) { + if !h.loggingRPCs { return } const eol = "\r\n\r\n\r\n" @@ -249,90 +265,3 @@ func (h *handler) log(direction jsonrpc2.Direction, id *jsonrpc2.ID, elapsed tim fmt.Fprintf(outx, ".\r\nParams: %s%s", params, eol) fmt.Fprintf(h.out, "%s", outx.String()) } - -type rpcStats struct { - server bool - method string - close func() - start time.Time -} - -func start(ctx context.Context, server bool, method string, id *ID) (context.Context, *rpcStats) { - if method == "" { - panic("no method in rpc stats") - } - s := &rpcStats{ - server: server, - method: method, - start: time.Now(), - } - mode := telemetry.Outbound - if server { - mode = telemetry.Inbound - } - ctx, s.close = trace.StartSpan(ctx, method, - tag.Tag{Key: telemetry.Method, Value: method}, - tag.Tag{Key: telemetry.RPCDirection, Value: mode}, - tag.Tag{Key: telemetry.RPCID, Value: id}, - ) - telemetry.Started.Record(ctx, 1) - return ctx, s -} - -func (s *rpcStats) end(ctx context.Context, err *error) { - if err != nil && *err != nil { - ctx = telemetry.StatusCode.With(ctx, "ERROR") - } else { - ctx = telemetry.StatusCode.With(ctx, "OK") - } - elapsedTime := time.Since(s.start) - latencyMillis := float64(elapsedTime) / float64(time.Millisecond) - telemetry.Latency.Record(ctx, latencyMillis) - s.close() -} - -type statsKeyType int - -const statsKey = statsKeyType(0) - -type tracer struct { -} - -func (h *tracer) Deliver(ctx context.Context, r *Request, delivered bool) bool { - return false -} - -func (h *tracer) Cancel(ctx context.Context, conn *Conn, id ID, cancelled bool) bool { - return false -} - -func (h *tracer) Request(ctx context.Context, direction Direction, r *WireRequest) context.Context { - ctx, stats := start(ctx, direction == Receive, r.Method, r.ID) - ctx = context.WithValue(ctx, statsKey, stats) - return ctx -} - -func (h *tracer) Response(ctx context.Context, direction Direction, r *WireResponse) context.Context { - return ctx -} - -func (h *tracer) Done(ctx context.Context, err error) { - stats, ok := ctx.Value(statsKey).(*rpcStats) - if ok && stats != nil { - stats.end(ctx, &err) - } -} - -func (h *tracer) Read(ctx context.Context, bytes int64) context.Context { - telemetry.SentBytes.Record(ctx, bytes) - return ctx -} - -func (h *tracer) Wrote(ctx context.Context, bytes int64) context.Context { - telemetry.ReceivedBytes.Record(ctx, bytes) - return ctx -} - -func (h *tracer) Error(ctx context.Context, err error) { - log.Printf("%v", err) -}