1717from fsspec .implementations .cache_mapper import create_cache_mapper
1818from fsspec .implementations .cache_metadata import CacheMetadata
1919from fsspec .spec import AbstractBufferedFile
20+ from fsspec .transaction import Transaction
2021from fsspec .utils import infer_compression
2122
2223if TYPE_CHECKING :
2526logger = logging .getLogger ("fsspec.cached" )
2627
2728
29+ class WriteCachedTransaction (Transaction ):
30+ def complete (self , commit = True ):
31+ rpaths = [f .path for f in self .files ]
32+ lpaths = [f .fn for f in self .files ]
33+ if commit :
34+ self .fs .put (lpaths , rpaths )
35+ # else remove?
36+ self .fs ._intrans = False
37+
38+
2839class CachingFileSystem (AbstractFileSystem ):
2940 """Locally caching filesystem, layer over any other FS
3041
@@ -415,6 +426,10 @@ def __getattribute__(self, item):
415426 "__eq__" ,
416427 "to_json" ,
417428 "cache_size" ,
429+ "pipe_file" ,
430+ "pipe" ,
431+ "start_transaction" ,
432+ "end_transaction" ,
418433 ]:
419434 # all the methods defined in this class. Note `open` here, since
420435 # it calls `_open`, but is actually in superclass
@@ -423,7 +438,10 @@ def __getattribute__(self, item):
423438 )
424439 if item in ["__reduce_ex__" ]:
425440 raise AttributeError
426- if item in ["_cache" ]:
441+ if item in ["transaction" ]:
442+ # property
443+ return type (self ).transaction .__get__ (self )
444+ if item in ["_cache" , "transaction_type" ]:
427445 # class attributes
428446 return getattr (type (self ), item )
429447 if item == "__class__" :
@@ -512,7 +530,13 @@ def open_many(self, open_files):
512530 self ._mkcache ()
513531 else :
514532 return [
515- LocalTempFile (self .fs , path , mode = open_files .mode ) for path in paths
533+ LocalTempFile (
534+ self .fs ,
535+ path ,
536+ mode = open_files .mode ,
537+ fn = os .path .join (self .storage [- 1 ], self ._mapper (path )),
538+ )
539+ for path in paths
516540 ]
517541
518542 if self .compression :
@@ -625,7 +649,8 @@ def cat(
625649 def _open (self , path , mode = "rb" , ** kwargs ):
626650 path = self ._strip_protocol (path )
627651 if "r" not in mode :
628- return LocalTempFile (self , path , mode = mode )
652+ fn = self ._make_local_details (path )
653+ return LocalTempFile (self , path , mode = mode , fn = fn )
629654 detail = self ._check_file (path )
630655 if detail :
631656 detail , fn = detail
@@ -692,6 +717,7 @@ class SimpleCacheFileSystem(WholeFileCacheFileSystem):
692717
693718 protocol = "simplecache"
694719 local_file = True
720+ transaction_type = WriteCachedTransaction
695721
696722 def __init__ (self , ** kwargs ):
697723 kw = kwargs .copy ()
@@ -716,6 +742,22 @@ def save_cache(self):
716742 def load_cache (self ):
717743 pass
718744
745+ def pipe_file (self , path , value = None , ** kwargs ):
746+ if self ._intrans :
747+ with self .open (path , "wb" ) as f :
748+ f .write (value )
749+ else :
750+ super ().pipe_file (path , value )
751+
752+ def pipe (self , path , value = None , ** kwargs ):
753+ if isinstance (path , str ):
754+ self .pipe_file (self ._strip_protocol (path ), value , ** kwargs )
755+ elif isinstance (path , dict ):
756+ for k , v in path .items ():
757+ self .pipe_file (self ._strip_protocol (k ), v , ** kwargs )
758+ else :
759+ raise ValueError ("path must be str or dict" )
760+
719761 def cat_ranges (
720762 self , paths , starts , ends , max_gap = None , on_error = "return" , ** kwargs
721763 ):
@@ -729,14 +771,17 @@ def cat_ranges(
729771
730772 def _open (self , path , mode = "rb" , ** kwargs ):
731773 path = self ._strip_protocol (path )
774+ sha = self ._mapper (path )
732775
733776 if "r" not in mode :
734- return LocalTempFile (self , path , mode = mode )
777+ fn = os .path .join (self .storage [- 1 ], sha )
778+ return LocalTempFile (
779+ self , path , mode = mode , autocommit = not self ._intrans , fn = fn
780+ )
735781 fn = self ._check_file (path )
736782 if fn :
737783 return open (fn , mode )
738784
739- sha = self ._mapper (path )
740785 fn = os .path .join (self .storage [- 1 ], sha )
741786 logger .debug ("Copying %s to local cache" , path )
742787 kwargs ["mode" ] = mode
@@ -767,13 +812,9 @@ def _open(self, path, mode="rb", **kwargs):
767812class LocalTempFile :
768813 """A temporary local file, which will be uploaded on commit"""
769814
770- def __init__ (self , fs , path , fn = None , mode = "wb" , autocommit = True , seek = 0 ):
771- if fn :
772- self .fn = fn
773- self .fh = open (fn , mode )
774- else :
775- fd , self .fn = tempfile .mkstemp ()
776- self .fh = open (fd , mode )
815+ def __init__ (self , fs , path , fn , mode = "wb" , autocommit = True , seek = 0 ):
816+ self .fn = fn
817+ self .fh = open (fn , mode )
777818 self .mode = mode
778819 if seek :
779820 self .fh .seek (seek )
0 commit comments