11use  clap:: Parser ; 
22use  cudarc:: cublas; 
33use  cudarc:: cublas:: sys:: cublasOperation_t; 
4- use  cudarc:: cublas:: { Gemm ,  GemmConfig } ; 
4+ use  cudarc:: cublas:: { CudaBlas ,   Gemm ,  GemmConfig } ; 
55use  cudarc:: driver:: result:: mem_get_info; 
6- use  cudarc:: driver:: { sys,  CudaDevice ,  CudaSlice } ; 
6+ use  cudarc:: driver:: { sys,  CudaDevice ,  CudaSlice ,   DeviceRepr ,   ValidAsZeroBits } ; 
77use  nvml_wrapper:: bitmasks:: device:: ThrottleReasons ; 
88use  nvml_wrapper:: enum_wrappers:: device:: TemperatureSensor :: Gpu ; 
99use  nvml_wrapper:: Nvml ; 
1010use  rand:: rngs:: SmallRng ; 
1111use  rand:: RngCore ; 
1212use  rand:: SeedableRng ; 
13+ use  std:: fmt:: Debug ; 
1314use  std:: sync:: Arc ; 
1415use  tokio:: signal; 
1516use  tokio:: sync:: mpsc:: { 
@@ -25,7 +26,7 @@ const GPU_THROTTLING_REASON: &str =
2526const  GPU_FLOPS_REASON :  & str  =
2627    "GPU is not performing as expected. Check the flops values and temperatures" ; 
2728
28- type  AllocBufferTuple  = ( CudaSlice < f32 > ,  CudaSlice < f32 > ,  Vec < CudaSlice < f32 > > ) ; 
29+ type  AllocBufferTuple < T >  = ( CudaSlice < T > ,  CudaSlice < T > ,  Vec < CudaSlice < T > > ) ; 
2930
3031#[ derive( Parser ,  Debug ) ]  
3132#[ clap( author,  version,  about,  long_about = None ) ]  
@@ -43,6 +44,9 @@ struct Args {
4344     /// If the TFLOPS are within `tflops_tolerance`% of the best performing GPU, test will pass 
4445     #[ clap( long,  default_value = "10" ) ]  
4546    tflops_tolerance :  f64 , 
47+     /// Use BF16 precision instead of FP32. GPU must support BF16 type. If unset, will use BF16 only if all GPUs support it. 
48+      #[ clap( long) ]  
49+     use_bf16 :  Option < bool > , 
4650} 
4751
4852#[ derive( Debug ,  Clone ) ]  
@@ -115,6 +119,25 @@ struct Config {
115119    nvml_lib_path :  String , 
116120    tflops_tolerance :  f64 , 
117121    tolerate_software_throttling :  bool , 
122+     use_bf16 :  Option < bool > , 
123+ } 
124+ 
125+ trait  VariablePrecisionFloat : 
126+     Copy  + Debug  + Send  + Sync  + Unpin  + DeviceRepr  + ValidAsZeroBits  + ' static 
127+ { 
128+     fn  from_f32 ( f :  f32 )  -> Self ; 
129+ } 
130+ 
131+ impl  VariablePrecisionFloat  for  f32  { 
132+     fn  from_f32 ( f :  f32 )  -> Self  { 
133+         f
134+     } 
135+ } 
136+ 
137+ impl  VariablePrecisionFloat  for  half:: bf16  { 
138+     fn  from_f32 ( f :  f32 )  -> Self  { 
139+         half:: bf16:: from_f32 ( f) 
140+     } 
118141} 
119142
120143#[ tokio:: main]  
@@ -134,6 +157,7 @@ async fn main() {
134157        nvml_lib_path :  args. nvml_lib_path . clone ( ) , 
135158        tflops_tolerance :  args. tflops_tolerance , 
136159        tolerate_software_throttling :  args. tolerate_software_throttling , 
160+         use_bf16 :  args. use_bf16 , 
137161    } ; 
138162
139163    match  run ( config) . await  { 
@@ -173,15 +197,45 @@ async fn run(config: Config) -> anyhow::Result<()> {
173197        ) ; 
174198    } 
175199
200+     let  use_bf16 = if  let  Some ( requested_bf16)  = config. use_bf16  { 
201+         // If explicitly requested, check if all GPUs support it 
202+         let  all_support_bf16 = gpus. iter ( ) . all ( |gpu| supports_bf16 ( gpu) . unwrap_or ( false ) ) ; 
203+         if  requested_bf16 && !all_support_bf16 { 
204+             return  Err ( anyhow:: anyhow!( 
205+                 "BF16 was explicitly requested but not all GPUs support it. Remove the --use-bf16 flag" 
206+             ) ) ; 
207+         } 
208+         requested_bf16
209+     }  else  { 
210+         // Auto-detect: use BF16 only if all GPUs support it 
211+         gpus. iter ( ) . all ( |gpu| supports_bf16 ( gpu) . unwrap_or ( false ) ) 
212+     } ; 
213+ 
214+     println ! ( "Using {}" ,  if  use_bf16 {  "BF16"  }  else {  "FP32"  } ) ; 
215+ 
216+     if  use_bf16 { 
217+         run_with_precision :: < half:: bf16 > ( config,  gpus) . await 
218+     }  else  { 
219+         run_with_precision :: < f32 > ( config,  gpus) . await 
220+     } 
221+ } 
222+ 
223+ async  fn  run_with_precision < T :  VariablePrecisionFloat > ( 
224+     config :  Config , 
225+     gpus :  Vec < Arc < CudaDevice > > , 
226+ )  -> anyhow:: Result < ( ) > 
227+ where 
228+     CudaBlas :  Gemm < T > , 
229+ { 
176230    // create 2 matrix with random values 
177231    println ! ( "Creating random matrices" ) ; 
178232    // use SmallRng to create random values, we don't need cryptographic security but we need speed 
179233    let  mut  small_rng = SmallRng :: from_entropy ( ) ; 
180-     let  mut  a = vec ! [ 0.0f32 ;  SIZE  *  SIZE ] ; 
181-     let  mut  b = vec ! [ 0.0f32 ;  SIZE  *  SIZE ] ; 
234+     let  mut  a = vec ! [ T :: from_f32 ( 0.0 ) ;  SIZE  *  SIZE ] ; 
235+     let  mut  b = vec ! [ T :: from_f32 ( 0.0 ) ;  SIZE  *  SIZE ] ; 
182236    for  i in  0 ..SIZE  *  SIZE  { 
183-         a[ i]  = small_rng. next_u32 ( )  as  f32 ; 
184-         b[ i]  = small_rng. next_u32 ( )  as  f32 ; 
237+         a[ i]  = T :: from_f32 ( small_rng. next_u32 ( )  as  f32 ) ; 
238+         b[ i]  = T :: from_f32 ( small_rng. next_u32 ( )  as  f32 ) ; 
185239    } 
186240    println ! ( "Matrices created" ) ; 
187241
@@ -221,9 +275,17 @@ async fn run(config: Config) -> anyhow::Result<()> {
221275    } ) ; 
222276    handles. push ( t) ; 
223277    // burn the GPU for given duration 
224-     tokio:: time:: sleep ( std:: time:: Duration :: from_secs ( config. duration_secs ) ) . await ; 
225-     stop. store ( true ,  std:: sync:: atomic:: Ordering :: Relaxed ) ; 
226-     drop ( tx) ; 
278+     let  wait = tokio:: spawn ( async  move  { 
279+         let  mut  interval = tokio:: time:: interval ( std:: time:: Duration :: from_secs ( 1 ) ) ; 
280+         let  mut  tick = 0 ; 
281+         while  !stop. load ( std:: sync:: atomic:: Ordering :: Relaxed )  && tick < config. duration_secs  { 
282+             interval. tick ( ) . await ; 
283+             tick += 1 ; 
284+         } 
285+         stop. store ( true ,  std:: sync:: atomic:: Ordering :: Relaxed ) ; 
286+         drop ( tx) ; 
287+     } ) ; 
288+     handles. push ( wait) ; 
227289    for  handle in  handles { 
228290        handle. await . expect ( "Thread panicked" ) ; 
229291    } 
@@ -252,6 +314,16 @@ fn poll_throttling(nvml: &Nvml, gpu_count: usize) -> anyhow::Result<Vec<Throttle
252314    Ok ( throttling) 
253315} 
254316
317+ fn  supports_bf16 ( gpu :  & Arc < CudaDevice > )  -> anyhow:: Result < bool >  { 
318+     Ok ( 
319+         gpu. attribute ( sys:: CUdevice_attribute_enum :: CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR ) ?
320+             >= 8 
321+             && gpu. attribute ( 
322+                 sys:: CUdevice_attribute_enum :: CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR , 
323+             ) ? >= 0 , 
324+     ) 
325+ } 
326+ 
255327async  fn  report_progress ( 
256328    config :  Config , 
257329    gpu_count :  usize , 
@@ -408,13 +480,16 @@ fn are_gpus_healthy(
408480    ( reasons. is_empty ( ) ,  reasons) 
409481} 
410482
411- async  fn  burn_gpu ( 
483+ async  fn  burn_gpu < T :   VariablePrecisionFloat > ( 
412484    gpu_idx :  usize , 
413-     a :  Vec < f32 > , 
414-     b :  Vec < f32 > , 
485+     a :  Vec < T > , 
486+     b :  Vec < T > , 
415487    tx :  Sender < ( usize ,  usize ) > , 
416488    stop :  Arc < std:: sync:: atomic:: AtomicBool > , 
417- )  -> anyhow:: Result < usize >  { 
489+ )  -> anyhow:: Result < usize > 
490+ where 
491+     CudaBlas :  Gemm < T > , 
492+ { 
418493    let  gpu = CudaDevice :: new ( gpu_idx) ?; 
419494    // compute the output matrix size 
420495    let  ( free_mem,  _)  = get_gpu_memory ( gpu. clone ( ) ) ?; 
@@ -425,8 +500,8 @@ async fn burn_gpu(
425500        mem_to_use / 1024  / 1024 , 
426501        free_mem / 1024  / 1024 
427502    ) ; 
428-     let  iters =
429-         ( mem_to_use -  2   *   SIZE   *   SIZE   *   size_of :: < f32 > ( ) )   / ( SIZE  *  SIZE  *  size_of :: < f32 > ( ) ) ; 
503+     let  iters =  ( mem_to_use -  2   *   SIZE   *   SIZE   *   get_memory_size :: < T > ( ) ) 
504+         / ( SIZE  *  SIZE  *  get_memory_size :: < T > ( ) ) ; 
430505    let  ( a_gpu,  b_gpu,  mut  out_slices_gpu)  = alloc_buffers ( gpu. clone ( ) ,  a,  b,  iters) ?; 
431506    let  handle = cublas:: safe:: CudaBlas :: new ( gpu) ?; 
432507    let  mut  i = 0 ; 
@@ -441,44 +516,51 @@ async fn burn_gpu(
441516    Ok ( i) 
442517} 
443518
519+ fn  get_memory_size < T > ( )  -> usize  { 
520+     size_of :: < T > ( ) 
521+ } 
522+ 
444523fn  get_gpu_memory ( gpu :  Arc < CudaDevice > )  -> anyhow:: Result < ( usize ,  usize ) >  { 
445524    CudaDevice :: new ( gpu. ordinal ( ) ) ?; 
446525    let  mem_info = mem_get_info ( ) ?; 
447526    Ok ( mem_info) 
448527} 
449528
450- fn  alloc_buffers ( 
529+ fn  alloc_buffers < T :   VariablePrecisionFloat > ( 
451530    gpu :  Arc < CudaDevice > , 
452-     a :  Vec < f32 > , 
453-     b :  Vec < f32 > , 
531+     a :  Vec < T > , 
532+     b :  Vec < T > , 
454533    num_out_slices :  usize , 
455- )  -> anyhow:: Result < AllocBufferTuple >  { 
534+ )  -> anyhow:: Result < AllocBufferTuple < T > >  { 
456535    let  a_gpu = gpu. htod_copy ( a) ?; 
457536    let  b_gpu = gpu. htod_copy ( b) ?; 
458537    let  mut  out_slices = vec ! [ ] ; 
459538    for  _ in  0 ..num_out_slices { 
460-         let  out = gpu. alloc_zeros :: < f32 > ( SIZE  *  SIZE ) ?; 
539+         let  out = gpu. alloc_zeros :: < T > ( SIZE  *  SIZE ) ?; 
461540        out_slices. push ( out) ; 
462541    } 
463542    Ok ( ( a_gpu,  b_gpu,  out_slices) ) 
464543} 
465544
466- fn  compute ( 
545+ fn  compute < T :   VariablePrecisionFloat > ( 
467546    handle :  & cublas:: safe:: CudaBlas , 
468-     a :  & CudaSlice < f32 > , 
469-     b :  & CudaSlice < f32 > , 
470-     out :  & mut  CudaSlice < f32 > , 
471- )  -> anyhow:: Result < ( ) >  { 
547+     a :  & CudaSlice < T > , 
548+     b :  & CudaSlice < T > , 
549+     out :  & mut  CudaSlice < T > , 
550+ )  -> anyhow:: Result < ( ) > 
551+ where 
552+     CudaBlas :  Gemm < T > , 
553+ { 
472554    let  cfg = GemmConfig  { 
473555        transa :  cublasOperation_t:: CUBLAS_OP_N , 
474556        transb :  cublasOperation_t:: CUBLAS_OP_N , 
475557        m :  SIZE  as  i32 , 
476558        n :  SIZE  as  i32 , 
477559        k :  SIZE  as  i32 , 
478-         alpha :  1.0 , 
560+         alpha :  T :: from_f32 ( 1.0 ) , 
479561        lda :  SIZE  as  i32 , 
480562        ldb :  SIZE  as  i32 , 
481-         beta :  0.0 , 
563+         beta :  T :: from_f32 ( 0.0 ) , 
482564        ldc :  SIZE  as  i32 , 
483565    } ; 
484566    unsafe  { 
0 commit comments