Skip to content

Commit 068c765

Browse files
authored
feat: Allow using entity's join_key in get_online_features (#2420)
* allowing using entity's join_key in get_online_features Signed-off-by: pyalex <[email protected]> * fix tests Signed-off-by: pyalex <[email protected]>
1 parent 04dea73 commit 068c765

File tree

6 files changed

+84
-58
lines changed

6 files changed

+84
-58
lines changed

sdk/python/feast/feature_store.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1266,9 +1266,11 @@ def _get_online_features(
12661266
features=features, allow_cache=True, hide_dummy_entity=False
12671267
)
12681268

1269-
entity_name_to_join_key_map, entity_type_map = self._get_entity_maps(
1270-
requested_feature_views
1271-
)
1269+
(
1270+
entity_name_to_join_key_map,
1271+
entity_type_map,
1272+
join_keys_set,
1273+
) = self._get_entity_maps(requested_feature_views)
12721274

12731275
# Extract Sequence from RepeatedValue Protobuf.
12741276
entity_value_lists: Dict[str, Union[List[Any], List[Value]]] = {
@@ -1322,22 +1324,32 @@ def _get_online_features(
13221324
join_key_values: Dict[str, List[Value]] = {}
13231325
request_data_features: Dict[str, List[Value]] = {}
13241326
# Entity rows may be either entities or request data.
1325-
for entity_name, values in entity_proto_values.items():
1327+
for join_key_or_entity_name, values in entity_proto_values.items():
13261328
# Found request data
13271329
if (
1328-
entity_name in needed_request_data
1329-
or entity_name in needed_request_fv_features
1330+
join_key_or_entity_name in needed_request_data
1331+
or join_key_or_entity_name in needed_request_fv_features
13301332
):
1331-
if entity_name in needed_request_fv_features:
1333+
if join_key_or_entity_name in needed_request_fv_features:
13321334
# If the data was requested as a feature then
13331335
# make sure it appears in the result.
1334-
requested_result_row_names.add(entity_name)
1335-
request_data_features[entity_name] = values
1336+
requested_result_row_names.add(join_key_or_entity_name)
1337+
request_data_features[join_key_or_entity_name] = values
13361338
else:
1337-
try:
1338-
join_key = entity_name_to_join_key_map[entity_name]
1339-
except KeyError:
1340-
raise EntityNotFoundException(entity_name, self.project)
1339+
if join_key_or_entity_name in join_keys_set:
1340+
join_key = join_key_or_entity_name
1341+
else:
1342+
try:
1343+
join_key = entity_name_to_join_key_map[join_key_or_entity_name]
1344+
except KeyError:
1345+
raise EntityNotFoundException(
1346+
join_key_or_entity_name, self.project
1347+
)
1348+
else:
1349+
warnings.warn(
1350+
"Using entity name is deprecated. Use join_key instead."
1351+
)
1352+
13411353
# All join keys should be returned in the result.
13421354
requested_result_row_names.add(join_key)
13431355
join_key_values[join_key] = values
@@ -1422,7 +1434,9 @@ def _get_columnar_entity_values(
14221434
return res
14231435
return cast(Dict[str, List[Any]], columnar)
14241436

1425-
def _get_entity_maps(self, feature_views):
1437+
def _get_entity_maps(
1438+
self, feature_views
1439+
) -> Tuple[Dict[str, str], Dict[str, ValueType], Set[str]]:
14261440
entities = self._list_entities(allow_cache=True, hide_dummy_entity=False)
14271441
entity_name_to_join_key_map: Dict[str, str] = {}
14281442
entity_type_map: Dict[str, ValueType] = {}
@@ -1444,7 +1458,11 @@ def _get_entity_maps(self, feature_views):
14441458
)
14451459
entity_name_to_join_key_map[entity_name] = join_key
14461460
entity_type_map[join_key] = entity.value_type
1447-
return entity_name_to_join_key_map, entity_type_map
1461+
return (
1462+
entity_name_to_join_key_map,
1463+
entity_type_map,
1464+
set(entity_name_to_join_key_map.values()),
1465+
)
14481466

14491467
@staticmethod
14501468
def _get_table_entity_values(

sdk/python/tests/example_repos/example_feature_repo_1.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,14 @@
2828

2929
driver = Entity(
3030
name="driver", # The name is derived from this argument, not object name.
31+
join_key="driver_id",
3132
value_type=ValueType.INT64,
3233
description="driver id",
3334
)
3435

3536
customer = Entity(
3637
name="customer", # The name is derived from this argument, not object name.
38+
join_key="customer_id",
3739
value_type=ValueType.STRING,
3840
)
3941

sdk/python/tests/integration/e2e/test_universal_e2e.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def check_offline_and_online_features(
4545
# Check online store
4646
response_dict = fs.get_online_features(
4747
[f"{fv.name}:value"],
48-
[{"driver": driver_id}],
48+
[{"driver_id": driver_id}],
4949
full_feature_names=full_feature_names,
5050
).to_dict()
5151

sdk/python/tests/integration/online_store/test_online_retrieval.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def test_online() -> None:
3434
provider = store._get_provider()
3535

3636
driver_key = EntityKeyProto(
37-
join_keys=["driver"], entity_values=[ValueProto(int64_val=1)]
37+
join_keys=["driver_id"], entity_values=[ValueProto(int64_val=1)]
3838
)
3939
provider.online_write_batch(
4040
config=store.config,
@@ -54,7 +54,7 @@ def test_online() -> None:
5454
)
5555

5656
customer_key = EntityKeyProto(
57-
join_keys=["customer"], entity_values=[ValueProto(string_val="5")]
57+
join_keys=["customer_id"], entity_values=[ValueProto(string_val="5")]
5858
)
5959
provider.online_write_batch(
6060
config=store.config,
@@ -75,7 +75,7 @@ def test_online() -> None:
7575
)
7676

7777
customer_key = EntityKeyProto(
78-
join_keys=["customer", "driver"],
78+
join_keys=["customer_id", "driver_id"],
7979
entity_values=[ValueProto(string_val="5"), ValueProto(int64_val=1)],
8080
)
8181
provider.online_write_batch(
@@ -100,15 +100,18 @@ def test_online() -> None:
100100
"customer_profile:name",
101101
"customer_driver_combined:trips",
102102
],
103-
entity_rows=[{"driver": 1, "customer": "5"}, {"driver": 1, "customer": 5}],
103+
entity_rows=[
104+
{"driver_id": 1, "customer_id": "5"},
105+
{"driver_id": 1, "customer_id": 5},
106+
],
104107
full_feature_names=False,
105108
).to_dict()
106109

107110
assert "lon" in result
108111
assert "avg_orders_day" in result
109112
assert "name" in result
110-
assert result["driver"] == [1, 1]
111-
assert result["customer"] == ["5", "5"]
113+
assert result["driver_id"] == [1, 1]
114+
assert result["customer_id"] == ["5", "5"]
112115
assert result["lon"] == ["1.0", "1.0"]
113116
assert result["avg_orders_day"] == [1.0, 1.0]
114117
assert result["name"] == ["John", "John"]
@@ -117,7 +120,7 @@ def test_online() -> None:
117120
# Ensure features are still in result when keys not found
118121
result = store.get_online_features(
119122
features=["customer_driver_combined:trips"],
120-
entity_rows=[{"driver": 0, "customer": 0}],
123+
entity_rows=[{"driver_id": 0, "customer_id": 0}],
121124
full_feature_names=False,
122125
).to_dict()
123126

@@ -127,7 +130,7 @@ def test_online() -> None:
127130
with pytest.raises(FeatureViewNotFoundException):
128131
store.get_online_features(
129132
features=["driver_locations_bad:lon"],
130-
entity_rows=[{"driver": 1}],
133+
entity_rows=[{"driver_id": 1}],
131134
full_feature_names=False,
132135
)
133136

@@ -152,7 +155,7 @@ def test_online() -> None:
152155
"customer_profile:name",
153156
"customer_driver_combined:trips",
154157
],
155-
entity_rows=[{"driver": 1, "customer": 5}],
158+
entity_rows=[{"driver_id": 1, "customer_id": 5}],
156159
full_feature_names=False,
157160
).to_dict()
158161
assert result["lon"] == ["1.0"]
@@ -173,7 +176,7 @@ def test_online() -> None:
173176
"customer_profile:name",
174177
"customer_driver_combined:trips",
175178
],
176-
entity_rows=[{"driver": 1, "customer": 5}],
179+
entity_rows=[{"driver_id": 1, "customer_id": 5}],
177180
full_feature_names=False,
178181
).to_dict()
179182

