Skip to content

Commit 28f7ec3

Browse files
committed
prefix match + code action
1 parent 397ea86 commit 28f7ec3

File tree

5 files changed

+190
-77
lines changed

5 files changed

+190
-77
lines changed

gopls/internal/cache/testfuncs/tests.go

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,20 @@ func (index *Index) All() []Result {
4949
type Result struct {
5050
Location protocol.Location // location of the test
5151
Name string // name of the test
52+
Type TestType // type of the test
53+
Subtest bool
5254
}
5355

56+
type TestType int
57+
58+
const (
59+
TypeInvalid TestType = iota
60+
TypeTest
61+
TypeBenchmark
62+
TypeFuzz
63+
TypeExample
64+
)
65+
5466
// NewIndex returns a new index of method-set information for all
5567
// package-level types in the specified package.
5668
func NewIndex(files []*parsego.File, info *types.Info) *Index {
@@ -84,15 +96,16 @@ func (b *indexBuilder) build(files []*parsego.File, info *types.Info) *Index {
8496
continue
8597
}
8698

87-
isTest, isExample := isTestOrExample(obj)
88-
if !isTest && !isExample {
99+
testType := getTestType(obj)
100+
if testType == TypeInvalid {
89101
continue
90102
}
91103

92104
var t gobTest
93105
t.Name = decl.Name.Name
94106
t.Location.URI = file.URI
95107
t.Location.Range, _ = file.NodeRange(decl)
108+
t.Type = testType
96109

97110
i, ok := b.fileIndex[t.Location.URI]
98111
if !ok {
@@ -105,7 +118,8 @@ func (b *indexBuilder) build(files []*parsego.File, info *types.Info) *Index {
105118
b.visited[obj] = true
106119

107120
// Check for subtests
108-
if isTest {
121+
switch testType {
122+
case TypeTest, TypeBenchmark, TypeFuzz:
109123
b.Files[i].Tests = append(b.Files[i].Tests, b.findSubtests(t, decl.Type, decl.Body, file, files, info)...)
110124
}
111125
}
@@ -168,6 +182,8 @@ func (b *indexBuilder) findSubtests(parent gobTest, typ *ast.FuncType, body *ast
168182
t.Name = b.uniqueName(parent.Name, rewrite(constant.StringVal(val)))
169183
t.Location.URI = file.URI
170184
t.Location.Range, _ = file.NodeRange(call)
185+
t.Type = parent.Type
186+
t.Subtest = true
171187
tests = append(tests, t)
172188

173189
fn, typ, body := findFunc(files, info, body, call.Args[1])
@@ -182,7 +198,8 @@ func (b *indexBuilder) findSubtests(parent gobTest, typ *ast.FuncType, body *ast
182198
}
183199

184200
// Never recurse if the second argument is a top-level test function
185-
if isTest, _ := isTestOrExample(fn); isTest {
201+
switch getTestType(fn) {
202+
case TypeTest, TypeBenchmark, TypeFuzz:
186203
continue
187204
}
188205

@@ -258,30 +275,35 @@ func findFunc(files []*parsego.File, info *types.Info, body *ast.BlockStmt, expr
258275
return nil, nil, nil
259276
}
260277

261-
// isTestOrExample reports whether the given func is a testing func or an
262-
// example func (or neither). isTestOrExample returns (true, false) for testing
263-
// funcs, (false, true) for example funcs, and (false, false) otherwise.
264-
func isTestOrExample(fn *types.Func) (isTest, isExample bool) {
278+
// getTestType reports the test type of the given function.
279+
func getTestType(fn *types.Func) TestType {
265280
sig := fn.Type().(*types.Signature)
266-
if sig.Params().Len() == 0 &&
267-
sig.Results().Len() == 0 {
268-
return false, isTestName(fn.Name(), "Example")
281+
if sig.Params().Len() == 0 && sig.Results().Len() == 0 {
282+
if isTestName(fn.Name(), "Example") {
283+
return TypeExample
284+
}
285+
return TypeInvalid
269286
}
270287

271288
kind, ok := testKind(sig)
272289
if !ok {
273-
return false, false
290+
return TypeInvalid
274291
}
275292
switch kind.Name() {
276293
case "T":
277-
return isTestName(fn.Name(), "Test"), false
294+
if isTestName(fn.Name(), "Test") {
295+
return TypeTest
296+
}
278297
case "B":
279-
return isTestName(fn.Name(), "Benchmark"), false
298+
if isTestName(fn.Name(), "Benchmark") {
299+
return TypeBenchmark
300+
}
280301
case "F":
281-
return isTestName(fn.Name(), "Fuzz"), false
282-
default:
283-
return false, false // "can't happen" (see testKind)
302+
if isTestName(fn.Name(), "Fuzz") {
303+
return TypeFuzz
304+
}
284305
}
306+
return TypeInvalid
285307
}
286308

287309
// isTestName reports whether name is a valid test name for the test kind
@@ -352,6 +374,8 @@ type gobFile struct {
352374
type gobTest struct {
353375
Location protocol.Location // location of the test
354376
Name string // name of the test
377+
Type TestType // type of the test
378+
Subtest bool
355379
}
356380

357381
func (t *gobTest) result() Result {

gopls/internal/golang/code_lens.go

Lines changed: 92 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"golang.org/x/tools/gopls/internal/cache"
2020
"golang.org/x/tools/gopls/internal/cache/metadata"
2121
"golang.org/x/tools/gopls/internal/cache/parsego"
22+
"golang.org/x/tools/gopls/internal/cache/testfuncs"
2223
"golang.org/x/tools/gopls/internal/file"
2324
"golang.org/x/tools/gopls/internal/protocol"
2425
"golang.org/x/tools/gopls/internal/protocol/command"
@@ -213,6 +214,29 @@ func regenerateCgoLens(ctx context.Context, snapshot *cache.Snapshot, fh file.Ha
213214
}
214215

215216
func goToTestCodeLens(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle) ([]protocol.CodeLens, error) {
217+
matches, err := matchFunctionsWithTests(ctx, snapshot, fh)
218+
if err != nil {
219+
return nil, err
220+
}
221+
222+
lenses := make([]protocol.CodeLens, 0, len(matches))
223+
for _, t := range matches {
224+
lenses = append(lenses, protocol.CodeLens{
225+
Range: protocol.Range{Start: t.FuncPos, End: t.FuncPos},
226+
Command: command.NewGoToTestCommand("Go to "+t.Name, t.Loc),
227+
})
228+
}
229+
return lenses, nil
230+
}
231+
232+
type TestMatch struct {
233+
FuncPos protocol.Position // function position
234+
Name string // test name
235+
Loc protocol.Location // test location
236+
Type testfuncs.TestType // test type
237+
}
238+
239+
func matchFunctionsWithTests(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle) (matches []TestMatch, err error) {
216240
if strings.HasSuffix(fh.URI().Path(), "_test.go") {
217241
// Ignore test files.
218242
return nil, nil
@@ -238,7 +262,12 @@ func goToTestCodeLens(ctx context.Context, snapshot *cache.Snapshot, fh file.Han
238262
if err != nil {
239263
return nil, fmt.Errorf("couldn't parse file: %w", err)
240264
}
241-
funcPos := make(map[string]protocol.Position)
265+
266+
type Func struct {
267+
Name string
268+
Pos protocol.Position
269+
}
270+
var fileFuncs []Func
242271
for _, d := range pgf.File.Decls {
243272
fn, ok := d.(*ast.FuncDecl)
244273
if !ok {
@@ -254,32 +283,11 @@ func goToTestCodeLens(ctx context.Context, snapshot *cache.Snapshot, fh file.Han
254283
_, rname, _ := astutil.UnpackRecv(fn.Recv.List[0].Type)
255284
name = rname.Name + "_" + fn.Name.Name
256285
}
257-
funcPos[name] = rng.Start
258-
}
259-
260-
type TestType int
261-
262-
// Types are sorted by priority from high to low.
263-
const (
264-
T TestType = iota + 1
265-
E
266-
B
267-
F
268-
)
269-
testTypes := map[string]TestType{
270-
"Test": T,
271-
"Example": E,
272-
"Benchmark": B,
273-
"Fuzz": F,
274-
}
275-
276-
type Test struct {
277-
FuncPos protocol.Position
278-
Name string
279-
Loc protocol.Location
280-
Type TestType
286+
fileFuncs = append(fileFuncs, Func{
287+
Name: name,
288+
Pos: rng.Start,
289+
})
281290
}
282-
var matchedTests []Test
283291

284292
pkgIDs := make([]PackageID, 0, len(testPackages))
285293
for _, pkg := range testPackages {
@@ -291,48 +299,54 @@ func goToTestCodeLens(ctx context.Context, snapshot *cache.Snapshot, fh file.Han
291299
}
292300
for _, tests := range allTests {
293301
for _, test := range tests.All() {
294-
var (
295-
name string
296-
testType TestType
297-
)
298-
for prefix, t := range testTypes {
299-
if strings.HasPrefix(test.Name, prefix) {
300-
testType = t
301-
name = test.Name[len(prefix):]
302-
break
303-
}
302+
if test.Subtest {
303+
continue
304304
}
305-
if testType == 0 {
306-
continue // unknown type
305+
potentialFuncNames := getPotentialFuncNames(test)
306+
if len(potentialFuncNames) == 0 {
307+
continue
307308
}
308-
name = strings.TrimPrefix(name, "_")
309-
310-
// Try to find 'Foo' for 'TestFoo' and 'foo' for 'Test_foo'.
311-
pos, ok := funcPos[name]
312-
if !ok && token.IsExported(name) {
313-
// Try to find 'foo' for 'TestFoo'.
314-
runes := []rune(name)
315-
runes[0] = unicode.ToLower(runes[0])
316-
pos, ok = funcPos[string(runes)]
309+
310+
var matchedFunc Func
311+
for _, fn := range fileFuncs {
312+
var matched bool
313+
for _, n := range potentialFuncNames {
314+
// Check the prefix to be able to match 'TestDeletePanics' with 'Delete'.
315+
if strings.HasPrefix(n, fn.Name) {
316+
matched = true
317+
break
318+
}
319+
}
320+
if !matched {
321+
continue
322+
}
323+
324+
// Use the most specific function:
325+
//
326+
// - match 'TestDelete', 'TestDeletePanics' with 'Delete'
327+
// - match 'TestDeleteFunc', 'TestDeleteFuncClearTail' with 'DeleteFunc', not 'Delete'
328+
if len(matchedFunc.Name) < len(fn.Name) {
329+
matchedFunc = fn
330+
}
317331
}
318-
if ok {
332+
if matchedFunc.Name != "" {
319333
loc := test.Location
320334
loc.Range.End = loc.Range.Start // move cursor to the test's beginning
321335

322-
matchedTests = append(matchedTests, Test{
323-
FuncPos: pos,
336+
matches = append(matches, TestMatch{
337+
FuncPos: matchedFunc.Pos,
324338
Name: test.Name,
325339
Loc: loc,
326-
Type: testType,
340+
Type: test.Type,
327341
})
328342
}
329343
}
330344
}
331-
if len(matchedTests) == 0 {
345+
if len(matches) == 0 {
332346
return nil, nil
333347
}
334348

335-
slices.SortFunc(matchedTests, func(a, b Test) int {
349+
slices.SortFunc(matches, func(a, b TestMatch) int {
336350
if v := protocol.ComparePosition(a.FuncPos, b.FuncPos); v != 0 {
337351
return v
338352
}
@@ -341,13 +355,31 @@ func goToTestCodeLens(ctx context.Context, snapshot *cache.Snapshot, fh file.Han
341355
}
342356
return cmp.Compare(a.Name, b.Name)
343357
})
358+
return matches, nil
359+
}
344360

345-
lenses := make([]protocol.CodeLens, 0, len(matchedTests))
346-
for _, t := range matchedTests {
347-
lenses = append(lenses, protocol.CodeLens{
348-
Range: protocol.Range{Start: t.FuncPos, End: t.FuncPos},
349-
Command: command.NewGoToTestCommand("Go to "+t.Name, t.Loc),
350-
})
361+
func getPotentialFuncNames(test testfuncs.Result) []string {
362+
var name string
363+
switch test.Type {
364+
case testfuncs.TypeTest:
365+
name = strings.TrimPrefix(test.Name, "Test")
366+
case testfuncs.TypeBenchmark:
367+
name = strings.TrimPrefix(test.Name, "Benchmark")
368+
case testfuncs.TypeFuzz:
369+
name = strings.TrimPrefix(test.Name, "Fuzz")
370+
case testfuncs.TypeExample:
371+
name = strings.TrimPrefix(test.Name, "Example")
372+
}
373+
if name == "" {
374+
return nil
375+
}
376+
name = strings.TrimPrefix(name, "_")
377+
378+
lowerCasedName := []rune(name)
379+
lowerCasedName[0] = unicode.ToLower(lowerCasedName[0])
380+
381+
return []string{
382+
name, // 'Foo' for 'TestFoo', 'foo' for 'Test_foo'
383+
string(lowerCasedName), // 'foo' for 'TestFoo'
351384
}
352-
return lenses, nil
353385
}

gopls/internal/golang/codeaction.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ var codeActionProducers = [...]codeActionProducer{
266266
{kind: settings.RefactorRewriteEliminateDotImport, fn: refactorRewriteEliminateDotImport, needPkg: true},
267267
{kind: settings.RefactorRewriteAddTags, fn: refactorRewriteAddStructTags, needPkg: true},
268268
{kind: settings.RefactorRewriteRemoveTags, fn: refactorRewriteRemoveStructTags, needPkg: true},
269+
{kind: settings.GoToTest, fn: goToTest, needPkg: true},
269270
{kind: settings.GoplsDocFeatures, fn: goplsDocFeatures}, // offer this one last (#72742)
270271

271272
// Note: don't forget to update the allow-list in Server.CodeAction
@@ -1129,3 +1130,34 @@ func toggleCompilerOptDetails(ctx context.Context, req *codeActionsRequest) erro
11291130
}
11301131
return nil
11311132
}
1133+
1134+
// goToTest produces "Go to TestXxx" code action.
1135+
// See [server.commandHandler.GoToTest] for command implementation.
1136+
func goToTest(ctx context.Context, req *codeActionsRequest) error {
1137+
// TODO: add tests.
1138+
1139+
path, _ := astutil.PathEnclosingInterval(req.pgf.File, req.start, req.end)
1140+
if len(path) < 2 {
1141+
return nil
1142+
}
1143+
fn, ok := path[len(path)-2].(*ast.FuncDecl)
1144+
if !ok {
1145+
return nil
1146+
}
1147+
fnRng, err := req.pgf.NodeRange(fn)
1148+
if err != nil {
1149+
return fmt.Errorf("couldn't get node range: %w", err)
1150+
}
1151+
1152+
matches, err := matchFunctionsWithTests(ctx, req.snapshot, req.fh)
1153+
if err != nil {
1154+
return err
1155+
}
1156+
for _, m := range matches {
1157+
if m.FuncPos == fnRng.Start {
1158+
cmd := command.NewGoToTestCommand("Go to "+m.Name, m.Loc)
1159+
req.addCommandAction(cmd, false)
1160+
}
1161+
}
1162+
return nil
1163+
}

gopls/internal/settings/codeactionkind.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ const (
8383
GoToggleCompilerOptDetails protocol.CodeActionKind = "source.toggleCompilerOptDetails"
8484
AddTest protocol.CodeActionKind = "source.addTest"
8585
OrganizeImports protocol.CodeActionKind = "source.organizeImports"
86+
GoToTest protocol.CodeActionKind = "source.go_to_test"
8687

8788
// gopls
8889
GoplsDocFeatures protocol.CodeActionKind = "gopls.doc.features"

0 commit comments

Comments
 (0)