Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions stanza/models/constituency/dynamic_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def advance_past_constituents(gold_sequence, cur_index):
"""
count = 0
while cur_index < len(gold_sequence):
if isinstance(gold_sequence[cur_index], OpenConstituent):
if type(gold_sequence[cur_index]) is OpenConstituent:
count = count + 1
elif isinstance(gold_sequence[cur_index], CloseConstituent):
elif type(gold_sequence[cur_index]) is CloseConstituent:
count = count - 1
if count == -1: return cur_index
cur_index = cur_index + 1
Expand Down Expand Up @@ -102,12 +102,12 @@ def find_in_order_constituent_end(gold_sequence, cur_index):
count = 0
saw_shift = False
while cur_index < len(gold_sequence):
if isinstance(gold_sequence[cur_index], OpenConstituent):
if type(gold_sequence[cur_index]) is OpenConstituent:
count = count + 1
elif isinstance(gold_sequence[cur_index], CloseConstituent):
elif type(gold_sequence[cur_index]) is CloseConstituent:
count = count - 1
if count == -1: return cur_index
elif isinstance(gold_sequence[cur_index], Shift):
elif type(gold_sequence[cur_index]) is Shift:
if saw_shift and count == 0:
return cur_index
else:
Expand Down
16 changes: 10 additions & 6 deletions stanza/models/constituency/in_order_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ def fix_wrong_open_multiple_subtrees(gold_transition, pred_transition, gold_sequ
return fix_wrong_open_subtrees(gold_transition, pred_transition, gold_sequence, gold_index, root_labels, more_than_two=True)

def advance_past_unaries(gold_sequence, cur_index):
while cur_index + 2 < len(gold_sequence) and isinstance(gold_sequence[cur_index], OpenConstituent) and isinstance(gold_sequence[cur_index+1], CloseConstituent):
Open = OpenConstituent
Close = CloseConstituent
is_type = type
seq_len = len(gold_sequence)
while cur_index + 2 < seq_len and is_type(gold_sequence[cur_index]) is Open and is_type(gold_sequence[cur_index+1]) is Close:
cur_index += 2
return cur_index

Expand Down Expand Up @@ -445,17 +449,17 @@ def fix_close_shift_shift(gold_transition, pred_transition, gold_sequence, gold_
"""
Repair Close/Shift -> Shift by moving the Close to after the next block is created
"""
if not isinstance(gold_transition, CloseConstituent):
if type(gold_transition) is not CloseConstituent:
return None
if not isinstance(pred_transition, Shift):
if type(pred_transition) is not Shift:
return None
if len(gold_sequence) < gold_index + 2:
return None
start_index = gold_index + 1
start_index = advance_past_unaries(gold_sequence, start_index)
if len(gold_sequence) < start_index + 2:
return None
if not isinstance(gold_sequence[start_index], Shift):
if type(gold_sequence[start_index]) is not Shift:
return None

end_index = find_in_order_constituent_end(gold_sequence, start_index)
Expand All @@ -467,9 +471,9 @@ def fix_close_shift_shift(gold_transition, pred_transition, gold_sequence, gold_
# if you would normally start building stuff_3,
# it is not clear if you want to close at the end of
# stuff_2 or build stuff_3 instead.
if ambiguous and isinstance(gold_sequence[end_index], CloseConstituent):
if ambiguous and type(gold_sequence[end_index]) is CloseConstituent:
return None
elif not ambiguous and isinstance(gold_sequence[end_index], Shift):
elif not ambiguous and type(gold_sequence[end_index]) is Shift:
return None

# close at the end of the brackets, rather than once the first bracket is finished
Expand Down