Skip to content

Commit 089b77e

Browse files
authored
Fix(athena): DDL fixes (#4132)
* Fix(athena): DDL fixes * PR feedback
1 parent ba015dc commit 089b77e

File tree

3 files changed

+118
-8
lines changed

3 files changed

+118
-8
lines changed

sqlglot/dialects/athena.py

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,40 @@ def _generate_as_hive(expression: exp.Expression) -> bool:
2020
else:
2121
return expression.kind != "VIEW" # CREATE VIEW is never Hive but CREATE SCHEMA etc is
2222

23-
elif isinstance(expression, exp.Alter) or isinstance(expression, exp.Drop):
24-
return True # all ALTER and DROP statements are Hive
23+
# https://docs.aws.amazon.com/athena/latest/ug/ddl-reference.html
24+
elif isinstance(expression, (exp.Alter, exp.Drop, exp.Describe)):
25+
if isinstance(expression, exp.Drop) and expression.kind == "VIEW":
26+
# DROP VIEW is Trino (I guess because CREATE VIEW is)
27+
return False
28+
29+
# Everything else is Hive
30+
return True
2531

2632
return False
2733

2834

35+
def _location_property_sql(self: Athena.Generator, e: exp.LocationProperty):
36+
# If table_type='iceberg', the LocationProperty is called 'location'
37+
# Otherwise, it's called 'external_location'
38+
# ref: https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html
39+
40+
prop_name = "external_location"
41+
42+
if isinstance(e.parent, exp.Properties):
43+
table_type_property = next(
44+
(
45+
p
46+
for p in e.parent.expressions
47+
if isinstance(p, exp.Property) and p.name == "table_type"
48+
),
49+
None,
50+
)
51+
if table_type_property and table_type_property.text("value") == "iceberg":
52+
prop_name = "location"
53+
54+
return f"{prop_name}={self.sql(e, 'this')}"
55+
56+
2957
class Athena(Trino):
3058
"""
3159
Over the years, it looks like AWS has taken various execution engines, bolted on AWS-specific modifications and then
@@ -48,7 +76,7 @@ class Athena(Trino):
4876
Trino:
4977
- Uses double quotes to quote identifiers
5078
- Used for DDL operations that involve SELECT queries, eg:
51-
- CREATE VIEW
79+
- CREATE VIEW / DROP VIEW
5280
- CREATE TABLE... AS SELECT
5381
- Used for DML operations
5482
- SELECT, INSERT, UPDATE, DELETE, MERGE
@@ -79,27 +107,40 @@ class Parser(Trino.Parser):
79107
TokenType.USING: lambda self: self._parse_as_command(self._prev),
80108
}
81109

110+
class _HiveGenerator(Hive.Generator):
111+
def alter_sql(self, expression: exp.Alter) -> str:
112+
# package any ALTER TABLE ADD actions into a Schema object
113+
# so it gets generated as `ALTER TABLE .. ADD COLUMNS(...)`
114+
# instead of `ALTER TABLE ... ADD COLUMN` which is invalid syntax on Athena
115+
if isinstance(expression, exp.Alter) and expression.kind == "TABLE":
116+
if expression.actions and isinstance(expression.actions[0], exp.ColumnDef):
117+
new_actions = exp.Schema(expressions=expression.actions)
118+
expression.set("actions", [new_actions])
119+
120+
return super().alter_sql(expression)
121+
82122
class Generator(Trino.Generator):
83123
"""
84124
Generate queries for the Athena Trino execution engine
85125
"""
86126

87-
TYPE_MAPPING = {
88-
**Trino.Generator.TYPE_MAPPING,
89-
exp.DataType.Type.TEXT: "STRING",
127+
PROPERTIES_LOCATION = {
128+
**Trino.Generator.PROPERTIES_LOCATION,
129+
exp.LocationProperty: exp.Properties.Location.POST_WITH,
90130
}
91131

92132
TRANSFORMS = {
93133
**Trino.Generator.TRANSFORMS,
94-
exp.FileFormatProperty: lambda self, e: f"'FORMAT'={self.sql(e, 'this')}",
134+
exp.FileFormatProperty: lambda self, e: f"format={self.sql(e, 'this')}",
135+
exp.LocationProperty: _location_property_sql,
95136
}
96137

97138
def __init__(self, *args, **kwargs):
98139
super().__init__(*args, **kwargs)
99140

100141
hive_kwargs = {**kwargs, "dialect": "hive"}
101142

102-
self._hive_generator = Hive.Generator(*args, **hive_kwargs)
143+
self._hive_generator = Athena._HiveGenerator(*args, **hive_kwargs)
103144

104145
def generate(self, expression: exp.Expression, copy: bool = True) -> str:
105146
if _generate_as_hive(expression):

sqlglot/expressions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4398,6 +4398,15 @@ class Alter(Expression):
43984398
"not_valid": False,
43994399
}
44004400

4401+
@property
4402+
def kind(self) -> t.Optional[str]:
4403+
kind = self.args.get("kind")
4404+
return kind and kind.upper()
4405+
4406+
@property
4407+
def actions(self) -> t.List[Expression]:
4408+
return self.args.get("actions") or []
4409+
44014410

44024411
class AddConstraint(Expression):
44034412
arg_types = {"expressions": True}

tests/dialects/test_athena.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from sqlglot import exp
12
from tests.dialects.test_dialect import Validator
23

34

@@ -68,6 +69,23 @@ def test_ddl(self):
6869
"CREATE TABLE foo AS WITH foo AS (SELECT a, b FROM bar) SELECT * FROM foo"
6970
)
7071

72+
# ALTER TABLE ADD COLUMN not supported, it needs to be generated as ALTER TABLE ADD COLUMNS
73+
self.validate_identity(
74+
"ALTER TABLE `foo`.`bar` ADD COLUMN `end_ts` BIGINT",
75+
write_sql="ALTER TABLE `foo`.`bar` ADD COLUMNS (`end_ts` BIGINT)",
76+
)
77+
78+
def test_dml(self):
79+
self.validate_all(
80+
"SELECT CAST(ds AS VARCHAR) AS ds FROM (VALUES ('2022-01-01')) AS t(ds)",
81+
read={"": "SELECT CAST(ds AS STRING) AS ds FROM (VALUES ('2022-01-01')) AS t(ds)"},
82+
write={
83+
"hive": "SELECT CAST(ds AS STRING) AS ds FROM (VALUES ('2022-01-01')) AS t(ds)",
84+
"trino": "SELECT CAST(ds AS VARCHAR) AS ds FROM (VALUES ('2022-01-01')) AS t(ds)",
85+
"athena": "SELECT CAST(ds AS VARCHAR) AS ds FROM (VALUES ('2022-01-01')) AS t(ds)",
86+
},
87+
)
88+
7189
def test_ddl_quoting(self):
7290
self.validate_identity("CREATE SCHEMA `foo`")
7391
self.validate_identity("CREATE SCHEMA foo")
@@ -111,6 +129,10 @@ def test_ddl_quoting(self):
111129
'CREATE VIEW `foo` AS SELECT "id" FROM `tbl`',
112130
write_sql='CREATE VIEW "foo" AS SELECT "id" FROM "tbl"',
113131
)
132+
self.validate_identity(
133+
"DROP VIEW IF EXISTS `foo`.`bar`",
134+
write_sql='DROP VIEW IF EXISTS "foo"."bar"',
135+
)
114136

115137
self.validate_identity(
116138
'ALTER TABLE "foo" ADD COLUMNS ("id" STRING)',
@@ -128,6 +150,8 @@ def test_ddl_quoting(self):
128150
write_sql='CREATE TABLE "foo" AS WITH "foo" AS (SELECT "a", "b" FROM "bar") SELECT * FROM "foo"',
129151
)
130152

153+
self.validate_identity("DESCRIBE foo.bar", write_sql="DESCRIBE `foo`.`bar`", identify=True)
154+
131155
def test_dml_quoting(self):
132156
self.validate_identity("SELECT a AS foo FROM tbl")
133157
self.validate_identity('SELECT "a" AS "foo" FROM "tbl"')
@@ -167,3 +191,39 @@ def test_dml_quoting(self):
167191
write_sql='WITH "foo" AS (SELECT "a", "b" FROM "bar") SELECT * FROM "foo"',
168192
identify=True,
169193
)
194+
195+
def test_ctas(self):
196+
# Hive tables use 'external_location' to specify the table location, Iceberg tables use 'location' to specify the table location
197+
# The 'table_type' property is used to determine if it's a Hive or an Iceberg table
198+
# ref: https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties
199+
ctas_hive = exp.Create(
200+
this=exp.to_table("foo.bar"),
201+
kind="TABLE",
202+
properties=exp.Properties(
203+
expressions=[
204+
exp.FileFormatProperty(this=exp.Literal.string("parquet")),
205+
exp.LocationProperty(this=exp.Literal.string("s3://foo")),
206+
]
207+
),
208+
expression=exp.select("1"),
209+
)
210+
self.assertEqual(
211+
ctas_hive.sql(dialect=self.dialect, identify=True),
212+
"CREATE TABLE \"foo\".\"bar\" WITH (format='parquet', external_location='s3://foo') AS SELECT 1",
213+
)
214+
215+
ctas_iceberg = exp.Create(
216+
this=exp.to_table("foo.bar"),
217+
kind="TABLE",
218+
properties=exp.Properties(
219+
expressions=[
220+
exp.Property(this=exp.var("table_type"), value=exp.Literal.string("iceberg")),
221+
exp.LocationProperty(this=exp.Literal.string("s3://foo")),
222+
]
223+
),
224+
expression=exp.select("1"),
225+
)
226+
self.assertEqual(
227+
ctas_iceberg.sql(dialect=self.dialect, identify=True),
228+
"CREATE TABLE \"foo\".\"bar\" WITH (table_type='iceberg', location='s3://foo') AS SELECT 1",
229+
)

0 commit comments

Comments
 (0)