@@ -66,28 +66,33 @@ def __init__(
6666 self .split_info = self .dataset .info .splits [split ]
6767 self .num_samples = self .split_info .num_examples
6868
69- if isinstance (additional_features , str ):
70- self .additional_features = [additional_features ]
71- elif isinstance (additional_features , list ):
72- self .additional_features = additional_features
69+ if additional_features is not None :
70+ if isinstance (additional_features , list ):
71+ self .additional_features = additional_features
72+ else :
73+ self .additional_features = [additional_features ]
7374 else :
74- self .additional_features = []
75+ self .additional_features = None
7576
7677 def __getitem__ (self , index ):
7778 item = self .dataset [index ]
7879 image = item [self .image_key ]
79- features = [item [feat ] for feat in self .additional_features ]
8080
8181 if 'bytes' in image and image ['bytes' ]:
8282 image = io .BytesIO (image ['bytes' ])
8383 else :
8484 assert 'path' in image and image ['path' ]
8585 image = open (image ['path' ], 'rb' )
86+
8687 label = item [self .label_key ]
8788 if self .remap_class :
8889 label = self .class_to_idx [label ]
8990
90- return image , label , * features
91+ if self .additional_features is not None :
92+ features = [item [feat ] for feat in self .additional_features ]
93+ return image , label , * features
94+ else :
95+ return image , label
9196
9297 def __len__ (self ):
9398 return len (self .dataset )
0 commit comments