diff --git a/go/analysis/passes/stdmethods/stdmethods.go b/go/analysis/passes/stdmethods/stdmethods.go index eead289e..1c18dbc9 100644 --- a/go/analysis/passes/stdmethods/stdmethods.go +++ b/go/analysis/passes/stdmethods/stdmethods.go @@ -7,10 +7,7 @@ package stdmethods import ( - "bytes" - "fmt" "go/ast" - "go/printer" "go/token" "go/types" "strings" @@ -95,12 +92,12 @@ func run(pass *analysis.Pass) (interface{}, error) { switch n := n.(type) { case *ast.FuncDecl: if n.Recv != nil { - canonicalMethod(pass, n.Name, n.Type) + canonicalMethod(pass, n.Name) } case *ast.InterfaceType: for _, field := range n.Methods.List { for _, id := range field.Names { - canonicalMethod(pass, id, field.Type.(*ast.FuncType)) + canonicalMethod(pass, id) } } } @@ -108,7 +105,7 @@ func run(pass *analysis.Pass) (interface{}, error) { return nil, nil } -func canonicalMethod(pass *analysis.Pass, id *ast.Ident, t *ast.FuncType) { +func canonicalMethod(pass *analysis.Pass, id *ast.Ident) { // Expected input/output. expect, ok := canonicalMethods[id.Name] if !ok { @@ -116,11 +113,9 @@ func canonicalMethod(pass *analysis.Pass, id *ast.Ident, t *ast.FuncType) { } // Actual input/output - args := typeFlatten(t.Params.List) - var results []ast.Expr - if t.Results != nil { - results = typeFlatten(t.Results.List) - } + sign := pass.TypesInfo.Defs[id].Type().(*types.Signature) + args := sign.Params() + results := sign.Results() // Do the =s (if any) all match? if !matchParams(pass, expect.args, args, "=") || !matchParams(pass, expect.results, results, "=") { @@ -136,11 +131,7 @@ func canonicalMethod(pass *analysis.Pass, id *ast.Ident, t *ast.FuncType) { expectFmt += " (" + argjoin(expect.results) + ")" } - var buf bytes.Buffer - if err := printer.Fprint(&buf, pass.Fset, t); err != nil { - fmt.Fprintf(&buf, "<%s>", err) - } - actual := buf.String() + actual := sign.String() actual = strings.TrimPrefix(actual, "func") actual = id.Name + actual @@ -159,45 +150,27 @@ func argjoin(x []string) string { return strings.Join(y, ", ") } -// Turn parameter list into slice of types -// (in the ast, types are Exprs). -// Have to handle f(int, bool) and f(x, y, z int) -// so not a simple 1-to-1 conversion. -func typeFlatten(l []*ast.Field) []ast.Expr { - var t []ast.Expr - for _, f := range l { - if len(f.Names) == 0 { - t = append(t, f.Type) - continue - } - for range f.Names { - t = append(t, f.Type) - } - } - return t -} - // Does each type in expect with the given prefix match the corresponding type in actual? -func matchParams(pass *analysis.Pass, expect []string, actual []ast.Expr, prefix string) bool { +func matchParams(pass *analysis.Pass, expect []string, actual *types.Tuple, prefix string) bool { for i, x := range expect { if !strings.HasPrefix(x, prefix) { continue } - if i >= len(actual) { + if i >= actual.Len() { return false } - if !matchParamType(pass.Fset, pass.Pkg, x, actual[i]) { + if !matchParamType(pass.Fset, pass.Pkg, x, actual.At(i).Type()) { return false } } - if prefix == "" && len(actual) > len(expect) { + if prefix == "" && actual.Len() > len(expect) { return false } return true } // Does this one type match? -func matchParamType(fset *token.FileSet, pkg *types.Package, expect string, actual ast.Expr) bool { +func matchParamType(fset *token.FileSet, pkg *types.Package, expect string, actual types.Type) bool { expect = strings.TrimPrefix(expect, "=") // Strip package name if we're in that package. if n := len(pkg.Name()); len(expect) > n && expect[:n] == pkg.Name() && expect[n] == '.' { @@ -205,7 +178,5 @@ func matchParamType(fset *token.FileSet, pkg *types.Package, expect string, actu } // Overkill but easy. - var buf bytes.Buffer - printer.Fprint(&buf, fset, actual) - return buf.String() == expect + return actual.String() == expect } diff --git a/go/analysis/passes/stdmethods/testdata/src/a/a.go b/go/analysis/passes/stdmethods/testdata/src/a/a.go index a007bbfb..829c5b53 100644 --- a/go/analysis/passes/stdmethods/testdata/src/a/a.go +++ b/go/analysis/passes/stdmethods/testdata/src/a/a.go @@ -8,7 +8,7 @@ import "fmt" type T int -func (T) Scan(x fmt.ScanState, c byte) {} // want "should have signature Scan" +func (T) Scan(x fmt.ScanState, c byte) {} // want `should have signature Scan\(fmt\.ScanState, rune\) error` func (T) Format(fmt.State, byte) {} // want `should have signature Format\(fmt.State, rune\)` @@ -19,5 +19,5 @@ func (U) Format(byte) {} // no error: first parameter must be fmt.State to trigg func (U) GobDecode() {} // want `should have signature GobDecode\(\[\]byte\) error` type I interface { - ReadByte() byte // want "should have signature ReadByte" + ReadByte() byte // want `should have signature ReadByte\(\) \(byte, error\)` }