Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/code-quality-checks.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name: Code Quality Checks
on: [push]
jobs:
run-tests:
run-unit-tests:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does unit test mean in this context?

Is this a mock-based unit test that runs without a Databricks account,
or an end-to-end integration test that requires a Databricks account?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unit means no databricks account is required.
e2e means databricks account is required.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm updating the CONTRIBUTING doc with this info.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks Jesse.

runs-on: ubuntu-latest
steps:
#----------------------------------------------
Expand Down Expand Up @@ -48,7 +48,7 @@ jobs:
# run test suite
#----------------------------------------------
- name: Run tests
run: poetry run pytest tests/
run: poetry run python -m pytest tests/unit
check-linting:
runs-on: ubuntu-latest
steps:
Expand Down
55 changes: 55 additions & 0 deletions .github/workflows/e2e-tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
name: Core e2e Tests
on: [push]
jobs:
  run-core-e2e-tests:
    runs-on: ubuntu-latest
    env:
      # Connection parameters for the e2e suite, injected from repo secrets.
      host: ${{ secrets.E2E_TEST_HOST }}
      http_path: ${{ secrets.E2E_TEST_HTTP_PATH }}
      access_token: ${{ secrets.E2E_TEST_TOKEN }}
    steps:
      #----------------------------------------------
      # check-out repo and set-up python
      #----------------------------------------------
      - name: Check out repository
        uses: actions/checkout@v2
      - name: Set up python
        id: setup-python
        uses: actions/setup-python@v2
        with:
          # Quoted so YAML keeps this a string (an unquoted 3.10 would parse as the float 3.1)
          python-version: "3.7"
      #----------------------------------------------
      # ----- install & configure poetry -----
      #----------------------------------------------
      - name: Install Poetry
        uses: snok/install-poetry@v1
        with:
          virtualenvs-create: true
          virtualenvs-in-project: true
          installer-parallel: true

      #----------------------------------------------
      # load cached venv if cache exists
      #----------------------------------------------
      - name: Load cached venv
        id: cached-poetry-dependencies
        uses: actions/cache@v2
        with:
          path: .venv
          key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }}
      #----------------------------------------------
      # install dependencies if cache does not exist
      #----------------------------------------------
      - name: Install dependencies
        if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
        run: poetry install --no-interaction --no-root
      #----------------------------------------------
      # install your root project, if required
      #----------------------------------------------
      - name: Install library
        run: poetry install --no-interaction
      #----------------------------------------------
      # run test suite
      #----------------------------------------------
      - name: Run tests
        run: poetry run python -m pytest tests/e2e/driver_tests.py::PySQLCoreTestSuite
Empty file added tests/__init__.py
Empty file.
Empty file added tests/e2e/common/__init__.py
Empty file.
131 changes: 131 additions & 0 deletions tests/e2e/common/core_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import decimal
import datetime
from collections import namedtuple

# All three failure records share the same core fields: the query that was run,
# the expected column/row types and value, what was actually observed, the
# cursor description, and the connection conf in effect.
_FAILURE_FIELDS = ("query,columnType,resultType,resultValue,"
                   "actualValue,actualType,description,conf")

# A column or row value came back with the wrong *type*.
TypeFailure = namedtuple("TypeFailure", _FAILURE_FIELDS)
# A row value came back with the wrong *value*.
ResultFailure = namedtuple("ResultFailure", _FAILURE_FIELDS)
# Executing the query raised; the exception is carried in `error`.
ExecFailure = namedtuple("ExecFailure", _FAILURE_FIELDS + ",error")


class SmokeTestMixin:
    """Minimal connectivity check: a trivial constant query must round-trip."""

    def test_smoke_test(self):
        # "select 0" should yield exactly one row whose only column is 0.
        with self.cursor() as cursor:
            cursor.execute("select 0")
            result = cursor.fetchall()
            self.assertEqual(len(result), 1)
            self.assertEqual(result[0][0], 0)


