|
1 | 1 | import numpy as np |
2 | | -from sklearn.utils.testing import assert_almost_equal |
| 2 | +from sklearn.utils.testing import assert_almost_equal, assert_array_almost_equal |
3 | 3 |
|
4 | | -from lightning.impl.penalty import project_l1_ball |
| 4 | +from lightning.impl.penalty import project_l1_ball, project_simplex |
| 5 | + |
| 6 | + |
| 7 | +def project_simplex_bisection(v, z=1, tau=0.0001, max_iter=1000): |
| 8 | + lower = 0 |
| 9 | + upper = np.max(v) |
| 10 | + current = np.inf |
| 11 | + |
| 12 | + for it in xrange(max_iter): |
| 13 | + if np.abs(current) / z < tau and current < 0: |
| 14 | + break |
| 15 | + |
| 16 | + theta = (upper + lower) / 2.0 |
| 17 | + w = np.maximum(v - theta, 0) |
| 18 | + current = np.sum(w) - z |
| 19 | + if current <= 0: |
| 20 | + upper = theta |
| 21 | + else: |
| 22 | + lower = theta |
| 23 | + return w |
| 24 | + |
| 25 | + |
| 26 | +def test_proj_simplex(): |
| 27 | + rng = np.random.RandomState(0) |
| 28 | + |
| 29 | + v = rng.rand(100) |
| 30 | + w = project_simplex(v, z=10) |
| 31 | + w2 = project_simplex_bisection(v, z=10, max_iter=100) |
| 32 | + assert_array_almost_equal(w, w2, 3) |
| 33 | + |
| 34 | + v = rng.rand(3) |
| 35 | + w = project_simplex(v, z=1) |
| 36 | + w2 = project_simplex_bisection(v, z=1, max_iter=100) |
| 37 | + assert_array_almost_equal(w, w2, 3) |
| 38 | + |
| 39 | + v = rng.rand(2) |
| 40 | + w = project_simplex(v, z=1) |
| 41 | + w2 = project_simplex_bisection(v, z=1, max_iter=100) |
| 42 | + assert_array_almost_equal(w, w2, 3) |
5 | 43 |
|
6 | 44 |
|
7 | 45 | def test_proj_l1_ball(): |
|
0 commit comments