Skip to content

Commit 5a50b97

Browse files
Revert "Revert "credentials/alts: defer ALTS stream creation until handshake …" (#6179)
1 parent 89ec960 commit 5a50b97

File tree

2 files changed

+102
-18
lines changed

2 files changed

+102
-18
lines changed

credentials/alts/internal/handshaker/handshaker.go

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,16 @@ func DefaultServerHandshakerOptions() *ServerHandshakerOptions {
138138
// and server options (server options struct does not exist now. When
139139
// caller can provide endpoints, it should be created.
140140

141-
// altsHandshaker is used to complete a ALTS handshaking between client and
141+
// altsHandshaker is used to complete an ALTS handshake between client and
142142
// server. This handshaker talks to the ALTS handshaker service in the metadata
143143
// server.
144144
type altsHandshaker struct {
145145
// RPC stream used to access the ALTS Handshaker service.
146146
stream altsgrpc.HandshakerService_DoHandshakeClient
147147
// the connection to the peer.
148148
conn net.Conn
149+
// a virtual connection to the ALTS handshaker service.
150+
clientConn *grpc.ClientConn
149151
// client handshake options.
150152
clientOpts *ClientHandshakerOptions
151153
// server handshake options.
@@ -154,39 +156,33 @@ type altsHandshaker struct {
154156
side core.Side
155157
}
156158

157-
// NewClientHandshaker creates a ALTS handshaker for GCP which contains an RPC
158-
// stub created using the passed conn and used to talk to the ALTS Handshaker
159+
// NewClientHandshaker creates a core.Handshaker that performs a client-side
160+
// ALTS handshake by acting as a proxy between the peer and the ALTS handshaker
159161
// service in the metadata server.
160162
func NewClientHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ClientHandshakerOptions) (core.Handshaker, error) {
161-
stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx)
162-
if err != nil {
163-
return nil, err
164-
}
165163
return &altsHandshaker{
166-
stream: stream,
164+
stream: nil,
167165
conn: c,
166+
clientConn: conn,
168167
clientOpts: opts,
169168
side: core.ClientSide,
170169
}, nil
171170
}
172171

173-
// NewServerHandshaker creates a ALTS handshaker for GCP which contains an RPC
174-
// stub created using the passed conn and used to talk to the ALTS Handshaker
172+
// NewServerHandshaker creates a core.Handshaker that performs a server-side
173+
// ALTS handshake by acting as a proxy between the peer and the ALTS handshaker
175174
// service in the metadata server.
176175
func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ServerHandshakerOptions) (core.Handshaker, error) {
177-
stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx)
178-
if err != nil {
179-
return nil, err
180-
}
181176
return &altsHandshaker{
182-
stream: stream,
177+
stream: nil,
183178
conn: c,
179+
clientConn: conn,
184180
serverOpts: opts,
185181
side: core.ServerSide,
186182
}, nil
187183
}
188184

189-
// ClientHandshake starts and completes a client ALTS handshaking for GCP. Once
185+
// ClientHandshake starts and completes a client ALTS handshake for GCP. Once
190186
// done, ClientHandshake returns a secure connection.
191187
func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
192188
if !acquire() {
@@ -198,6 +194,16 @@ func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credent
198194
return nil, nil, errors.New("only handshakers created using NewClientHandshaker can perform a client handshaker")
199195
}
200196

197+
// TODO(matthewstevenson88): Change unit tests to use public APIs so
198+
// that h.stream can unconditionally be set based on h.clientConn.
199+
if h.stream == nil {
200+
stream, err := altsgrpc.NewHandshakerServiceClient(h.clientConn).DoHandshake(ctx)
201+
if err != nil {
202+
return nil, nil, fmt.Errorf("failed to establish stream to ALTS handshaker service: %v", err)
203+
}
204+
h.stream = stream
205+
}
206+
201207
// Create target identities from service account list.
202208
targetIdentities := make([]*altspb.Identity, 0, len(h.clientOpts.TargetServiceAccounts))
203209
for _, account := range h.clientOpts.TargetServiceAccounts {
@@ -229,7 +235,7 @@ func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credent
229235
return conn, authInfo, nil
230236
}
231237

232-
// ServerHandshake starts and completes a server ALTS handshaking for GCP. Once
238+
// ServerHandshake starts and completes a server ALTS handshake for GCP. Once
233239
// done, ServerHandshake returns a secure connection.
234240
func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
235241
if !acquire() {
@@ -241,6 +247,16 @@ func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credent
241247
return nil, nil, errors.New("only handshakers created using NewServerHandshaker can perform a server handshaker")
242248
}
243249

