Skip to content

Commit 432310b

Browse files
committed
Merge branch 'rafaelha/fix-slicing-reverse-edge-case' of https://github.com/QuEraComputing/kirin into rafaelha/inline_getitem_has_done_something_fix
2 parents 9ca4cb0 + f61f805 commit 432310b

File tree

8 files changed

+11
-118
lines changed

8 files changed

+11
-118
lines changed

src/kirin/dialects/ilist/rewrite/inline_getitem.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from kirin.rewrite import abc
33
from kirin.analysis import const
44
from kirin.dialects import py
5-
from kirin.dialects.py.slice import SliceAttribute
65

76
from ..stmts import New
87

@@ -35,10 +34,8 @@ def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
3534
):
3635
node.result.replace_by(stmt.args[index])
3736
return abc.RewriteResult(has_done_something=True)
38-
elif isinstance(index, (slice, SliceAttribute)):
39-
new_tuple = New(
40-
tuple(stmt.args[index.start : index.stop : index.step]),
41-
)
37+
elif isinstance(index, slice):
38+
new_tuple = New(tuple(stmt.args[index]))
4239
node.replace_by(new_tuple)
4340
return abc.RewriteResult(has_done_something=True)
4441
else:

src/kirin/dialects/py/indexing.py

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -98,17 +98,7 @@ class Concrete(interp.MethodTable):
9898

9999
@interp.impl(GetItem)
100100
def getindex(self, interp, frame: interp.Frame, stmt: GetItem):
101-
from kirin.dialects.py.slice import SliceAttribute
102-
103-
index = frame.get(stmt.index)
104-
105-
# need to handle special case of slice attribute
106-
if isinstance(index, SliceAttribute):
107-
index_value = index.unwrap()
108-
else:
109-
index_value = index
110-
111-
return (frame.get(stmt.obj)[index_value],)
101+
return (frame.get(stmt.obj)[frame.get(stmt.index)],)
112102

113103

114104
@dialect.register(key="typeinfer")
@@ -218,23 +208,13 @@ def getitem(
218208
return (const.Unknown(),)
219209

220210
if isinstance(obj, const.Value):
221-
from kirin.dialects.py.slice import SliceAttribute
222-
223-
# need to handle special case of slice attribute
224-
if isinstance(index.data, SliceAttribute):
225-
index_value = index.data.unwrap()
226-
else:
227-
index_value = index.data
228-
229-
return (const.Value(obj.data[index_value]),)
230-
211+
return (const.Value(obj.data[index.data]),)
231212
elif isinstance(obj, const.PartialTuple):
232213
obj = obj.data
233214
if isinstance(index.data, int) and 0 <= index.data < len(obj):
234215
return (obj[index.data],)
235216
elif isinstance(index.data, slice):
236-
sl = index.data
237-
return (const.PartialTuple(obj[sl.start : sl.stop : sl.step]),)
217+
return (const.PartialTuple(obj[index.data]),)
238218
return (const.Unknown(),)
239219

240220

src/kirin/dialects/py/slice.py

Lines changed: 3 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
from kirin import ir, types, interp, lowering
1515
from kirin.decl import info, statement
16-
from kirin.print.printer import Printer
1716
from kirin.dialects.py.constant import Constant
1817

1918
dialect = ir.Dialect("py.slice")
@@ -63,55 +62,18 @@ def __init__(
6362
)
6463

6564

66-
@dataclass
67-
class SliceAttribute(ir.Data[slice]):
68-
69-
start: int | None
70-
stop: int | None
71-
step: int | None
72-
73-
def __post_init__(self) -> None:
74-
if self.start is None and self.step is None:
75-
self.type = types.Slice[types.Literal(self.stop)]
76-
else:
77-
self.type = types.Slice3[
78-
types.Literal(self.start),
79-
types.Literal(self.stop),
80-
types.Literal(self.step),
81-
]
82-
83-
def unwrap(self):
84-
return slice(self.start, self.stop, self.step)
85-
86-
def __hash__(self):
87-
return hash((type(self), self.start, self.stop, self.step))
88-
89-
def print_impl(self, printer: Printer) -> None:
90-
return printer.plain_print(f"slice({self.start}, {self.stop}, {self.step})")
91-
92-
def is_structurally_equal(
93-
self, other: ir.Attribute, context: dict | None = None
94-
) -> bool:
95-
return (
96-
isinstance(other, SliceAttribute)
97-
and self.start == other.start
98-
and self.stop == other.stop
99-
and self.step == other.step
100-
)
101-
102-
10365
@dialect.register
10466
class Concrete(interp.MethodTable):
10567

10668
@interp.impl(Slice)
10769
def _slice(self, interp, frame: interp.Frame, stmt: Slice):
10870
start, stop, step = frame.get_values(stmt.args)
10971
if start is None and step is None:
110-
return (SliceAttribute(None, stop, None),)
72+
return (slice(stop),)
11173
elif step is None:
112-
return (SliceAttribute(start, stop, None),)
74+
return (slice(start, stop),)
11375
else:
114-
return (SliceAttribute(start, stop, step),)
76+
return (slice(start, stop, step),)
11577

11678

11779
@dialect.register

src/kirin/rewrite/getitem.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from kirin.analysis import const
55
from kirin.dialects import py
66
from kirin.rewrite.abc import RewriteRule, RewriteResult
7-
from kirin.dialects.py.slice import SliceAttribute
87

98

109
@dataclass
@@ -30,10 +29,8 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
3029
):
3130
node.result.replace_by(stmt.args[index])
3231
return RewriteResult(has_done_something=True)
33-
elif isinstance(index, (slice, SliceAttribute)):
34-
new_tuple = py.tuple.New(
35-
tuple(stmt.args[index.start : index.stop : index.step]),
36-
)
32+
elif isinstance(index, slice):
33+
new_tuple = py.tuple.New(tuple(stmt.args[index]))
3734
node.replace_by(new_tuple)
3835
return RewriteResult(has_done_something=True)
3936
else:

