@@ -191,7 +191,7 @@ class THC_LT_THC_ALS : public CP<Tile, Policy> {
191191 // Checking the convergence of the thing
192192
193193 update_factors_left ();
194- {
194+ /* {
195195 DistArray<Tile, Policy> abcd_old, abcd_new, diff;
196196 abcd_old("a,b,c,d") = TA::einsum(ref_orb_a("a,m,P"), ref_orb_b("b,m,P"), "a,b,m,P")("a,b,m,P") *
197197 ref_core("P,Q") * TA::einsum(ref_orb_c("c,m,Q"), ref_orb_d("d,m,Q"), "c,d,m,Q")("c,d,m,Q");
@@ -203,7 +203,7 @@ class THC_LT_THC_ALS : public CP<Tile, Policy> {
203203 std::cout << "Norm old: " << TA::norm2(abcd_old) << std::endl;
204204 std::cout << "Norm new: " << TA::norm2(abcd_new) << std::endl;
205205 std::cout << "Error: " << TA::norm2(diff) / TA::norm2(abcd_old) << std::endl;
206- }
206+ }*/
207207 // Preserve symmetry in the structure
208208 {
209209 cp_factors[3 ] = cp_factors[0 ].clone ();
@@ -216,7 +216,7 @@ class THC_LT_THC_ALS : public CP<Tile, Policy> {
216216 }
217217 // Update the core tensor and don't rescale to normalized
218218 update_core ();
219- {
219+ /* {
220220 DistArray<Tile, Policy> abcd_old, abcd_new, diff;
221221 abcd_old("a,b,c,d") = TA::einsum(ref_orb_a("a,m,P"), ref_orb_b("b,m,P"), "a,b,m,P")("a,b,m,P") *
222222 ref_core("P,Q") * TA::einsum(ref_orb_c("c,m,Q"), ref_orb_d("d,m,Q"), "c,d,m,Q")("c,d,m,Q");
@@ -228,24 +228,25 @@ class THC_LT_THC_ALS : public CP<Tile, Policy> {
228228 std::cout << "Norm old: " << TA::norm2(abcd_old) << std::endl;
229229 std::cout << "Norm new: " << TA::norm2(abcd_new) << std::endl;
230230 std::cout << "Error: " << TA::norm2(diff) / TA::norm2(abcd_old) << std::endl;
231- }
231+ }*/
232232 // update_factors_right();
233- // {
234- // DistArray<Tile, Policy> abcd_old, abcd_new, diff;
235- // abcd_old("a,b,c,d") = TA::einsum(ref_orb_a("a,m,P"), ref_orb_b("b,m,P"), "a,b,m,P")("a,b,m,P") *
236- // ref_core("P,Q") * TA::einsum(ref_orb_c("c,m,Q"), ref_orb_d("d,m,Q"), "c,d,m,Q")("c,d,m,Q");
237- // abcd_new("a,b,c,d") = TA::einsum(cp_factors[0]("P,a"), cp_factors[1]("P,b"), "P,a,b")("P,a,b") *
238- // (UnNormalizedLeft("P,X") * UnNormalizedRight("Q,X")) * TA::einsum(cp_factors[3]("Q,c"), cp_factors[4]("Q,d"), "Q,c,d")("Q,c,d");
239- // diff("a,b,c,d") = abcd_new("a,b,c,d") - abcd_old("a,b,c,d");
240- // std::cout << "Norm old: " << TA::norm2(abcd_old) << std::endl;
241- // std::cout << "Norm new: " << TA::norm2(abcd_new) << std::endl;
242- // std::cout << "Error: " << TA::norm2(diff) / TA::norm2(abcd_old) << std::endl;
243- // }
233+ /* {
234+ DistArray<Tile, Policy> abcd_old, abcd_new, diff;
235+ abcd_old("a,b,c,d") = TA::einsum(ref_orb_a("a,m,P"), ref_orb_b("b,m,P"), "a,b,m,P")("a,b,m,P") *
236+ ref_core("P,Q") * TA::einsum(ref_orb_c("c,m,Q"), ref_orb_d("d,m,Q"), "c,d,m,Q")("c,d,m,Q");
237+ abcd_new("a,b,c,d") = TA::einsum(cp_factors[0]("P,a"), cp_factors[1]("P,b"), "P,a,b")("P,a,b") *
238+ (UnNormalizedLeft("P,X") * UnNormalizedRight("Q,X")) * TA::einsum(cp_factors[3]("Q,c"), cp_factors[4]("Q,d"), "Q,c,d")("Q,c,d");
239+ diff("a,b,c,d") = abcd_new("a,b,c,d") - abcd_old("a,b,c,d");
240+ std::cout << "Norm old: " << TA::norm2(abcd_old) << std::endl;
241+ std::cout << "Norm new: " << TA::norm2(abcd_new) << std::endl;
242+ std::cout << "Error: " << TA::norm2(diff) / TA::norm2(abcd_old) << std::endl;
243+ }*/
244244
245245 converged = this ->check_thc_fit (verbose);
246246
247247 ++iter;
248248 } while (iter < max_iter && !converged);
249+ this ->unNormalized_Factor = cp_factors[4 ];
249250 }
250251
251252 // These assume the center is a sqrt of the core tensor.
@@ -359,12 +360,17 @@ class THC_LT_THC_ALS : public CP<Tile, Policy> {
359360 this->normalize_factor(MTtKRP);
360361 cp_factors[5] = MTtKRP;
361362 }*/
363+
364+ //
362365 void update_factors_left (){
363366 DistArray<Tile, Policy> env, b_mON, MttKRP, W, W_env, pq;
364367
365368 // solve for A
366369 env (" m,M,P" ) = ref_core (" M,N" ) * THC_times_CPD[1 ](" m,N,Q" ) * cp_factors[2 ](" P,Q" );
367370 b_mON (" m,M,P" ) = ref_orb_b (" b,m,M" ) * cp_factors[1 ](" P,b" );
371+ env.truncate ();
372+ b_mON.truncate ();
373+
368374 MttKRP (" P,a" ) = TA::einsum (env (" m,M,P" ), b_mON (" m,M,P" ), " m,M,P" )(" m,M,P" ) * ref_orb_a (" a,m,M" );
369375
370376 DistArray<Tile, Policy> temp;
@@ -380,6 +386,7 @@ class THC_LT_THC_ALS : public CP<Tile, Policy> {
380386 world.gop .fence (); // N.B. seems to deadlock without this
381387
382388 this ->normalize_factor (MttKRP);
389+ MttKRP.truncate ();
383390 cp_factors[0 ] = MttKRP;
384391 this ->partial_grammian [0 ](" r,rp" ) = MttKRP (" r,n" ) * MttKRP (" rp,n" );
385392 pq (" m,M,P" ) = ref_orb_a (" a,m,M" ) * MttKRP (" P,a" );
@@ -394,6 +401,7 @@ class THC_LT_THC_ALS : public CP<Tile, Policy> {
394401
395402 UnNormalizedLeft = MttKRP.clone ();
396403 this ->normalize_factor (MttKRP);
404+ MttKRP.truncate ();
397405 cp_factors[1 ] = MttKRP;
398406 this ->partial_grammian [1 ](" r,rp" ) = MttKRP (" r,n" ) * MttKRP (" rp,n" );
399407 THC_times_CPD[0 ](" m,M,P" ) = pq (" m,M,P" ) * (ref_orb_b (" b,m,M" ) * MttKRP (" P,b" ));
@@ -408,6 +416,7 @@ class THC_LT_THC_ALS : public CP<Tile, Policy> {
408416 R = math::linalg::lu_inv (TA::einsum (this ->partial_grammian [2 ](" P,Q" ), this ->partial_grammian [3 ](" P,Q" )," P,Q" ));
409417
410418 cp_factors[2 ](" P,Q" ) = L (" P,L" ) * MttKRP (" L,M" ) * R (" Q,M" );
419+ cp_factors[2 ].truncate ();
411420 }
412421
413422 void update_factors_right (){
0 commit comments