PTX and SASS (Nvidia)¶
C Kernel to Source Code¶
In [1]:
import numpy as np
import pyopencl as cl
import pyopencl.array as cla
# Pre-supplied answers for pyopencl's platform/device selection prompts.
# NOTE(review): "nvi" presumably selects the NVIDIA platform by name match and
# 1 a device index on it -- confirm on the machine running this notebook.
device_answers = ["nvi", 1]
ctx = cl.create_some_context(answers=device_answers)

# Every kernel launch and array allocation below goes through this queue.
queue = cl.CommandQueue(ctx)
/usr/lib/python3.6/importlib/_bootstrap_external.py:426: ImportWarning: Not importing directory /home/andreask_work/src/env-3.6/lib/python3.6/site-packages/sphinxcontrib: missing __init__ _warnings.warn(msg.format(portions[0]), ImportWarning)
In [2]:
# OpenCL C source: each work-item adds one pair of floats from a_g and b_g
# and stores the result in res_g at its global id.
vector_add_src = """
__kernel void sum(
__global const float *a_g, __global const float *b_g, __global float *res_g)
{
int gid = get_global_id(0);
res_g[gid] = a_g[gid] + b_g[gid];
}
"""
prg = cl.Program(ctx, vector_add_src).build()

# The first entry of `binaries` is decoded as text and printed; the output
# below shows that on this driver it is PTX.
ptx_bytes = prg.binaries[0]
print(ptx_bytes.decode())
// // Generated by NVIDIA NVVM Compiler // // Compiler Build ID: CL-24786619 // Driver 390.87 // Based on LLVM 3.4svn // .version 6.1 .target sm_52, texmode_independent .address_size 64 // .globl sum .const .align 4 .u32 pyopencl_defeat_cache_ba263b69975445bfabd08e2a274f5a06; .entry sum( .param .u64 .ptr .global .align 4 sum_param_0, .param .u64 .ptr .global .align 4 sum_param_1, .param .u64 .ptr .global .align 4 sum_param_2 ) { .reg .f32 %f<4>; .reg .b32 %r<7>; .reg .b64 %rd<8>; ld.param.u64 %rd1, [sum_param_0]; ld.param.u64 %rd2, [sum_param_1]; ld.param.u64 %rd3, [sum_param_2]; mov.b32 %r1, %envreg3; mov.u32 %r2, %ntid.x; mov.u32 %r3, %ctaid.x; mad.lo.s32 %r4, %r3, %r2, %r1; mov.u32 %r5, %tid.x; add.s32 %r6, %r4, %r5; mul.wide.s32 %rd4, %r6, 4; add.s64 %rd5, %rd1, %rd4; ld.global.f32 %f1, [%rd5]; add.s64 %rd6, %rd2, %rd4; ld.global.f32 %f2, [%rd6]; add.f32 %f3, %f1, %f2; add.s64 %rd7, %rd3, %rd4; st.global.f32 [%rd7], %f3; ret; }
Comments:
- Intel or AT&T style?
- Note: address spaces always explicit
- What is `%ctaid.x`? `%ntid.x`?
- How does parameter passing work?
- Is this the lowest-level abstraction?
In [3]:
!mkdir -p tmp
hacked_binary = prg.binaries[0].replace(b".version 6.1", b".version 6.0")
with open("tmp/binary.ptx", "wb") as outf:
outf.write(hacked_binary)
!(cd tmp; ptxas --gpu-name sm_61 --verbose binary.ptx -o binary.o)
ptxas info : 0 bytes gmem, 4 bytes cmem[3] ptxas info : Compiling entry function 'sum' for 'sm_61' ptxas info : Function properties for sum 0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads ptxas info : Used 8 registers, 344 bytes cmem[0]
In [4]:
# Disassemble the object produced by ptxas into SASS (the GPU's machine-level instructions).
!/usr/local/cuda/bin/cuobjdump --dump-sass tmp/binary.o
code for sm_61 Function : sum .headerflags @"EF_CUDA_SM61 EF_CUDA_PTX_SM(EF_CUDA_SM61)" /* 0x001c4400fe0007f6 */ /*0008*/ MOV R1, c[0x0][0x20]; /* 0x4c98078000870001 */ /*0010*/ { MOV R3, c[0x0][0x8]; /* 0x4c98078000270003 */ /*0018*/ S2R R0, SR_CTAID.X; } /* 0xf0c8000002570000 */ /* 0x081fd841fe20073f */ /*0028*/ S2R R2, SR_TID.X; /* 0xf0c8000002170002 */ /*0030*/ XMAD R3, R0.reuse, R3, c[0x0][0x3c]; /* 0x5100018000f70003 */ /*0038*/ XMAD.MRG R5, R0.reuse, c[0x0] [0x8].H1, RZ; /* 0x4f107f8000270005 */ /* 0x001f8402fec007f6 */ /*0048*/ XMAD.PSL.CBCC R0, R0.H1, R5.H1, R3; /* 0x5b30019800570000 */ /*0050*/ IADD R0, R0, R2; /* 0x5c10000000270000 */ /*0058*/ SHL R6, R0, 0x2; /* 0x3848000000270006 */ /* 0x081fc840fec007f5 */ /*0068*/ SHR R0, R0, 0x1e; /* 0x3829000001e70000 */ /*0070*/ IADD R2.CC, R6.reuse, c[0x0][0x140]; /* 0x4c10800005070602 */ /*0078*/ IADD.X R3, R0.reuse, c[0x0][0x144]; /* 0x4c10080005170003 */ /* 0x001fc800eec207f0 */ /*0088*/ { IADD R4.CC, R6.reuse, c[0x0][0x148]; /* 0x4c10800005270604 */ /*0090*/ LDG.E R2, [R2]; } /* 0xeed4200000070202 */ /*0098*/ IADD.X R5, R0, c[0x0][0x14c]; /* 0x4c10080005370005 */ /* 0x001fc400fcc00771 */ /*00a8*/ LDG.E R4, [R4]; /* 0xeed4200000070404 */ /*00b0*/ IADD R6.CC, R6, c[0x0][0x150]; /* 0x4c10800005470606 */ /*00b8*/ IADD.X R7, R0, c[0x0][0x154]; /* 0x4c10080005570007 */ /* 0x001ffc001e2047f2 */ /*00c8*/ FADD R0, R2, R4; /* 0x5c58000000470200 */ /*00d0*/ STG.E [R6], R0; /* 0xeedc200000070600 */ /*00d8*/ EXIT; /* 0xe30000000007000f */ /* 0x001f8000fc0007ff */ /*00e8*/ BRA 0xe0; /* 0xe2400fffff07000f */ /*00f0*/ NOP; /* 0x50b0000000070f00 */ /*00f8*/ NOP; /* 0x50b0000000070f00 */ ..............
Is Division Expensive?¶
In [13]:
prg = cl.Program(ctx, """
__kernel void sum(
__global float *a_g, int n)
{
int gid = get_global_id(0);
// try dividing by n
int row = gid / 117;
int col = gid % 117;
a_g[row * 128 + col] *= 2;
// a_g[gid] *= 2;
}
""").build()
hacked_binary = prg.binaries[0].replace(b".version 6.1", b".version 6.0")
with open("tmp/binary.ptx", "wb") as outf:
outf.write(hacked_binary)
!(cd tmp; ptxas --gpu-name sm_60 --verbose binary.ptx -o binary.o)
!/usr/local/cuda/bin/cuobjdump --dump-sass tmp/binary.o | cut -c -80
ptxas info : 0 bytes gmem, 4 bytes cmem[3] ptxas info : Compiling entry function 'sum' for 'sm_60' ptxas info : Function properties for sum 0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads ptxas info : Used 6 registers, 332 bytes cmem[0] code for sm_60 Function : sum .headerflags @"EF_CUDA_SM60 EF_CUDA_PTX_SM(EF_CUDA_SM60)" /*0008*/ MOV R1, c[0x0][0x20]; /*0010*/ { MOV R3, c[0x0][0x8]; /*0018*/ S2R R0, SR_CTAID.X; } /*0028*/ S2R R2, SR_TID.X; /*0030*/ XMAD R3, R0, R3, c[0x0][0x3c]; /*0038*/ XMAD.MRG R5, R0, c[0x0] [0x8].H1, RZ; /*0048*/ XMAD.PSL.CBCC R0, R0.H1, R5.H1, R3; /*0050*/ MOV32I R5, 0x8c08c08d; /*0058*/ IADD R0, R0, R2; /*0068*/ XMAD R2, R0, R5, RZ; /*0070*/ XMAD.U16.S16 R3, R0, R5.H1, RZ; /*0078*/ XMAD.S16.S16.CSFU R4, R0.H1, R5.H1, R0; /*0088*/ XMAD.S16.U16.CHI R2, R0.H1, R5, R2; /*0090*/ IADD3.RS R2, R2, R3, R4; /*0098*/ SHR R3, R2.reuse, 0x6; /*00a8*/ LEA.HI R2, R2, R3, RZ, 0x1; /*00b0*/ IADD R3, -R2, RZ; /*00b8*/ XMAD R0, R3, 0x75, R0; /*00c8*/ XMAD.PSL R0, R3.H1, 0x75, R0; /*00d0*/ ISCADD R0, R2, R0, 0x7; /*00d8*/ ISCADD R2.CC, R0.reuse, c[0x0][0x140], 0x2; /*00e8*/ SHR R0, R0, 0x1e; /*00f0*/ IADD.X R3, R0, c[0x0][0x144]; /*00f8*/ LDG.E R0, [R2]; /*0108*/ FMUL R0, R0, 2; /*0110*/ STG.E [R2], R0; /*0118*/ EXIT; /*0128*/ BRA 0x120; /*0130*/ NOP; /*0138*/ NOP; ..............
An Example with Control Flow¶
In [68]:
prg = cl.Program(ctx, """
__kernel void sum(
__global const float *a_g, __global const float *b_g, __global float *res_g, int n)
{
int gsize = get_global_size(0);
for (int i = get_global_id(0); i < n; i += gsize)
res_g[i] = a_g[i] + b_g[i];
res_g[get_global_id(0)] = 15;
}
""").build()
hacked_binary = prg.binaries[0].replace(b".version 6.1", b".version 6.0")
with open("tmp/binary.ptx", "wb") as outf:
outf.write(hacked_binary)
!(cd tmp; ptxas --gpu-name sm_60 --verbose binary.ptx -o binary.o)
!/usr/local/cuda/bin/cuobjdump --dump-sass tmp/binary.o | cut -c -80
ptxas info : 0 bytes gmem, 4 bytes cmem[3] ptxas info : Compiling entry function 'sum' for 'sm_60' ptxas info : Function properties for sum 0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads ptxas info : Used 12 registers, 348 bytes cmem[0] code for sm_60 Function : sum .headerflags @"EF_CUDA_SM60 EF_CUDA_PTX_SM(EF_CUDA_SM60)" /*0008*/ MOV R1, c[0x0][0x20]; /*0010*/ { MOV R11, c[0x0][0x8]; /*0018*/ S2R R0, SR_CTAID.X; } /*0028*/ S2R R2, SR_TID.X; /*0030*/ { XMAD R3, R0.reuse, R11, c[0x0][0x3c]; /*0038*/ SSY 0x150; } /*0048*/ XMAD.MRG R5, R0.reuse, c[0x0] [0x8].H1, RZ; /*0050*/ XMAD.PSL.CBCC R0, R0.H1, R5.H1, R3; /*0058*/ IADD R0, R0, R2; /*0068*/ ISETP.GE.AND P0, PT, R0, c[0x0][0x158], PT; /*0070*/ @P0 SYNC; /*0078*/ MOV R8, R0; /*0088*/ SHL R9, R8.reuse, 0x2; /*0090*/ SHR R7, R8, 0x1e; /*0098*/ IADD R4.CC, R9.reuse, c[0x0][0x140]; /*00a8*/ IADD.X R5, R7.reuse, c[0x0][0x144]; /*00b0*/ { IADD R2.CC, R9, c[0x0][0x148]; /*00b8*/ LDG.E R4, [R4]; } /*00c8*/ IADD.X R3, R7, c[0x0][0x14c]; /*00d0*/ LDG.E R2, [R2]; /*00d8*/ XMAD R8, R11.reuse, c[0x0] [0x48], R8; /*00e8*/ XMAD.MRG R6, R11.reuse, c[0x0] [0x48].H1, RZ; /*00f0*/ XMAD.PSL.CBCC R8, R11.H1, R6.H1, R8; /*00f8*/ IADD R9.CC, R9, c[0x0][0x150]; /*0108*/ ISETP.LT.AND P0, PT, R8, c[0x0][0x158], PT; /*0110*/ MOV R6, R9; /*0118*/ IADD.X R7, R7, c[0x0][0x154]; /*0128*/ FADD R9, R2, R4; /*0130*/ STG.E [R6], R9; /*0138*/ @P0 BRA 0x80; /*0148*/ SYNC; /*0150*/ ISCADD R2.CC, R0.reuse, c[0x0][0x150], 0x2; /*0158*/ SHR R0, R0, 0x1e; /*0168*/ IADD.X R3, R0, c[0x0][0x154]; /*0170*/ MOV32I R0, 0x41700000; /*0178*/ STG.E [R2], R0; /*0188*/ EXIT; /*0190*/ BRA 0x190; /*0198*/ NOP; /*01a8*/ NOP; /*01b0*/ NOP; /*01b8*/ NOP; ..............
- Spot something that doesn't quite seem to belong?
From CUDA¶
Vector add stolen from ORNL.
In [34]:
%%writefile tmp/vector-add.cu
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
// Element-wise vector add: c[i] = a[i] + b[i] for every i < n,
// with one thread handling one element.
__global__ void vecAdd(double *a, double *b, double *c, int n)
{
    // Combine block and thread coordinates into one global element index
    const int idx = threadIdx.x + blockDim.x*blockIdx.x;

    // Threads whose index is n or larger have no element to process
    if (idx >= n)
        return;

    c[idx] = a[idx] + b[idx];
}
// Abort with a readable message if a CUDA runtime call did not succeed.
static void checkCuda(cudaError_t status, const char *what)
{
    if (status != cudaSuccess) {
        fprintf(stderr, "%s failed: %s\n", what, cudaGetErrorString(status));
        exit(EXIT_FAILURE);
    }
}

