@@ -80,6 +80,7 @@ def calc_arrayfire(A, b, x0, maxiter=10):
8080 beta_num = af .dot (r , r )
8181 beta = beta_num / alpha_num
8282 p = r + af .tile (beta , p .dims ()[0 ]) * p
83+ af .eval (x )
8384 res = x0 - x
8485 return x , af .dot (res , res )
8586
@@ -137,11 +138,11 @@ def timeit(calc, iters, args):
137138
138139def test ():
139140 print ("\n Testing benchmark functions..." )
140- A , b , x0 = setup_input (50 ) # dense A
141+ A , b , x0 = setup_input (n = 50 , sparsity = 7 ) # dense A
141142 Asp = to_sparse (A )
142143 x1 , _ = calc_arrayfire (A , b , x0 )
143144 x2 , _ = calc_arrayfire (Asp , b , x0 )
144- if af .sum (af .abs (x1 - x2 )/ x2 > 1e-6 ):
145+ if af .sum (af .abs (x1 - x2 )/ x2 > 1e-5 ):
145146 raise ValueError ("arrayfire test failed" )
146147 if np :
147148 An = to_numpy (A )
@@ -162,11 +163,13 @@ def test():
162163
163164
164165def bench (n = 4 * 1024 , sparsity = 7 , maxiter = 10 , iters = 10 ):
166+
165167 # generate data
166168 print ("\n Generating benchmark data for n = %i ..." % n )
167169 A , b , x0 = setup_input (n , sparsity ) # dense A
168170 Asp = to_sparse (A ) # sparse A
169171 input_info (A , Asp )
172+
170173 # make benchmarks
171174 print ("Benchmarking CG solver for n = %i ..." % n )
172175 t1 = timeit (calc_arrayfire , iters , args = (A , b , x0 , maxiter ))
@@ -192,9 +195,8 @@ def bench(n=4*1024, sparsity=7, maxiter=10, iters=10):
192195 if (len (sys .argv ) > 1 ):
193196 af .set_device (int (sys .argv [1 ]))
194197
195- af .info ()
196-
198+ af .info ()
197199 test ()
198-
200+
199201 for n in (128 , 256 , 512 , 1024 , 2048 , 4096 ):
200202 bench (n )
0 commit comments