Skip to content

Commit 56df72c

Browse files
author
Jim.Idle
committed
fix: Rework of all Hash() and Equals() methods - implement generic collections
- Implement new collections using generics that implement the functionality required by the Java runtime in a more idiomatic Go way. - Fix Hash() and Equals() for all objects in the runtime - Fix getConflictingAlts so that it behaves the same way as Java, using a new generic collection - Replaces the use of the array2DHashSet, which was causing unneeded memory allocations. Replaced with generic collection that allocates minimally (though, I think I can improve on that with a little analysis). Jim Idle - [email protected] Signed-off-by: Jim.Idle <[email protected]>
1 parent ec01c11 commit 56df72c

17 files changed

+453
-320
lines changed

runtime/Go/antlr/atn_config.go

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,14 @@ import (
88
"fmt"
99
)
1010

11-
type comparable interface {
12-
equals(other interface{}) bool
13-
}
14-
1511
// ATNConfig is a tuple: (ATN state, predicted alt, syntactic, semantic
1612
// context). The syntactic context is a graph-structured stack node whose
1713
// path(s) to the root is the rule invocation(s) chain used to arrive at the
1814
// state. The semantic context is the tree of semantic predicates encountered
1915
// before reaching an ATN state.
2016
type ATNConfig interface {
21-
comparable
22-
23-
gequals(other Collectable[ATNConfig]) bool
24-
25-
hash() int
17+
Equals(o Collectable[ATNConfig]) bool
18+
Hash() int
2619

2720
GetState() ATNState
2821
GetAlt() int
@@ -136,15 +129,17 @@ func (b *BaseATNConfig) GetReachesIntoOuterContext() int {
136129
func (b *BaseATNConfig) SetReachesIntoOuterContext(v int) {
137130
b.reachesIntoOuterContext = v
138131
}
139-
func (b *BaseATNConfig) equals(o interface{}) bool {
140-
return b.gequals(o.(Collectable[ATNConfig]))
141-
}
142132

133+
// Equals is the default comparison function for an ATNConfig when no specialist implementation is required
134+
// for a collection.
135+
//
143136
// An ATN configuration is equal to another if both have the same state, they
144137
// predict the same alternative, and syntactic/semantic contexts are the same.
145-
func (b *BaseATNConfig) gequals(o Collectable[ATNConfig]) bool {
138+
func (b *BaseATNConfig) Equals(o Collectable[ATNConfig]) bool {
146139
if b == o {
147140
return true
141+
} else if o == nil {
142+
return false
148143
}
149144

150145
var other, ok = o.(*BaseATNConfig)
@@ -158,30 +153,32 @@ func (b *BaseATNConfig) gequals(o Collectable[ATNConfig]) bool {
158153
if b.context == nil {
159154
equal = other.context == nil
160155
} else {
161-
equal = b.context.gequals(other.context)
156+
equal = b.context.Equals(other.context)
162157
}
163158

164159
var (
165160
nums = b.state.GetStateNumber() == other.state.GetStateNumber()
166161
alts = b.alt == other.alt
167-
cons = b.semanticContext.equals(other.semanticContext)
162+
cons = b.semanticContext.Equals(other.semanticContext)
168163
sups = b.precedenceFilterSuppressed == other.precedenceFilterSuppressed
169164
)
170165

171166
return nums && alts && cons && sups && equal
172167
}
173168

174-
func (b *BaseATNConfig) hash() int {
169+
// Hash is the default hash function for BaseATNConfig, when no specialist hash function
170+
// is required for a collection
171+
func (b *BaseATNConfig) Hash() int {
175172
var c int
176173
if b.context != nil {
177-
c = b.context.hash()
174+
c = b.context.Hash()
178175
}
179176

180177
h := murmurInit(7)
181178
h = murmurUpdate(h, b.state.GetStateNumber())
182179
h = murmurUpdate(h, b.alt)
183180
h = murmurUpdate(h, c)
184-
h = murmurUpdate(h, b.semanticContext.hash())
181+
h = murmurUpdate(h, b.semanticContext.Hash())
185182
return murmurFinish(h, 4)
186183
}
187184

@@ -248,7 +245,9 @@ func NewLexerATNConfig1(state ATNState, alt int, context PredictionContext) *Lex
248245
return &LexerATNConfig{BaseATNConfig: NewBaseATNConfig5(state, alt, context, SemanticContextNone)}
249246
}
250247

251-
func (l *LexerATNConfig) hash() int {
248+
// Hash is the default hash function for LexerATNConfig objects; it can be used directly or via
249+
// the default comparator [ObjEqComparator].
250+
func (l *LexerATNConfig) Hash() int {
252251
var f int
253252
if l.passedThroughNonGreedyDecision {
254253
f = 1
@@ -258,19 +257,20 @@ func (l *LexerATNConfig) hash() int {
258257
h := murmurInit(7)
259258
h = murmurUpdate(h, l.state.GetStateNumber())
260259
h = murmurUpdate(h, l.alt)
261-
h = murmurUpdate(h, l.context.hash())
262-
h = murmurUpdate(h, l.semanticContext.hash())
260+
h = murmurUpdate(h, l.context.Hash())
261+
h = murmurUpdate(h, l.semanticContext.Hash())
263262
h = murmurUpdate(h, f)
264-
h = murmurUpdate(h, l.lexerActionExecutor.hash())
263+
h = murmurUpdate(h, l.lexerActionExecutor.Hash())
265264
h = murmurFinish(h, 6)
266265
return h
267266
}
268267

269-
func (l *LexerATNConfig) equals(other interface{}) bool {
270-
return l.gequals(other.(Collectable[ATNConfig]))
271-
}
272-
273-
func (l *LexerATNConfig) gequals(other Collectable[ATNConfig]) bool {
268+
// Equals is the default comparison function for LexerATNConfig objects; it can be used directly or via
269+
// the default comparator [ObjEqComparator].
270+
func (l *LexerATNConfig) Equals(other Collectable[ATNConfig]) bool {
271+
if l == other {
272+
return true
273+
}
274274
var othert, ok = other.(*LexerATNConfig)
275275

276276
if l == other {
@@ -284,7 +284,7 @@ func (l *LexerATNConfig) gequals(other Collectable[ATNConfig]) bool {
284284
var b bool
285285

286286
if l.lexerActionExecutor != nil {
287-
b = !l.lexerActionExecutor.equals(othert.lexerActionExecutor)
287+
b = !l.lexerActionExecutor.Equals(othert.lexerActionExecutor)
288288
} else {
289289
b = othert.lexerActionExecutor != nil
290290
}
@@ -293,7 +293,7 @@ func (l *LexerATNConfig) gequals(other Collectable[ATNConfig]) bool {
293293
return false
294294
}
295295

296-
return l.BaseATNConfig.equals(othert.BaseATNConfig)
296+
return l.BaseATNConfig.Equals(othert.BaseATNConfig)
297297
}
298298

299299
func checkNonGreedyDecision(source *LexerATNConfig, target ATNState) bool {

runtime/Go/antlr/atn_config_set.go

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,17 @@ package antlr
77
import "fmt"
88

99
type ATNConfigSet interface {
10-
hash() int
10+
Hash() int
11+
Equals(o Collectable[ATNConfig]) bool
1112
Add(ATNConfig, *DoubleDict) bool
1213
AddAll([]ATNConfig) bool
1314

14-
GetStates() Set
15+
GetStates() *JStore[ATNState, Comparator[ATNState]]
1516
GetPredicates() []SemanticContext
1617
GetItems() []ATNConfig
1718

1819
OptimizeConfigs(interpreter *BaseATNSimulator)
1920

20-
Equals(other interface{}) bool
21-
2221
Length() int
2322
IsEmpty() bool
2423
Contains(ATNConfig) bool
@@ -57,7 +56,7 @@ type BaseATNConfigSet struct {
5756
// effectively doubles the number of objects associated with ATNConfigs. All
5857
// keys are hashed by (s, i, _, pi), not including the context. Wiped out when
5958
// read-only because a set becomes a DFA state.
60-
configLookup Set
59+
configLookup *JStore[ATNConfig, Comparator[ATNConfig]]
6160

6261
// configs is the added elements.
6362
configs []ATNConfig
@@ -83,7 +82,7 @@ type BaseATNConfigSet struct {
8382

8483
// readOnly is whether it is read-only. Do not
8584
// allow any code to manipulate the set if true because DFA states will point at
86-
// sets and those must not change. It not protect other fields; conflictingAlts
85+
// sets and those must not change. It does not protect other fields; conflictingAlts
8786
// in particular, which is assigned after readOnly.
8887
readOnly bool
8988

@@ -104,7 +103,7 @@ func (b *BaseATNConfigSet) Alts() *BitSet {
104103
func NewBaseATNConfigSet(fullCtx bool) *BaseATNConfigSet {
105104
return &BaseATNConfigSet{
106105
cachedHash: -1,
107-
configLookup: newArray2DHashSetWithCap(hashATNConfig, equalATNConfigs, 16, 2),
106+
configLookup: NewJStore[ATNConfig, Comparator[ATNConfig]](&ATNConfigComparator[ATNConfig]{}),
108107
fullCtx: fullCtx,
109108
}
110109
}
@@ -126,9 +125,11 @@ func (b *BaseATNConfigSet) Add(config ATNConfig, mergeCache *DoubleDict) bool {
126125
b.dipsIntoOuterContext = true
127126
}
128127

129-
existing := b.configLookup.Add(config).(ATNConfig)
128+
existing, present := b.configLookup.Put(config)
130129

131-
if existing == config {
130+
// The config was not already in the set
131+
//
132+
if !present {
132133
b.cachedHash = -1
133134
b.configs = append(b.configs, config) // Track order here
134135
return true
@@ -154,11 +155,14 @@ func (b *BaseATNConfigSet) Add(config ATNConfig, mergeCache *DoubleDict) bool {
154155
return true
155156
}
156157

157-
func (b *BaseATNConfigSet) GetStates() Set {
158-
states := newArray2DHashSet(nil, nil)
158+
func (b *BaseATNConfigSet) GetStates() *JStore[ATNState, Comparator[ATNState]] {
159+
160+
// states uses the standard comparator provided by the ATNState instance
161+
//
162+
states := NewJStore[ATNState, Comparator[ATNState]](&ObjEqComparator[ATNState]{})
159163

160164
for i := 0; i < len(b.configs); i++ {
161-
states.Add(b.configs[i].GetState())
165+
states.Put(b.configs[i].GetState())
162166
}
163167

164168
return states
@@ -227,7 +231,7 @@ func (b *BaseATNConfigSet) Compare(bs *BaseATNConfigSet) bool {
227231
for _, c := range b.configs {
228232
found := false
229233
for _, c2 := range bs.configs {
230-
if c.equals(c2) {
234+
if c.Equals(c2) {
231235
found = true
232236
break
233237
}
@@ -241,7 +245,7 @@ func (b *BaseATNConfigSet) Compare(bs *BaseATNConfigSet) bool {
241245
return true
242246
}
243247

244-
func (b *BaseATNConfigSet) Equals(other interface{}) bool {
248+
func (b *BaseATNConfigSet) Equals(other Collectable[ATNConfig]) bool {
245249
if b == other {
246250
return true
247251
} else if _, ok := other.(*BaseATNConfigSet); !ok {
@@ -259,7 +263,7 @@ func (b *BaseATNConfigSet) Equals(other interface{}) bool {
259263
b.Compare(other2)
260264
}
261265

262-
func (b *BaseATNConfigSet) hash() int {
266+
func (b *BaseATNConfigSet) Hash() int {
263267
if b.readOnly {
264268
if b.cachedHash == -1 {
265269
b.cachedHash = b.hashCodeConfigs()
@@ -274,7 +278,7 @@ func (b *BaseATNConfigSet) hash() int {
274278
func (b *BaseATNConfigSet) hashCodeConfigs() int {
275279
h := 1
276280
for _, config := range b.configs {
277-
h = 31*h + config.hash()
281+
h = 31*h + config.Hash()
278282
}
279283
return h
280284
}
@@ -310,7 +314,7 @@ func (b *BaseATNConfigSet) Clear() {
310314

311315
b.configs = make([]ATNConfig, 0)
312316
b.cachedHash = -1
313-
b.configLookup = newArray2DHashSet(nil, equalATNConfigs)
317+
b.configLookup = NewJStore[ATNConfig, Comparator[ATNConfig]](&BaseATNConfigComparator[ATNConfig]{})
314318
}
315319

316320
func (b *BaseATNConfigSet) FullContext() bool {
@@ -392,7 +396,8 @@ type OrderedATNConfigSet struct {
392396
func NewOrderedATNConfigSet() *OrderedATNConfigSet {
393397
b := NewBaseATNConfigSet(false)
394398

395-
b.configLookup = newArray2DHashSet(nil, nil)
399+
// This set uses the standard Hash() and Equals() from ATNConfig
400+
b.configLookup = NewJStore[ATNConfig, Comparator[ATNConfig]](&ObjEqComparator[ATNConfig]{})
396401

397402
return &OrderedATNConfigSet{BaseATNConfigSet: b}
398403
}
@@ -402,7 +407,7 @@ func hashATNConfig(i interface{}) int {
402407
hash := 7
403408
hash = 31*hash + o.GetState().GetStateNumber()
404409
hash = 31*hash + o.GetAlt()
405-
hash = 31*hash + o.GetSemanticContext().hash()
410+
hash = 31*hash + o.GetSemanticContext().Hash()
406411
return hash
407412
}
408413

@@ -430,5 +435,5 @@ func equalATNConfigs(a, b interface{}) bool {
430435
return false
431436
}
432437

433-
return ai.GetSemanticContext().equals(bi.GetSemanticContext())
438+
return ai.GetSemanticContext().Equals(bi.GetSemanticContext())
434439
}

runtime/Go/antlr/atn_state.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ type ATNState interface {
4949
AddTransition(Transition, int)
5050

5151
String() string
52-
hash() int
52+
Hash() int
53+
Equals(Collectable[ATNState]) bool
5354
}
5455

5556
type BaseATNState struct {
@@ -123,15 +124,15 @@ func (as *BaseATNState) SetNextTokenWithinRule(v *IntervalSet) {
123124
as.NextTokenWithinRule = v
124125
}
125126

126-
func (as *BaseATNState) hash() int {
127+
func (as *BaseATNState) Hash() int {
127128
return as.stateNumber
128129
}
129130

130131
func (as *BaseATNState) String() string {
131132
return strconv.Itoa(as.stateNumber)
132133
}
133134

134-
func (as *BaseATNState) equals(other interface{}) bool {
135+
func (as *BaseATNState) Equals(other Collectable[ATNState]) bool {
135136
if ot, ok := other.(ATNState); ok {
136137
return as.stateNumber == ot.GetStateNumber()
137138
}

0 commit comments

Comments
 (0)