Add basic ranking example. #7
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,224 @@ | ||
| """ | ||
| # Recommending movies: ranking | ||
|
|
||
| Recommender systems are often composed of two stages: | ||
|
|
||
| 1. The retrieval stage is responsible for selecting an initial set of hundreds | ||
| of candidates from all possible candidates. The main objective of this model | ||
| is to efficiently weed out all candidates that the user is not interested in. | ||
| Because the retrieval model may be dealing with millions of candidates, it | ||
| has to be computationally efficient. | ||
| 2. The ranking stage takes the outputs of the retrieval model and fine-tunes | ||
| them to select the best possible handful of recommendations. Its task is to | ||
| narrow down the set of items the user may be interested in to a shortlist of | ||
| likely candidates. | ||
|
|
||
| In this tutorial, we're going to focus on the second stage, ranking. If you are | ||
| interested in the retrieval stage, have a look at our | ||
| [retrieval](https://github.com/keras-team/keras-rs/blob/main/examples/basic_retrieval.py) | ||
| tutorial. | ||
|
|
||
| In this tutorial, we're going to: | ||
|
|
||
| 1. Get our data and split it into a training and test set. | ||
| 2. Implement a ranking model. | ||
| 3. Fit and evaluate it. | ||
| 4. Test running predictions with the model. | ||
| """ | ||
|
|
||
| import keras | ||
|
|
||
| # Needed for the dataset | ||
| import tensorflow as tf | ||
| import tensorflow_datasets as tfds | ||
|
|
||
| """ | ||
| ## Preparing the dataset | ||
|
|
||
| We're going to use the same data as the | ||
| [retrieval](https://github.com/keras-team/keras-rs/blob/main/examples/basic_retrieval.py) | ||
| tutorial. The ratings are the objectives we are trying to predict. | ||
| """ | ||
|
|
||
| # Ratings data. | ||
| ratings = tfds.load("movielens/100k-ratings", split="train") | ||
| # Features of all the available movies. | ||
| movies = tfds.load("movielens/100k-movies", split="train") | ||
|
|
||
| """ | ||
| In the MovieLens dataset, user ids are integers (represented as strings) | ||
| starting at 1 and with no gap. Normally, you would need to create a lookup table | ||
| to map user ids to integers from 0 to N-1. But as a simplification, we'll use the | ||
| user id directly as an index in our model, in particular to look up the user | ||
| embedding from the user embedding table. So we need to know the number of users. | ||
| """ | ||
|
|
||
| users_count = ( | ||
| ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32)) | ||
| .reduce(tf.constant(0, tf.int32), tf.maximum) | ||
| .numpy() | ||
| ) | ||
|
|
||
| """ | ||
| In the MovieLens dataset, movie ids are integers (represented as strings) | ||
| starting at 1 and with no gap. Normally, you would need to create a lookup table | ||
| to map movie ids to integers from 0 to N-1. But as a simplification, we'll use the | ||
| movie id directly as an index in our model, in particular to look up the movie | ||
| embedding from the movie embedding table. So we need to know the number of | ||
| movies. | ||
| """ | ||
|
|
||
| movies_count = movies.cardinality().numpy() | ||
|
|
||
| """ | ||
| The inputs to the model are the user ids and movie ids and the labels are the | ||
| ratings. | ||
| """ | ||
|
|
||
|
|
||
| def preprocess_rating(x): | ||
| return ( | ||
| # Inputs are user ids and movie ids | ||
| { | ||
| "user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32), | ||
| "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32), | ||
| }, | ||
| # Labels are ratings between 0 and 1. | ||
| (x["user_rating"] - 1.0) / 4.0, | ||
| ) | ||
|
|
||
|
|
||
| """ | ||
| We'll split the data by putting 80% of the ratings in the train set, and 20% in | ||
| the test set. | ||
| """ | ||
|
|
||
| shuffled_ratings = ratings.map(preprocess_rating).shuffle( | ||
| 100_000, seed=42, reshuffle_each_iteration=False | ||
| ) | ||
| train_ratings = shuffled_ratings.take(80_000).batch(1000).cache() | ||
| test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache() | ||
|
|
||
| """ | ||
| ## Implementing the Model | ||
|
|
||
| ### Architecture | ||
|
|
||
| Ranking models do not face the same efficiency constraints as retrieval models | ||
| do, and so we have a little bit more freedom in our choice of architectures. | ||
|
|
||
| A model composed of multiple stacked dense layers is a relatively common | ||
| architecture for ranking tasks. We can implement it as follows: | ||
| """ | ||
|
|
||
|
|
||
| class RankingModel(keras.Model): | ||
|
|
||
| def __init__( | ||
| self, | ||
| num_users, | ||
| num_candidates, | ||
| embedding_dimension=32, | ||
| **kwargs, | ||
| ): | ||
| super().__init__(**kwargs) | ||
| # Embedding table for users. | ||
| self.user_embedding = keras.layers.Embedding( | ||
| num_users, embedding_dimension | ||
| ) | ||
| # Embedding table for candidates. | ||
| self.candidate_embedding = keras.layers.Embedding( | ||
| num_candidates, embedding_dimension | ||
| ) | ||
| # Predictions. | ||
| self.ratings = keras.Sequential( | ||
| [ | ||
| # Learn multiple dense layers. | ||
| keras.layers.Dense(256, activation="relu"), | ||
| keras.layers.Dense(64, activation="relu"), | ||
| # Make rating predictions in the final layer. | ||
| keras.layers.Dense(1), | ||
| ] | ||
| ) | ||
|
|
||
| def call(self, inputs): | ||
| user_id, movie_id = inputs["user_id"], inputs["movie_id"] | ||
| user_embeddings = self.user_embedding(user_id) | ||
| candidate_embeddings = self.candidate_embedding(movie_id) | ||
| return self.ratings( | ||
| keras.ops.concatenate( | ||
| [user_embeddings, candidate_embeddings], axis=1 | ||
| ) | ||
| ) | ||
|
|
||
|
|
||
| """ | ||
| Let's first instantiate the model. Note that we add `+ 1` to the number of users | ||
| and movies to account for the fact that id zero is not used for either (ids | ||
| start at 1), but still takes a row in the embedding tables. | ||
| """ | ||
|
|
||
| model = RankingModel(users_count + 1, movies_count + 1) | ||
|
|
||
| """ | ||
| ### Loss and metrics | ||
|
|
||
| The next component is the loss used to train our model. Keras has several losses | ||
| to make this easy. In this instance, we'll make use of the `MeanSquaredError` | ||
| loss in order to predict the ratings. We'll also look at the | ||
| `RootMeanSquaredError` metric. | ||
| """ | ||
|
|
||
| model.compile( | ||
| loss=keras.losses.MeanSquaredError(), | ||
| metrics=[keras.metrics.RootMeanSquaredError()], | ||
| optimizer=keras.optimizers.Adagrad(learning_rate=0.1), | ||
| ) | ||
|
|
||
| """ | ||
| ## Fitting and evaluating | ||
|
|
||
| After defining the model, we can use the standard Keras `model.fit()` to train | ||
| the model. | ||
| """ | ||
|
|
||
| model.fit(train_ratings, epochs=5) | ||
|
|
||
| """ | ||
| As the model trains, the loss falls and the RMSE metric improves. | ||
|
|
||
| Finally, we can evaluate our model on the test set. The lower the RMSE metric, | ||
| the more accurate our model is at predicting ratings. | ||
| """ | ||
|
|
||
| model.evaluate(test_ratings, return_dict=True) | ||
|
|
||
| """ | ||
| ## Testing the ranking model | ||
|
|
||
| So far, we have only handled movies by id. Now is the time to create a mapping | ||
| keyed by movie ids to be able to surface the titles. | ||
| """ | ||
|
|
||
| movie_id_to_movie_title = { | ||
| int(x["movie_id"]): x["movie_title"] for x in movies.as_numpy_iterator() | ||
| } | ||
| movie_id_to_movie_title[0] = "" # Because id 0 is not in the dataset. | ||
|
|
||
| """ | ||
| Now we can test the ranking model by computing predictions for a set of movies | ||
| and then ranking these movies based on the predictions: | ||
| """ | ||
|
|
||
| user_id = 42 | ||
| movie_ids = [204, 141, 131] | ||
| predictions = model.predict( | ||
| { | ||
| "user_id": keras.ops.array([user_id] * len(movie_ids)), | ||
| "movie_id": keras.ops.array(movie_ids), | ||
| } | ||
| ) | ||
| predictions = keras.ops.convert_to_numpy(keras.ops.squeeze(predictions, axis=1)) | ||
|
|
||
| for movie_id, prediction in zip(movie_ids, predictions): | ||
| print(f"{movie_id_to_movie_title[movie_id]}: {4.0 * prediction + 1.0:,.2f}") | ||
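
The final loop in the diff prints scores in input order; it does not sort them. A minimal sketch of the ranking step itself, assuming the scores are model outputs on the normalized 0 to 1 label scale, so the inverse of `(rating - 1) / 4` maps them back to stars. The helper `rank_movies`, the titles, and the scores are hypothetical and only for illustration:

```python
def rank_movies(movie_ids, scores, titles):
    """Return (title, rating) pairs sorted from highest to lowest score.

    `scores` are assumed to be on the normalized [0, 1] label scale, so
    4 * score + 1 converts them back to the 1-5 star scale.
    """
    ranked = sorted(zip(movie_ids, scores), key=lambda pair: pair[1], reverse=True)
    return [(titles[movie_id], 4.0 * score + 1.0) for movie_id, score in ranked]


# Hypothetical titles and scores, not real model output.
example_titles = {204: "Movie A", 141: "Movie B", 131: "Movie C"}
example_scores = [0.50, 0.75, 0.25]
for title, rating in rank_movies([204, 141, 131], example_scores, example_titles):
    print(f"{title}: {rating:,.2f}")
```

In the example from the diff, the same helper could be applied to `movie_ids`, the squeezed `predictions` array, and `movie_id_to_movie_title`.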
Are these IDs guaranteed to be sequential? If they're just random int-like, we would want to reindex them from 0 to num_users
I cheated because I know the dataset and they are sequential.
The original example was more general because it wasn't making any assumption on the IDs, not even that they were numbers, because it's using a `StringLookup` layer as part of the model. `StringLookup` is not multi-backend so we can't do that as part of the model, it would have to be part of preprocessing. But then it gets more complicated to demonstrate inference (in the retrieval example), because the reverse lookup is needed, so we would need:
We have 3 options:
`StringLookup` outside of the model
That's ok if you clearly document this gotcha
Added some detailed language about this shortcut.
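
For datasets whose ids are not sequential, the preprocessing-side lookup discussed in this thread might look roughly like the following. This is a sketch in plain Python, not `StringLookup`; `build_id_lookup`, `is_sequential`, and the sample ids are hypothetical. Reserving index 0 is an assumption made here to mirror the unused row 0 in the example's embedding tables.

```python
def build_id_lookup(raw_ids):
    """Map arbitrary raw ids to contiguous indices 1..N, plus the reverse map.

    Index 0 is left unused (an assumption of this sketch) so it can act as
    a padding or out-of-vocabulary row in an embedding table.
    """
    vocabulary = sorted(set(raw_ids))
    id_to_index = {raw_id: index for index, raw_id in enumerate(vocabulary, start=1)}
    index_to_id = {index: raw_id for raw_id, index in id_to_index.items()}
    return id_to_index, index_to_id


def is_sequential(int_ids):
    """Check the shortcut's assumption: ids are exactly 1..N with no gap."""
    unique = set(int_ids)
    return unique == set(range(1, len(unique) + 1))


# Hypothetical raw ids with gaps and a duplicate, stored as strings.
raw_ids = ["17", "3", "250", "3"]
id_to_index, index_to_id = build_id_lookup(raw_ids)
print(id_to_index)
print(is_sequential(int(raw_id) for raw_id in raw_ids))
```

With such a mapping, the embedding table would need `len(id_to_index) + 1` rows, and `index_to_id` provides the reverse lookup needed to turn indices back into ids at inference time.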