@@ -50,8 +50,8 @@ def __init__(self, alpha: float = 1.0, feature_type: str = "discrete") -> None:
5050 # Model parameters
5151 self .classes_ : np .ndarray | None = None
5252 self .class_prior_ : dict [int , float ] = {}
53- self .feature_count_ : dict [int , dict [int , int ]] = {}
54- self .feature_log_prob_ : dict [int , dict [int , float ]] = {}
53+ self .feature_count_ : dict [int , dict [int , dict [ int , int ] ]] = {}
54+ self .feature_log_prob_ : dict [int , dict [int , dict [ int , float ] ]] = {}
5555 self .feature_mean_ : dict [int , dict [int , float ]] = {}
5656 self .feature_var_ : dict [int , dict [int , float ]] = {}
5757 self .n_features_ : int | None = None
@@ -104,7 +104,7 @@ def _compute_class_prior(self, y: np.ndarray) -> dict[int, float]:
104104 return prior
105105
106106 def _compute_feature_counts (self , x : np .ndarray , y : np .ndarray
107- ) -> dict [int , dict [int , int ]]:
107+ ) -> dict [int , dict [int , dict [ int , int ] ]]:
108108 """
109109 Compute feature counts for each class (for discrete features).
110110
@@ -139,12 +139,12 @@ def _compute_feature_counts(self, x: np.ndarray, y: np.ndarray
139139
140140 for feature_value in np .unique (x [:, feature_idx ]):
141141 count = np .sum (x_class [:, feature_idx ] == feature_value )
142- feature_counts [class_label ][feature_idx ][feature_value ] = count
142+ feature_counts [class_label ][feature_idx ][int ( feature_value ) ] = int ( count )
143143
144144 return feature_counts
145145
146146 def _compute_feature_statistics (self , x : np .ndarray , y : np .ndarray
147- ) -> tuple [dict , dict ]:
147+ ) -> tuple [dict [ int , dict [ int , float ]], dict [ int , dict [ int , float ]] ]:
148148 """
149149 Compute mean and variance for each feature in each class (continuous features).
150150
@@ -296,6 +296,9 @@ def _predict_log_proba_discrete(self, x: np.ndarray) -> np.ndarray:
296296 Returns:
297297 Log probability matrix of shape (n_samples, n_classes)
298298 """
299+ if self .classes_ is None :
300+ raise ValueError ("Model must be fitted before predict" )
301+
299302 n_samples = x .shape [0 ]
300303 n_classes = len (self .classes_ )
301304 log_proba = np .zeros ((n_samples , n_classes ))
@@ -310,13 +313,14 @@ def _predict_log_proba_discrete(self, x: np.ndarray) -> np.ndarray:
310313 feature_value = x [sample_idx , feature_idx ]
311314
312315 # Get log probability for this feature value in this class
316+ feature_value_int = int (feature_value )
313317 if (
314- feature_value
318+ feature_value_int
315319 in self .feature_log_prob_ [class_label ][feature_idx ]
316320 ):
317321 log_prob = self .feature_log_prob_ [class_label ][
318322 feature_idx
319- ][feature_value ]
323+ ][feature_value_int ]
320324 else :
321325 # Unseen feature value: use Laplace smoothing
322326 all_values = list (
@@ -347,6 +351,9 @@ def _predict_log_proba_continuous(self, x: np.ndarray) -> np.ndarray:
347351 Returns:
348352 Log probability matrix of shape (n_samples, n_classes)
349353 """
354+ if self .classes_ is None :
355+ raise ValueError ("Model must be fitted before predict" )
356+
350357 n_samples = x .shape [0 ]
351358 n_classes = len (self .classes_ )
352359 log_proba = np .zeros ((n_samples , n_classes ))
@@ -362,9 +369,10 @@ def _predict_log_proba_continuous(self, x: np.ndarray) -> np.ndarray:
362369
363370 # Compute Gaussian log probabilities for all samples
364371 feature_values = x [:, feature_idx ]
365- log_proba [:, i ] += self ._gaussian_log_probability (
366- feature_values , means , variances
367- )
372+ log_proba [:, i ] += np .array ([
373+ self ._gaussian_log_probability (val , means , variances )
374+ for val in feature_values
375+ ])
368376
369377 return log_proba
370378
@@ -445,6 +453,9 @@ def predict(self, x: np.ndarray) -> np.ndarray:
445453 >>> len(predictions) == x_test.shape[0]
446454 True
447455 """
456+ if self .classes_ is None :
457+ raise ValueError ("Model must be fitted before predict" )
458+
448459 log_proba = self .predict_log_proba (x )
449460 predictions = self .classes_ [np .argmax (log_proba , axis = 1 )]
450461 return predictions
0 commit comments