@@ -27,10 +27,17 @@ class ProductTree:
27
27
sage: R.<x> = F[]
28
28
sage: ms = [x - a^i for i in range(1024)] # roots of unity
29
29
sage: ys = [F.random_element() for _ in range(1024)] # input vector
30
- sage: zs = ProductTree(ms).remainders(R(ys)) # compute FFT!
30
+ sage: tree = ProductTree(ms)
31
+ sage: zs = tree.remainders(R(ys)) # compute FFT!
31
32
sage: zs == [R(ys) % m for m in ms]
32
33
True
33
34
35
+ Similarly, the :meth:`interpolation` method can be used to implement
36
+ the inverse Fast Fourier Transform::
37
+
38
+ sage: tree.interpolation(zs).padded_list(len(ys)) == ys
39
+ True
40
+
34
41
This class encodes the tree as *layers*: Layer `0` is just a tuple
35
42
of the leaves. Layer `i+1` is obtained from layer `i` by replacing
36
43
each pair of two adjacent elements by their product, starting from
@@ -175,7 +182,6 @@ def remainders(self, x):
175
182
The base ring must support the ``%`` operator for this
176
183
method to work.
177
184
178
-
179
185
INPUT:
180
186
181
187
- ``x`` -- an element of the base ring of this product tree
@@ -196,6 +202,55 @@ def remainders(self, x):
196
202
X = [X [i // 2 ] % V [i ] for i in range (len (V ))]
197
203
return X
198
204
205
+ _crt_bases = None
206
+
207
+ def interpolation (self , xs ):
208
+ r"""
209
+ Given a sequence ``xs`` of values, one per leaf, return a
210
+ single element `x` which is congruent to the `i`\th value in
211
+ ``xs`` modulo the `i`\th leaf, for all `i`.
212
+
213
+ This is an explicit version of the Chinese remainder theorem;
214
+ see also :meth:`CRT`. Using this product tree is faster for
215
+ repeated calls since the required CRT bases are cached after
216
+ the first run.
217
+
218
+ The base ring must support the :func:`xgcd` function for this
219
+ method to work.
220
+
221
+ EXAMPLES::
222
+
223
+ sage: from sage.rings.generic import ProductTree
224
+ sage: vs = prime_range(100)
225
+ sage: tree = ProductTree(vs)
226
+ sage: tree.interpolation([1, 1, 2, 1, 9, 1, 7, 15, 8, 20, 15, 6, 27, 11, 2, 6, 0, 25, 49, 5, 51, 4, 19, 74, 13])
227
+ 1085749272377676749812331719267
228
+
229
+ This method is faster than :func:`CRT` for repeated calls with
230
+ the same moduli::
231
+
232
+ sage: vs = prime_range(1000,2000)
233
+ sage: rs = list(range(len(vs)))
234
+ sage: tree = ProductTree(vs)
235
+ sage: %timeit CRT(rs,vs) # not tested
236
+ 324 µs ± 637 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
237
+ sage: %timeit tree.interpolation(rs) # not tested
238
+ 102 µs ± 92.5 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
239
+ """
240
+ if self ._crt_bases is None :
241
+ from sage .arith .misc import CRT_basis
242
+ self ._crt_bases = []
243
+ for V in self .layers [:- 1 ]:
244
+ B = tuple (CRT_basis (V [i :i + 2 ]) for i in range (0 ,len (V ),2 ))
245
+ self ._crt_bases .append (B )
246
+ if len (xs ) != len (self .layers [0 ]):
247
+ raise ValueError ('number of given elements must equal the number of leaves' )
248
+ for basis ,layer in zip (self ._crt_bases , self .layers [1 :]):
249
+ xs = [sum (c * x for c ,x in zip (cs ,xs [2 * i :2 * i + 2 ])) % mod
250
+ for i ,(cs ,mod ) in enumerate (zip (basis ,layer ))]
251
+ assert len (xs ) == 1
252
+ return xs [0 ]
253
+
199
254
200
255
def prod_with_derivative (pairs ):
201
256
r"""
0 commit comments