
Possibly inefficient code #9

@shai-almog


I'm reading through the code to understand it and noticed something that might be a mistake, or might just be my misunderstanding of the code:

    static void matmul(float[] xout, float[] x, FloatBuffer w, int n, int d) {
        // W (d,n) @ x (n,) -> xout (d,)
        // by far the most amount of time is spent inside this little function
        MemorySegment wSegment = MemorySegment.ofBuffer(w);
        IntStream.range(0, d).parallel().forEach(i -> {
            float val = 0f;
            int j = 0;
            if (USE_VECTOR_API) {
                VectorSpecies<Float> species = FloatVector.SPECIES_256;
                FloatVector sum0 = FloatVector.zero(species);
                FloatVector sum1 = FloatVector.zero(species);
                FloatVector sum2 = FloatVector.zero(species);
                FloatVector sum3 = FloatVector.zero(species);
                int width = species.length();
                int upperBound = n - n % (4 * width);
                for (; j < upperBound; j += 4 * width) {
                    var wj0 = FloatVector.fromMemorySegment(species, wSegment, (i * n + j + 0 * width) * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
                    var wj1 = FloatVector.fromMemorySegment(species, wSegment, (i * n + j + 1 * width) * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
                    var wj2 = FloatVector.fromMemorySegment(species, wSegment, (i * n + j + 2 * width) * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
                    var wj3 = FloatVector.fromMemorySegment(species, wSegment, (i * n + j + 3 * width) * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
                    var xj0 = FloatVector.fromArray(species, x, j + 0 * width);
                    var xj1 = FloatVector.fromArray(species, x, j + 1 * width);
                    var xj2 = FloatVector.fromArray(species, x, j + 2 * width);
                    var xj3 = FloatVector.fromArray(species, x, j + 3 * width);
                    sum0 = wj0.fma(xj0, sum0);
                    sum1 = wj1.fma(xj1, sum1);
                    sum2 = wj2.fma(xj2, sum2);
                    sum3 = wj3.fma(xj3, sum3);
                }
                val = sum0.add(sum1).add(sum2).add(sum3).reduceLanes(VectorOperators.ADD);
            }

            // Graal's auto-vectorization.
            int upperBound = n & ~3;
            float[] sum = new float[4];
            for (; j < upperBound; j += sum.length) {
                sum[0] += w.get(i * n + j + 0) * x[j + 0];
                sum[1] += w.get(i * n + j + 1) * x[j + 1];
                sum[2] += w.get(i * n + j + 2) * x[j + 2];
                sum[3] += w.get(i * n + j + 3) * x[j + 3];
            }
            val += sum[0] + sum[1] + sum[2] + sum[3];

            for (; j < n; j++) {
                val += w.get(i * n + j) * x[j];
            }
            xout[i] = val;
        });
    }

First, there's a small inefficiency in the if (USE_VECTOR_API) { line. Since USE_VECTOR_API is a constant, re-evaluating the if statement inside the forEach lambda, once per output row, is wasted work. The JIT might optimize it away eventually, but I would still move the check outside the lambda rather than rely on the JIT to do it, since the JIT isn't magic.
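A minimal sketch of the hoisting idea, in plain Java so it runs without the incubator module: the flag is tested once, a per-row kernel is chosen, and only then does the parallel loop start. The class HoistedBranch, the flag USE_UNROLLED (a stand-in for USE_VECTOR_API) and the two dot-product helpers are hypothetical illustrations, not code from the project.

```java
import java.util.function.IntConsumer;
import java.util.stream.IntStream;

public class HoistedBranch {
    // Hypothetical stand-in for USE_VECTOR_API. As a static final boolean it
    // is a constant, but the branch is still taken outside the lambda.
    static final boolean USE_UNROLLED = true;

    // xout[i] = dot(row i of w, x), where w holds d rows of n floats, row-major.
    static void matvec(float[] xout, float[] x, float[] w, int n, int d) {
        // Pick the per-row kernel once, before the parallel loop starts,
        // instead of re-testing the flag for every row i.
        IntConsumer kernel;
        if (USE_UNROLLED) {
            kernel = i -> xout[i] = dotUnrolled(w, i * n, x, n);
        } else {
            kernel = i -> xout[i] = dotSimple(w, i * n, x, n);
        }
        IntStream.range(0, d).parallel().forEach(kernel);
    }

    // Plain loop over one row that starts at offset off in w.
    static float dotSimple(float[] w, int off, float[] x, int n) {
        float val = 0f;
        for (int j = 0; j < n; j++) {
            val += w[off + j] * x[j];
        }
        return val;
    }

    // Four independent accumulators, then a scalar tail for the n % 4 leftovers.
    static float dotUnrolled(float[] w, int off, float[] x, int n) {
        float s0 = 0f, s1 = 0f, s2 = 0f, s3 = 0f;
        int j = 0;
        int upperBound = n & ~3;
        for (; j < upperBound; j += 4) {
            s0 += w[off + j] * x[j];
            s1 += w[off + j + 1] * x[j + 1];
            s2 += w[off + j + 2] * x[j + 2];
            s3 += w[off + j + 3] * x[j + 3];
        }
        float val = s0 + s1 + s2 + s3;
        for (; j < n; j++) {
            val += w[off + j] * x[j];
        }
        return val;
    }

    public static void main(String[] args) {
        float[] w = {1, 2, 3, 4, 5,
                     1, 1, 1, 1, 1};
        float[] x = {1, 1, 1, 1, 2};
        float[] xout = new float[2];
        matvec(xout, x, w, 5, 2);
        System.out.println(xout[0] + " " + xout[1]); // prints "20.0 6.0"
    }
}
```

The same shape applies to the real matmul: both kernels stay as they are, and only the place where the flag is read moves.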

The main thing that isn't clear to me: it seems the code performs the operation twice when running under a regular JIT. Shouldn't the rest of the code be in an else branch?
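One way to check whether any element is actually processed twice is to replay only the index arithmetic of a single row from the matmul above and count how often each column index is touched. Note that all three loops share the same j variable. The class LoopCoverage and its visits helper are hypothetical; width = 8 assumes FloatVector.SPECIES_256 holding 32-bit floats, and no multiplication happens, only counting. If every count is 1, the loops split the row between them; any count of 2 would confirm the duplicated work.

```java
public class LoopCoverage {
    // Replays the loop bounds of one row of matmul: an optional vector-style
    // loop over blocks of 4 * width, then the 4-way unrolled loop, then a
    // scalar tail. All loops share the same j, exactly as in the original.
    // Returns how many times each column index in [0, n) is visited.
    static int[] visits(int n, int width, boolean useVector) {
        int[] count = new int[n];
        int j = 0;
        if (useVector) {
            int upperBound = n - n % (4 * width);
            for (; j < upperBound; j += 4 * width) {
                for (int k = 0; k < 4 * width; k++) count[j + k]++;
            }
        }
        int upperBound = n & ~3;
        for (; j < upperBound; j += 4) {
            for (int k = 0; k < 4; k++) count[j + k]++;
        }
        for (; j < n; j++) {
            count[j]++;
        }
        return count;
    }

    public static void main(String[] args) {
        // width = 8: a 256-bit vector holds eight 32-bit floats.
        for (int n : new int[] {7, 32, 45, 100}) {
            int[] c = visits(n, 8, true);
            int min = Integer.MAX_VALUE, max = 0;
            for (int v : c) {
                min = Math.min(min, v);
                max = Math.max(max, v);
            }
            System.out.println("n=" + n + " min visits=" + min + " max visits=" + max);
        }
    }
}
```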

This is how I think it should look, if I'm reading the code correctly. I haven't tested it, though, so I might be completely off here:

    import java.lang.foreign.MemorySegment;
    import java.nio.ByteOrder;
    import java.nio.FloatBuffer;
    import java.util.stream.IntStream;
    import jdk.incubator.vector.FloatVector;
    import jdk.incubator.vector.VectorOperators;
    import jdk.incubator.vector.VectorSpecies;

    // USE_VECTOR_API is the existing class-level flag; this method lives in the same class.
    static void matmul(float[] xout, float[] x, FloatBuffer w, int n, int d) {
        // W (d,n) @ x (n,) -> xout (d,)
        // by far the most amount of time is spent inside this little function
        if (USE_VECTOR_API) {
            // Only the vector path needs the segment view of the weights.
            MemorySegment wSegment = MemorySegment.ofBuffer(w);
            IntStream.range(0, d).parallel().forEach(i -> {
                int j = 0;
                VectorSpecies<Float> species = FloatVector.SPECIES_256;
                FloatVector sum0 = FloatVector.zero(species);
                FloatVector sum1 = FloatVector.zero(species);
                FloatVector sum2 = FloatVector.zero(species);
                FloatVector sum3 = FloatVector.zero(species);
                int width = species.length();
                int upperBound = n - n % (4 * width);
                for (; j < upperBound; j += 4 * width) {
                    var wj0 = FloatVector.fromMemorySegment(species, wSegment, (i * n + j + 0 * width) * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
                    var wj1 = FloatVector.fromMemorySegment(species, wSegment, (i * n + j + 1 * width) * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
                    var wj2 = FloatVector.fromMemorySegment(species, wSegment, (i * n + j + 2 * width) * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
                    var wj3 = FloatVector.fromMemorySegment(species, wSegment, (i * n + j + 3 * width) * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
                    var xj0 = FloatVector.fromArray(species, x, j + 0 * width);
                    var xj1 = FloatVector.fromArray(species, x, j + 1 * width);
                    var xj2 = FloatVector.fromArray(species, x, j + 2 * width);
                    var xj3 = FloatVector.fromArray(species, x, j + 3 * width);
                    sum0 = wj0.fma(xj0, sum0);
                    sum1 = wj1.fma(xj1, sum1);
                    sum2 = wj2.fma(xj2, sum2);
                    sum3 = wj3.fma(xj3, sum3);
                }
                float val = sum0.add(sum1).add(sum2).add(sum3).reduceLanes(VectorOperators.ADD);
                // Scalar tail: the vector loop stops at a multiple of 4 * width,
                // so the last n % (4 * width) elements still have to be added here.
                for (; j < n; j++) {
                    val += w.get(i * n + j) * x[j];
                }
                xout[i] = val;
            });
        } else {
            // Graal's auto-vectorization.
            IntStream.range(0, d).parallel().forEach(i -> {
                int j = 0;
                float val = 0;
                int upperBound = n & ~3;
                float[] sum = new float[4];
                for (; j < upperBound; j += sum.length) {
                    sum[0] += w.get(i * n + j) * x[j];
                    sum[1] += w.get(i * n + j + 1) * x[j + 1];
                    sum[2] += w.get(i * n + j + 2) * x[j + 2];
                    sum[3] += w.get(i * n + j + 3) * x[j + 3];
                }
                val += sum[0] + sum[1] + sum[2] + sum[3];

                // Scalar tail for the last n % 4 elements.
                for (; j < n; j++) {
                    val += w.get(i * n + j) * x[j];
                }
                xout[i] = val;
            });
        }
    }
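Since the rewrite is untested, one hedged way to check it is to compare its output against a naive reference on random inputs, using an n that is not a multiple of 4 or of 32 (4 * width with SPECIES_256) so every tail path runs. The sketch below is plain Java and checks only a copy of the scalar (else) branch; the class MatmulCheck and its helpers are hypothetical. The vector branch could be dropped in the same way, but running it requires --add-modules jdk.incubator.vector.

```java
import java.nio.FloatBuffer;
import java.util.Random;
import java.util.stream.IntStream;

public class MatmulCheck {
    // Naive reference: xout[i] = sum over j of w[i * n + j] * x[j], accumulated in double.
    static void reference(float[] xout, float[] x, FloatBuffer w, int n, int d) {
        for (int i = 0; i < d; i++) {
            double val = 0;
            for (int j = 0; j < n; j++) {
                val += (double) w.get(i * n + j) * x[j];
            }
            xout[i] = (float) val;
        }
    }

    // Copy of the scalar (else) branch of the proposed matmul: the code under test.
    static void scalarBranch(float[] xout, float[] x, FloatBuffer w, int n, int d) {
        IntStream.range(0, d).parallel().forEach(i -> {
            int j = 0;
            float val = 0;
            int upperBound = n & ~3;
            float[] sum = new float[4];
            for (; j < upperBound; j += sum.length) {
                sum[0] += w.get(i * n + j) * x[j];
                sum[1] += w.get(i * n + j + 1) * x[j + 1];
                sum[2] += w.get(i * n + j + 2) * x[j + 2];
                sum[3] += w.get(i * n + j + 3) * x[j + 3];
            }
            val += sum[0] + sum[1] + sum[2] + sum[3];
            for (; j < n; j++) {
                val += w.get(i * n + j) * x[j];
            }
            xout[i] = val;
        });
    }

    // Largest absolute difference between reference and code under test
    // on a random d x n matrix and n-vector with entries in [-1, 1).
    static float maxAbsDiff(int n, int d, long seed) {
        Random rnd = new Random(seed);
        float[] x = new float[n];
        float[] wArr = new float[n * d];
        for (int j = 0; j < n; j++) x[j] = rnd.nextFloat() * 2 - 1;
        for (int k = 0; k < wArr.length; k++) wArr[k] = rnd.nextFloat() * 2 - 1;
        FloatBuffer w = FloatBuffer.wrap(wArr);
        float[] expected = new float[d];
        float[] actual = new float[d];
        reference(expected, x, w, n, d);
        scalarBranch(actual, x, w, n, d);
        float maxDiff = 0f;
        for (int i = 0; i < d; i++) {
            maxDiff = Math.max(maxDiff, Math.abs(expected[i] - actual[i]));
        }
        return maxDiff;
    }

    public static void main(String[] args) {
        // n = 45 is not a multiple of 4 or 32, so the tail loops are exercised.
        System.out.println("max |diff| = " + maxAbsDiff(45, 16, 42L));
    }
}
```

A small nonzero difference is expected, since float and double accumulation round differently; a large one would point to skipped or repeated elements.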

Thanks for the project. It's very interesting!