250+
// TODO(matthewstevenson88): Change unit tests to use public APIs so
251+
// that h.stream can unconditionally be set based on h.clientConn.
252+
if h.stream == nil {
253+
stream, err := altsgrpc.NewHandshakerServiceClient(h.clientConn).DoHandshake(ctx)
254+
if err != nil {
255+
return nil, nil, fmt.Errorf("failed to establish stream to ALTS handshaker service: %v", err)
256+
}
257+
h.stream = stream
258+
}
259+
244260
p := make([]byte, frameLimit)
245261
n, err := h.conn.Read(p)
246262
if err != nil {
@@ -371,5 +387,7 @@ func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []b
371387
// Close terminates the Handshaker. It should be called when the caller obtains
372388
// the secure connection.
373389
func (h *altsHandshaker) Close() {
374-
h.stream.CloseSend()
390+
if h.stream != nil {
391+
h.stream.CloseSend()
392+
}
375393
}

credentials/alts/internal/handshaker/handshaker_test.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ import (
2525
"testing"
2626
"time"
2727

28+
"github.com/google/go-cmp/cmp"
29+
"github.com/google/go-cmp/cmp/cmpopts"
2830
grpc "google.golang.org/grpc"
2931
core "google.golang.org/grpc/credentials/alts/internal"
3032
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
@@ -283,3 +285,67 @@ func (s) TestPeerNotResponding(t *testing.T) {
283285
t.Errorf("ClientHandshake() = %v, want %v", got, want)
284286
}
285287
}
288+
289+
func (s) TestNewClientHandshaker(t *testing.T) {
290+
conn := testutil.NewTestConn(nil, nil)
291+
clientConn := &grpc.ClientConn{}
292+
opts := &ClientHandshakerOptions{}
293+
hs, err := NewClientHandshaker(context.Background(), clientConn, conn, opts)
294+
if err != nil {
295+
t.Errorf("NewClientHandshaker returned unexpected error: %v", err)
296+
}
297+
expectedHs := &altsHandshaker{
298+
stream: nil,
299+
conn: conn,
300+
clientConn: clientConn,
301+
clientOpts: opts,
302+
serverOpts: nil,
303+
side: core.ClientSide,
304+
}
305+
cmpOpts := []cmp.Option{
306+
cmp.AllowUnexported(altsHandshaker{}),
307+
cmpopts.IgnoreFields(altsHandshaker{}, "conn", "clientConn"),
308+
}
309+
if got, want := hs.(*altsHandshaker), expectedHs; !cmp.Equal(got, want, cmpOpts...) {
310+
t.Errorf("NewClientHandshaker() returned unexpected handshaker: got: %v, want: %v", got, want)
311+
}
312+
if hs.(*altsHandshaker).stream != nil {
313+
t.Errorf("NewClientHandshaker() returned handshaker with non-nil stream")
314+
}
315+
if hs.(*altsHandshaker).clientConn != clientConn {
316+
t.Errorf("NewClientHandshaker() returned handshaker with unexpected clientConn")
317+
}
318+
hs.Close()
319+
}
320+
321+
func (s) TestNewServerHandshaker(t *testing.T) {
322+
conn := testutil.NewTestConn(nil, nil)
323+
clientConn := &grpc.ClientConn{}
324+
opts := &ServerHandshakerOptions{}
325+
hs, err := NewServerHandshaker(context.Background(), clientConn, conn, opts)
326+
if err != nil {
327+
t.Errorf("NewServerHandshaker returned unexpected error: %v", err)
328+
}
329+
expectedHs := &altsHandshaker{
330+
stream: nil,
331+
conn: conn,
332+
clientConn: clientConn,
333+
clientOpts: nil,
334+
serverOpts: opts,
335+
side: core.ServerSide,
336+
}
337+
cmpOpts := []cmp.Option{
338+
cmp.AllowUnexported(altsHandshaker{}),
339+
cmpopts.IgnoreFields(altsHandshaker{}, "conn", "clientConn"),
340+
}
341+
if got, want := hs.(*altsHandshaker), expectedHs; !cmp.Equal(got, want, cmpOpts...) {
342+
t.Errorf("NewServerHandshaker() returned unexpected handshaker: got: %v, want: %v", got, want)
343+
}
344+
if hs.(*altsHandshaker).stream != nil {
345+
t.Errorf("NewServerHandshaker() returned handshaker with non-nil stream")
346+
}
347+
if hs.(*altsHandshaker).clientConn != clientConn {
348+
t.Errorf("NewServerHandshaker() returned handshaker with unexpected clientConn")
349+
}
350+
hs.Close()
351+
}

0 commit comments

Comments
 (0)