AVX Playground¶
This notebook allows interactive experimentation with AVX instructions operating on 256-bit vectors, each represented as a NumPy array of eight `float32` values.
Python vector intrinsics: You saw it here first! :)
In [1]:
import numpy as np
import ctypes
In [3]:
def address_from_numpy(obj):
    """Return the integer data address advertised by *obj*'s array interface.

    The address is the first entry of the ``"data"`` pair in
    ``obj.__array_interface__`` plus the interface's ``"offset"`` entry
    (taken as 0 when absent).

    Raises:
        RuntimeError: if *obj* has no ``__array_interface__`` attribute, or
            the attribute is ``None``.
    """
    interface = getattr(obj, "__array_interface__", None)
    if interface is not None:
        # "data" is a (address, read_only_flag) pair; only the address is used.
        base_address, _read_only = interface["data"]
        return base_address + interface.get("offset", 0)

    raise RuntimeError("no array interface")
def cptr_from_numpy(obj):
    """Wrap the data address of *obj* in a ``ctypes.c_void_p``.

    The address is obtained from ``address_from_numpy``, so any error raised
    there (e.g. for objects without an array interface) propagates unchanged.
    """
    raw_address = address_from_numpy(obj)
    return ctypes.c_void_p(raw_address)
In [6]:
def make_func(operation):
    """Compile a C expression into an AVX kernel and return a Python wrapper.

    *operation* is a C expression of type ``__m256``. It may refer to
    ``avec`` and ``bvec``, the two 256-bit inputs loaded from the wrapper's
    arguments. The expression is pasted into a small C function, compiled
    with ``gcc -march=sandybridge -shared`` in a fresh temporary directory,
    and loaded through ``ctypes``.

    Returns:
        ``wrapper(a, b=None)``: takes one or two ``float32`` arrays of shape
        ``(8,)`` and returns a new ``float32`` array of shape ``(8,)`` holding
        the result. When *b* is omitted, a vector of zeros is used.

    Raises:
        RuntimeError: if the compiler exits with a non-zero status; the
            message contains the compiler's output.
        subprocess.TimeoutExpired: if compilation takes longer than 4 seconds.
    """
    import subprocess
    from os.path import join
    from tempfile import mkdtemp

    # The placeholder is filled in with str.replace rather than str.format so
    # that the braces of the C function body need no escaping.
    c_code = """
    #include <x86intrin.h>

    void f(float *a, float *b, float *out)
    {
        __m256 avec = _mm256_loadu_ps(a);
        __m256 bvec = _mm256_loadu_ps(b);
        __m256 result = {operation};
        _mm256_storeu_ps(out, result);
    }
    """.replace("{operation}", operation)

    # The directory is deliberately left on disk: the shared object is loaded
    # from it, and the generated source stays available for inspection when
    # the compiler rejects it.
    tempdir = mkdtemp()
    source_path = join(tempdir, "code.c")
    library_path = join(tempdir, "code.so")

    with open(source_path, "w") as outf:
        outf.write(c_code)

    # subprocess.run kills the compiler if the timeout expires, so no stray
    # process is left behind.
    cc_proc = subprocess.run(
        ["gcc", "-march=sandybridge", "-shared", source_path, "-o", library_path],
        stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=4)

    if cc_proc.returncode:
        raise RuntimeError("C compiler failed. It said:\n<pre>%s</pre>"
                % (cc_proc.stdout+cc_proc.stderr).decode())

    user_dll = ctypes.CDLL(library_path)
    user_func = user_dll.f
    user_func.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
    user_func.restype = None

    def wrapper(a, b=None):
        """Apply the compiled operation to *a* (and *b*) and return the result."""
        if b is None:
            # Zeros rather than uninitialised memory, so that a two-operand
            # kernel called with a single argument gives a repeatable result.
            b = np.zeros((8,), np.float32)

        assert a.dtype == np.float32
        assert b.dtype == np.float32
        assert a.shape == (8,)
        assert b.shape == (8,)

        # The kernel reads 8 consecutive floats starting at each address, so
        # strided or reversed views must be copied into contiguous storage
        # first. Already-contiguous arrays are passed through without a copy.
        a = np.ascontiguousarray(a)
        b = np.ascontiguousarray(b)

        result = np.empty((8,), np.float32)
        user_func(cptr_from_numpy(a), cptr_from_numpy(b), cptr_from_numpy(result))
        return result

    return wrapper
The next cell uses the `make_func` helper above to create the vector "intrinsics". Only a few are covered for now, but it is easy to add more by following the same pattern.
In [8]:
# Two-operand kernels wrapping the AVX unpack intrinsics. Each make_func call
# compiles and loads its own shared library.
unpackhi = make_func("_mm256_unpackhi_ps(avec, bvec)")
unpacklo = make_func("_mm256_unpacklo_ps(avec, bvec)")
def make_permute2(sel_lower, sel_upper):
    """Build a ``_mm256_permute2f128_ps`` kernel for the given half selectors.

    The immediate operand is assembled as ``sel_upper << 4 | sel_lower``:
    *sel_lower* occupies the low four bits and controls the lower 128-bit
    half of the result, *sel_upper* occupies the high four bits and controls
    the upper half. See the Intel documentation of the intrinsic for the
    meaning of each selector value.

    Raises:
        ValueError: if a selector does not fit in four bits, since it would
            otherwise spill into the other selector's field (or out of the
            8-bit immediate) and silently select something else.
    """
    for name, selector in (("sel_lower", sel_lower), ("sel_upper", sel_upper)):
        if not 0 <= selector <= 0xF:
            raise ValueError(
                f"{name} must be in the range 0..15, got {selector!r}")

    imm = sel_upper << 4 | sel_lower
    return make_func(f"_mm256_permute2f128_ps(avec, bvec, {imm})")
