Skip to content

Commit ba015dc

Browse files
authored
Feat!: add returning to merge expression builder (#4125)
* Add returning to merge * Fix * Fix * Fix test * Fmt * Quote self
1 parent 22c456d commit ba015dc

File tree

4 files changed

+28
-12
lines changed

4 files changed

+28
-12
lines changed

sqlglot/expressions.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
"""
1212

1313
from __future__ import annotations
14-
1514
import datetime
1615
import math
1716
import numbers
@@ -36,6 +35,7 @@
3635
from sqlglot.tokens import Token, TokenError
3736

3837
if t.TYPE_CHECKING:
38+
from typing_extensions import Self
3939
from sqlglot._typing import E, Lit
4040
from sqlglot.dialects.dialect import DialectType
4141

@@ -1368,7 +1368,7 @@ def returning(
13681368
dialect: DialectType = None,
13691369
copy: bool = True,
13701370
**opts,
1371-
) -> DML:
1371+
) -> "Self":
13721372
"""
13731373
Set the RETURNING expression. Not supported by all dialects.
13741374
@@ -6276,7 +6276,7 @@ class Use(Expression):
62766276
arg_types = {"this": True, "kind": False}
62776277

62786278

6279-
class Merge(Expression):
6279+
class Merge(DML):
62806280
arg_types = {
62816281
"this": True,
62826282
"using": True,
@@ -6840,9 +6840,7 @@ def delete(
68406840
if where:
68416841
delete_expr = delete_expr.where(where, dialect=dialect, copy=False, **opts)
68426842
if returning:
6843-
delete_expr = t.cast(
6844-
Delete, delete_expr.returning(returning, dialect=dialect, copy=False, **opts)
6845-
)
6843+
delete_expr = delete_expr.returning(returning, dialect=dialect, copy=False, **opts)
68466844
return delete_expr
68476845

68486846

@@ -6885,7 +6883,7 @@ def insert(
68856883
insert = Insert(this=this, expression=expr, overwrite=overwrite)
68866884

68876885
if returning:
6888-
insert = t.cast(Insert, insert.returning(returning, dialect=dialect, copy=False, **opts))
6886+
insert = insert.returning(returning, dialect=dialect, copy=False, **opts)
68896887

68906888
return insert
68916889

@@ -6895,6 +6893,7 @@ def merge(
68956893
into: ExpOrStr,
68966894
using: ExpOrStr,
68976895
on: ExpOrStr,
6896+
returning: t.Optional[ExpOrStr] = None,
68986897
dialect: DialectType = None,
68996898
copy: bool = True,
69006899
**opts,
@@ -6915,14 +6914,15 @@ def merge(
69156914
into: The target table to merge data into.
69166915
using: The source table to merge data from.
69176916
on: The join condition for the merge.
6917+
returning: The columns to return from the merge.
69186918
dialect: The dialect used to parse the input expressions.
69196919
copy: Whether to copy the expression.
69206920
**opts: Other options to use to parse the input expressions.
69216921
69226922
Returns:
69236923
Merge: The syntax tree for the MERGE statement.
69246924
"""
6925-
return Merge(
6925+
merge = Merge(
69266926
this=maybe_parse(into, dialect=dialect, copy=copy, **opts),
69276927
using=maybe_parse(using, dialect=dialect, copy=copy, **opts),
69286928
on=maybe_parse(on, dialect=dialect, copy=copy, **opts),
@@ -6931,6 +6931,10 @@ def merge(
69316931
for when_expr in when_exprs
69326932
],
69336933
)
6934+
if returning:
6935+
merge = merge.returning(returning, dialect=dialect, copy=False, **opts)
6936+
6937+
return merge
69346938

69356939

69366940
def condition(

sqlglot/generator.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3625,13 +3625,15 @@ def merge_sql(self, expression: exp.Merge) -> str:
36253625
using = f"USING {self.sql(expression, 'using')}"
36263626
on = f"ON {self.sql(expression, 'on')}"
36273627
expressions = self.expressions(expression, sep=" ", indent=False)
3628+
returning = self.sql(expression, "returning")
3629+
if returning:
3630+
expressions = f"{expressions}{returning}"
3631+
36283632
sep = self.sep()
3629-
returning = self.expressions(expression, key="returning", indent=False)
3630-
returning = f"RETURNING {returning}" if returning else ""
36313633

36323634
return self.prepend_ctes(
36333635
expression,
3634-
f"MERGE INTO {this}{table_alias}{sep}{using}{sep}{on}{sep}{expressions}{sep}{returning}",
3636+
f"MERGE INTO {this}{table_alias}{sep}{using}{sep}{on}{sep}{expressions}",
36353637
)
36363638

36373639
@unsupported_args("format")

sqlglot/parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6846,7 +6846,7 @@ def _parse_merge(self) -> exp.Merge:
68466846
using=using,
68476847
on=on,
68486848
expressions=self._parse_when_matched(),
6849-
returning=self._match(TokenType.RETURNING) and self._parse_csv(self._parse_bitwise),
6849+
returning=self._parse_returning(),
68506850
)
68516851

68526852
def _parse_when_matched(self) -> t.List[exp.When]:

tests/test_build.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,16 @@ def test_build(self):
761761
),
762762
"MERGE INTO target_table AS target USING source_table AS source ON target.id = source.id WHEN MATCHED THEN UPDATE SET target.name = source.name",
763763
),
764+
(
765+
lambda: exp.merge(
766+
"WHEN MATCHED THEN UPDATE SET target.name = source.name",
767+
into=exp.table_("target_table").as_("target"),
768+
using=exp.table_("source_table").as_("source"),
769+
on="target.id = source.id",
770+
returning="target.*",
771+
),
772+
"MERGE INTO target_table AS target USING source_table AS source ON target.id = source.id WHEN MATCHED THEN UPDATE SET target.name = source.name RETURNING target.*",
773+
),
764774
]:
765775
with self.subTest(sql):
766776
self.assertEqual(expression().sql(dialect[0] if dialect else None), sql)

0 commit comments

Comments
 (0)