Skip to content

Commit 6af3ea1

Browse files
committed
Fix all mypy type errors in decision tree
- Fixed incompatible types in assignment (best_improvement)
- Added None checks for node.left and node.right
- Added None check for self.root_
- Added None check for node.value
- Added type ignore for Literal type in example
- All 12 mypy errors resolved
1 parent ac8c8f5 commit 6af3ea1

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

machine_learning/decision_tree_pruning.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def _reduced_error_pruning(self, x_val: np.ndarray, y_val: np.ndarray) -> None:
287287
improved = True
288288
while improved:
289289
improved = False
290-
best_improvement = 0
290+
best_improvement = 0.0
291291
best_node = None
292292

293293
for node in internal_nodes:
@@ -364,8 +364,8 @@ def _calculate_cost_complexity(self, node: "TreeNode") -> float:
364364
return 0.0
365365

366366
# Calculate cost-complexity for children
367-
left_cc = self._calculate_cost_complexity(node.left)
368-
right_cc = self._calculate_cost_complexity(node.right)
367+
left_cc = self._calculate_cost_complexity(node.left) if node.left else 0.0
368+
right_cc = self._calculate_cost_complexity(node.right) if node.right else 0.0
369369

370370
# Calculate total cost-complexity
371371
total_cc = left_cc + right_cc + self.ccp_alpha
@@ -396,8 +396,10 @@ def _prune_high_cost_nodes(self, node: "TreeNode") -> None:
396396
node.value = 0.0 # Will be updated during fit
397397
else:
398398
# Recursively check children
399-
self._prune_high_cost_nodes(node.left)
400-
self._prune_high_cost_nodes(node.right)
399+
if node.left:
400+
self._prune_high_cost_nodes(node.left)
401+
if node.right:
402+
self._prune_high_cost_nodes(node.right)
401403

402404
def _get_internal_nodes(self, node: "TreeNode") -> list["TreeNode"]:
403405
"""
@@ -413,8 +415,10 @@ def _get_internal_nodes(self, node: "TreeNode") -> list["TreeNode"]:
413415
return []
414416

415417
nodes = [node]
416-
nodes.extend(self._get_internal_nodes(node.left))
417-
nodes.extend(self._get_internal_nodes(node.right))
418+
if node.left:
419+
nodes.extend(self._get_internal_nodes(node.left))
420+
if node.right:
421+
nodes.extend(self._get_internal_nodes(node.right))
418422
return nodes
419423

420424
def _predict_batch(self, x: np.ndarray) -> np.ndarray:
@@ -427,6 +431,9 @@ def _predict_batch(self, x: np.ndarray) -> np.ndarray:
427431
Returns:
428432
Predictions
429433
"""
434+
if self.root_ is None:
435+
raise ValueError("Model must be fitted before predict")
436+
430437
predictions = np.zeros(len(x))
431438
for i, sample in enumerate(x):
432439
predictions[i] = self._predict_single(sample, self.root_)
@@ -444,11 +451,17 @@ def _predict_single(self, sample: np.ndarray, node: "TreeNode") -> int | float:
444451
Prediction
445452
"""
446453
if node.is_leaf:
454+
if node.value is None:
455+
raise ValueError("Leaf node must have a value")
447456
return node.value
448457

449458
if sample[node.feature] <= node.threshold:
459+
if node.left is None:
460+
raise ValueError("Non-leaf node must have left child")
450461
return self._predict_single(sample, node.left)
451462
else:
463+
if node.right is None:
464+
raise ValueError("Non-leaf node must have right child")
452465
return self._predict_single(sample, node.right)
453466

454467
def _calculate_error(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
@@ -637,7 +650,7 @@ def compare_pruning_methods() -> None:
637650
tree = DecisionTreePruning(
638651
max_depth=10,
639652
min_samples_leaf=2,
640-
pruning_method=method,
653+
pruning_method=method, # type: ignore[arg-type]
641654
ccp_alpha=0.01
642655
)
643656

0 commit comments

Comments
 (0)