Skip to content

Commit d920fad

Browse files
authored
move context handling in trillian RPC calls to be request based and idiomatic (#2536)
* move context handling to be request based and idiomatic Signed-off-by: Bob Callaway <[email protected]> * reuse var Signed-off-by: Bob Callaway <[email protected]> --------- Signed-off-by: Bob Callaway <[email protected]>
1 parent 4b09ef5 commit d920fad

File tree

5 files changed

+46
-60
lines changed

5 files changed

+46
-60
lines changed

pkg/api/api.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,8 @@ func NewAPI(treeID uint) (*API, error) {
183183

184184
cachedCheckpoints := make(map[int64]string)
185185
for _, r := range ranges.GetInactive() {
186-
tc := trillianclient.NewTrillianClient(ctx, logClient, r.TreeID)
187-
resp := tc.GetLatest(0)
186+
tc := trillianclient.NewTrillianClient(logClient, r.TreeID)
187+
resp := tc.GetLatest(ctx, 0)
188188
if resp.Status != codes.OK {
189189
return nil, fmt.Errorf("error fetching latest tree head for inactive shard %d: resp code is %d, err is %w", r.TreeID, resp.Status, resp.Err)
190190
}

pkg/api/entries.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -332,9 +332,9 @@ func createLogEntry(params entries.CreateLogEntryParams) (models.LogEntry, middl
332332
return nil, handleRekorAPIError(params, http.StatusInternalServerError, err, failedToGenerateCanonicalEntry)
333333
}
334334

335-
tc := trillianclient.NewTrillianClient(ctx, api.logClient, api.treeID)
335+
tc := trillianclient.NewTrillianClient(api.logClient, api.treeID)
336336

337-
resp := tc.AddLeaf(leaf)
337+
resp := tc.AddLeaf(ctx, leaf)
338338
// this represents overall GRPC response state (not the results of insertion into the log)
339339
if resp.Status != codes.OK {
340340
return nil, handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("grpc error: %w", resp.Err), trillianUnexpectedResult)
@@ -622,8 +622,8 @@ func SearchLogQueryHandler(params entries.SearchLogQueryParams) middleware.Respo
622622
for i, hash := range searchHashes {
623623
var results map[int64]*trillian.GetEntryAndProofResponse
624624
for _, shard := range api.logRanges.AllShards() {
625-
tcs := trillianclient.NewTrillianClient(httpReqCtx, api.logClient, shard)
626-
resp := tcs.GetLeafAndProofByHash(hash)
625+
tcs := trillianclient.NewTrillianClient(api.logClient, shard)
626+
resp := tcs.GetLeafAndProofByHash(httpReqCtx, hash)
627627
switch resp.Status {
628628
case codes.OK:
629629
leafResult := resp.GetLeafAndProofResult
@@ -677,10 +677,10 @@ func retrieveLogEntryByIndex(ctx context.Context, logIndex int) (models.LogEntry
677677
log.ContextLogger(ctx).Infof("Retrieving log entry by index %d", logIndex)
678678

679679
tid, resolvedIndex := api.logRanges.ResolveVirtualIndex(logIndex)
680-
tc := trillianclient.NewTrillianClient(ctx, api.logClient, tid)
680+
tc := trillianclient.NewTrillianClient(api.logClient, tid)
681681
log.ContextLogger(ctx).Debugf("Retrieving resolved index %v from TreeID %v", resolvedIndex, tid)
682682

683-
resp := tc.GetLeafAndProofByIndex(resolvedIndex)
683+
resp := tc.GetLeafAndProofByIndex(ctx, resolvedIndex)
684684
switch resp.Status {
685685
case codes.OK:
686686
case codes.NotFound, codes.OutOfRange, codes.InvalidArgument:
@@ -744,10 +744,10 @@ func retrieveUUIDFromTree(ctx context.Context, uuid string, tid int64) (models.L
744744
return models.LogEntry{}, &types.InputValidationError{Err: fmt.Errorf("parsing UUID: %w", err)}
745745
}
746746

747-
tc := trillianclient.NewTrillianClient(ctx, api.logClient, tid)
747+
tc := trillianclient.NewTrillianClient(api.logClient, tid)
748748
log.ContextLogger(ctx).Debugf("Attempting to retrieve UUID %v from TreeID %v", uuid, tid)
749749

750-
resp := tc.GetLeafAndProofByHash(hashValue)
750+
resp := tc.GetLeafAndProofByHash(ctx, hashValue)
751751
switch resp.Status {
752752
case codes.OK:
753753
result := resp.GetLeafAndProofResult

pkg/api/tlog.go

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,14 @@ import (
3737

3838
// GetLogInfoHandler returns the current size of the tree and the STH
3939
func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
40-
tc := trillianclient.NewTrillianClient(params.HTTPRequest.Context(), api.logClient, api.treeID)
40+
ctx := params.HTTPRequest.Context()
41+
tc := trillianclient.NewTrillianClient(api.logClient, api.treeID)
4142

4243
// for each inactive shard, get the loginfo
4344
var inactiveShards []*models.InactiveShardLogInfo
4445
for _, shard := range api.logRanges.GetInactive() {
4546
// Get details for this inactive shard
46-
is, err := inactiveShardLogInfo(params.HTTPRequest.Context(), shard.TreeID, api.cachedCheckpoints)
47+
is, err := inactiveShardLogInfo(ctx, shard.TreeID, api.cachedCheckpoints)
4748
if err != nil {
4849
return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("inactive shard error: %w", err), unexpectedInactiveShardError)
4950
}
@@ -53,7 +54,7 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
5354
if swag.BoolValue(params.Stable) && redisClient != nil {
5455
// key is treeID/latest
5556
key := fmt.Sprintf("%d/latest", api.logRanges.GetActive().TreeID)
56-
redisResult, err := redisClient.Get(params.HTTPRequest.Context(), key).Result()
57+
redisResult, err := redisClient.Get(ctx, key).Result()
5758
if err != nil {
5859
return handleRekorAPIError(params, http.StatusInternalServerError,
5960
fmt.Errorf("error getting checkpoint from redis: %w", err), "error getting checkpoint from redis")
@@ -82,7 +83,7 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
8283
return tlog.NewGetLogInfoOK().WithPayload(&logInfo)
8384
}
8485

85-
resp := tc.GetLatest(0)
86+
resp := tc.GetLatest(ctx, 0)
8687
if resp.Status != codes.OK {
8788
return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("grpc error: %w", resp.Err), trillianCommunicationError)
8889
}
@@ -96,7 +97,7 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
9697
hashString := hex.EncodeToString(root.RootHash)
9798
treeSize := int64(root.TreeSize)
9899

99-
scBytes, err := util.CreateAndSignCheckpoint(params.HTTPRequest.Context(),
100+
scBytes, err := util.CreateAndSignCheckpoint(ctx,
100101
viper.GetString("rekor_server.hostname"), api.logRanges.GetActive().TreeID, root.TreeSize, root.RootHash, api.logRanges.GetActive().Signer)
101102
if err != nil {
102103
return handleRekorAPIError(params, http.StatusInternalServerError, err, sthGenerateError)
@@ -123,17 +124,18 @@ func GetLogProofHandler(params tlog.GetLogProofParams) middleware.Responder {
123124
errMsg := fmt.Sprintf(firstSizeLessThanLastSize, *params.FirstSize, params.LastSize)
124125
return handleRekorAPIError(params, http.StatusBadRequest, fmt.Errorf("consistency proof: %s", errMsg), errMsg)
125126
}
126-
tc := trillianclient.NewTrillianClient(params.HTTPRequest.Context(), api.logClient, api.treeID)
127+
ctx := params.HTTPRequest.Context()
128+
tc := trillianclient.NewTrillianClient(api.logClient, api.treeID)
127129
if treeID := swag.StringValue(params.TreeID); treeID != "" {
128130
id, err := strconv.Atoi(treeID)
129131
if err != nil {
130132
log.Logger.Infof("Unable to convert %s to string, skipping initializing client with Tree ID: %v", treeID, err)
131133
} else {
132-
tc = trillianclient.NewTrillianClient(params.HTTPRequest.Context(), api.logClient, int64(id))
134+
tc = trillianclient.NewTrillianClient(api.logClient, int64(id))
133135
}
134136
}
135137

136-
resp := tc.GetConsistencyProof(*params.FirstSize, params.LastSize)
138+
resp := tc.GetConsistencyProof(ctx, *params.FirstSize, params.LastSize)
137139
if resp.Status != codes.OK {
138140
return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("grpc error: %w", resp.Err), trillianCommunicationError)
139141
}
@@ -168,8 +170,8 @@ func GetLogProofHandler(params tlog.GetLogProofParams) middleware.Responder {
168170
}
169171

170172
func inactiveShardLogInfo(ctx context.Context, tid int64, cachedCheckpoints map[int64]string) (*models.InactiveShardLogInfo, error) {
171-
tc := trillianclient.NewTrillianClient(ctx, api.logClient, tid)
172-
resp := tc.GetLatest(0)
173+
tc := trillianclient.NewTrillianClient(api.logClient, tid)
174+
resp := tc.GetLatest(ctx, 0)
173175
if resp.Status != codes.OK {
174176
return nil, fmt.Errorf("resp code is %d", resp.Status)
175177
}

pkg/trillianclient/trillian_client.go

Lines changed: 22 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,15 @@ import (
3636

3737
// TrillianClient provides a wrapper around the Trillian client
3838
type TrillianClient struct {
39-
client trillian.TrillianLogClient
40-
logID int64
41-
context context.Context
39+
client trillian.TrillianLogClient
40+
logID int64
4241
}
4342

4443
// NewTrillianClient creates a TrillianClient with the given Trillian client and log/tree ID.
45-
func NewTrillianClient(ctx context.Context, logClient trillian.TrillianLogClient, logID int64) TrillianClient {
44+
func NewTrillianClient(logClient trillian.TrillianLogClient, logID int64) TrillianClient {
4645
return TrillianClient{
47-
client: logClient,
48-
logID: logID,
49-
context: ctx,
46+
client: logClient,
47+
logID: logID,
5048
}
5149
}
5250

@@ -76,26 +74,26 @@ func unmarshalLogRoot(logRoot []byte) (types.LogRootV1, error) {
7674
return root, nil
7775
}
7876

79-
func (t *TrillianClient) root() (types.LogRootV1, error) {
77+
func (t *TrillianClient) root(ctx context.Context) (types.LogRootV1, error) {
8078
rqst := &trillian.GetLatestSignedLogRootRequest{
8179
LogId: t.logID,
8280
}
83-
resp, err := t.client.GetLatestSignedLogRoot(t.context, rqst)
81+
resp, err := t.client.GetLatestSignedLogRoot(ctx, rqst)
8482
if err != nil {
8583
return types.LogRootV1{}, err
8684
}
8785
return unmarshalLogRoot(resp.SignedLogRoot.LogRoot)
8886
}
8987

90-
func (t *TrillianClient) AddLeaf(byteValue []byte) *Response {
88+
func (t *TrillianClient) AddLeaf(ctx context.Context, byteValue []byte) *Response {
9189
leaf := &trillian.LogLeaf{
9290
LeafValue: byteValue,
9391
}
9492
rqst := &trillian.QueueLeafRequest{
9593
LogId: t.logID,
9694
Leaf: leaf,
9795
}
98-
resp, err := t.client.QueueLeaf(t.context, rqst)
96+
resp, err := t.client.QueueLeaf(ctx, rqst)
9997

10098
// check for error
10199
if err != nil || (resp.QueuedLeaf.Status != nil && resp.QueuedLeaf.Status.Code != int32(codes.OK)) {
@@ -106,7 +104,7 @@ func (t *TrillianClient) AddLeaf(byteValue []byte) *Response {
106104
}
107105
}
108106

109-
root, err := t.root()
107+
root, err := t.root(ctx)
110108
if err != nil {
111109
return &Response{
112110
Status: status.Code(err),
@@ -131,7 +129,7 @@ func (t *TrillianClient) AddLeaf(byteValue []byte) *Response {
131129
for {
132130
root = *logClient.GetRoot()
133131
if root.TreeSize >= 1 {
134-
proofResp := t.getProofByHash(resp.QueuedLeaf.Leaf.MerkleLeafHash)
132+
proofResp := t.getProofByHash(ctx, resp.QueuedLeaf.Leaf.MerkleLeafHash)
135133
// if this call succeeds or returns an error other than "not found", return
136134
if proofResp.Err == nil || (proofResp.Err != nil && status.Code(proofResp.Err) != codes.NotFound) {
137135
return proofResp
@@ -148,7 +146,7 @@ func (t *TrillianClient) AddLeaf(byteValue []byte) *Response {
148146
}
149147
}
150148

151-
proofResp := waitForInclusion(t.context, resp.QueuedLeaf.Leaf.MerkleLeafHash)
149+
proofResp := waitForInclusion(ctx, resp.QueuedLeaf.Leaf.MerkleLeafHash)
152150
if proofResp.Err != nil {
153151
return &Response{
154152
Status: status.Code(proofResp.Err),
@@ -168,7 +166,7 @@ func (t *TrillianClient) AddLeaf(byteValue []byte) *Response {
168166
}
169167

170168
leafIndex := proofs[0].LeafIndex
171-
leafResp := t.GetLeafAndProofByIndex(leafIndex)
169+
leafResp := t.GetLeafAndProofByIndex(ctx, leafIndex)
172170
if leafResp.Err != nil {
173171
return &Response{
174172
Status: status.Code(leafResp.Err),
@@ -189,9 +187,9 @@ func (t *TrillianClient) AddLeaf(byteValue []byte) *Response {
189187
}
190188
}
191189

192-
func (t *TrillianClient) GetLeafAndProofByHash(hash []byte) *Response {
190+
func (t *TrillianClient) GetLeafAndProofByHash(ctx context.Context, hash []byte) *Response {
193191
// get inclusion proof for hash, extract index, then fetch leaf using index
194-
proofResp := t.getProofByHash(hash)
192+
proofResp := t.getProofByHash(ctx, hash)
195193
if proofResp.Err != nil {
196194
return &Response{
197195
Status: status.Code(proofResp.Err),
@@ -208,14 +206,11 @@ func (t *TrillianClient) GetLeafAndProofByHash(hash []byte) *Response {
208206
}
209207
}
210208

211-
return t.GetLeafAndProofByIndex(proofs[0].LeafIndex)
209+
return t.GetLeafAndProofByIndex(ctx, proofs[0].LeafIndex)
212210
}
213211

214-
func (t *TrillianClient) GetLeafAndProofByIndex(index int64) *Response {
215-
ctx, cancel := context.WithTimeout(t.context, 20*time.Second)
216-
defer cancel()
217-
218-
rootResp := t.GetLatest(0)
212+
func (t *TrillianClient) GetLeafAndProofByIndex(ctx context.Context, index int64) *Response {
213+
rootResp := t.GetLatest(ctx, 0)
219214
if rootResp.Err != nil {
220215
return &Response{
221216
Status: status.Code(rootResp.Err),
@@ -262,11 +257,7 @@ func (t *TrillianClient) GetLeafAndProofByIndex(index int64) *Response {
262257
}
263258
}
264259

265-
func (t *TrillianClient) GetLatest(leafSizeInt int64) *Response {
266-
267-
ctx, cancel := context.WithTimeout(t.context, 20*time.Second)
268-
defer cancel()
269-
260+
func (t *TrillianClient) GetLatest(ctx context.Context, leafSizeInt int64) *Response {
270261
resp, err := t.client.GetLatestSignedLogRoot(ctx,
271262
&trillian.GetLatestSignedLogRootRequest{
272263
LogId: t.logID,
@@ -280,11 +271,7 @@ func (t *TrillianClient) GetLatest(leafSizeInt int64) *Response {
280271
}
281272
}
282273

283-
func (t *TrillianClient) GetConsistencyProof(firstSize, lastSize int64) *Response {
284-
285-
ctx, cancel := context.WithTimeout(t.context, 20*time.Second)
286-
defer cancel()
287-
274+
func (t *TrillianClient) GetConsistencyProof(ctx context.Context, firstSize, lastSize int64) *Response {
288275
resp, err := t.client.GetConsistencyProof(ctx,
289276
&trillian.GetConsistencyProofRequest{
290277
LogId: t.logID,
@@ -299,11 +286,8 @@ func (t *TrillianClient) GetConsistencyProof(firstSize, lastSize int64) *Respons
299286
}
300287
}
301288

302-
func (t *TrillianClient) getProofByHash(hashValue []byte) *Response {
303-
ctx, cancel := context.WithTimeout(t.context, 20*time.Second)
304-
defer cancel()
305-
306-
rootResp := t.GetLatest(0)
289+
func (t *TrillianClient) getProofByHash(ctx context.Context, hashValue []byte) *Response {
290+
rootResp := t.GetLatest(ctx, 0)
307291
if rootResp.Err != nil {
308292
return &Response{
309293
Status: status.Code(rootResp.Err),

pkg/witness/publish_checkpoint.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ func NewCheckpointPublisher(ctx context.Context,
8181
// before publishing the latest checkpoint. If this occurs due to a sporadic failure, this simply
8282
// means that a witness will not see a fresh checkpoint for an additional period.
8383
func (c *CheckpointPublisher) StartPublisher(ctx context.Context) {
84-
tc := trillianclient.NewTrillianClient(context.Background(), c.logClient, c.treeID)
84+
tc := trillianclient.NewTrillianClient(c.logClient, c.treeID)
8585
sTreeID := strconv.FormatInt(c.treeID, 10)
8686

8787
// publish on startup to ensure a checkpoint is available the first time Rekor starts up
@@ -103,7 +103,7 @@ func (c *CheckpointPublisher) StartPublisher(ctx context.Context) {
103103
// publish publishes the latest checkpoint to Redis once
104104
func (c *CheckpointPublisher) publish(tc *trillianclient.TrillianClient, sTreeID string) {
105105
// get latest checkpoint
106-
resp := tc.GetLatest(0)
106+
resp := tc.GetLatest(context.Background(), 0)
107107
if resp.Status != codes.OK {
108108
c.reqCounter.With(
109109
map[string]string{

0 commit comments

Comments
 (0)