-
Notifications
You must be signed in to change notification settings - Fork 28
Description
I'm trying to read through the code to understand it and noticed something that might be a mistake or just it's my misunderstanding of the code:
static void matmul(float[] xout, float[] x, FloatBuffer w, int n, int d) {
// W (d,n) @ x (n,) -> xout (d,)
// by far the most amount of time is spent inside this little function
MemorySegment wSegment = MemorySegment.ofBuffer(w);
IntStream.range(0, d).parallel().forEach(i -> {
float val = 0f;
int j = 0;
if (USE_VECTOR_API) {
VectorSpecies<Float> species = FloatVector.SPECIES_256;
FloatVector sum0 = FloatVector.zero(species);
FloatVector sum1 = FloatVector.zero(species);
FloatVector sum2 = FloatVector.zero(species);
FloatVector sum3 = FloatVector.zero(species);
int width = species.length();
int upperBound = n - n % (4 * width);
for (; j < upperBound; j += 4 * width) {
var wj0 = FloatVector.fromMemorySegment(species, wSegment, (i * n + j + 0 * width) * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
var wj1 = FloatVector.fromMemorySegment(species, wSegment, (i * n + j + 1 * width) * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
var wj2 = FloatVector.fromMemorySegment(species, wSegment, (i * n + j + 2 * width) * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
var wj3 = FloatVector.fromMemorySegment(species, wSegment, (i * n + j + 3 * width) * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
var xj0 = FloatVector.fromArray(species, x, j + 0 * width);
var xj1 = FloatVector.fromArray(species, x, j + 1 * width);
var xj2 = FloatVector.fromArray(species, x, j + 2 * width);
var xj3 = FloatVector.fromArray(species, x, j + 3 * width);
sum0 = wj0.fma(xj0, sum0);
sum1 = wj1.fma(xj1, sum1);
sum2 = wj2.fma(xj2, sum2);
sum3 = wj3.fma(xj3, sum3);
}
val = sum0.add(sum1).add(sum2).add(sum3).reduceLanes(VectorOperators.ADD);
}
// Graal's auto-vectorization.
int upperBound = n & ~3;
float[] sum = new float[4];
for (; j < upperBound; j += sum.length) {
sum[0] += w.get(i * n + j + 0) * x[j + 0];
sum[1] += w.get(i * n + j + 1) * x[j + 1];
sum[2] += w.get(i * n + j + 2) * x[j + 2];
sum[3] += w.get(i * n + j + 3) * x[j + 3];
}
val += sum[0] + sum[1] + sum[2] + sum[3];
for (; j < n; j++) {
val += w.get(i * n + j) * x[j];
}
xout[i] = val;
});
}
First, there's a small inefficiency in the if (USE_VECTOR_API) {
line. Since that is a constant, having the if statement for every forEach
call is inefficient. The JIT might optimize it away eventually but I would still have that outside of the block to keep the code efficient as the JIT isn't magic.
The main thing that isn't clear to me. It seems the code does the operation twice when running under a regular JIT. Shouldn't the rest of the code be under an else statement?
This is how I think it should be if I'm reading the code correctly. I haven't tested it though so I might be completely off here:
static void matmul(float[] xout, float[] x, FloatBuffer w, int n, int d) {
// W (d,n) @ x (n,) -> xout (d,)
// by far the most amount of time is spent inside this little function
MemorySegment wSegment = MemorySegment.ofBuffer(w);
if (USE_VECTOR_API) {
IntStream.range(0, d).parallel().forEach(i -> {
int j = 0;
VectorSpecies<Float> species = FloatVector.SPECIES_256;
FloatVector sum0 = FloatVector.zero(species);
FloatVector sum1 = FloatVector.zero(species);
FloatVector sum2 = FloatVector.zero(species);
FloatVector sum3 = FloatVector.zero(species);
int width = species.length();
int upperBound = n - n % (4 * width);
for (; j < upperBound; j += 4 * width) {
var wj0 = FloatVector.fromMemorySegment(species, wSegment, (i * n + j + 0 * width) * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
var wj1 = FloatVector.fromMemorySegment(species, wSegment, (i * n + j + 1 * width) * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
var wj2 = FloatVector.fromMemorySegment(species, wSegment, (i * n + j + 2 * width) * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
var wj3 = FloatVector.fromMemorySegment(species, wSegment, (i * n + j + 3 * width) * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
var xj0 = FloatVector.fromArray(species, x, j + 0 * width);
var xj1 = FloatVector.fromArray(species, x, j + 1 * width);
var xj2 = FloatVector.fromArray(species, x, j + 2 * width);
var xj3 = FloatVector.fromArray(species, x, j + 3 * width);
sum0 = wj0.fma(xj0, sum0);
sum1 = wj1.fma(xj1, sum1);
sum2 = wj2.fma(xj2, sum2);
sum3 = wj3.fma(xj3, sum3);
}
xout[i] = sum0.add(sum1).add(sum2).add(sum3).reduceLanes(VectorOperators.ADD);
});
} else {
// Graal's auto-vectorization.
IntStream.range(0, d).parallel().forEach(i -> {
int j = 0;
float val = 0;
int upperBound = n & ~3;
float[] sum = new float[4];
for (; j < upperBound; j += sum.length) {
sum[0] += w.get(i * n + j) * x[j];
sum[1] += w.get(i * n + j + 1) * x[j + 1];
sum[2] += w.get(i * n + j + 2) * x[j + 2];
sum[3] += w.get(i * n + j + 3) * x[j + 3];
}
val += sum[0] + sum[1] + sum[2] + sum[3];
for (; j < n; j++) {
val += w.get(i * n + j) * x[j];
}
xout[i] = val;
});
}
}
Thanks for the project. It's very interesting!