@@ -1504,29 +1504,30 @@ class MovingMNISTTestCase(datasets_utils.DatasetTestCase):
15041504
15051505 ADDITIONAL_CONFIGS = combinations_grid (split = (None , "train" , "test" ), split_ratio = (10 , 1 , 19 ))
15061506
1507+ _NUM_FRAMES = 20
1508+
15071509 def inject_fake_data (self , tmpdir , config ):
15081510 base_folder = os .path .join (tmpdir , self .DATASET_CLASS .__name__ )
15091511 os .makedirs (base_folder , exist_ok = True )
1510- num_samples = 20
1512+ num_samples = 5
15111513 data = np .concatenate (
15121514 [
15131515 np .zeros ((config ["split_ratio" ], num_samples , 64 , 64 )),
1514- np .ones ((20 - config ["split_ratio" ], num_samples , 64 , 64 )),
1516+ np .ones ((self . _NUM_FRAMES - config ["split_ratio" ], num_samples , 64 , 64 )),
15151517 ]
15161518 )
15171519 np .save (os .path .join (base_folder , "mnist_test_seq.npy" ), data )
15181520 return num_samples
15191521
15201522 @datasets_utils .test_all_configs
15211523 def test_split (self , config ):
1522- if config ["split" ] is None :
1523- return
1524-
1525- with self .create_dataset (config ) as (dataset , info ):
1524+ with self .create_dataset (config ) as (dataset , _ ):
15261525 if config ["split" ] == "train" :
15271526 assert (dataset .data == 0 ).all ()
1528- else :
1527+ elif config [ "split" ] == "test" :
15291528 assert (dataset .data == 1 ).all ()
1529+ else :
1530+ assert dataset .data .size ()[1 ] == self ._NUM_FRAMES
15301531
15311532
15321533class DatasetFolderTestCase (datasets_utils .ImageDatasetTestCase ):
0 commit comments