Skip to content
37 changes: 37 additions & 0 deletions balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,3 +395,40 @@ func (rr *roundRobin) Close() error {
}
return nil
}

type pickFirst struct {
rr *roundRobin
}

// PickFirst Balancer is a simple balancer for testing multi-addresses in one addrConn.
// By using this balancer, all address shares the same addrConn.
// Although it wrapped by RoundRobin balancer, the logic of all methods work fine because
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should go on pickFirst instead, otherwise it will be visible in our godoc, but it's intended for readers of the code (not the API).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although it's a wrapper around RoundRobin?

// balancer.Get() returns the address Up by resetTransport()
func PickFirst(r naming.Resolver) Balancer {
return &pickFirst{rr: &roundRobin{r: r}}
}

// The only difference is using ff.watchAddrUpdates() to use findFirstMD
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function comments must start with the name of the function.

In this case, you should just leave them out because the type (and, consequently this function) are not exported.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I think this comment is stale. watchAddrUpdates() wasn't changed and Start just calls the RR balancer's Start.

I think you can comment pickFirst as:

pickFirst is the same as roundRobin, but with a different type to control conditional behavior in grpc internals.

func (ff *pickFirst) Start(target string, config BalancerConfig) error {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pickFirst -> pf for receiver name?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I'll change that. In my first implementation, I thought the word was findFirst, so I used ff.
And also I forgot to update the comment because I was using metadata type(findFirstMD) to decide which balancer in that version.

return ff.rr.Start(target, config)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you embed the rr balancer in pickFirst, you won't need to rewrite functions like this:

type pickFirst struct {
  *roundRobin
}

}

// Up sets the connected state of addr and sends notification if there are pending
// Get() calls.
func (ff *pickFirst) Up(addr Address) func(error) {
return ff.rr.Up(addr)
}

// Get returns the next addr in the rotation.
func (ff *pickFirst) Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) {
addr, put, err = ff.rr.Get(ctx, opts)
return
}

func (ff *pickFirst) Notify() <-chan []Address {
return ff.rr.addrCh
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return ff.rr.Notify()

}

func (ff *pickFirst) Close() error {
return ff.rr.Close()
}
344 changes: 344 additions & 0 deletions balancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package grpc
import (
"fmt"
"math"
"strconv"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -421,3 +422,346 @@ func TestOneAddressRemoval(t *testing.T) {
servers[i].stop()
}
}

func checkServerUp(t *testing.T, currentServer *server) {
req := "port"
port := currentServer.port
cc, err := Dial("localhost:"+port, WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
var reply string
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == port {
break
}
time.Sleep(10 * time.Millisecond)
}
cc.Close()
}

func TestPickFirstEmptyAddrs(t *testing.T) {
servers, r := startServers(t, 1, math.MaxUint32)
cc, err := Dial("foo.bar.com", WithBalancer(PickFirst(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil || reply != expectedResponse {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, reply = %q, want %q, <nil>", err, reply, expectedResponse)
}
// Inject name resolution change to remove the server so that there is no address
// available after that.
u := &naming.Update{
Op: naming.Delete,
Addr: "localhost:" + servers[0].port,
}
r.w.inject([]*naming.Update{u})
// Loop until the above updates apply.
for {
time.Sleep(10 * time.Millisecond)
ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); err != nil {
break
}
}
cc.Close()
servers[0].stop()
}

func TestPickFirstCloseWithPendingRPC(t *testing.T) {
servers, r := startServers(t, 1, math.MaxUint32)
cc, err := Dial("foo.bar.com", WithBalancer(PickFirst(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port)
}
// Remove the server.
updates := []*naming.Update{{
Op: naming.Delete,
Addr: "localhost:" + servers[0].port,
}}
r.w.inject(updates)
// Loop until the above update applies.
for {
ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); Code(err) == codes.DeadlineExceeded {
break
}
time.Sleep(10 * time.Millisecond)
}
// Issue 2 RPCs which should be completed with error status once cc is closed.
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
}
}()
go func() {
defer wg.Done()
var reply string
time.Sleep(5 * time.Millisecond)
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
}
}()
time.Sleep(5 * time.Millisecond)
cc.Close()
wg.Wait()
servers[0].stop()
}

func TestPickFirstOrderAllServerUp(t *testing.T) {
// Start 3 servers on 3 ports.
numServers := 3
servers, r := startServers(t, numServers, math.MaxUint32)
cc, err := Dial("foo.bar.com", WithBalancer(PickFirst(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}

// Add servers[1] and [2] to the service discovery.
u := &naming.Update{
Op: naming.Add,
Addr: "localhost:" + servers[1].port,
}
r.w.inject([]*naming.Update{u})

u = &naming.Update{
Op: naming.Add,
Addr: "localhost:" + servers[2].port,
}
r.w.inject([]*naming.Update{u})

// Loop until all 3 servers are up
checkServerUp(t, servers[0])
checkServerUp(t, servers[1])
checkServerUp(t, servers[2])

// Check the incoming RPCs served in server[0]
req := "port"
var reply string
for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[0].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port)
}
time.Sleep(10 * time.Millisecond)
}

