Skip to content

Commit a244a94

Browse files
committed
Updates to thc solver to improve stability
1 parent 855fe42 commit a244a94

File tree

1 file changed

+24
-15
lines changed

1 file changed

+24
-15
lines changed

src/TiledArray/math/solvers/cp/thc_lt_thc_als.h

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ class THC_LT_THC_ALS : public CP<Tile, Policy> {
191191
// Checking the convergence of the thing
192192

193193
update_factors_left();
194-
{
194+
/*{
195195
DistArray<Tile, Policy> abcd_old, abcd_new, diff;
196196
abcd_old("a,b,c,d") = TA::einsum(ref_orb_a("a,m,P"), ref_orb_b("b,m,P"), "a,b,m,P")("a,b,m,P") *
197197
ref_core("P,Q") * TA::einsum(ref_orb_c("c,m,Q"), ref_orb_d("d,m,Q"), "c,d,m,Q")("c,d,m,Q");
@@ -203,7 +203,7 @@ class THC_LT_THC_ALS : public CP<Tile, Policy> {
203203
std::cout << "Norm old: " << TA::norm2(abcd_old) << std::endl;
204204
std::cout << "Norm new: " << TA::norm2(abcd_new) << std::endl;
205205
std::cout << "Error: " << TA::norm2(diff) / TA::norm2(abcd_old) << std::endl;
206-
}
206+
}*/
207207
// Preserve symmetry in the structure
208208
{
209209
cp_factors[3] = cp_factors[0].clone();
@@ -216,7 +216,7 @@ class THC_LT_THC_ALS : public CP<Tile, Policy> {
216216
}
217217
// Update the core tensor and don't rescale to normalized
218218
update_core();
219-
{
219+
/*{
220220
DistArray<Tile, Policy> abcd_old, abcd_new, diff;
221221
abcd_old("a,b,c,d") = TA::einsum(ref_orb_a("a,m,P"), ref_orb_b("b,m,P"), "a,b,m,P")("a,b,m,P") *
222222
ref_core("P,Q") * TA::einsum(ref_orb_c("c,m,Q"), ref_orb_d("d,m,Q"), "c,d,m,Q")("c,d,m,Q");
@@ -228,24 +228,25 @@ class THC_LT_THC_ALS : public CP<Tile, Policy> {
228228
std::cout << "Norm old: " << TA::norm2(abcd_old) << std::endl;
229229
std::cout << "Norm new: " << TA::norm2(abcd_new) << std::endl;
230230
std::cout << "Error: " << TA::norm2(diff) / TA::norm2(abcd_old) << std::endl;
231-
}
231+
}*/
232232
//update_factors_right();
233-
// {
234-
// DistArray<Tile, Policy> abcd_old, abcd_new, diff;
235-
// abcd_old("a,b,c,d") = TA::einsum(ref_orb_a("a,m,P"), ref_orb_b("b,m,P"), "a,b,m,P")("a,b,m,P") *
236-
// ref_core("P,Q") * TA::einsum(ref_orb_c("c,m,Q"), ref_orb_d("d,m,Q"), "c,d,m,Q")("c,d,m,Q");
237-
// abcd_new("a,b,c,d") = TA::einsum(cp_factors[0]("P,a"), cp_factors[1]("P,b"), "P,a,b")("P,a,b") *
238-
// (UnNormalizedLeft("P,X") * UnNormalizedRight("Q,X")) * TA::einsum(cp_factors[3]("Q,c"), cp_factors[4]("Q,d"), "Q,c,d")("Q,c,d");
239-
// diff("a,b,c,d") = abcd_new("a,b,c,d") - abcd_old("a,b,c,d");
240-
// std::cout << "Norm old: " << TA::norm2(abcd_old) << std::endl;
241-
// std::cout << "Norm new: " << TA::norm2(abcd_new) << std::endl;
242-
// std::cout << "Error: " << TA::norm2(diff) / TA::norm2(abcd_old) << std::endl;
243-
// }
233+
/*{
234+
DistArray<Tile, Policy> abcd_old, abcd_new, diff;
235+
abcd_old("a,b,c,d") = TA::einsum(ref_orb_a("a,m,P"), ref_orb_b("b,m,P"), "a,b,m,P")("a,b,m,P") *
236+
ref_core("P,Q") * TA::einsum(ref_orb_c("c,m,Q"), ref_orb_d("d,m,Q"), "c,d,m,Q")("c,d,m,Q");
237+
abcd_new("a,b,c,d") = TA::einsum(cp_factors[0]("P,a"), cp_factors[1]("P,b"), "P,a,b")("P,a,b") *
238+
(UnNormalizedLeft("P,X") * UnNormalizedRight("Q,X")) * TA::einsum(cp_factors[3]("Q,c"), cp_factors[4]("Q,d"), "Q,c,d")("Q,c,d");
239+
diff("a,b,c,d") = abcd_new("a,b,c,d") - abcd_old("a,b,c,d");
240+
std::cout << "Norm old: " << TA::norm2(abcd_old) << std::endl;
241+
std::cout << "Norm new: " << TA::norm2(abcd_new) << std::endl;
242+
std::cout << "Error: " << TA::norm2(diff) / TA::norm2(abcd_old) << std::endl;
243+
}*/
244244

245245
converged = this->check_thc_fit(verbose);
246246

247247
++iter;
248248
} while (iter < max_iter && !converged);
249+
this->unNormalized_Factor = cp_factors[4];
249250
}
250251

251252
// These assume the center is a sqrt of the core tensor.
@@ -359,12 +360,17 @@ class THC_LT_THC_ALS : public CP<Tile, Policy> {
359360
this->normalize_factor(MTtKRP);
360361
cp_factors[5] = MTtKRP;
361362
}*/
363+
364+
//
362365
void update_factors_left(){
363366
DistArray<Tile, Policy> env, b_mON, MttKRP, W, W_env, pq;
364367

365368
// solve for A
366369
env("m,M,P") = ref_core("M,N") * THC_times_CPD[1]("m,N,Q") * cp_factors[2]("P,Q");
367370
b_mON("m,M,P") = ref_orb_b("b,m,M") * cp_factors[1]("P,b");
371+
env.truncate();
372+
b_mON.truncate();
373+
368374
MttKRP("P,a") = TA::einsum(env("m,M,P"), b_mON("m,M,P"), "m,M,P")("m,M,P") * ref_orb_a("a,m,M");
369375

370376
DistArray<Tile, Policy> temp;
@@ -380,6 +386,7 @@ class THC_LT_THC_ALS : public CP<Tile, Policy> {
380386
world.gop.fence(); // N.B. seems to deadlock without this
381387

382388
this->normalize_factor(MttKRP);
389+
MttKRP.truncate();
383390
cp_factors[0] = MttKRP;
384391
this->partial_grammian[0]("r,rp") = MttKRP("r,n") * MttKRP("rp,n");
385392
pq("m,M,P") = ref_orb_a("a,m,M") * MttKRP("P,a");
@@ -394,6 +401,7 @@ class THC_LT_THC_ALS : public CP<Tile, Policy> {
394401

395402
UnNormalizedLeft = MttKRP.clone();
396403
this->normalize_factor(MttKRP);
404+
MttKRP.truncate();
397405
cp_factors[1] = MttKRP;
398406
this->partial_grammian[1]("r,rp") = MttKRP("r,n") * MttKRP("rp,n");
399407
THC_times_CPD[0]("m,M,P") = pq("m,M,P") * (ref_orb_b("b,m,M") * MttKRP("P,b"));
@@ -408,6 +416,7 @@ class THC_LT_THC_ALS : public CP<Tile, Policy> {
408416
R = math::linalg::lu_inv(TA::einsum(this->partial_grammian[2]("P,Q"), this->partial_grammian[3]("P,Q"),"P,Q"));
409417

410418
cp_factors[2]("P,Q") = L("P,L") * MttKRP("L,M") * R("Q,M");
419+
cp_factors[2].truncate();
411420
}
412421

413422
void update_factors_right(){

0 commit comments

Comments
 (0)