|
18 | 18 | Cluster object. |
19 | 19 | """ |
20 | 20 |
|
21 | | -from dataclasses import dataclass, field |
22 | 21 | import pathlib |
23 | | -import typing |
24 | 22 | import warnings |
| 23 | +from dataclasses import dataclass, field, fields |
| 24 | +from typing import Dict, List, Optional, Union, get_args, get_origin |
25 | 25 |
|
26 | 26 | dir = pathlib.Path(__file__).parent.parent.resolve() |
27 | 27 |
|
@@ -73,43 +73,45 @@ class ClusterConfiguration: |
73 | 73 | """ |
74 | 74 |
|
75 | 75 | name: str |
76 | | - namespace: str = None |
77 | | - head_info: list = field(default_factory=list) |
78 | | - head_cpus: typing.Union[int, str] = 2 |
79 | | - head_memory: typing.Union[int, str] = 8 |
80 | | - head_gpus: int = None # Deprecating |
81 | | - head_extended_resource_requests: typing.Dict[str, int] = field(default_factory=dict) |
82 | | - machine_types: list = field(default_factory=list) # ["m4.xlarge", "g4dn.xlarge"] |
83 | | - worker_cpu_requests: typing.Union[int, str] = 1 |
84 | | - worker_cpu_limits: typing.Union[int, str] = 1 |
85 | | - min_cpus: typing.Union[int, str] = None # Deprecating |
86 | | - max_cpus: typing.Union[int, str] = None # Deprecating |
| 76 | + namespace: Optional[str] = None |
| 77 | + head_info: List[str] = field(default_factory=list) |
| 78 | + head_cpus: Union[int, str] = 2 |
| 79 | + head_memory: Union[int, str] = 8 |
| 80 | + head_gpus: Optional[int] = None # Deprecating |
| 81 | + head_extended_resource_requests: Dict[str, int] = field(default_factory=dict) |
| 82 | + machine_types: List[str] = field( |
| 83 | + default_factory=list |
| 84 | + ) # ["m4.xlarge", "g4dn.xlarge"] |
| 85 | + worker_cpu_requests: Union[int, str] = 1 |
| 86 | + worker_cpu_limits: Union[int, str] = 1 |
| 87 | + min_cpus: Optional[Union[int, str]] = None # Deprecating |
| 88 | + max_cpus: Optional[Union[int, str]] = None # Deprecating |
87 | 89 | num_workers: int = 1 |
88 | | - worker_memory_requests: typing.Union[int, str] = 2 |
89 | | - worker_memory_limits: typing.Union[int, str] = 2 |
90 | | - min_memory: typing.Union[int, str] = None # Deprecating |
91 | | - max_memory: typing.Union[int, str] = None # Deprecating |
92 | | - num_gpus: int = None # Deprecating |
| 90 | + worker_memory_requests: Union[int, str] = 2 |
| 91 | + worker_memory_limits: Union[int, str] = 2 |
| 92 | + min_memory: Optional[Union[int, str]] = None # Deprecating |
| 93 | + max_memory: Optional[Union[int, str]] = None # Deprecating |
| 94 | + num_gpus: Optional[int] = None # Deprecating |
93 | 95 | template: str = f"{dir}/templates/base-template.yaml" |
94 | 96 | appwrapper: bool = False |
95 | | - envs: dict = field(default_factory=dict) |
| 97 | + envs: Dict[str, str] = field(default_factory=dict) |
96 | 98 | image: str = "" |
97 | | - image_pull_secrets: list = field(default_factory=list) |
| 99 | + image_pull_secrets: List[str] = field(default_factory=list) |
98 | 100 | write_to_file: bool = False |
99 | 101 | verify_tls: bool = True |
100 | | - labels: dict = field(default_factory=dict) |
101 | | - worker_extended_resource_requests: typing.Dict[str, int] = field( |
102 | | - default_factory=dict |
103 | | - ) |
104 | | - extended_resource_mapping: typing.Dict[str, str] = field(default_factory=dict) |
| 102 | + labels: Dict[str, str] = field(default_factory=dict) |
| 103 | + worker_extended_resource_requests: Dict[str, int] = field(default_factory=dict) |
| 104 | + extended_resource_mapping: Dict[str, str] = field(default_factory=dict) |
105 | 105 | overwrite_default_resource_mapping: bool = False |
| 106 | + local_queue: Optional[str] = None |
106 | 107 |
|
107 | 108 | def __post_init__(self): |
108 | 109 | if not self.verify_tls: |
109 | 110 | print( |
110 | 111 | "Warning: TLS verification has been disabled - Endpoint checks will be bypassed" |
111 | 112 | ) |
112 | 113 |
|
| 114 | + self._validate_types() |
113 | 115 | self._memory_to_string() |
114 | 116 | self._str_mem_no_unit_add_GB() |
115 | 117 | self._memory_to_resource() |
@@ -139,9 +141,7 @@ def _combine_extended_resource_mapping(self): |
139 | 141 | **self.extended_resource_mapping, |
140 | 142 | } |
141 | 143 |
|
142 | | - def _validate_extended_resource_requests( |
143 | | - self, extended_resources: typing.Dict[str, int] |
144 | | - ): |
| 144 | + def _validate_extended_resource_requests(self, extended_resources: Dict[str, int]): |
145 | 145 | for k in extended_resources.keys(): |
146 | 146 | if k not in self.extended_resource_mapping.keys(): |
147 | 147 | raise ValueError( |
@@ -206,4 +206,34 @@ def _memory_to_resource(self): |
206 | 206 | warnings.warn("max_memory is being deprecated, use worker_memory_limits") |
207 | 207 | self.worker_memory_limits = f"{self.max_memory}G" |
208 | 208 |
|
209 | | - local_queue: str = None |
| 209 | + def _validate_types(self): |
| 210 | + """Validate the types of all fields in the ClusterConfiguration dataclass.""" |
| 211 | + for field_info in fields(self): |
| 212 | + value = getattr(self, field_info.name) |
| 213 | + expected_type = field_info.type |
| 214 | + if not self._is_type(value, expected_type): |
| 215 | + raise TypeError( |
| 216 | + f"'{field_info.name}' should be of type {expected_type}" |
| 217 | + ) |
| 218 | + |
| 219 | + @staticmethod |
| 220 | + def _is_type(value, expected_type): |
| 221 | + """Check if the value matches the expected type.""" |
| 222 | + |
| 223 | + def check_type(value, expected_type): |
| 224 | + origin_type = get_origin(expected_type) |
| 225 | + args = get_args(expected_type) |
| 226 | + if origin_type is Union: |
| 227 | + return any(check_type(value, union_type) for union_type in args) |
| 228 | + if origin_type is list: |
| 229 | + return all(check_type(elem, args[0]) for elem in value) |
| 230 | + if origin_type is dict: |
| 231 | + return all( |
| 232 | + check_type(k, args[0]) and check_type(v, args[1]) |
| 233 | + for k, v in value.items() |
| 234 | + ) |
| 235 | + if origin_type is tuple: |
| 236 | + return all(check_type(elem, etype) for elem, etype in zip(value, args)) |
| 237 | + return isinstance(value, expected_type) |
| 238 | + |
| 239 | + return check_type(value, expected_type) |
0 commit comments