Merged · 25 commits
16df2ff
Doc: added documentation of negative log likelihood function
jkauerl Jun 3, 2024
1ce0142
Feat: created function signature for negative log likelihood
jkauerl Jun 3, 2024
cd8240d
Feat: implementation of the negative log likelihood
jkauerl Jun 3, 2024
20e1f24
Test: created a test for the negative log likelihood
jkauerl Jun 3, 2024
f83510e
Feat: added the needed exports
jkauerl Jun 3, 2024
a5f9567
Feat: added explicit checks regarding
jkauerl Jun 4, 2024
a56b324
Test: added test cases for the checks
jkauerl Jun 4, 2024
886ae77
Feat: added another empty check
jkauerl Jun 4, 2024
314e895
Test: added tests for negative values
jkauerl Jun 18, 2024
748dbd5
Refactoring: added auxiliary function to check range of values
jkauerl Jun 18, 2024
c794c08
Docs: added link to an article with the explanation
jkauerl Jun 18, 2024
bee6f8f
Fix: fixed cargo clippy warning by moving location of function
jkauerl Jun 18, 2024
d925eb3
Docs: added the algorithm to the directory with link
jkauerl Jun 19, 2024
65671b7
Fix: reverted the file to previous format
jkauerl Jun 19, 2024
060454d
Merge branch 'master' into feat/ml/loss/nll
vil02 Jun 27, 2024
55f1d18
Fix: removed an innecesary condition
jkauerl Jul 3, 2024
346510a
Merge branch 'feat/ml/loss/nll' of https://github.com/jkauerl/algorit…
jkauerl Jul 3, 2024
8ef77e8
Fix: changed return type to Result instead of Option
jkauerl Jul 3, 2024
e11b6e0
Fix: fixed test and moved position of if statements
jkauerl Jul 3, 2024
abd523a
Feat: added suggestions and removed one condition
jkauerl Jul 4, 2024
d1dff79
Tests: added suggestion to use a macro for testing purposes
jkauerl Jul 4, 2024
171ef93
Fix: fixed clippy issue and wrapped tests
jkauerl Jul 4, 2024
7771e43
Docs: added more documentation for the binary problem
jkauerl Jul 4, 2024
555b3fb
style: remove blank line
vil02 Jul 9, 2024
8c62652
Merge branch 'master' into feat/ml/loss/nll
vil02 Jul 9, 2024
1 change: 1 addition & 0 deletions DIRECTORY.md
@@ -160,6 +160,7 @@
* [Kl Divergence Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/kl_divergence_loss.rs)
* [Mean Absolute Error Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/mean_absolute_error_loss.rs)
* [Mean Squared Error Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/mean_squared_error_loss.rs)
* [Negative Log Likelihood Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/negative_log_likelihood.rs)
* Optimization
* [Adam](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/optimization/adam.rs)
* [Gradient Descent](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/optimization/gradient_descent.rs)
2 changes: 2 additions & 0 deletions src/machine_learning/loss_function/mod.rs
@@ -3,9 +3,11 @@ mod huber_loss;
mod kl_divergence_loss;
mod mean_absolute_error_loss;
mod mean_squared_error_loss;
mod negative_log_likelihood;

pub use self::hinge_loss::hng_loss;
pub use self::huber_loss::huber_loss;
pub use self::kl_divergence_loss::kld_loss;
pub use self::mean_absolute_error_loss::mae_loss;
pub use self::mean_squared_error_loss::mse_loss;
pub use self::negative_log_likelihood::neg_log_likelihood;
100 changes: 100 additions & 0 deletions src/machine_learning/loss_function/negative_log_likelihood.rs
@@ -0,0 +1,100 @@
// Negative Log Likelihood Loss Function
//
// The `neg_log_likelihood` function calculates the Negative Log Likelihood loss,
// which is a loss function used for classification problems in machine learning.
//
// ## Formula
//
// For each pair of actual and predicted values in the vectors `y_true` and
// `y_pred`, the Negative Log Likelihood loss is calculated as:
//
// - loss = `-y_true * log(y_pred) - (1 - y_true) * log(1 - y_pred)`.
//
// It returns the average loss, i.e. `total_loss` divided by the number of
// elements.
//
// https://towardsdatascience.com/cross-entropy-negative-log-likelihood-and-all-that-jazz-47a95bd2e81
// http://neuralnetworksanddeeplearning.com/chap3.html
// Derivation of the formula:
// https://medium.com/@bhardwajprakarsh/negative-log-likelihood-loss-why-do-we-use-it-for-binary-classification-7625f9e3c944
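//
// For example (illustrative numbers, not taken from the sources above): with
// `y_true = [1.0, 0.0]` and `y_pred = [0.9, 0.2]`, the element-wise losses are
// `-ln(0.9) ≈ 0.1054` and `-ln(1.0 - 0.2) ≈ 0.2231`, so the function returns
// their average, `(0.1054 + 0.2231) / 2 ≈ 0.1643`.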

pub fn neg_log_likelihood(
    y_true: &[f64],
    y_pred: &[f64],
) -> Result<f64, NegativeLogLikelihoodLossError> {
    // Checks that the actual and predicted values have the same length
    if y_true.len() != y_pred.len() {
        return Err(NegativeLogLikelihoodLossError::InputsHaveDifferentLength);
    }
    // Checks that the inputs are not empty
    if y_pred.is_empty() {
        return Err(NegativeLogLikelihoodLossError::EmptyInputs);
    }
    // Checks that all values are between 0 and 1
    if !are_all_values_in_range(y_true) || !are_all_values_in_range(y_pred) {
        return Err(NegativeLogLikelihoodLossError::InvalidValues);
    }

    let mut total_loss: f64 = 0.0;
    for (p, a) in y_pred.iter().zip(y_true.iter()) {
        let loss: f64 = -a * p.ln() - (1.0 - a) * (1.0 - p).ln();
        total_loss += loss;
    }
    Ok(total_loss / (y_pred.len() as f64))
}

#[derive(Debug, PartialEq, Eq)]
pub enum NegativeLogLikelihoodLossError {
    InputsHaveDifferentLength,
    EmptyInputs,
    InvalidValues,
}

fn are_all_values_in_range(values: &[f64]) -> bool {
    values.iter().all(|&x| (0.0..=1.0).contains(&x))
}

#[cfg(test)]
mod tests {
    use super::*;

    macro_rules! test_with_wrong_inputs {
        ($($name:ident: $inputs:expr,)*) => {
            $(
                #[test]
                fn $name() {
                    let (values_a, values_b, expected_error) = $inputs;
                    assert_eq!(neg_log_likelihood(&values_a, &values_b), expected_error);
                    assert_eq!(neg_log_likelihood(&values_b, &values_a), expected_error);
                }
            )*
        }
    }

    test_with_wrong_inputs! {
        different_length: (vec![0.9, 0.0, 0.8], vec![0.9, 0.1], Err(NegativeLogLikelihoodLossError::InputsHaveDifferentLength)),
        different_length_one_empty: (vec![], vec![0.9, 0.1], Err(NegativeLogLikelihoodLossError::InputsHaveDifferentLength)),
        value_greater_than_1: (vec![1.1, 0.0, 0.8], vec![0.1, 0.2, 0.3], Err(NegativeLogLikelihoodLossError::InvalidValues)),
        value_smaller_than_0: (vec![0.9, 0.0, -0.1], vec![0.1, 0.2, 0.3], Err(NegativeLogLikelihoodLossError::InvalidValues)),
        empty_input: (vec![], vec![], Err(NegativeLogLikelihoodLossError::EmptyInputs)),
    }

    macro_rules! test_neg_log_likelihood {
        ($($name:ident: $inputs:expr,)*) => {
            $(
                #[test]
                fn $name() {
                    let (actual_values, predicted_values, expected) = $inputs;
                    assert_eq!(neg_log_likelihood(&actual_values, &predicted_values).unwrap(), expected);
                }
            )*
        }
    }

    test_neg_log_likelihood! {
        set_0: (vec![1.0, 0.0, 1.0], vec![0.9, 0.1, 0.8], 0.14462152754328741),
        set_1: (vec![1.0, 0.0, 1.0], vec![0.1, 0.2, 0.3], 1.2432338162113972),
        set_2: (vec![0.0, 1.0, 0.0], vec![0.1, 0.2, 0.3], 0.6904911240102196),
        set_3: (vec![1.0, 0.0, 1.0, 0.0], vec![0.9, 0.1, 0.8, 0.2], 0.164252033486018),
    }
}
1 change: 1 addition & 0 deletions src/machine_learning/mod.rs
@@ -12,5 +12,6 @@ pub use self::loss_function::huber_loss;
pub use self::loss_function::kld_loss;
pub use self::loss_function::mae_loss;
pub use self::loss_function::mse_loss;
pub use self::loss_function::neg_log_likelihood;
pub use self::optimization::gradient_descent;
pub use self::optimization::Adam;
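
For context, a minimal usage sketch of the new export (the `demo` function and its call site are hypothetical; the path follows the re-exports above). The `Result` return type, adopted in 8ef77e8 in place of `Option`, lets a caller distinguish the three failure modes:

use crate::machine_learning::neg_log_likelihood;

fn demo() {
    let y_true = vec![1.0, 0.0, 1.0];
    let y_pred = vec![0.9, 0.1, 0.8];

    // Same vectors as `set_0` in the tests: the average loss is ~0.14462.
    match neg_log_likelihood(&y_true, &y_pred) {
        Ok(loss) => println!("average NLL loss: {loss}"),
        Err(err) => eprintln!("invalid inputs: {err:?}"),
    }
}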