1717""" 
1818
1919import  collections 
20+ import  functools 
2021import  inspect 
2122import  math 
2223import  warnings 
@@ -1043,11 +1044,13 @@ def stateless_call(
10431044                if  self ._remat_mode  is  not   None :
10441045                    outputs  =  self .rematerialized_call (
10451046                        self .quantized_call , * args , ** kwargs 
1046-                     )
1047+                     )( * args ,  ** kwargs ) 
10471048                else :
10481049                    outputs  =  self .quantized_call (* args , ** kwargs )
10491050            elif  self ._remat_mode  is  not   None :
1050-                 outputs  =  self .rematerialized_call (self .call , * args , ** kwargs )
1051+                 outputs  =  self .rematerialized_call (self .call , * args , ** kwargs )(
1052+                     * args , ** kwargs 
1053+                 )
10511054            else :
10521055                outputs  =  self .call (* args , ** kwargs )
10531056            if  return_losses :
@@ -1601,13 +1604,13 @@ def compute_size(x):
16011604
16021605        # Full rematerialization 
16031606        if  self ._remat_mode .mode  ==  "full" :
1604-             return  remat .remat (layer_call )( * args ,  ** kwargs ) 
1607+             return  remat .remat (layer_call )
16051608
16061609        # Apply rematerialization to specific layers 
16071610        elif  self ._remat_mode .mode  ==  "list_of_layers"  and  (
16081611            self .name  in  self ._remat_mode .layer_names 
16091612        ):
1610-             return  remat .remat (layer_call )( * args ,  ** kwargs ) 
1613+             return  remat .remat (layer_call )
16111614
16121615        # Apply rematerialization based on output size threshold 
16131616        elif  self ._remat_mode .mode  ==  "larger_than" :
@@ -1619,20 +1622,24 @@ def compute_size(x):
16191622                output_size 
16201623                and  output_size  >  self ._remat_mode .output_size_threshold 
16211624            ):
1622-                 return  remat .remat (layer_call )( * args ,  ** kwargs ) 
1625+                 return  remat .remat (layer_call )
16231626        elif  self ._remat_mode .mode  ==  "activations" :
16241627            has_activation  =  (
16251628                hasattr (self , "activation" ) and  self .activation  is  not   None 
16261629            )
16271630            if  has_activation :
1628-                 not_rematted_activation  =  self .activation 
1629-                 try :
1630-                     self .activation  =  remat .remat (not_rematted_activation )
1631-                     return  layer_call (* args , ** kwargs )
1632-                 finally :
1633-                     self .activation  =  not_rematted_activation 
16341631
1635-         return  layer_call (* args , ** kwargs )
1632+                 @functools .wraps (layer_call ) 
1633+                 def  rematerialized_activation_call_wrapper (* args , ** kwargs ):
1634+                     original_activation  =  self .activation 
1635+                     self .activation  =  remat .remat (original_activation )
1636+                     try :
1637+                         return  layer_call (* args , ** kwargs )
1638+                     finally :
1639+                         self .activation  =  original_activation 
1640+ 
1641+                 return  rematerialized_activation_call_wrapper 
1642+         return  layer_call 
16361643
16371644
16381645def  is_backend_tensor_or_symbolic (x , allow_none = False ):
0 commit comments