// Delete server[0] in the balancer, the incoming RPCs served in server[1]
// For test addrconn, close server[0] instead
u = &naming.Update{
Op: naming.Delete,
Addr: "localhost:" + servers[0].port,
}
r.w.inject([]*naming.Update{u})
// Loop until it changes to server[1]
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
}
for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[1].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port)
}
time.Sleep(10 * time.Millisecond)
}

// Add server[0] back to the balancer, the incoming RPCs served in server[1]
// Add is append operation, the order of Notify now is {server[2].port server[0].port}
u = &naming.Update{
Op: naming.Add,
Addr: "localhost:" + servers[0].port,
}
r.w.inject([]*naming.Update{u})
for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[1].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port)
}
time.Sleep(10 * time.Millisecond)
}

// Delete server[1] in the balancer, the incoming RPCs served in server[2]
u = &naming.Update{
Op: naming.Delete,
Addr: "localhost:" + servers[1].port,
}
r.w.inject([]*naming.Update{u})
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[2].port {
break
}
time.Sleep(1 * time.Second)
}
for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[2].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 2, err, servers[2].port)
}
time.Sleep(10 * time.Millisecond)
}

// After remove server[3], incoming RPCs still served in server[0]
cc.Close()
for i := 0; i < numServers; i++ {
servers[i].stop()
}
}

func TestPickFirstOrderOneServerDown(t *testing.T) {
// Start 3 servers on 3 ports.
numServers := 3
servers, r := startServers(t, numServers, math.MaxUint32)
cc, err := Dial("foo.bar.com", WithBalancer(PickFirst(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}

// Add servers[1] and [2] to the service discovery.
u := &naming.Update{
Op: naming.Add,
Addr: "localhost:" + servers[1].port,
}
r.w.inject([]*naming.Update{u})

u = &naming.Update{
Op: naming.Add,
Addr: "localhost:" + servers[2].port,
}
r.w.inject([]*naming.Update{u})

// Loop until all 3 servers are up
checkServerUp(t, servers[0])
checkServerUp(t, servers[1])
checkServerUp(t, servers[2])

// Check the incoming RPCs served in server[0]
req := "port"
var reply string
for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[0].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port)
}
time.Sleep(10 * time.Millisecond)
}

// server[0] down, incoming RPCs served in server[1], but the order of Notify still remains
// {server[0] server[1] server[2]}
servers[0].stop()
// Loop until it changes to server[1]
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
}
for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[1].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port)
}
time.Sleep(10 * time.Millisecond)
}

// up the server[0] back, the incoming RPCs served in server[1]
p, _ := strconv.Atoi(servers[0].port)
servers[0] = newTestServer()
go servers[0].start(t, p, math.MaxUint32)
servers[0].wait(t, 2*time.Second)
checkServerUp(t, servers[0])

for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[1].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port)
}
time.Sleep(10 * time.Millisecond)
}

// Delete server[1] in the balancer, the incoming RPCs served in server[0]
u = &naming.Update{
Op: naming.Delete,
Addr: "localhost:" + servers[1].port,
}
r.w.inject([]*naming.Update{u})
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[0].port {
break
}
time.Sleep(1 * time.Second)
}
for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[0].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port)
}
time.Sleep(10 * time.Millisecond)
}

// After remove server[3], incoming RPCs still served in server[0]
cc.Close()
for i := 0; i < numServers; i++ {
servers[i].stop()
}
}

func TestPickFirstOneAddressRemoval(t *testing.T) {
// Start 2 servers.
numServers := 2
servers, r := startServers(t, numServers, math.MaxUint32)
cc, err := Dial("localhost:"+servers[0].port, WithBalancer(PickFirst(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
// Add servers[1] to the service discovery.
var updates []*naming.Update
updates = append(updates, &naming.Update{
Op: naming.Add,
Addr: "localhost:" + servers[1].port,
})
r.w.inject(updates)

// Create a new cc to Loop until servers[1] is up
checkServerUp(t, servers[0])
checkServerUp(t, servers[1])

var wg sync.WaitGroup
numRPC := 100
sleepDuration := 10 * time.Millisecond
wg.Add(1)
go func() {
time.Sleep(sleepDuration)
// After sleepDuration, delete server[0].
var updates []*naming.Update
updates = append(updates, &naming.Update{
Op: naming.Delete,
Addr: "localhost:" + servers[0].port,
})
r.w.inject(updates)
wg.Done()
}()

// All non-failfast RPCs should not fail because there's at least one connection available.
for i := 0; i < numRPC; i++ {
wg.Add(1)
go func() {
var reply string
time.Sleep(sleepDuration)
// After sleepDuration, invoke RPC.
// server[0] is removed around the same time to make it racy between balancer and gRPC internals.
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
}
wg.Done()
}()
}
wg.Wait()
cc.Close()
for i := 0; i < numServers; i++ {
servers[i].stop()
}
}
Loading