class CoreTestMixin:
    """
    This mixin expects to be mixed with a CursorTest-like class with the following extra attributes:
    validate_row_value_type: bool
    validate_result: bool

    The host class must also provide `cursor(conf)` (a context manager),
    `expected_column_types(columnType)` and unittest's `fail`.
    """

    # A list of (subquery, column_type, python_type, expected_result)
    # To be executed as "SELECT {} FROM RANGE(...)" and "SELECT {}"
    range_queries = [
        ("TRUE", 'boolean', bool, True),
        ("cast(1 AS TINYINT)", 'byte', int, 1),
        ("cast(1000 AS SMALLINT)", 'short', int, 1000),
        ("cast(100000 AS INTEGER)", 'integer', int, 100000),
        ("cast(10000000000000 AS BIGINT)", 'long', int, 10000000000000),
        ("cast(100.001 AS DECIMAL(6, 3))", 'decimal', decimal.Decimal, 100.001),
        ("date '2020-02-20'", 'date', datetime.date, datetime.date(2020, 2, 20)),
        ("unhex('f000')", 'binary', bytes, b'\xf0\x00'),  # pyodbc internal mismatch
        ("'foo'", 'string', str, 'foo'),
        # SPARK-32130: 6.x: "4 weeks 2 days" vs 7.x: "30 days"
        # ("interval 30 days", str, str, "interval 4 weeks 2 days"),
        # ("interval 3 days", str, str, "interval 3 days"),
        ("CAST(NULL AS DOUBLE)", 'double', type(None), None),
    ]

    # Full queries, only the first column of the first row is checked
    queries = [("NULL UNION (SELECT 1) order by 1", 'integer', type(None), None)]

    def run_tests_on_queries(self, default_conf):
        """Run every range query (both as a scalar SELECT and over a RANGE) plus the
        full queries, collect all failures, and report them in a single test failure."""
        failures = []
        for (query, columnType, rowValueType, answer) in self.range_queries:
            with self.cursor(default_conf) as cursor:
                failures.extend(
                    self.run_query(cursor, query, columnType, rowValueType, answer, default_conf))
                failures.extend(
                    self.run_range_query(cursor, query, columnType, rowValueType, answer,
                                         default_conf))

        for (query, columnType, rowValueType, answer) in self.queries:
            with self.cursor(default_conf) as cursor:
                failures.extend(
                    self.run_query(cursor, query, columnType, rowValueType, answer, default_conf))

        if failures:
            self.fail("Failed testing result set with Arrow. "
                      "Failed queries: {}".format("\n\n".join([str(f) for f in failures])))

    def run_query(self, cursor, query, columnType, rowValueType, answer, conf):
        """Execute "SELECT {query}" and validate, in order: the description's column
        type, the Python type of the fetched value, and the value itself.

        Returns a list with at most one failure record; [] means the query passed.
        """
        full_query = "SELECT {}".format(query)
        expected_column_types = self.expected_column_types(columnType)
        try:
            cursor.execute(full_query)
            (result, ) = cursor.fetchone()
            # Loop variable renamed from `type` to avoid shadowing the builtin.
            if not all(cursor.description[0][1] == expected for expected in expected_column_types):
                return [
                    TypeFailure(full_query, expected_column_types, rowValueType, answer, result,
                                type(result), cursor.description, conf)
                ]
            if self.validate_row_value_type and type(result) is not rowValueType:
                return [
                    TypeFailure(full_query, expected_column_types, rowValueType, answer, result,
                                type(result), cursor.description, conf)
                ]
            if self.validate_result and str(answer) != str(result):
                # BUG FIX: the original passed an extra `query` argument here (9 args for
                # ResultFailure's 8 fields), so constructing the record raised TypeError
                # and every value mismatch was misreported as an ExecFailure.
                return [
                    ResultFailure(full_query, expected_column_types, rowValueType, answer,
                                  result, type(result), cursor.description, conf)
                ]
            return []
        except Exception as e:
            return [
                ExecFailure(full_query, columnType, rowValueType, None, None, None,
                            cursor.description, conf, e)
            ]

    def run_range_query(self, cursor, query, columnType, rowValueType, expected, conf):
        """Execute "SELECT {query}, id FROM RANGE(5000)" and validate every row of the
        first column the same way `run_query` validates a single value.

        Returns a list with at most one failure record; [] means the query passed.
        """
        full_query = "SELECT {}, id FROM RANGE({})".format(query, 5000)
        expected_column_types = self.expected_column_types(columnType)
        try:
            cursor.execute(full_query)
            while True:
                rows = cursor.fetchmany(1000)
                if len(rows) <= 0:
                    break
                # Only the queried expression matters; the trailing RANGE id is ignored.
                for result, _ in rows:
                    if not all(cursor.description[0][1] == expected_type
                               for expected_type in expected_column_types):
                        return [
                            TypeFailure(full_query, expected_column_types, rowValueType, expected,
                                        result, type(result), cursor.description, conf)
                        ]
                    if self.validate_row_value_type and type(result) is not rowValueType:
                        return [
                            TypeFailure(full_query, expected_column_types, rowValueType, expected,
                                        result, type(result), cursor.description, conf)
                        ]
                    if self.validate_result and str(expected) != str(result):
                        return [
                            ResultFailure(full_query, expected_column_types, rowValueType, expected,
                                          result, type(result), cursor.description, conf)
                        ]
            return []
        except Exception as e:
            return [
                ExecFailure(full_query, columnType, rowValueType, None, None, None,
                            cursor.description, conf, e)
            ]
