@@ -178,7 +178,7 @@ type AsyncFilterFunc func(context.Context, *Node) *Node
178
178
func AsyncFilter (it Iterator , check AsyncFilterFunc , workers int ) Iterator {
179
179
f := & asyncFilterIter {
180
180
it : ensureSourceIter (it ),
181
- slots : make (chan struct {}, workers + 1 ),
181
+ slots : make (chan struct {}, workers + 1 ), // extra 1 slot to make sure all the goroutines can be completed
182
182
passed : make (chan iteratorItem ),
183
183
}
184
184
for range cap (f .slots ) {
@@ -193,6 +193,9 @@ func AsyncFilter(it Iterator, check AsyncFilterFunc, workers int) Iterator {
193
193
return
194
194
case <- f .slots :
195
195
}
196
+ defer func () {
197
+ f .slots <- struct {}{} // the iterator has ended
198
+ }()
196
199
// read from the iterator and start checking nodes in parallel
197
200
// when a node is checked, it will be sent to the passed channel
198
201
// and the slot will be released
@@ -201,7 +204,11 @@ func AsyncFilter(it Iterator, check AsyncFilterFunc, workers int) Iterator {
201
204
nodeSource := f .it .NodeSource ()
202
205
203
206
// check the node async, in a separate goroutine
204
- <- f .slots
207
+ select {
208
+ case <- ctx .Done ():
209
+ return
210
+ case <- f .slots :
211
+ }
205
212
go func () {
206
213
if nn := check (ctx , node ); nn != nil {
207
214
item := iteratorItem {nn , nodeSource }
@@ -213,8 +220,6 @@ func AsyncFilter(it Iterator, check AsyncFilterFunc, workers int) Iterator {
213
220
f .slots <- struct {}{}
214
221
}()
215
222
}
216
- // the iterator has ended
217
- f .slots <- struct {}{}
218
223
}()
219
224
220
225
return f
0 commit comments