Nvidia target via NVRTC and Nim ↦ CUDA DSL #487
Conversation
Sorry for taking so long to review and discuss.

Approach

I think this is the best way forward for several reasons:
Regarding OpenCL, I spent some time looking into how to generate it from LLVM IR, but it seems like we need external tools to convert LLVM IR to SPIR-V (https://llvm.org/devmtg/2021-11/slides/2021-SPIR-V-SupportinLLVMandClang.pdf), and generating source code would avoid that. This is only relevant for Intel GPUs though, but with the rise of AI, it might be that they become fast enough to accelerate KZG commitments. Trusted external binaries should be reduced to the minimum, as we need the threat model/attack surface to be as small as possible, i.e. we trust Nim, Clang, nvcc and that's all.

Regarding Apple Metal, I spent a lot of time investigating Julia's compilation pipeline (see https://github.com/JuliaGPU/GPUCompiler.jl), but maintaining a fork of LLVM to allow downgrading IR to be compatible with Metal IR is a bit too much.

On the header file

The main appeal of the NVRTC approach is not needing a header file and so not dealing with the path issues that plagued Arraymancer (BLAS and CUDA config, see: https://github.com/mratsim/Arraymancer/blob/v0.7.33/nim.cfg). There should be a

On Nim feature support

It would be much cleaner to support generics, static int and static enums, which should be possible if we use a typed macro, and should be possible with untyped if we rewrite the generics to C++ templates, since CUDA/HIP works in C++ mode by default (but not OpenCL and maybe not Metal, so not portable).
Another note

I would still keep the LLVM IR code generator, as it would enable the following use cases:
Just updated the code to:
The main thing missing now is smarter detection of the CUDA paths on the user's machine (right now they are hardcoded). We can in principle also adapt the
mratsim left a comment
I wonder if by using `dynlib: "libnvrtc-builtins.so"` we can avoid hardcoding the path.
I think it should work for the .so, and it helps on Windows as well. I'm unsure about the static library linking step though.
```nim
# Add the device runtime (provides printf support)
## NOTE: Linking requires you to pass the path to `libcudadevrt.a` at CT
res = cuLinkAddFile(linkState, CU_JIT_INPUT_LIBRARY,
  "/usr/local/cuda/targets/x86_64-linux/lib/libcudadevrt.a", # Adjust path as needed
```
```nim
let threadIdx* = NvThreadIdx()

## Similar for procs. They don't need any implementation, as they won't ever be actually called.
proc printf*(fmt: string) {.varargs.} = discard
```
I think we can declare it with something like

```nim
proc cu_printf*(fmt: cstring): cint {.sideeffect, importc: "printf", dynlib: "libnvrtc-builtins.so", varargs, discardable, tags:[WriteIOEffect].}
```

and not need to import cuda.h and deal with paths.
```nim
## Similar for procs. They don't need any implementation, as they won't ever be actually called.
proc printf*(fmt: string) {.varargs.} = discard
proc memcpy*(dst, src: pointer, size: int) = discard
```
IMPORTANT NOTE: For LLVM we generate `array_t` types for the finite field elements. By doing this we make it impossible to just copy over a Constantine finite field element or elliptic curve point (which are also an `array_t` type). Therefore, we have a `CUfunctionLLVM` type, which is used to differentiate between different `execCuda` calls, based on their "origin" (i.e. LLVM or NVRTC backends). Based on that backend we either allow passing simple structs by their host pointer or force a copy.
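A minimal sketch of the distinction (illustrative only: the real `execCuda` is a macro and `CUfunction` comes from the CUDA driver bindings; both are stubbed here):

```nim
type
  CUfunction = object                   # stand-in for the driver API kernel handle
  CUfunctionLLVM = distinct CUfunction  # marks kernels coming from the LLVM backend

# Dispatch on the kernel's origin to pick the copy behaviour:
proc execCuda(kernel: CUfunction; res, inputs: tuple) = discard     # NVRTC path: host pointers allowed
proc execCuda(kernel: CUfunctionLLVM; res, inputs: tuple) = discard # LLVM path: forces a copy
```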
This was added for a reason after all in 5d66b52
By mapping them to a regular cast.
```nim
proc setZero(a: var BigInt) {.device.}
proc setOne(a: var BigInt) {.device.}
proc add(r: var BigInt, a, b: BigInt) {.device.}
proc sub(r: var BigInt, a, b: BigInt) {.device.}
proc mul(r: var BigInt, a, b: BigInt) {.device.}
proc ccopy(a: var BigInt, b: BigInt, condition: bool) {.device.}
proc csetZero(r: var BigInt, condition: bool) {.device.}
proc csetOne(r: var BigInt, condition: bool) {.device.}
proc cadd(r: var BigInt, a: BigInt, condition: bool) {.device.}
proc csub(r: var BigInt, a: BigInt, condition: bool) {.device.}
proc doubleElement(r: var BigInt, a: BigInt) {.device.}
proc nsqr(r: var BigInt, a: BigInt, count: int) {.device.}
proc isZero(r: var bool, a: BigInt) {.device.}
proc isOdd(r: var bool, a: BigInt) {.device.}
proc neg(r: var BigInt, a: BigInt) {.device.}
proc cneg(r: var BigInt, a: BigInt, condition: bool) {.device.}
proc shiftRight(r: var BigInt, k: uint32) {.device.}
proc div2(r: var BigInt) {.device.}
```
To fix:
```nim
const code = cuda:
  proc sum() {.device.} =
    let inputIdx = 0
    let rateIdx = 0
    if inputIdx > 0 and rateIdx > 0:
      discard
    return
echo code
```
which previously produced:
```
extern "C" __device__ void sum(){
long long inputIdx = 0;
long long rateIdx = 0;
if (((0 < inputIdx); && (0 < rateIdx);)) {
;
};
return ;
};
```
and now:
```
extern "C" __device__ void sum(){
long long inputIdx = 0;
long long rateIdx = 0;
if (((0 < inputIdx) && (0 < rateIdx))) {
;
};
return ;
};
```
This allows the user to e.g. allocate and copy memory to a CUDA device before calling `execute`.
so that one can e.g. copy to a global symbol before execution
i.e. to write
```cuda
extern __shared__ int foo[];
```
we will now support
```nim
var foo {.cuExtern, shared.}: array[0, int]
```
(the 0 size is the current placeholder for how to designate a `[]` array from Nim)
Allows mapping a proc to a custom name. Useful for names that we can't write due to Nim limitations (e.g. starting with an underscore, or names that match Nim keywords).
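A hedged sketch of what this looks like; `cuName` is a placeholder for the PR's actual pragma spelling:

```nim
# Hypothetical pragma name: maps the Nim-legal identifier `syncthreads` to the
# CUDA name `__syncthreads`, which Nim cannot spell (leading underscores).
proc syncthreads*() {.device, cuName: "__syncthreads".} = discard
```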
One can either define a Nim `const` for a variable that is already a constant at Nim compile time, or use
```nim
var foo {.constant.}: theType
```
if one wishes to copy to the symbol before kernel execution. This is
useful for global constants that are not filled in at CUDA compile
time, but before execution. For example:
```nim
# Filled with `copyToSymbol` at runtime from host!
var rc16 {.constant.}: array[30, array[BABYBEAR_WIDTH, BigInt]]
var matInternalDiagM1 {.constant.}: array[BABYBEAR_WIDTH, BigInt]
var montyInverse {.constant.}: BigInt
```
And in the host code:
```nim
var nvrtc = initNvrtc(CudaCode)
nvrtc.compile()
nvrtc.getPtx()
nvrtc.load()
var p2bb = Poseidon2BabyBear.init()
# copy Poseidon2 constants to CUDA kernel
nvrtc.copyToSymbol("rc16", p2bb.rc16)
nvrtc.copyToSymbol("matInternalDiagM1", p2bb.matInternalDiagM1)
nvrtc.copyToSymbol("montyInverse", p2bb.montyInverse)
```
Need to finalize the logic of mapping 64 bit limbs to 32 bit limb constants
Merging this for now. Next steps:
(Note: this is a draft, because it is a) a proof of concept and b) still depends on `nimcuda` for simplicity)

Table of contents & introduction

- CUDA execution
- NVRTC compiler helper
- Important note about CUDA library
- CUDA code generator
- Notes on the `cuda` design
- Profiling Nvidia code
- A more complex example

This (draft) PR adds an experimental alternative to generating code targeting Nvidia GPUs. Instead of relying on LLVM to generate Nvidia PTX instructions, this PR adds 3 pieces:
- an `execCuda` macro to handle executing compiled CUDA kernels, similar to the one defined in `codegen_nvidia.nim`: https://github.com/mratsim/constantine/blob/master/constantine/math_compiler/codegen_nvidia.nim (if we decide to go ahead with this PR, I'll merge the two. They are compatible, the new one just has a few extra features),
- a helper to interface with the NVRTC (runtime compilation library) compiler and compile a string of CUDA code,
- a `cuda` macro that generates CUDA code from Nim.
A few words on each of these first:
CUDA execution
Starting with this as it is already present in Constantine. Once one has
a compiled CUDA kernel and wishes to execute it, in principle one needs
to:
- copy the data to the device for all arguments that are not pure value types
- call `cuLaunchKernel`, making sure to pass all parameters as an array of pointers
Instead of having to do this manually, we use a typed macro, which
determines the required action based on the parameters passed to it.
The basic usage looks like:
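Roughly like the following sketch (the exact parameter names and how the launch configuration is passed are assumptions here; `fn` stands for a loaded kernel handle):

```nim
# `r` receives the output, `a` and `b` are inputs
var r: array[4, uint32]
let a = [1'u32, 2, 3, 4]
let b = [5'u32, 6, 7, 8]
execCuda(fn, res = (r,), inputs = (a, b))
# after the call `r` holds the data copied back from the device
```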
where `res` and `inputs` are tuples (to support heterogeneous types).

Arguments passed as `res` are treated as output parameters. They will both be copied to the device and afterwards back to the local identifiers.

`inputs` will either be passed by value or copied, depending on whether the data is a `ref` type or not. NOTE: A currently not implemented feature is deep copying data structures which contain references / pointers themselves. This is important in particular if one wishes to pass data as a struct of arrays (SoA).
In practice, in the context of the code of this PR, you don't directly interact with `execCuda`. This is done via the NVRTC compiler in the next section.

NOTE: The parameters will be passed in the order:

- the `res` tuple in its tuple order
- the `inputs` tuple in their order

This means that your output arguments currently must be the first arguments of the kernel!
NVRTC compiler helper
This is essentially an equivalent of the LLVM based
`NvidiaAssembler` part of the LLVM backend,
https://github.com/mratsim/constantine/blob/master/constantine/math_compiler/codegen_nvidia.nim#L501-L512
Similarly to all CUDA work, lots of boilerplate code is required to
initialize the device, set up the compilation pipeline, call the
compiler on the CUDA string etc. As most of this is identical in the
majority of use cases, we can automate it away. NOTE: We will likely want to eventually add some context or config object to store e.g. specific parameters to pass to the NVRTC compiler.
As an example, let's look at what the Saxpy
example
from the CUDA documentation looks like for us now.
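A hedged sketch of the flow (the `execute` call's exact signature is an assumption; the rest follows the API used further below):

```nim
const SaxpyCuda = """
extern "C" __global__ void saxpy(float *out, float a, float *x, float *y, size_t n) {
  size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < n)
    out[tid] = a * x[tid] + y[tid];
}
"""

var nvrtc = initNvrtc(SaxpyCuda)
nvrtc.compile()
nvrtc.getPtx()
nvrtc.load()

# set up the input data for the kernel
const N = 1 shl 20
var
  a = 2.0'f32
  x = newSeq[float32](N)
  y = newSeq[float32](N)
  output = newSeq[float32](N)
for i in 0 ..< N:
  x[i] = float32 i
  y[i] = 1.0'f32

# output arguments come first (see the parameter-order note above)
nvrtc.execute("saxpy", res = (output,), inputs = (a, x, y, N.csize_t))
```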
Clearly, most of the steps (`compile`, `getPtx`) could also just be done as part of the `execute`. I just haven't merged them yet.

We can see that the majority of the code is now setting up the input data for the kernel.
Important note about CUDA library
To fully support all CUDA features using NVRTC, we need to use the `header` pragma in the CUDA wrapper. See this `nimcuda` issue about the problem: SciNim/nimcuda#27

(Note: the current existing CUDA wrapper in Constantine also avoids the `header` pragma. Once we switch to using our own, we'll have to make that change and thus need the code below)
This implies that we need to know the path to the CUDA libraries at
compile time. Given that most people on linux systems tend to install
CUDA outside their package manager, this implies we need to pass the
path to the compiler.
The `runtime_compile.nim` file contains the following variables:
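A hedged sketch of the kind of constants meant (the actual names and defaults in `runtime_compile.nim` may differ):

```nim
# `-d:CudaPath=...` overrides this default at compile time
const CudaPath* {.strdefine.} = "/usr/local/cuda"
const CudaIncludePath* = CudaPath & "/include"
const CudaLibPath* = CudaPath & "/lib64"
```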
You can compile a program using `-d:CudaPath=<path/to/your/cuda>` to set the paths accordingly.
CUDA code generator
This brings us to the most interesting part of this PR. In the example
above we simply had a string of raw CUDA code. But for anyone who tends
to write Nim, this is likely not the most attractive nor elegant
solution. So instead for the Saxpy example from above, we can write:
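A hedged sketch (pragma and builtin spellings follow the mapping described below; the PR's actual example may differ in details):

```nim
const SaxpyCuda = cuda:
  proc saxpy(output: ptr UncheckedArray[cfloat],  # output first, see the `execCuda` note
             a: cfloat,
             x, y: ptr UncheckedArray[cfloat],
             n: cint) {.global.} =
    let tid = int(blockIdx.x * blockDim.x + threadIdx.x)
    if tid < n:
      output[tid] = a * x[tid] + y[tid]
```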
Due to the somewhat restricted nature of writing CUDA code anyhow, the vast majority of practical code is already supported. You likely won't think of CUDA devices as your first choice for complex string handling or `ref object` madness. Note that the features you'd expect to see all work. We can access arrays, we have more sane types (`UncheckedArray` instead of raw pointers), can access the CUDA special block / thread related variables etc. The latter is implemented by defining dummy types in `runtime_compile.nim`, which are only there to make the Nim compiler happy as part of the typed macro pass. Also, typical CUDA annotations like `__global__` are mapped to Nim pragmas as you can see.
Important Nim features that are currently not supported:
- the `result` variable
- `seq`, `string` etc.
- `openArray`
- `while` loops (simple)
- `case` statements (should be straightforward, but likely not very useful)
- `echo` on device (but you can `printf`, see below!)
- `staticFor`. Constantine's is currently slightly broken in the macro.
Important Nim features that do work:
- `if` statements
- `for` loops
- passing a `seq[T]` (for T being value types!) to a kernel (technically a feature of `execCuda`) and using `seq[T]` as a return type
- `templates` in the `cuda` macro
- to access a constant defined outside the `cuda` macro, you can create a template with a `static` body accessing the constant. The template body will be replaced by the constant value
- basic types (`cfloat`, `cint` etc)
- `when` statements to avoid a runtime branch
- common operations (dereferencing, casting, object constructors, …)

Note: you cannot construct an object containing a statically sized array from a runtime value (or in C / C++ for that matter). So `BigInt(limbs: someArray)` is invalid. You'll need to `memcpy` / manually assign data. Statically known arrays work though.
Important CUDA features currently not supported:
- `__shared__` memory (just needs to implement the pragma)
- `__syncthreads` and similar functions (also just need a Nim name for them and then map them to their CUDA name)
Important CUDA features that do work:
- `blockIdx`, `blockDim`, `threadIdx`
- `__global__`, `__device__`, `__forceinline__` pragmas (via equivalent Nim pragmas without the `_`)
- the `asm` statement
- `printf` on device (obviously only use this to debug; see the short sketch after this list)
- `memcpy`
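A small hedged sketch of device-side `printf` together with the thread/block builtins:

```nim
const DebugCode = cuda:
  proc whoAmI() {.global.} =
    # prints which block / thread a given lane runs on (debugging only!)
    printf("block %d, thread %d\n", blockIdx.x, threadIdx.x)
```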
Notes on the `cuda` design

Initially I started out with an `untyped` macro and thought I'd just have the Nim code be only one layer above being a pure string literal, essentially just mapping Nim constructs directly to fixed strings. But I quickly realized that having a `typed` macro would be much better, because we could actually access type information and use templates in the body (as they are expanded before the typed macro is executed!).
I think it is likely possible to go one step further than the current code and access Nim procs defined outside the `cuda` macro, as long as they are in scope (and not overloaded!). With a typed macro we can get their bodies, insert them into the CUDA context and treat them as `__device__` functions.

I mainly think about this not really for the purpose of sharing lots of
code between the CUDA target and other targets. While the code sharing
could theoretically be quite beneficial, I think it likely won't be very
practical. Most likely different targets require a very different
approach in many details. E.g. the low level primitives using inline PTX
instructions. At a higher level one will need different approaches due
to the trade offs needed for efficient parallelism on Nvidia GPUs
compared to a CPU approach.
However, what I do think would be very useful is to be able to split the
`cuda` macro into multiple pieces (similar to how one writes Nim macros really). Say one `cuda` call for type definitions, one for some device functions etc. But due to the typed nature, this implies all the defined
types and functions would need to be visible in a global scope, which
currently would not be the case.
Profiling Nvidia code
Although it is probably obvious, it is worth mentioning that you can of
course use an Nvidia profiler (
`nvprof` or `ncu`) on Nim binaries which use this feature.
A more complex example
For a more complex example, see the BigInt example file part of this PR.
There we implement modular addition for finite field elements, similar
to the current existing implementation for the LLVM target (using inline
PTX instructions).
It shows how one defines a type on the CUDA device, accesses a constant
from Constantine (the field modulus) using a
`template` with a `static` body, how to construct objects on device and more. You'll see that the
code essentially looks like normal host code.
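As a rough illustration of the pattern (not the actual example file; the constant and type layout here are only illustrative):

```nim
# BN254_Snarks field modulus as 8 little-endian uint32 limbs, defined on the host
const M = [0xD87CFD47'u32, 0x3C208C16'u32, 0x6871CA8D'u32, 0x97816A91'u32,
           0x8181585D'u32, 0xB85045B6'u32, 0xE131A029'u32, 0x30644E72'u32]

const FieldAddCuda = cuda:
  type BigInt = object
    limbs: array[8, uint32]

  # the template's `static` body is replaced by the constant's value
  template modulus(): untyped =
    static: M

  proc addmod(r: var BigInt, a, b: BigInt) {.device.} =
    # add limbs with carries, then conditionally subtract `modulus()` (elided)
    discard
```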
Be aware of course that to actually achieve really high performance, just
launching lots of blocks with many threads won't give you an O(1-10k)
(depending on # of CUDA cores) speedup over a single thread. You'll
need to make sure to first go down the rabbit hole of thinking about
memory coalescence, blocks, warps and all that… As an example, a simple
benchmark performing additions of 2^25 pairs of finite field elements
of `BN254_Snarks` using a `BigInt` type which stores 8 `uint32` limbs leads to only a 10x speedup compared to a single CPU core (using our very optimized CPU code of course). `nvprof` shows that the memory performance in that case is only 12.5%, because each thread has to jump
over the 8 limbs of the neighboring threads/lanes. This leads to non
coalesced memory access and causes a massive performance penalty. I
mention this in particular, because implementing a structure of arrays (SoA) approach for the data (where we have a single `BigInts` type, which has one array for limb 0, one for limb 1 and so on) is currently not supported in the context of copying data to the device via `execCuda`. We need to extend the "when and what to copy" logic in the macro first.
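For reference, the struct-of-arrays layout meant here would look roughly like this (names illustrative):

```nim
# limbs[l][i] holds limb `l` of element `i`, so neighbouring threads reading the
# same limb index access contiguous (coalesced) memory
type BigInts[N: static int] = object
  limbs: array[8, array[N, uint32]]
```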