11import warnings
22
3+ from pytensor .tensor import TensorVariable , mul
4+
35
46try :
57 import xarray as xr
@@ -133,10 +135,99 @@ def __complex__(self):
133135 "Call `.astype(complex)` for the symbolic equivalent."
134136 )
135137
138+ # DataArray-like attributes
139+ # https://docs.xarray.dev/en/latest/api.html#id1
140+ @property
141+ def values (self ) -> TensorVariable :
142+ from pytensor .xtensor .basic import tensor_from_xtensor
143+
144+ return tensor_from_xtensor (self )
145+
146+ data = values
147+
148+ @property
149+ def coords (self ):
150+ raise NotImplementedError ("coords not implemented for XTensorVariable" )
151+
152+ @property
153+ def dims (self ) -> tuple [str ]:
154+ return self .type .dims
155+
156+ @property
157+ def sizes (self ) -> dict [str , TensorVariable ]:
158+ return dict (zip (self .dims , self .shape ))
159+
160+ @property
161+ def as_numpy (self ):
162+ # No-op, since the underlying data is always a numpy array
163+ return self
164+
165+ # ndarray attributes
166+ # https://docs.xarray.dev/en/latest/api.html#ndarray-attributes
167+ @property
168+ def ndim (self ) -> int :
169+ return self .type .ndim
170+
171+ @property
172+ def shape (self ) -> tuple [TensorVariable ]:
173+ from pytensor .xtensor .basic import tensor_from_xtensor
174+
175+ return tuple (tensor_from_xtensor (self ).shape )
176+
177+ @property
178+ def size (self ):
179+ return mul (* self .shape )
180+
181+ @property
182+ def dtype (self ):
183+ return self .type .dtype
184+
185+ # DataArray contents
186+ # https://docs.xarray.dev/en/latest/api.html#dataarray-contents
187+ def rename (self , new_name_or_name_dict , ** names ):
188+ from pytensor .xtensor .basic import rename
189+
190+ if isinstance (new_name_or_name_dict , str ):
191+ # TODO: Should we make a symbolic copy?
192+ self .name = new_name_or_name_dict
193+ name_dict = None
194+ else :
195+ name_dict = new_name_or_name_dict
196+ return rename (name_dict , ** names )
197+
198+ # def swap_dims(self, *args, **kwargs):
199+ # ...
200+ #
201+ # def expand_dims(self, *args, **kwargs):
202+ # ...
203+ #
204+ # def squeeze(self):
205+ # ...
206+
207+ def copy (self ):
208+ from pytensor .xtensor .math import identity
209+
210+ return identity (self )
211+
212+ def astype (self , dtype ):
213+ from pytensor .xtensor .math import cast
214+
215+ return cast (self , dtype )
216+
217+ def item (self ):
218+ raise NotImplementedError ("item not implemented for XTensorVariable" )
219+
220+ # Indexing
221+ # https://docs.xarray.dev/en/latest/api.html#id2
136222 def __setitem__ (self , key , value ):
137- raise TypeError (
138- "XTensorVariable does not support item assignment. Use the output of `x[idx].set` or `x[idx].inc` instead."
139- )
223+ raise TypeError ("XTensorVariable does not support item assignment." )
224+
225+ @property
226+ def loc (self ):
227+ raise NotImplementedError ("loc not implemented for XTensorVariable" )
228+
229+ def sel (self , * args , ** kwargs ):
230+ raise NotImplementedError ("sel not implemented for XTensorVariable" )
140231
141232 def __getitem__ (self , idx ):
142233 from pytensor .xtensor .indexing import index
@@ -159,11 +250,6 @@ def __getitem__(self, idx):
159250
160251 return index (self , * idx )
161252
162- def sel (self , * args , ** kwargs ):
163- raise NotImplementedError (
164- "sel not implemented for XTensorVariable, use isel instead"
165- )
166-
167253 def isel (
168254 self ,
169255 indexers : dict [str , Any ] | None = None ,
@@ -208,6 +294,81 @@ def isel(
208294
209295 return index (self , * indices )
210296
297+ def _head_tail_or_thin (
298+ self ,
299+ indexers : dict [str , Any ] | int | None ,
300+ indexers_kwargs : dict [str , Any ],
301+ * ,
302+ kind : Literal ["head" , "tail" , "thin" ],
303+ ):
304+ if indexers_kwargs :
305+ if indexers is not None :
306+ raise ValueError (
307+ "Cannot pass both indexers and indexers_kwargs to head"
308+ )
309+ indexers = indexers_kwargs
310+
311+ if indexers is None :
312+ if kind == "thin" :
313+ raise TypeError (
314+ "thin() indexers must be either dict-like or a single integer"
315+ )
316+ else :
317+ # Default to 5 for head and tail
318+ indexers = {dim : 5 for dim in self .type .dims }
319+
320+ elif not isinstance (indexers , dict ):
321+ indexers = {dim : indexers for dim in self .type .dims }
322+
323+ if kind == "head" :
324+ indices = {dim : slice (None , value ) for dim , value in indexers .items ()}
325+ elif kind == "tail" :
326+ sizes = self .sizes
327+ # Can't use slice(-value, None), in case value is zero
328+ indices = {
329+ dim : slice (sizes [dim ] - value , None ) for dim , value in indexers .items ()
330+ }
331+ elif kind == "thin" :
332+ indices = {dim : slice (None , None , value ) for dim , value in indexers .items ()}
333+ return self .isel (indices )
334+
335+ def head (self , indexers : dict [str , Any ] | int | None = None , ** indexers_kwargs ):
336+ return self ._head_tail_or_thin (indexers , indexers_kwargs , kind = "head" )
337+
338+ def tail (self , indexers : dict [str , Any ] | int | None = None , ** indexers_kwargs ):
339+ return self ._head_tail_or_thin (indexers , indexers_kwargs , kind = "tail" )
340+
341+ def thin (self , indexers : dict [str , Any ] | int | None = None , ** indexers_kwargs ):
342+ return self ._head_tail_or_thin (indexers , indexers_kwargs , kind = "thin" )
343+
344+ # ndarray methods
345+ # https://docs.xarray.dev/en/latest/api.html#id7
346+ def clip (self , min , max ):
347+ from pytensor .xtensor .math import clip
348+
349+ return clip (self , min , max )
350+
351+ def conj (self ):
352+ from pytensor .xtensor .math import conj
353+
354+ return conj (self )
355+
356+ @property
357+ def imag (self ):
358+ from pytensor .xtensor .math import imag
359+
360+ return imag (self )
361+
362+ @property
363+ def real (self ):
364+ from pytensor .xtensor .math import real
365+
366+ return real (self )
367+
368+ # @property
369+ # def T(self):
370+ # ...
371+
211372
212373class XTensorConstantSignature (tuple ):
213374 def __eq__ (self , other ):
0 commit comments