48 changes: 48 additions & 0 deletions tests/e2e/common/decimal_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from decimal import Decimal

import pyarrow


class DecimalTestsMixin:
    """Round-trip fidelity of DECIMAL values through Arrow result sets."""

    # Each entry: (SQL expression to CAST, expected Python value, expected Arrow column type).
    decimal_and_expected_results = [
        ("100.001 AS DECIMAL(6, 3)", Decimal("100.001"), pyarrow.decimal128(6, 3)),
        ("1000000.0000 AS DECIMAL(11, 4)", Decimal("1000000.0000"), pyarrow.decimal128(11, 4)),
        ("-10.2343 AS DECIMAL(10, 6)", Decimal("-10.234300"), pyarrow.decimal128(10, 6)),
        # TODO(SC-90767): Re-enable this test after we have a way of passing `ansi_mode` = False
        #("-13872347.2343 AS DECIMAL(10, 10)", None, pyarrow.decimal128(10, 10)),
        ("NULL AS DECIMAL(1, 1)", None, pyarrow.decimal128(1, 1)),
        ("1 AS DECIMAL(1, 0)", Decimal("1"), pyarrow.decimal128(1, 0)),
        ("0.00000 AS DECIMAL(5, 3)", Decimal("0.000"), pyarrow.decimal128(5, 3)),
        ("1e-3 AS DECIMAL(38, 3)", Decimal("0.001"), pyarrow.decimal128(38, 3)),
    ]

    # Each entry: (SQL expressions to UNION, expected values in order, expected unified type).
    multi_decimals_and_expected_results = [
        (["1 AS DECIMAL(6, 3)", "100.001 AS DECIMAL(6, 3)", "NULL AS DECIMAL(6, 3)"],
         [Decimal("1.00"), Decimal("100.001"), None], pyarrow.decimal128(6, 3)),
        (["1 AS DECIMAL(6, 3)", "2 AS DECIMAL(5, 2)"],
         [Decimal('1.000'), Decimal('2.000')], pyarrow.decimal128(6, 3)),
    ]

    def test_decimals(self):
        """A single DECIMAL cast comes back with the right Arrow type and value."""
        with self.cursor({}) as cursor:
            for (expression, expected_value, expected_type) in self.decimal_and_expected_results:
                query = "SELECT CAST ({})".format(expression)
                with self.subTest(query=query):
                    cursor.execute(query)
                    arrow_table = cursor.fetchmany_arrow(1)
                    self.assertEqual(arrow_table.field(0).type, expected_type)
                    self.assertEqual(arrow_table.to_pydict().popitem()[1][0], expected_value)

    def test_multi_decimals(self):
        """UNIONed DECIMAL columns are unified to one Arrow type with the right values."""
        with self.cursor({}) as cursor:
            for (expressions, expected_values,
                 expected_type) in self.multi_decimals_and_expected_results:
                union_str = " UNION ".join(["(SELECT CAST ({}))".format(dec) for dec in expressions])
                query = "SELECT * FROM ({}) ORDER BY 1 NULLS LAST".format(union_str)

                with self.subTest(query=query):
                    cursor.execute(query)
                    arrow_table = cursor.fetchall_arrow()
                    self.assertEqual(arrow_table.field(0).type, expected_type)
                    self.assertEqual(arrow_table.to_pydict().popitem()[1], expected_values)
100 changes: 100 additions & 0 deletions tests/e2e/common/large_queries_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import logging
import math
import time

# Module-level logger, named after this module per the logging convention.
log = logging.getLogger(__name__)