// Host driver: fills two vectors with sin(i)^2 and cos(i)^2, adds them on the
// GPU with vecAdd, and prints the mean of the result.
// Returns 0 on success; exits with a failure status if an allocation or a
// CUDA call fails.
int main( int argc, char* argv[] )
{
    // Size of vectors
    int n = 100000;
    size_t bytes = n*sizeof(double);

    // Host input/output vectors
    double *h_a = (double*)malloc(bytes);
    double *h_b = (double*)malloc(bytes);
    double *h_c = (double*)malloc(bytes);
    if (h_a == NULL || h_b == NULL || h_c == NULL) {
        fprintf(stderr, "host allocation of %zu bytes failed\n", bytes);
        return EXIT_FAILURE;
    }

    // Device input/output vectors
    double *d_a = NULL;
    double *d_b = NULL;
    double *d_c = NULL;
    checkCuda(cudaMalloc(&d_a, bytes), "cudaMalloc(d_a)");
    checkCuda(cudaMalloc(&d_b, bytes), "cudaMalloc(d_b)");
    checkCuda(cudaMalloc(&d_c, bytes), "cudaMalloc(d_c)");

    // Initialize so that h_a[i] + h_b[i] == sin^2(i) + cos^2(i)
    int i;
    for( i = 0; i < n; i++ ) {
        h_a[i] = sin(i)*sin(i);
        h_b[i] = cos(i)*cos(i);
    }

    // Copy host vectors to device
    checkCuda(cudaMemcpy( d_a, h_a, bytes, cudaMemcpyHostToDevice), "cudaMemcpy(h_a -> d_a)");
    checkCuda(cudaMemcpy( d_b, h_b, bytes, cudaMemcpyHostToDevice), "cudaMemcpy(h_b -> d_b)");

    // Integer ceiling division: enough blocks to cover all n elements.
    // (Going through float would lose precision once n exceeds 2^24.)
    int blockSize = 1024;
    int gridSize = (n + blockSize - 1) / blockSize;

    vecAdd<<<gridSize, blockSize>>>(d_a, d_b, d_c, n);
    // A bad launch configuration is only reported through cudaGetLastError
    checkCuda(cudaGetLastError(), "vecAdd launch");

    // Copy the result back to the host
    checkCuda(cudaMemcpy( h_c, d_c, bytes, cudaMemcpyDeviceToHost ), "cudaMemcpy(d_c -> h_c)");

    // Average the result vector and print it
    double sum = 0;
    for(i=0; i<n; i++)
        sum += h_c[i];
    printf("final result: %f\n", sum/n);

    cudaFree(d_a);
    cudaFree(d_b);
    cudaFree(d_c);
    free(h_a);
    free(h_b);
    free(h_c);
    return 0;
}
Overwriting tmp/vector-add.cu
In [35]:
# Compile the CUDA source to an object file only (-c), with g++-7 as the host compiler.
!(cd tmp; nvcc -c -ccbin g++-7 vector-add.cu)
# Dump the SASS embedded in the resulting object file.
!/usr/local/cuda/bin/cuobjdump --dump-sass tmp/vector-add.o
Fatbin elf code: ================ arch = sm_30 code version = [1,7] producer = cuda host = linux compile_size = 64bit code for sm_30 Function : _Z6vecAddPdS_S_i .headerflags @"EF_CUDA_SM30 EF_CUDA_PTX_SM(EF_CUDA_SM30)" /* 0x2202e2c2828232b7 */ /*0008*/ MOV R1, c[0x0][0x44]; /* 0x2800400110005de4 */ /*0010*/ S2R R0, SR_CTAID.X; /* 0x2c00000094001c04 */ /*0018*/ S2R R3, SR_TID.X; /* 0x2c0000008400dc04 */ /*0020*/ IMAD R0, R0, c[0x0][0x28], R3; /* 0x20064000a0001ca3 */ /*0028*/ ISETP.GE.AND P0, PT, R0, c[0x0][0x158], PT; /* 0x1b0e40056001dc23 */ /*0030*/ @P0 EXIT; /* 0x80000000000001e7 */ /*0038*/ ISCADD R2.CC, R0, c[0x0][0x140], 0x3; /* 0x4001400500009c63 */ /* 0x22c04282c04282b7 */ /*0048*/ MOV32I R9, 0x8; /* 0x1800000020025de2 */ /*0050*/ IMAD.HI.X R3, R0, R9, c[0x0][0x144]; /* 0x209280051000dce3 */ /*0058*/ ISCADD R4.CC, R0, c[0x0][0x148], 0x3; /* 0x4001400520011c63 */ /*0060*/ LD.E.64 R2, [R2]; /* 0x8400000000209ca5 */ /*0068*/ IMAD.HI.X R5, R0, R9, c[0x0][0x14c]; /* 0x2092800530015ce3 */ /*0070*/ LD.E.64 R4, [R4]; /* 0x8400000000411ca5 */ /*0078*/ ISCADD R8.CC, R0, c[0x0][0x150], 0x3; /* 0x4001400540021c63 */ /* 0x20000002e04293f7 */ /*0088*/ IMAD.HI.X R9, R0, R9, c[0x0][0x154]; /* 0x2092800550025ce3 */ /*0090*/ DADD R6, R4, R2; /* 0x4800000008419c01 */ /*0098*/ ST.E.64 [R8], R6; /* 0x9400000000819ca5 */ /*00a0*/ EXIT; /* 0x8000000000001de7 */ /*00a8*/ BRA 0xa8; /* 0x4003ffffe0001de7 */ /*00b0*/ NOP; /* 0x4000000000001de4 */ /*00b8*/ NOP; /* 0x4000000000001de4 */ ........................... Fatbin ptx code: ================ arch = sm_30 code version = [6,0] producer = cuda host = linux compile_size = 64bit compressed
- What is `_Z6vecAddPdS_S_i`?
In [36]:
# Demangle the C++ symbol name of the kernel with c++filt.
!echo _Z6vecAddPdS_S_i | c++filt
vecAdd(double*, double*, double*, int)
Inline PTX¶
In [47]:
# OpenCL C kernel: every work-item with id below `length` reads the PTX special
# register %laneid through inline assembly and stores it at its own index.
laneid_src = """
__kernel void getlaneid(__global int *d_ptr, int length)
{
int elemID = get_global_id(0);
if (elemID < length)
{
unsigned int laneid;
asm("mov.u32 %0, %%laneid;" : "=r"(laneid));
d_ptr[elemID] = laneid;
}
}
"""
prg = cl.Program(ctx, laneid_src).build()