src/kirin/types.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
NoneType = PyClass(type(None))
2929
List = Generic(list, TypeVar("T"))
3030
Slice = Generic(slice, TypeVar("T"))
31-
Slice3 = Generic(slice, TypeVar("T1"), TypeVar("T2"), TypeVar("T3"))
3231
Tuple = Generic(tuple, Vararg(TypeVar("T")))
3332
Dict = Generic(dict, TypeVar("K"), TypeVar("V"))
3433
Set = Generic(set, TypeVar("T"))

test/dialects/ilist/test_inline_getitem.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ def func(x: int):
3333

3434
assert before == after
3535
assert len(func.callable_region.blocks[0].stmts) == 1
36-
print(func.code.print())
3736

3837

3938
@pytest.mark.parametrize(
@@ -54,8 +53,6 @@ def func():
5453
ylist = ilist.New(values=(0, 1, 2, 3, 4), elem_type=types.PyClass(int))
5554
return ylist[sl]
5655

57-
func.code.print()
58-
5956
stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts]
6057
assert GetItem in stmt_types
6158

@@ -90,16 +87,13 @@ def func():
9087
assert GetItem in stmt_types
9188

9289
before = func()
93-
func.code.print()
9490

9591
apply_getitem_optimization(func)
9692

9793
stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts]
9894
assert GetItem not in stmt_types
9995
after = func()
10096

101-
func.code.print()
102-
10397
assert before == after
10498

10599

test/dialects/pystmts/test_slice.py

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from kirin import types
22
from kirin.prelude import basic_no_opt
3-
from kirin.dialects import py, ilist
4-
from kirin.dialects.py.slice import SliceAttribute
3+
from kirin.dialects import py
54

65

76
@basic_no_opt
@@ -45,32 +44,3 @@ def test_wrong_slice():
4544

4645
stmt: py.slice.Slice = wrong_slice.code.body.blocks[0].stmts.at(7)
4746
assert stmt.result.type.is_subseteq(types.Bottom)
48-
49-
50-
def test_slice_attr():
51-
52-
@basic_no_opt
53-
def test():
54-
55-
return (slice(0, 20), slice(30), slice(1, 40, 5))
56-
57-
result = test()
58-
assert result == (
59-
SliceAttribute(0, 20, None),
60-
SliceAttribute(None, 30, None),
61-
SliceAttribute(1, 40, 5),
62-
)
63-
64-
65-
def test_slice_attr_hash():
66-
assert hash(SliceAttribute(0, 20, None)) == hash((SliceAttribute, 0, 20, None))
67-
68-
69-
def test_slice_get_index():
70-
@basic_no_opt
71-
def test():
72-
x = slice(0, 20, None)
73-
y = range(40)
74-
return y[x]
75-
76-
assert test() == ilist.IList(range(0, 20, 1))

test/rules/test_getitem.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ def func(x: int):
3030

3131
assert before == after
3232
assert len(func.callable_region.blocks[0].stmts) == 1
33-
print(func.code.print())
3433

3534

3635
@pytest.mark.parametrize(
@@ -51,8 +50,6 @@ def func():
5150
ylist = (0, 1, 2, 3, 4)
5251
return ylist[sl]
5352

54-
func.code.print()
55-
5653
stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts]
5754
assert GetItem in stmt_types
5855

@@ -87,16 +84,13 @@ def func():
8784
assert GetItem in stmt_types
8885

8986
before = func()
90-
func.code.print()
9187

9288
apply_getitem_optimization(func)
9389

9490
stmt_types = [type(stmt) for stmt in func.callable_region.blocks[0].stmts]
9591
assert GetItem not in stmt_types
9692
after = func()
9793

98-
func.code.print()
99-
10094
assert before == after
10195

10296

0 commit comments

Comments
 (0)