@@ -287,7 +287,7 @@ def _reduced_error_pruning(self, x_val: np.ndarray, y_val: np.ndarray) -> None:
287287 improved = True
288288 while improved :
289289 improved = False
290- best_improvement = 0
290+ best_improvement = 0.0
291291 best_node = None
292292
293293 for node in internal_nodes :
@@ -364,8 +364,8 @@ def _calculate_cost_complexity(self, node: "TreeNode") -> float:
364364 return 0.0
365365
366366 # Calculate cost-complexity for children
367- left_cc = self ._calculate_cost_complexity (node .left )
368- right_cc = self ._calculate_cost_complexity (node .right )
367+ left_cc = self ._calculate_cost_complexity (node .left ) if node . left else 0.0
368+ right_cc = self ._calculate_cost_complexity (node .right ) if node . right else 0.0
369369
370370 # Calculate total cost-complexity
371371 total_cc = left_cc + right_cc + self .ccp_alpha
@@ -396,8 +396,10 @@ def _prune_high_cost_nodes(self, node: "TreeNode") -> None:
396396 node .value = 0.0 # Will be updated during fit
397397 else :
398398 # Recursively check children
399- self ._prune_high_cost_nodes (node .left )
400- self ._prune_high_cost_nodes (node .right )
399+ if node .left :
400+ self ._prune_high_cost_nodes (node .left )
401+ if node .right :
402+ self ._prune_high_cost_nodes (node .right )
401403
402404 def _get_internal_nodes (self , node : "TreeNode" ) -> list ["TreeNode" ]:
403405 """
@@ -413,8 +415,10 @@ def _get_internal_nodes(self, node: "TreeNode") -> list["TreeNode"]:
413415 return []
414416
415417 nodes = [node ]
416- nodes .extend (self ._get_internal_nodes (node .left ))
417- nodes .extend (self ._get_internal_nodes (node .right ))
418+ if node .left :
419+ nodes .extend (self ._get_internal_nodes (node .left ))
420+ if node .right :
421+ nodes .extend (self ._get_internal_nodes (node .right ))
418422 return nodes
419423
420424 def _predict_batch (self , x : np .ndarray ) -> np .ndarray :
@@ -427,6 +431,9 @@ def _predict_batch(self, x: np.ndarray) -> np.ndarray:
427431 Returns:
428432 Predictions
429433 """
434+ if self .root_ is None :
435+ raise ValueError ("Model must be fitted before predict" )
436+
430437 predictions = np .zeros (len (x ))
431438 for i , sample in enumerate (x ):
432439 predictions [i ] = self ._predict_single (sample , self .root_ )
@@ -444,11 +451,17 @@ def _predict_single(self, sample: np.ndarray, node: "TreeNode") -> int | float:
444451 Prediction
445452 """
446453 if node .is_leaf :
454+ if node .value is None :
455+ raise ValueError ("Leaf node must have a value" )
447456 return node .value
448457
449458 if sample [node .feature ] <= node .threshold :
459+ if node .left is None :
460+ raise ValueError ("Non-leaf node must have left child" )
450461 return self ._predict_single (sample , node .left )
451462 else :
463+ if node .right is None :
464+ raise ValueError ("Non-leaf node must have right child" )
452465 return self ._predict_single (sample , node .right )
453466
454467 def _calculate_error (self , y_true : np .ndarray , y_pred : np .ndarray ) -> float :
@@ -637,7 +650,7 @@ def compare_pruning_methods() -> None:
637650 tree = DecisionTreePruning (
638651 max_depth = 10 ,
639652 min_samples_leaf = 2 ,
640- pruning_method = method ,
653+ pruning_method = method , # type: ignore[arg-type]
641654 ccp_alpha = 0.01
642655 )
643656
0 commit comments