# Print the generated PTX so the inline assembly can be located in it.
ptx_bytes = prg.binaries[0]
print(ptx_bytes.decode())
// // Generated by NVIDIA NVVM Compiler // // Compiler Build ID: CL-24786619 // Driver 390.87 // Based on LLVM 3.4svn // .version 6.1 .target sm_52, texmode_independent .address_size 64 // .globl getlaneid .const .align 4 .u32 pyopencl_defeat_cache_8c062881a319435694e9b93964b505dc; .entry getlaneid( .param .u64 .ptr .global .align 4 getlaneid_param_0, .param .u32 getlaneid_param_1 ) { .reg .pred %p<2>; .reg .b32 %r<9>; .reg .b64 %rd<4>; ld.param.u64 %rd1, [getlaneid_param_0]; ld.param.u32 %r2, [getlaneid_param_1]; mov.b32 %r3, %envreg3; mov.u32 %r4, %ctaid.x; mov.u32 %r5, %ntid.x; mad.lo.s32 %r6, %r4, %r5, %r3; mov.u32 %r7, %tid.x; add.s32 %r1, %r6, %r7; setp.ge.s32 %p1, %r1, %r2; @%p1 bra BB0_2; // inline asm mov.u32 %r8, %laneid; // inline asm mul.wide.s32 %rd2, %r1, 4; add.s64 %rd3, %rd1, %rd2; st.global.u32 [%rd3], %r8; BB0_2: ret; }
- What do the constraints mean again?
- Spot the inline assembly
- Observe how the `if` is realized
- Observe the realization of `get_global_id()`
In [48]:
# Device array with one uint32 slot per work-item.
a = cla.empty(queue, 5000, np.uint32)

# Launch one work-item per array element. The global size is taken from `a`
# itself so this cell does not depend on names defined outside the notebook;
# passing None as the local size leaves the work-group size to the runtime.
prg.getlaneid(queue, a.shape, None, a.data, np.uint32(a.size))
Out[48]:
<pyopencl._cl.Event at 0x7fa528c8b360>
In [50]:
# Display the lane ids recorded by the first 500 work-items.
a[:500]
Out[50]:
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25], dtype=uint32)
In [ ]: