Skip to content

Commit 79e4b16

Browse files
committed
BUG: Fix blas + axpy by only using it for certain cases
Only use BLAS axpy for 1D arrays or contiguous arrays.
1 parent e1decd8 commit 79e4b16

File tree

1 file changed

+27
-29
lines changed

1 file changed

+27
-29
lines changed

src/linalg/impl_linalg.rs

Lines changed: 27 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -281,8 +281,10 @@ impl<A, S, D> ArrayBase<S, D>
281281
D: Dimension,
282282
{
283283
/// Perform the operation `self += alpha * rhs` efficiently, where
284-
/// `alpha` is a scalar and `rhs` is another array. This operation is
285-
/// also known as `axpy` in BLAS.
284+
/// `alpha` is a scalar and `rhs` is another array.
285+
///
286+
/// This operation calls BLAS `axpy` if blas is enabled, the scalar type
287+
/// is `f32` or `f64` and the dimensions are compatible.
286288
///
287289
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
288290
///
@@ -322,25 +324,19 @@ impl<A, S, D> ArrayBase<S, D>
322324
A: LinalgScalar,
323325
E: Dimension,
324326
{
325-
debug_assert_eq!(self.len(), rhs.len());
326-
assert!(self.len() == rhs.len());
327-
{
327+
if (same_type::<A, f32>() || same_type::<A, f64>()) &&
328+
self.len() == rhs.len() {
328329
macro_rules! axpy {
329330
($ty:ty, $func:ident) => {{
330-
if blas_compat::<$ty, _, _>(self) && blas_compat::<$ty, _, _>(rhs) {
331-
let order = Dimension::_fastest_varying_stride_order(&self.strides);
332-
let incx = self.strides()[order[0]];
333-
334-
let order = Dimension::_fastest_varying_stride_order(&rhs.strides);
335-
let incy = self.strides()[order[0]];
336-
331+
if let Some(self_stride) = blas_compat_axpy::<$ty, _, _>(self) {
332+
if let Some(rhs_stride) = blas_compat_axpy::<$ty, _, _>(rhs) {
337333
unsafe {
338334
let (lhs_ptr, n, incx) = blas_1d_params(self.ptr,
339335
self.len(),
340-
incx);
336+
self_stride);
341337
let (rhs_ptr, _, incy) = blas_1d_params(rhs.ptr,
342-
rhs.len(),
343-
incy);
338+
self.len(),
339+
rhs_stride);
344340
blas_sys::c::$func(
345341
n,
346342
cast_as(&alpha),
@@ -351,8 +347,10 @@ impl<A, S, D> ArrayBase<S, D>
351347
return;
352348
}
353349
}
354-
}}
350+
}
351+
}}
355352
}
353+
356354
axpy!{f32, cblas_saxpy};
357355
axpy!{f64, cblas_daxpy};
358356
}
@@ -595,29 +593,29 @@ fn blas_compat_1d<A, S>(a: &ArrayBase<S, Ix1>) -> bool
595593
}
596594

597595
#[cfg(feature="blas")]
598-
fn blas_compat<A, S, D>(a: &ArrayBase<S, D>) -> bool
596+
// Return the interelement stride if the memory layout is blas compatible
597+
fn blas_compat_axpy<A, S, D>(a: &ArrayBase<S, D>) -> Option<isize>
599598
where S: Data,
600599
A: 'static,
601600
S::Elem: 'static,
602601
D: Dimension,
603602
{
604603
if !same_type::<A, S::Elem>() {
605-
return false;
604+
return None
606605
}
607606

608-
match D::equispaced_stride(&a.raw_dim(), &a.strides) {
609-
Some(stride) => {
610-
if a.len() as isize * stride > blas_index::max_value() as isize ||
611-
stride < blas_index::min_value() as isize {
612-
return false;
613-
}
614-
},
615-
None => {
616-
return false;
607+
// FIXME: Support more memory layouts
608+
if a.ndim() == 1 {
609+
let stride = a.strides()[0];
610+
if (a.len() * stride.abs() as usize) < blas_index::max_value() as usize {
611+
return Some(stride)
612+
}
613+
} else if D::is_contiguous(&a.dim, &a.strides) {
614+
if a.len() < blas_index::max_value() as usize {
615+
return Some(1)
617616
}
618617
}
619-
620-
true
618+
None
621619
}
622620

623621
#[cfg(feature="blas")]

0 commit comments

Comments
 (0)