@@ -37,6 +37,7 @@ def __init__(
3737            class_map : dict  =  None ,
3838            input_key : str  =  'image' ,
3939            target_key : str  =  'label' ,
40+             additional_features : Optional [list [str ]] =  None ,
4041            download : bool  =  False ,
4142            trust_remote_code : bool  =  False 
4243    ):
@@ -65,9 +66,18 @@ def __init__(
6566        self .split_info  =  self .dataset .info .splits [split ]
6667        self .num_samples  =  self .split_info .num_examples 
6768
69+         if  isinstance (additional_features , str ):
70+             self .additional_features  =  [additional_features ]
71+         elif  isinstance (additional_features , list ):
72+             self .additional_features  =  additional_features 
73+         else :
74+             self .additional_features  =  []
75+ 
6876    def  __getitem__ (self , index ):
6977        item  =  self .dataset [index ]
7078        image  =  item [self .image_key ]
79+         features  =  [item [feat ] for  feat  in  self .additional_features ]
80+ 
7181        if  'bytes'  in  image  and  image ['bytes' ]:
7282            image  =  io .BytesIO (image ['bytes' ])
7383        else :
@@ -76,7 +86,8 @@ def __getitem__(self, index):
7686        label  =  item [self .label_key ]
7787        if  self .remap_class :
7888            label  =  self .class_to_idx [label ]
79-         return  image , label 
89+ 
90+         return  image , label , * features 
8091
8192    def  __len__ (self ):
8293        return  len (self .dataset )
0 commit comments