@@ -188,7 +191,7 @@ def test_online() -> None:
188191
"customer_profile:name",
189192
"customer_driver_combined:trips",
190193
],
191-
entity_rows=[{"driver": 1, "customer": 5}],
194+
entity_rows=[{"driver_id": 1, "customer_id": 5}],
192195
full_feature_names=False,
193196
).to_dict()
194197
assert result["lon"] == ["1.0"]
@@ -214,7 +217,7 @@ def test_online() -> None:
214217
"customer_profile:name",
215218
"customer_driver_combined:trips",
216219
],
217-
entity_rows=[{"driver": 1, "customer": 5}],
220+
entity_rows=[{"driver_id": 1, "customer_id": 5}],
218221
full_feature_names=False,
219222
).to_dict()
220223
assert result["lon"] == ["1.0"]
@@ -234,7 +237,7 @@ def test_online() -> None:
234237
"customer_profile:name",
235238
"customer_driver_combined:trips",
236239
],
237-
entity_rows=[{"driver": 1, "customer": 5}],
240+
entity_rows=[{"driver_id": 1, "customer_id": 5}],
238241
full_feature_names=False,
239242
).to_dict()
240243
assert result["lon"] == ["1.0"]
@@ -284,7 +287,7 @@ def test_online_to_df():
284287
3 3.0 0.3
285288
"""
286289
driver_key = EntityKeyProto(
287-
join_keys=["driver"], entity_values=[ValueProto(int64_val=d)]
290+
join_keys=["driver_id"], entity_values=[ValueProto(int64_val=d)]
288291
)
289292
provider.online_write_batch(
290293
config=store.config,
@@ -311,7 +314,7 @@ def test_online_to_df():
311314
6 6.0 foo6 60
312315
"""
313316
customer_key = EntityKeyProto(
314-
join_keys=["customer"], entity_values=[ValueProto(string_val=str(c))]
317+
join_keys=["customer_id"], entity_values=[ValueProto(string_val=str(c))]
315318
)
316319
provider.online_write_batch(
317320
config=store.config,
@@ -340,7 +343,7 @@ def test_online_to_df():
340343
6 3 18
341344
"""
342345
combo_keys = EntityKeyProto(
343-
join_keys=["customer", "driver"],
346+
join_keys=["customer_id", "driver_id"],
344347
entity_values=[ValueProto(string_val=str(c)), ValueProto(int64_val=d)],
345348
)
346349
provider.online_write_batch(
@@ -369,7 +372,7 @@ def test_online_to_df():
369372
],
370373
# Reverse the row order
371374
entity_rows=[
372-
{"driver": d, "customer": c}
375+
{"driver_id": d, "customer_id": c}
373376
for (d, c) in zip(reversed(driver_ids), reversed(customer_ids))
374377
],
375378
).to_df()
@@ -381,8 +384,8 @@ def test_online_to_df():
381384
1 4 1.0 0.1 4.0 foo4 40 4
382385
"""
383386
df_dict = {
384-
"driver": driver_ids,
385-
"customer": [str(c) for c in customer_ids],
387+
"driver_id": driver_ids,
388+
"customer_id": [str(c) for c in customer_ids],
386389
"lon": [str(d * lon_multiply) for d in driver_ids],
387390
"lat": [d * lat_multiply for d in driver_ids],
388391
"avg_orders_day": [c * avg_order_day_multiply for c in customer_ids],
@@ -392,8 +395,8 @@ def test_online_to_df():
392395
}
393396
# Requested column order
394397
ordered_column = [
395-
"driver",
396-
"customer",
398+
"driver_id",
399+
"customer_id",
397400
"lon",
398401
"lat",
399402
"avg_orders_day",

0 commit comments

Comments
 (0)