class LargeQueriesMixin:
    """
    This mixin expects to be mixed with a CursorTest-like class

    The host class must provide `cursor()` (a context manager), `get_some_rows`,
    and unittest's assertion methods.
    """

    def fetch_rows(self, cursor, row_count, fetchmany_size):
        """
        A generator for rows. Fetches until the end or up to 5 minutes.

        If all rows were read, asserts that exactly `row_count` rows were seen;
        on timeout the count check is skipped.
        """
        # TODO: Remove fetchmany_size when we have fixed the performance issues with fetchone
        # in the Python client
        max_fetch_time = 5 * 60  # Fetch for at most 5 minutes

        rows = self.get_some_rows(cursor, fetchmany_size)
        start_time = time.time()
        n = 0
        while rows:
            for row in rows:
                n += 1
                yield row
            if time.time() - start_time >= max_fetch_time:
                log.warning("Fetching rows timed out")
                break
            rows = self.get_some_rows(cursor, fetchmany_size)
        if not rows:
            # Read all the rows, row_count should match
            self.assertEqual(n, row_count)

        num_fetches = max(math.ceil(n / 10000), 1)
        # BUG FIX: the original line ended with a stray ", 1" which made latency_ms a
        # tuple (int(...), 1); compute the plain integer average latency instead.
        latency_ms = int((time.time() - start_time) * 1000 / num_fetches)
        print('Fetched {} rows with an avg latency of {} per fetch, '.format(n, latency_ms) +
              'assuming 10K fetch size.')

    def test_query_with_large_wide_result_set(self):
        """Fetch a ~300MB result set of wide (8KB) rows and verify ids and uuid widths."""
        resultSize = 300 * 1000 * 1000  # 300 MB
        width = 8192  # B
        rows = resultSize // width
        cols = width // 36

        # Set the fetchmany_size to get 10MB of data a go
        fetchmany_size = 10 * 1024 * 1024 // width
        # This is used by PyHive tests to determine the buffer size
        self.arraysize = 1000
        with self.cursor() as cursor:
            uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)])
            cursor.execute("SELECT id, {uuids} FROM RANGE({rows})".format(uuids=uuids, rows=rows))
            for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)):
                self.assertEqual(row[0], row_id)  # Verify no rows are dropped in the middle.
                self.assertEqual(len(row[1]), 36)

    def test_query_with_large_narrow_result_set(self):
        """Fetch a ~300MB result set of narrow (8B) rows and verify the ids."""
        resultSize = 300 * 1000 * 1000  # 300 MB
        width = 8  # sizeof(long)
        # BUG FIX: use integer division like the wide variant above; true division
        # produced a float that was formatted into the SQL as "RANGE(37500000.0)".
        rows = resultSize // width

        # Set the fetchmany_size to get 10MB of data a go
        fetchmany_size = 10 * 1024 * 1024 // width
        # This is used by PyHive tests to determine the buffer size
        self.arraysize = 10000000
        with self.cursor() as cursor:
            cursor.execute("SELECT * FROM RANGE({rows})".format(rows=rows))
            for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)):
                self.assertEqual(row[0], row_id)

    def test_long_running_query(self):
        """ Incrementally increase query size until it takes at least 5 minutes,
        and asserts that the query completes successfully.
        """
        minutes = 60
        min_duration = 5 * minutes

        duration = -1
        scale0 = 10000
        scale_factor = 1
        with self.cursor() as cursor:
            while duration < min_duration:
                # If linear extrapolation keeps undershooting, stop rather than spin.
                self.assertLess(scale_factor, 512, msg="Detected infinite loop")
                start = time.time()

                cursor.execute("""SELECT count(*)
                    FROM RANGE({scale}) x
                    JOIN RANGE({scale0}) y
                    ON from_unixtime(x.id * y.id, "yyyy-MM-dd") LIKE "%not%a%date%"
                    """.format(scale=scale_factor * scale0, scale0=scale0))

                n, = cursor.fetchone()
                self.assertEqual(n, 0)

                duration = time.time() - start
                current_fraction = duration / min_duration
                print('Took {} s with scale factor={}'.format(duration, scale_factor))
                # Extrapolate linearly to reach 5 min and add 50% padding to push over the limit
                scale_factor = math.ceil(1.5 * scale_factor / current_fraction)
Loading