Implementation of Marginal Ranking Loss Function #733
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files

```diff
@@           Coverage Diff           @@
##           master     #733   +/-   ##
=======================================
  Coverage   94.95%   94.95%
=======================================
  Files         303      304    +1
  Lines       22533    22549   +16
=======================================
+ Hits        21396    21412   +16
  Misses       1137     1137
```
Let's first focus on #734.
sozelfist left a comment
Feel free to adjust these suggestions on your own. Note that I am just a contributor, not a reviewer or maintainer of this repo.
```rust
pub fn mrg_ranking_loss(x_first: &[f64], x_second: &[f64], margin: f64, y_true: f64) -> f64 {
    let mut total_loss: f64 = 0.0;
    for (f, s) in x_first.iter().zip(x_second.iter()) {
        let loss: f64 = (margin - y_true * (f - s)).max(0.0);
        total_loss += loss;
    }
    total_loss / (x_first.len() as f64)
}
```
What do you think about returning an `Option<f64>` instead of just `f64`? When the lengths of `x_first` and `x_second` do not match, or `x_first` is empty, the function returns `None` to indicate that no loss was computed. Here is my suggestion:
```rust
pub fn mrg_ranking_loss(x_first: &[f64], x_second: &[f64], margin: f64, y_true: f64) -> Option<f64> {
    if x_first.len() != x_second.len() || x_first.is_empty() {
        return None;
    }
    let mut total_loss: f64 = 0.0;
    for (f, s) in x_first.iter().zip(x_second.iter()) {
        let loss: f64 = (margin - y_true * (f - s)).max(0.0);
        total_loss += loss;
    }
    Some(total_loss / (x_first.len() as f64))
}
```

You can review and apply this suggestion on your own.
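To make the motivation for the `Option` return concrete, here is a minimal sketch of how the original `f64`-returning version behaves on the two edge cases mentioned above. The name `mrg_ranking_loss_unchecked` is hypothetical and used only for illustration; the body mirrors the code under review.

```rust
/// Copy of the original `f64`-returning version under a hypothetical name,
/// kept only to illustrate the edge cases discussed in this thread.
pub fn mrg_ranking_loss_unchecked(
    x_first: &[f64],
    x_second: &[f64],
    margin: f64,
    y_true: f64,
) -> f64 {
    let mut total_loss: f64 = 0.0;
    // `zip` stops at the shorter slice, so extra elements are silently ignored.
    for (f, s) in x_first.iter().zip(x_second.iter()) {
        total_loss += (margin - y_true * (f - s)).max(0.0);
    }
    // For empty input this evaluates 0.0 / 0.0, which is NaN.
    total_loss / (x_first.len() as f64)
}
```

With two empty slices this returns NaN, and with `[1.0, 2.0]` against `[3.0]` (margin 1.0, `y_true` 1.0) it returns 1.5: only one pair is evaluated, yet the sum is divided by two. In both cases the caller gets no signal that the inputs were invalid, which is what returning `None` addresses.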
I really like this approach because it looks more Rust-like. Thanks for the suggestion!
```rust
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_marginal_ranking_loss() {
        let first_values: Vec<f64> = vec![1.0, 2.0, 3.0];
        let second_values: Vec<f64> = vec![2.0, 3.0, 4.0];
        let margin: f64 = 1.0;
        let actual_value: f64 = -1.0;
        assert_eq!(
            mrg_ranking_loss(&first_values, &second_values, margin, actual_value),
            0.0
        );
    }
}
```
I believe we can increase the number of tests by using a macro like the one below. Note that I have given the tests descriptive names for clarity. Here are my suggestions:
```rust
#[cfg(test)]
mod tests {
    use super::*;

    macro_rules! test_mrg_ranking_loss {
        ($($name:ident: $test_case:expr,)*) => {
            $(
                #[test]
                fn $name() {
                    let (x_first, x_second, margin, y_true, expected) = $test_case;
                    let result = mrg_ranking_loss(&x_first, &x_second, margin, y_true);
                    assert_eq!(result, expected);
                }
            )*
        };
    }

    test_mrg_ranking_loss! {
        test_simple_ranking_example: (vec![3.0, 5.0, 2.0], vec![2.0, 4.0, 1.0], 1.0, 1.0, Some(0.0)),
        test_negative_margin: (vec![1.0, 2.0, 3.0], vec![3.0, 2.0, 1.0], 0.5, -1.0, Some(1.0)),
        test_identical_scores: (vec![1.0, 1.0, 1.0], vec![1.0, 1.0, 1.0], 1.0, 1.0, Some(1.0)),
        test_mixed_y_true: (vec![3.0, 5.0, 7.0], vec![2.0, 6.0, 1.0], 1.0, -1.0, Some(3.0)),
        test_different_lengths: (vec![1.0, 2.0], vec![3.0], 1.0, 1.0, None),
        test_empty_vectors: (vec![], vec![], 1.0, 1.0, None),
    }
}
```
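The expected values in these test cases can be checked by hand by looking at the per-pair terms before averaging. A small sketch follows; `pairwise_losses` is a hypothetical helper introduced here for illustration and is not part of the PR.

```rust
/// Hypothetical helper (not part of the PR): returns the per-pair terms
/// `max(0, margin - y_true * (f - s))` before they are averaged.
/// Like the code under review, it relies on `zip`, so it assumes the two
/// slices have the same length.
pub fn pairwise_losses(x_first: &[f64], x_second: &[f64], margin: f64, y_true: f64) -> Vec<f64> {
    x_first
        .iter()
        .zip(x_second.iter())
        .map(|(f, s)| (margin - y_true * (f - s)).max(0.0))
        .collect()
}
```

For the `test_mixed_y_true` inputs the terms are `[2.0, 0.0, 7.0]`, whose mean is 3.0. For `test_negative_margin` they are `[0.0, 0.5, 2.5]`, whose mean is 1.0. Both match the `Some(...)` values in the table of cases.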
Super, will add the macro while making sure all the tests pass. Thanks for the suggestion!!!
```rust
// Marginal Ranking
//
// The 'mrg_ranking_loss' function calculates the Marginal Ranking loss, which is a
// loss function used for ranking problems in machine learning.
//
// ## Formula
//
// For a pair of values `x_first` and `x_second`, `margin`, and `y_true`,
// the Marginal Ranking loss is calculated as:
//
// - loss = `max(0, -y_true * (x_first - x_second) + margin)`.
//
// It returns the average loss by dividing the `total_loss` by total no. of
// elements.
```
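Written out as a single expression, the averaged loss described in this doc comment looks roughly as follows. The symbols are assumptions chosen for illustration: N is the number of pairs, m is `margin`, y is `y_true` (typically +1 or -1), and x_{1,i}, x_{2,i} are the i-th entries of `x_first` and `x_second`.

```latex
\[
  L = \frac{1}{N} \sum_{i=1}^{N} \max\bigl(0,\; -y\,(x_{1,i} - x_{2,i}) + m\bigr)
\]
% Worked check with x_first = (1, 2, 3), x_second = (2, 3, 4), m = 1, y = -1:
% every pair gives max(0, -(-1)(-1) + 1) = max(0, 0) = 0, so L = 0,
% which agrees with the value asserted in test_marginal_ranking_loss.
```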
One minor change we should make is to provide a link to a page that contains detailed information about the algorithm instead of writing a complex document.
Pull Request Template
Description
This PR includes the implementation of the Marginal Ranking Loss Function requested in issue #559. It creates a new file called `marginal_ranking.rs` with a function called `mrg_ranking_loss`. The implementation is inspired by the PyTorch documentation for Marginal Ranking.
Type of change
Checklist:
- `cargo clippy --all -- -D warnings` just before my last commit and fixed any issue that was found.
- `cargo fmt` just before my last commit.
- `cargo test` just before my last commit and all tests passed.
- `mod.rs` file within its own folder, and in any parent folder(s).
- `DIRECTORY.md` with the correct link.
- `CONTRIBUTING.md` and my code follows its guidelines.

PS: I'm new to Rust and open source, so I would love clarification regarding the items I did not check. Also, let me know if I should squash the commits.