11import datetime
22import decimal
33from enum import Enum , auto
4- from typing import Optional , Sequence
4+ from typing import Optional , Sequence , Any
55
66from databricks .sql .exc import NotSupportedError
77from databricks .sql .thrift_api .TCLIService .ttypes import (
88 TSparkParameter ,
99 TSparkParameterValue ,
10+ TSparkParameterValueArg ,
1011)
1112
1213import datetime
@@ -54,7 +55,17 @@ class DatabricksSupportedType(Enum):
5455
5556
5657TAllowedParameterValue = Union [
57- str , int , float , datetime .datetime , datetime .date , bool , decimal .Decimal , None
58+ str ,
59+ int ,
60+ float ,
61+ datetime .datetime ,
62+ datetime .date ,
63+ bool ,
64+ decimal .Decimal ,
65+ None ,
66+ list ,
67+ dict ,
68+ tuple ,
5869]
5970
6071
@@ -82,6 +93,7 @@ class DbsqlParameterBase:
8293
8394 CAST_EXPR : str
8495 name : Optional [str ]
96+ value : Any
8597
8698 def as_tspark_param (self , named : bool ) -> TSparkParameter :
8799 """Returns a TSparkParameter object that can be passed to the DBR thrift server."""
@@ -98,6 +110,10 @@ def as_tspark_param(self, named: bool) -> TSparkParameter:
98110 def _tspark_param_value (self ):
99111 return TSparkParameterValue (stringValue = str (self .value ))
100112
113+ def _tspark_value_arg (self ):
114+ """Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server."""
115+ return TSparkParameterValueArg (value = str (self .value ), type = self ._cast_expr ())
116+
101117 def _cast_expr (self ):
102118 return self .CAST_EXPR
103119
@@ -428,6 +444,99 @@ def __init__(self, value: int, name: Optional[str] = None):
428444 CAST_EXPR = DatabricksSupportedType .TINYINT .name
429445
430446
447+ class ArrayParameter (DbsqlParameterBase ):
448+ """Wrap a Python `Sequence` that will be bound to a Databricks SQL ARRAY type."""
449+
450+ def __init__ (self , value : Sequence [Any ], name : Optional [str ] = None ):
451+ """
452+ :value:
453+ The value to bind for this parameter. This will be casted to a ARRAY.
454+ :name:
455+ If None, your query must contain a `?` marker. Like:
456+
457+ ```sql
458+ SELECT * FROM table WHERE field = ?
459+ ```
460+ If not None, your query should contain a named parameter marker. Like:
461+ ```sql
462+ SELECT * FROM table WHERE field = :my_param
463+ ```
464+
465+ The `name` argument to this function would be `my_param`.
466+ """
467+ self .name = name
468+ self .value = [dbsql_parameter_from_primitive (val ) for val in value ]
469+
470+ def as_tspark_param (self , named : bool = False ) -> TSparkParameter :
471+ """Returns a TSparkParameter object that can be passed to the DBR thrift server."""
472+
473+ tsp = TSparkParameter (type = self ._cast_expr ())
474+ tsp .arguments = [val ._tspark_value_arg () for val in self .value ]
475+
476+ if named :
477+ tsp .name = self .name
478+ tsp .ordinal = False
479+ elif not named :
480+ tsp .ordinal = True
481+ return tsp
482+
483+ def _tspark_value_arg (self ):
484+ """Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server."""
485+ tva = TSparkParameterValueArg (type = self ._cast_expr ())
486+ tva .arguments = [val ._tspark_value_arg () for val in self .value ]
487+ return tva
488+
489+ CAST_EXPR = DatabricksSupportedType .ARRAY .name
490+
491+
492+ class MapParameter (DbsqlParameterBase ):
493+ """Wrap a Python `dict` that will be bound to a Databricks SQL MAP type."""
494+
495+ def __init__ (self , value : dict , name : Optional [str ] = None ):
496+ """
497+ :value:
498+ The value to bind for this parameter. This will be casted to a MAP.
499+ :name:
500+ If None, your query must contain a `?` marker. Like:
501+
502+ ```sql
503+ SELECT * FROM table WHERE field = ?
504+ ```
505+ If not None, your query should contain a named parameter marker. Like:
506+ ```sql
507+ SELECT * FROM table WHERE field = :my_param
508+ ```
509+
510+ The `name` argument to this function would be `my_param`.
511+ """
512+ self .name = name
513+ self .value = [
514+ dbsql_parameter_from_primitive (item )
515+ for key , val in value .items ()
516+ for item in (key , val )
517+ ]
518+
519+ def as_tspark_param (self , named : bool = False ) -> TSparkParameter :
520+ """Returns a TSparkParameter object that can be passed to the DBR thrift server."""
521+
522+ tsp = TSparkParameter (type = self ._cast_expr ())
523+ tsp .arguments = [val ._tspark_value_arg () for val in self .value ]
524+ if named :
525+ tsp .name = self .name
526+ tsp .ordinal = False
527+ elif not named :
528+ tsp .ordinal = True
529+ return tsp
530+
531+ def _tspark_value_arg (self ):
532+ """Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server."""
533+ tva = TSparkParameterValueArg (type = self ._cast_expr ())
534+ tva .arguments = [val ._tspark_value_arg () for val in self .value ]
535+ return tva
536+
537+ CAST_EXPR = DatabricksSupportedType .MAP .name
538+
539+
431540class DecimalParameter (DbsqlParameterBase ):
432541 """Wrap a Python `Decimal` that will be bound to a Databricks SQL DECIMAL type."""
433542
@@ -543,23 +652,26 @@ def dbsql_parameter_from_primitive(
543652 # havoc. We can't use TYPE_INFERRENCE_MAP because mypy doesn't trust
544653 # its logic
545654
546- if type (value ) is int :
655+ if isinstance (value , bool ):
656+ return BooleanParameter (value = value , name = name )
657+ elif isinstance (value , int ):
547658 return dbsql_parameter_from_int (value , name = name )
548- elif type (value ) is str :
659+ elif isinstance (value , str ) :
549660 return StringParameter (value = value , name = name )
550- elif type (value ) is float :
661+ elif isinstance (value , float ) :
551662 return FloatParameter (value = value , name = name )
552- elif type (value ) is datetime .datetime :
663+ elif isinstance (value , datetime .datetime ) :
553664 return TimestampParameter (value = value , name = name )
554- elif type (value ) is datetime .date :
665+ elif isinstance (value , datetime .date ) :
555666 return DateParameter (value = value , name = name )
556- elif type (value ) is bool :
557- return BooleanParameter (value = value , name = name )
558- elif type (value ) is decimal .Decimal :
667+ elif isinstance (value , decimal .Decimal ):
559668 return DecimalParameter (value = value , name = name )
669+ elif isinstance (value , dict ):
670+ return MapParameter (value = value , name = name )
671+ elif isinstance (value , Sequence ) and not isinstance (value , str ):
672+ return ArrayParameter (value = value , name = name )
560673 elif value is None :
561674 return VoidParameter (value = value , name = name )
562-
563675 else :
564676 raise NotSupportedError (
565677 f"Could not infer parameter type from value: { value } - { type (value )} \n "
@@ -581,6 +693,8 @@ def dbsql_parameter_from_primitive(
581693 TimestampNTZParameter ,
582694 TinyIntParameter ,
583695 DecimalParameter ,
696+ ArrayParameter ,
697+ MapParameter ,
584698]
585699
586700
0 commit comments