Skip to content

Commit 1d5caad

Browse files
committed
implement .interpolation() method for ProductTree
1 parent ebef87a commit 1d5caad

File tree

1 file changed

+57
-2
lines changed

1 file changed

+57
-2
lines changed

src/sage/rings/generic.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,17 @@ class ProductTree:
2727
sage: R.<x> = F[]
2828
sage: ms = [x - a^i for i in range(1024)] # roots of unity
2929
sage: ys = [F.random_element() for _ in range(1024)] # input vector
30-
sage: zs = ProductTree(ms).remainders(R(ys)) # compute FFT!
30+
sage: tree = ProductTree(ms)
31+
sage: zs = tree.remainders(R(ys)) # compute FFT!
3132
sage: zs == [R(ys) % m for m in ms]
3233
True
3334
35+
Similarly, the :meth:`interpolation` method can be used to implement
36+
the inverse Fast Fourier Transform::
37+
38+
sage: tree.interpolation(zs).padded_list(len(ys)) == ys
39+
True
40+
3441
This class encodes the tree as *layers*: Layer `0` is just a tuple
3542
of the leaves. Layer `i+1` is obtained from layer `i` by replacing
3643
each pair of two adjacent elements by their product, starting from
@@ -175,7 +182,6 @@ def remainders(self, x):
175182
The base ring must support the ``%`` operator for this
176183
method to work.
177184
178-
179185
INPUT:
180186
181187
- ``x`` -- an element of the base ring of this product tree
@@ -196,6 +202,55 @@ def remainders(self, x):
196202
X = [X[i // 2] % V[i] for i in range(len(V))]
197203
return X
198204

205+
_crt_bases = None
206+
207+
def interpolation(self, xs):
208+
r"""
209+
Given a sequence ``xs`` of values, one per leaf, return a
210+
single element `x` which is congruent to the `i`\th value in
211+
``xs`` modulo the `i`\th leaf, for all `i`.
212+
213+
This is an explicit version of the Chinese remainder theorem;
214+
see also :meth:`CRT`. Using this product tree is faster for
215+
repeated calls since the required CRT bases are cached after
216+
the first run.
217+
218+
The base ring must support the :func:`xgcd` function for this
219+
method to work.
220+
221+
EXAMPLES::
222+
223+
sage: from sage.rings.generic import ProductTree
224+
sage: vs = prime_range(100)
225+
sage: tree = ProductTree(vs)
226+
sage: tree.interpolation([1, 1, 2, 1, 9, 1, 7, 15, 8, 20, 15, 6, 27, 11, 2, 6, 0, 25, 49, 5, 51, 4, 19, 74, 13])
227+
1085749272377676749812331719267
228+
229+
This method is faster than :func:`CRT` for repeated calls with
230+
the same moduli::
231+
232+
sage: vs = prime_range(1000,2000)
233+
sage: rs = list(range(len(vs)))
234+
sage: tree = ProductTree(vs)
235+
sage: %timeit CRT(rs,vs) # not tested
236+
324 µs ± 637 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
237+
sage: %timeit tree.interpolation(rs) # not tested
238+
102 µs ± 92.5 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
239+
"""
240+
if self._crt_bases is None:
241+
from sage.arith.misc import CRT_basis
242+
self._crt_bases = []
243+
for V in self.layers[:-1]:
244+
B = tuple(CRT_basis(V[i:i+2]) for i in range(0,len(V),2))
245+
self._crt_bases.append(B)
246+
if len(xs) != len(self.layers[0]):
247+
raise ValueError('number of given elements must equal the number of leaves')
248+
for basis,layer in zip(self._crt_bases, self.layers[1:]):
249+
xs = [sum(c*x for c,x in zip(cs,xs[2*i:2*i+2])) % mod
250+
for i,(cs,mod) in enumerate(zip(basis,layer))]
251+
assert len(xs) == 1
252+
return xs[0]
253+
199254

200255
def prod_with_derivative(pairs):
201256
r"""

0 commit comments

Comments
 (0)