# The two digits in each name are the (sel_lower, sel_upper) arguments, in
# that order.
permute2_02 = make_permute2(0, 2)
permute2_13 = make_permute2(1, 3)
permute2_12 = make_permute2(1, 2)
permute2_21 = make_permute2(2, 1)
def make_permute(sel_0, sel_1, sel_2, sel_3):
    """Build a single-operand ``_mm256_permute_ps`` kernel.

    The four selectors are packed two bits each into the immediate operand,
    *sel_0* in the lowest two bits and *sel_3* in the highest two. The
    resulting kernel only uses ``avec``; a second argument passed to the
    wrapper is loaded but does not take part in the operation.

    Raises:
        ValueError: if a selector is outside 0..3, since it would otherwise
            overflow its two-bit field and silently alter a neighbouring
            selector.
    """
    for position, selector in enumerate((sel_0, sel_1, sel_2, sel_3)):
        if not 0 <= selector <= 3:
            raise ValueError(
                f"sel_{position} must be in the range 0..3, got {selector!r}")

    imm = sel_3 << 6 | sel_2 << 4 | sel_1 << 2 | sel_0
    return make_func(f"_mm256_permute_ps(avec, {imm})")
# The four digits in each name are the (sel_0, sel_1, sel_2, sel_3)
# arguments, in that order.
permute_3210 = make_permute(3, 2, 1, 0)
permute_2301 = make_permute(2, 3, 0, 1)
def make_shuffle(sel_0, sel_1, sel_2, sel_3):
    """Build a two-operand ``_mm256_shuffle_ps`` kernel.

    The four selectors are packed two bits each into the immediate operand,
    *sel_0* in the lowest two bits and *sel_3* in the highest two. See the
    Intel documentation of the intrinsic for which operand each selector
    reads from.

    Raises:
        ValueError: if a selector is outside 0..3, since it would otherwise
            overflow its two-bit field and silently alter a neighbouring
            selector.
    """
    for position, selector in enumerate((sel_0, sel_1, sel_2, sel_3)):
        if not 0 <= selector <= 3:
            raise ValueError(
                f"sel_{position} must be in the range 0..3, got {selector!r}")

    imm = sel_3 << 6 | sel_2 << 4 | sel_1 << 2 | sel_0
    return make_func(f"_mm256_shuffle_ps(avec, bvec, {imm})")
# The four digits in each name are the (sel_0, sel_1, sel_2, sel_3)
# arguments, in that order.
shuffle_0123 = make_shuffle(0, 1, 2, 3)
shuffle_1032 = make_shuffle(1, 0, 3, 2)
shuffle_2301 = make_shuffle(2, 3, 0, 1)
shuffle_0101 = make_shuffle(0, 1, 0, 1)
shuffle_2323 = make_shuffle(2, 3, 2, 3)
The examples below demonstrate how to use these functions:
In [9]:
# 8x6 test matrix whose entry (i, j) is 10*i + j, so every value shows the
# row and column it came from.
A = (
    10*np.arange(8).reshape(8, 1)
    + np.arange(6)
).astype(np.float32)
In [10]:
array([[ 0., 1., 2., 3., 4., 5.], [10., 11., 12., 13., 14., 15.], [20., 21., 22., 23., 24., 25.], [30., 31., 32., 33., 34., 35.], [40., 41., 42., 43., 44., 45.], [50., 51., 52., 53., 54., 55.], [60., 61., 62., 63., 64., 65.], [70., 71., 72., 73., 74., 75.]], dtype=float32)
In [11]:
# Regroup the 48 entries of A into rows of 8 floats, i.e. one 256-bit vector
# per row.
Avec = A.reshape(-1, 8)
array([[ 0., 1., 2., 3., 4., 5., 10., 11.], [12., 13., 14., 15., 20., 21., 22., 23.], [24., 25., 30., 31., 32., 33., 34., 35.], [40., 41., 42., 43., 44., 45., 50., 51.], [52., 53., 54., 55., 60., 61., 62., 63.], [64., 65., 70., 71., 72., 73., 74., 75.]], dtype=float32)
In [12]:
# Lower half of the result: upper half of Avec[0] (selector 1);
# upper half of the result: lower half of Avec[1] (selector 2).
permute2_12(Avec[0], Avec[1])
array([ 4., 5., 10., 11., 12., 13., 14., 15.], dtype=float32)
In [13]:
# Two easily distinguishable inputs: v1 holds 0..7 and v2 holds 10..17.
v1 = np.arange(8).astype(np.float32)
v2 = (np.arange(8)+10).astype(np.float32)
# Within each 128-bit half: elements 2 and 3 of v1, then elements 0 and 1 of v2.
shuffle_2301(v1, v2)
array([ 2., 3., 10., 11., 6., 7., 14., 15.], dtype=float32)
In [15]:
# Start from zeros rather than uninitialised memory: only the first two rows
# are filled in below, and the remaining rows are displayed as well.
A2 = np.zeros(Avec.shape, A.dtype)
# Interleave the first two vectors of Avec element by element.
A2[0] = unpacklo(Avec[0], Avec[1])
A2[1] = unpackhi(Avec[0], Avec[1])
array([[ 0., 12., 1., 13., 4., 20., 5., 21.], [ 2., 14., 3., 15., 10., 22., 11., 23.], [ 0., 0., 0., 0., 0., 0., 0., 0.], [ 0., 0., 0., 0., 0., 0., 0., 0.], [ 0., 0., 0., 0., 0., 0., 0., 0.], [ 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)
In [ ]: