Skip to content

Commit 35c1cad

Browse files
committed
Support RPC queries that filter by VM
1 parent 11fd5a3 commit 35c1cad

File tree

6 files changed

+40
-19
lines changed

6 files changed

+40
-19
lines changed

data/ram/memory/memory.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,11 @@ func (s *store) GetAllByMemoryAccount(_ context.Context, memoryAccount string) (
9292
}
9393

9494
// GetAllVirtualAccountsByAddressAndType implements ram.Store.GetAllVirtualAccountsByAddressAndType
95-
func (s *store) GetAllVirtualAccountsByAddressAndType(_ context.Context, address string, accountType cvm.VirtualAccountType) ([]*ram.Record, error) {
95+
func (s *store) GetAllVirtualAccountsByAddressAndType(_ context.Context, vm, address string, accountType cvm.VirtualAccountType) ([]*ram.Record, error) {
9696
s.mu.Lock()
9797
defer s.mu.Unlock()
9898

99-
items := s.findByAddressAndAccountType(address, accountType)
99+
items := s.findByVmAddressAndAccountType(vm, address, accountType)
100100
if len(items) == 0 {
101101
return nil, ram.ErrItemNotFound
102102
}
@@ -136,14 +136,14 @@ func (s *store) findByVm(vm string) []*ram.Record {
136136
return res
137137
}
138138

139-
func (s *store) findByAddressAndAccountType(address string, accountType cvm.VirtualAccountType) []*ram.Record {
139+
func (s *store) findByVmAddressAndAccountType(vm, address string, accountType cvm.VirtualAccountType) []*ram.Record {
140140
var res []*ram.Record
141141
for _, item := range s.records {
142142
if !item.IsAllocated {
143143
continue
144144
}
145145

146-
if *item.Address == address && *item.Type == accountType {
146+
if item.Vm == vm && *item.Address == address && *item.Type == accountType {
147147
res = append(res, item)
148148
}
149149
}

data/ram/postgres/model.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,16 +173,17 @@ func dbGetAllByMemoryAccount(ctx context.Context, tableName string, db *sqlx.DB,
173173
return res, nil
174174
}
175175

176-
func dbGetAllVirtualAccountsByAddressAndType(ctx context.Context, tableName string, db *sqlx.DB, address string, accountType cvm.VirtualAccountType) ([]*model, error) {
176+
func dbGetAllVirtualAccountsByAddressAndType(ctx context.Context, tableName string, db *sqlx.DB, vm, address string, accountType cvm.VirtualAccountType) ([]*model, error) {
177177
res := []*model{}
178178

179179
query := `SELECT id, vm, memory_account, index, is_allocated, address, item_type, data, slot, last_updated_at FROM ` + tableName + `
180-
WHERE address = $1 AND item_type = $2`
180+
WHERE vm = $1 AND address = $2 AND item_type = $3`
181181

182182
err := db.SelectContext(
183183
ctx,
184184
&res,
185185
query,
186+
vm,
186187
address,
187188
accountType,
188189
)

data/ram/postgres/store.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ func (s *store) GetAllByMemoryAccount(ctx context.Context, memoryAccount string)
6060
}
6161

6262
// GetAllVirtualAccountsByAddressAndType implements ram.Store.GetAllVirtualAccountsByAddressAndType
63-
func (s *store) GetAllVirtualAccountsByAddressAndType(ctx context.Context, address string, accountType cvm.VirtualAccountType) ([]*ram.Record, error) {
64-
models, err := dbGetAllVirtualAccountsByAddressAndType(ctx, s.tableName, s.db, address, accountType)
63+
func (s *store) GetAllVirtualAccountsByAddressAndType(ctx context.Context, vm, address string, accountType cvm.VirtualAccountType) ([]*ram.Record, error) {
64+
models, err := dbGetAllVirtualAccountsByAddressAndType(ctx, s.tableName, s.db, vm, address, accountType)
6565
if err != nil {
6666
return nil, err
6767
}

data/ram/store.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,6 @@ type Store interface {
2424
GetAllByMemoryAccount(ctx context.Context, memoryAccount string) ([]*Record, error)
2525

2626
// GetAllVirtualAccountsByAddressAndType gets all database records for
27-
// allocated memory with the provided address and account type
28-
GetAllVirtualAccountsByAddressAndType(ctx context.Context, address string, accountType cvm.VirtualAccountType) ([]*Record, error)
27+
// allocated memory with the provided address and account type in a VM
28+
GetAllVirtualAccountsByAddressAndType(ctx context.Context, vm, address string, accountType cvm.VirtualAccountType) ([]*Record, error)
2929
}

data/ram/tests/tests.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,12 @@ func testRoundTrip(t *testing.T, s ram.Store) {
2929
t.Run("testRoundTrip", func(t *testing.T) {
3030
ctx := context.Background()
3131

32+
vm := "vm"
3233
memoryAccount := "memory_account"
3334
address := "address"
3435
accountType := cvm.VirtualAccountTypeTimelock
3536

36-
_, err := s.GetAllVirtualAccountsByAddressAndType(ctx, address, accountType)
37+
_, err := s.GetAllVirtualAccountsByAddressAndType(ctx, vm, address, accountType)
3738
assert.Equal(t, ram.ErrItemNotFound, err)
3839

3940
_, err = s.GetAllByMemoryAccount(ctx, memoryAccount)
@@ -42,7 +43,7 @@ func testRoundTrip(t *testing.T, s ram.Store) {
4243
start := time.Now()
4344

4445
expected := &ram.Record{
45-
Vm: "vm",
46+
Vm: vm,
4647

4748
MemoryAccount: memoryAccount,
4849
Index: 12345,
@@ -61,7 +62,7 @@ func testRoundTrip(t *testing.T, s ram.Store) {
6162
assert.EqualValues(t, 1, expected.Id)
6263
assert.True(t, expected.LastUpdatedAt.After(start))
6364

64-
actual, err := s.GetAllVirtualAccountsByAddressAndType(ctx, address, accountType)
65+
actual, err := s.GetAllVirtualAccountsByAddressAndType(ctx, vm, address, accountType)
6566
require.NoError(t, err)
6667
require.Len(t, actual, 1)
6768
assertEquivalentRecords(t, &cloned, actual[0])
@@ -77,7 +78,7 @@ func testRoundTrip(t *testing.T, s ram.Store) {
7778
expected.Data = nil
7879
assert.Equal(t, ram.ErrStaleState, s.Save(ctx, expected))
7980

80-
actual, err = s.GetAllVirtualAccountsByAddressAndType(ctx, address, accountType)
81+
actual, err = s.GetAllVirtualAccountsByAddressAndType(ctx, vm, address, accountType)
8182
require.NoError(t, err)
8283
require.Len(t, actual, 1)
8384
assertEquivalentRecords(t, &cloned, actual[0])
@@ -91,7 +92,7 @@ func testRoundTrip(t *testing.T, s ram.Store) {
9192
cloned = expected.Clone()
9293
require.NoError(t, s.Save(ctx, expected))
9394

94-
_, err = s.GetAllVirtualAccountsByAddressAndType(ctx, address, accountType)
95+
_, err = s.GetAllVirtualAccountsByAddressAndType(ctx, vm, address, accountType)
9596
assert.Equal(t, ram.ErrItemNotFound, err)
9697

9798
actual, err = s.GetAllByMemoryAccount(ctx, memoryAccount)
@@ -172,6 +173,7 @@ func testGetAllVirtualAccountsByAddressAndType(t *testing.T, s ram.Store) {
172173
t.Run("testGetAllVirtualAccountsByAddressAndType", func(t *testing.T) {
173174
ctx := context.Background()
174175

176+
vm := "vm"
175177
addressToQuery := "address0"
176178

177179
var expected []*ram.Record
@@ -202,12 +204,15 @@ func testGetAllVirtualAccountsByAddressAndType(t *testing.T, s ram.Store) {
202204
}
203205
}
204206

205-
actual, err := s.GetAllVirtualAccountsByAddressAndType(ctx, addressToQuery, cvm.VirtualAccountTypeDurableNonce)
207+
actual, err := s.GetAllVirtualAccountsByAddressAndType(ctx, vm, addressToQuery, cvm.VirtualAccountTypeDurableNonce)
206208
require.NoError(t, err)
207209
require.Len(t, actual, len(expected))
208210
for i, record := range actual {
209211
assertEquivalentRecords(t, record, expected[i])
210212
}
213+
214+
_, err = s.GetAllVirtualAccountsByAddressAndType(ctx, vm+"-other", addressToQuery, cvm.VirtualAccountTypeDurableNonce)
215+
assert.Equal(t, ram.ErrItemNotFound, err)
211216
})
212217
}
213218

rpc/server.go

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,12 @@ func (s *server) GetVirtualTimelockAccounts(ctx context.Context, req *indexerpb.
3939
"owner": base58.Encode(req.Owner.Value),
4040
})
4141

42-
records, err := s.ramStore.GetAllVirtualAccountsByAddressAndType(ctx, base58.Encode(req.Owner.Value), cvm.VirtualAccountTypeTimelock)
42+
records, err := s.ramStore.GetAllVirtualAccountsByAddressAndType(
43+
ctx,
44+
base58.Encode(req.VmAccount.Value),
45+
base58.Encode(req.Owner.Value),
46+
cvm.VirtualAccountTypeTimelock,
47+
)
4348
if err == ram.ErrItemNotFound {
4449
return &indexerpb.GetVirtualTimelockAccountsResponse{
4550
Result: indexerpb.GetVirtualTimelockAccountsResponse_NOT_FOUND,
@@ -100,7 +105,12 @@ func (s *server) GetVirtualDurableNonce(ctx context.Context, req *indexerpb.GetV
100105
"address": base58.Encode(req.Address.Value),
101106
})
102107

103-
records, err := s.ramStore.GetAllVirtualAccountsByAddressAndType(ctx, base58.Encode(req.Address.Value), cvm.VirtualAccountTypeDurableNonce)
108+
records, err := s.ramStore.GetAllVirtualAccountsByAddressAndType(
109+
ctx,
110+
base58.Encode(req.VmAccount.Value),
111+
base58.Encode(req.Address.Value),
112+
cvm.VirtualAccountTypeDurableNonce,
113+
)
104114
if err == ram.ErrItemNotFound {
105115
return &indexerpb.GetVirtualDurableNonceResponse{
106116
Result: indexerpb.GetVirtualDurableNonceResponse_NOT_FOUND,
@@ -156,7 +166,12 @@ func (s *server) GetVirtualRelayAccount(ctx context.Context, req *indexerpb.GetV
156166
"address": base58.Encode(req.Address.Value),
157167
})
158168

159-
records, err := s.ramStore.GetAllVirtualAccountsByAddressAndType(ctx, base58.Encode(req.Address.Value), cvm.VirtualAccountTypeRelay)
169+
records, err := s.ramStore.GetAllVirtualAccountsByAddressAndType(
170+
ctx,
171+
base58.Encode(req.VmAccount.Value),
172+
base58.Encode(req.Address.Value),
173+
cvm.VirtualAccountTypeRelay,
174+
)
160175
if err == ram.ErrItemNotFound {
161176
return &indexerpb.GetVirtualRelayAccountResponse{
162177
Result: indexerpb.GetVirtualRelayAccountResponse_NOT_FOUND,

0 commit comments

Comments
 (0)