@@ -489,16 +489,33 @@ size_all_negative(::SpinGlass) = false
489489size_all_positive (:: SpinGlass ) =  false 
490490
491491#  NOTE: `findmin` and `findmax` are required by `ProblemReductions.jl`
492+ """ 
493+     GTNSolver(; optimizer=TreeSA(), single=false, usecuda=false, T=Float64) 
494+ 
495+ A generic tensor network based backend for the `findbest`, `findmin` and `findmax` interfaces in `ProblemReductions.jl`. 
496+ 
497+ Keyword arguments 
498+ ------------------------------------- 
499+ * `optimizer` is the optimizer for the tensor network contraction. 
500+ * `single` is a switch to return single solution instead of all solutions. 
501+ * `usecuda` is a switch to use CUDA (when applicable), user need to call statement `using CUDA` before turning on this switch. 
502+ * `T` is the "base" element type, sometimes can be used to reduce the memory cost. 
503+ """ 
492504Base. @kwdef  struct  GTNSolver
493505    optimizer:: OMEinsum.CodeOptimizer  =  TreeSA ()
506+     single:: Bool  =  false 
494507    usecuda:: Bool  =  false 
495508    T:: Type  =  Float64
496509end 
497- function  Base. findmin (problem:: AbstractProblem , solver:: GTNSolver )
498-     res =  collect (solve (GenericTensorNetwork (problem; optimizer= solver. optimizer), ConfigsMin (; tree_storage= true ); usecuda= solver. usecuda, T= solver. T)[]. c)
499-     return  map (x ->  ProblemReductions. id_to_config (problem, Int .(x) .+  1 ), res)
500- end 
501- function  Base. findmax (problem:: AbstractProblem , solver:: GTNSolver )
502-     res =  collect (solve (GenericTensorNetwork (problem; optimizer= solver. optimizer), ConfigsMax (; tree_storage= true ); usecuda= solver. usecuda, T= solver. T)[]. c)
503-     return  map (x ->  ProblemReductions. id_to_config (problem, Int .(x) .+  1 ), res)
504- end 
510+ for  (PROP, SPROP, SOLVER) in  [
511+         (:ConfigsMin , :SingleConfigMin , :findmin ), (:ConfigsMax , :SingleConfigMax , :findmax )
512+     ]
513+     @eval  function  Base. $ (SOLVER)(problem:: AbstractProblem , solver:: GTNSolver )
514+         if  solver. single
515+             res =  [solve (GenericTensorNetwork (problem; optimizer= solver. optimizer), $ (SPROP)(); usecuda= solver. usecuda, T= solver. T)[]. c. data]
516+         else 
517+             res =  collect (solve (GenericTensorNetwork (problem; optimizer= solver. optimizer), $ (PROP)(; tree_storage= true ); usecuda= solver. usecuda, T= solver. T)[]. c)
518+         end 
519+         return  map (x ->  ProblemReductions. id_to_config (problem, Int .(x) .+  1 ), res)
520+     end 
521+ end 
0 commit comments