CuPy – NumPy & SciPy for GPU¶
Overview¶
CuPy is a NumPy/SciPy-compatible array library for GPU-accelerated computing with Python. CuPy acts as a drop-in replacement to run existing NumPy/SciPy code on NVIDIA CUDA or AMD ROCm platforms.
CuPy provides cupy.ndarray, sparse matrices, and the associated routines for GPU devices, all having the same API as NumPy and SciPy:

- N-dimensional array (ndarray): cupy.ndarray
  - Data types (dtypes): boolean (bool_), integer (int8, int16, int32, int64, uint8, uint16, uint32, uint64), float (float16, float32, float64), and complex (complex64, complex128)
  - Supports semantics identical to numpy.ndarray, including basic / advanced indexing and broadcasting
- Sparse matrices: cupyx.scipy.sparse
  - 2-D sparse matrix: csr_matrix, coo_matrix, csc_matrix, and dia_matrix
- NumPy Routines
  - Module-level Functions (cupy.*)
  - Linear Algebra Functions (cupy.linalg.*)
  - Fast Fourier Transform (cupy.fft.*)
  - Random Number Generator (cupy.random.*)
- SciPy Routines
  - Discrete Fourier Transforms (cupyx.scipy.fft.* and cupyx.scipy.fftpack.*)
  - Advanced Linear Algebra (cupyx.scipy.linalg.*)
  - Multidimensional Image Processing (cupyx.scipy.ndimage.*)
  - Sparse Matrices (cupyx.scipy.sparse.*)
  - Sparse Linear Algebra (cupyx.scipy.sparse.linalg.*)
  - Special Functions (cupyx.scipy.special.*)
  - Signal Processing (cupyx.scipy.signal.*)
  - Statistical Functions (cupyx.scipy.stats.*)
Routines are backed by CUDA libraries (cuBLAS, cuFFT, cuSPARSE, cuSOLVER, cuRAND), Thrust, CUB, and cuTENSOR to provide the best performance.
It is also possible to easily implement custom CUDA kernels that work with ndarray using:
- Kernel Templates: Quickly define element-wise and reduction operations as a single CUDA kernel
- Raw Kernel: Import existing CUDA C/C++ code
- Just-in-time Transpiler (JIT): Generate CUDA kernels from Python source code
- Kernel Fusion: Fuse multiple CuPy operations into a single CUDA kernel
CuPy can run in multi-GPU or cluster environments. The distributed communication package (cupyx.distributed) provides collective and peer-to-peer primitives for ndarray, backed by NCCL.
For users who need fine-grained control over performance, access to low-level CUDA features is available:
- Stream and Event: CUDA stream and per-thread default stream are supported by all APIs
- Memory Pool: Customizable memory allocator with a built-in memory pool
- Profiler: Supports profiling code using CUDA Profiler and NVTX
- Host API Binding: Directly call CUDA libraries, such as the NCCL, cuDNN, cuTENSOR, and cuSPARSELt APIs, from Python
CuPy implements standard APIs for data exchange and interoperability, such as DLPack, CUDA Array Interface, __array_ufunc__ (NEP 13), __array_function__ (NEP 18), and the Array API Standard.
Thanks to these protocols, CuPy easily integrates with NumPy, PyTorch, TensorFlow, MPI4Py, and any other libraries supporting the standard.
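For example, thanks to the __array_function__ protocol (NEP 18), many NumPy functions can be called directly on CuPy arrays. A minimal sketch, assuming NumPy 1.17+ with NEP 18 dispatch enabled:

>>> import numpy as np
>>> import cupy as cp
>>> x = cp.arange(4)
>>> np.mean(x)  # dispatched to CuPy; the result stays on the GPU
array(1.5)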
Under an AMD ROCm environment, CuPy automatically translates all CUDA API calls to ROCm HIP (hipBLAS, hipFFT, hipSPARSE, hipRAND, hipCUB, hipThrust, RCCL, etc.), allowing code written using CuPy to run on both NVIDIA and AMD GPUs without any modification.
Project Goal¶
The goal of the CuPy project is to provide Python users with GPU acceleration capabilities without requiring in-depth knowledge of the underlying GPU technologies. The CuPy team focuses on providing:
- Complete NumPy and SciPy API coverage to become a full drop-in replacement, as well as advanced CUDA features to maximize performance.
- A mature, high-quality library serving as a fundamental package for all projects needing acceleration, from a lab environment to a large-scale cluster.
Installation¶
Requirements¶
NVIDIA CUDA GPU with Compute Capability 3.0 or higher.
CUDA Toolkit: v10.2 / v11.0 / v11.1 / v11.2 / v11.3 / v11.4 / v11.5
If you have multiple versions of CUDA Toolkit installed, CuPy will automatically choose one of the CUDA installations. See Working with Custom CUDA Installation for details.
This requirement is optional if you install CuPy from conda-forge. However, you still need to have a compatible driver installed for your GPU. See Installing CuPy from Conda-Forge for details.
Python: v3.7.0+ / v3.8.0+ / v3.9.0+ / v3.10.0+
Note
Currently, CuPy is tested against Ubuntu 18.04 LTS / 20.04 LTS (x86_64), CentOS 7 / 8 (x86_64) and Windows Server 2016 (x86_64).
Python Dependencies¶
NumPy/SciPy-compatible API in CuPy v10 is based on NumPy 1.21 and SciPy 1.7, and has been tested against the following versions:
NumPy: v1.18 / v1.19 / v1.20 / v1.21
SciPy (optional): v1.4 / v1.5 / v1.6 / v1.7
Required only when using Routines (SciPy) (cupyx.scipy).
Optuna (optional): v2.x
Required only when using Automatic Kernel Parameters Optimizations (cupyx.optimizing).
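As a small sketch of how the optimizer is used (assuming Optuna is installed), cupyx.optimizing.optimize() is a context manager that tunes kernel launch parameters for the operations run inside it:

>>> import cupy as cp
>>> from cupyx import optimizing
>>>
>>> x = cp.arange(1000)
>>> with optimizing.optimize():
...     y = cp.sum(x)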
Note
SciPy and Optuna are optional dependencies and will not be installed automatically.
Note
Before installing CuPy, we recommend you upgrade setuptools and pip:
$ python -m pip install -U setuptools pip
Additional CUDA Libraries¶
Some of the CUDA features in CuPy are activated only when the corresponding libraries are installed.
cuTENSOR: v1.3
The library to accelerate tensor operations. See Environment variables for the details.
NCCL: v2.8 / v2.9 / v2.10 / v2.11
The library to perform collective multi-GPU / multi-node computations.
cuDNN: v7.6 / v8.0 / v8.1 / v8.2 / v8.3
The library to accelerate deep neural network computations.
cuSPARSELt: v0.1.0
The library to accelerate sparse matrix-matrix multiplication.
Installing CuPy¶
Installing CuPy from PyPI¶
Wheels (precompiled binary packages) are available for Linux (x86_64) and Windows (amd64). Package names are different depending on your CUDA Toolkit version.
CUDA | Command
---|---
v10.2 | pip install cupy-cuda102
v11.0 | pip install cupy-cuda110
v11.1 | pip install cupy-cuda111
v11.2 | pip install cupy-cuda112
v11.3 | pip install cupy-cuda113
v11.4 | pip install cupy-cuda114
v11.5 | pip install cupy-cuda115
Note
To enable features provided by additional CUDA libraries (cuTENSOR / NCCL / cuDNN), you need to install them manually. If you installed CuPy via wheels, you can use the installer command below to set up these libraries in case you don't have a previous installation:
$ python -m cupyx.tools.install_library --cuda 11.2 --library cutensor
Note
Use pip install cupy-cudaXXX -f https://pip.cupy.dev/pre to install pre-release (development) versions.
When using wheels, please be careful not to install multiple CuPy packages at the same time. Any of these packages and the cupy package (source installation) conflict with each other. Please make sure that only one CuPy package (cupy or cupy-cudaXX where XX is a CUDA version) is installed:
$ pip freeze | grep cupy
Installing CuPy from Conda-Forge¶
Conda/Anaconda is a cross-platform package management solution widely used in scientific computing and other fields.
The above pip install instruction is compatible with conda environments. Alternatively, for both Linux (x86_64, ppc64le, aarch64-sbsa) and Windows, once the CUDA driver is correctly set up, you can also install CuPy from the conda-forge channel:
$ conda install -c conda-forge cupy
and conda will install a pre-built CuPy binary package for you, along with the CUDA runtime libraries (cudatoolkit). It is not necessary to install CUDA Toolkit in advance. Conda has a built-in mechanism to determine and install the latest version of cudatoolkit supported by your driver.
However, if for any reason you need to force-install a particular CUDA version (say 11.0), you can do:
$ conda install -c conda-forge cupy cudatoolkit=11.0
Note
cuDNN, cuTENSOR, and NCCL are available on conda-forge as optional dependencies. The following command can install them all at once:
$ conda install -c conda-forge cupy cudnn cutensor nccl
Each of them can also be installed separately as needed.
Note
If you encounter any problem with CuPy installed from conda-forge, please feel free to report it to cupy-feedstock, and we will help investigate whether it is just a packaging issue in conda-forge's recipe or a real issue in CuPy.
Note
If you did not install CUDA Toolkit yourself, the nvcc compiler might not be available, as the cudatoolkit package from conda-forge does not include the nvcc compiler toolchain. If you would like to use it from a local CUDA installation, you need to make sure the version of CUDA Toolkit matches that of cudatoolkit to avoid surprises.
Installing CuPy from Source¶
Use of wheel packages is recommended whenever possible. However, if wheels cannot meet your requirements (e.g., you are running a non-Linux environment or want to use a version of CUDA / cuDNN / NCCL not supported by wheels), you can also build CuPy from source.
Note
CuPy source build requires g++-6 or later. For Ubuntu 18.04, run apt-get install g++. For Ubuntu 16.04, CentOS 6 or 7, follow the instructions here.
Note
When installing CuPy from source, features provided by additional CUDA libraries will be disabled if these libraries are not available at the build time. See Installing cuDNN and NCCL for the instructions.
Note
If you upgrade or downgrade the version of CUDA Toolkit, cuDNN, NCCL or cuTENSOR, you may need to reinstall CuPy. See Reinstalling CuPy for details.
You can install the latest stable release version of the CuPy source package via pip:
$ pip install cupy
If you want to install the latest development version of CuPy from a cloned Git repository:
$ git clone --recursive https://github.com/cupy/cupy.git
$ cd cupy
$ pip install .
Note
Cython 0.29.22 or later is required to build CuPy from source. It will be automatically installed during the build process if not available.
Uninstalling CuPy¶
Use pip to uninstall CuPy:
$ pip uninstall cupy
Note
If you are using a wheel, cupy shall be replaced with cupy-cudaXX (where XX is a CUDA version number).
Note
If CuPy is installed via conda, please do conda uninstall cupy instead.
Upgrading CuPy¶
Just use pip install with the -U option:
$ pip install -U cupy
Note
If you are using a wheel, cupy shall be replaced with cupy-cudaXX (where XX is a CUDA version number).
Reinstalling CuPy¶
To reinstall CuPy, please uninstall CuPy and then install it.
When reinstalling CuPy, we recommend using the --no-cache-dir option, as pip caches the previously built binaries:
$ pip uninstall cupy
$ pip install cupy --no-cache-dir
Note
If you are using a wheel, cupy shall be replaced with cupy-cudaXX (where XX is a CUDA version number).
Using CuPy inside Docker¶
We provide official Docker images. Use the NVIDIA Container Toolkit to run the CuPy image with GPU support. You can log in to the environment with bash and run the Python interpreter:
$ docker run --gpus all -it cupy/cupy /bin/bash
Or run the interpreter directly:
$ docker run --gpus all -it cupy/cupy /usr/bin/python3
FAQ¶
pip fails to install CuPy¶
Please make sure that you are using the latest setuptools and pip:
$ pip install -U setuptools pip
Use the -vvvv option with the pip command. This will display all logs of the installation:
$ pip install cupy -vvvv
If you are using sudo to install CuPy, note that the sudo command does not propagate environment variables. If you need to pass an environment variable (e.g., CUDA_PATH), you need to specify it inside sudo like this:
$ sudo CUDA_PATH=/opt/nvidia/cuda pip install cupy
If you are using certain versions of conda, it may fail to build CuPy with the error g++: error: unrecognized command line option '-R'.
This is due to a bug in conda (see conda/conda#6030 for details).
If you encounter this problem, please upgrade your conda.
Installing cuDNN and NCCL¶
We recommend installing cuDNN and NCCL using the binary packages (i.e., using apt or yum) provided by NVIDIA.
If you want to install the tar-gz version of cuDNN and NCCL, we recommend installing it under the CUDA_PATH directory. For example, if you are using Ubuntu, copy *.h files to the include directory and *.so* files to the lib64 directory:
$ cp /path/to/cudnn.h $CUDA_PATH/include
$ cp /path/to/libcudnn.so* $CUDA_PATH/lib64
The destination directories depend on your environment. If you want to use cuDNN or NCCL installed in another directory, please use the CFLAGS, LDFLAGS, and LD_LIBRARY_PATH environment variables before installing CuPy:
$ export CFLAGS=-I/path/to/cudnn/include
$ export LDFLAGS=-L/path/to/cudnn/lib
$ export LD_LIBRARY_PATH=/path/to/cudnn/lib:$LD_LIBRARY_PATH
Working with Custom CUDA Installation¶
If you have installed CUDA in a non-default directory, or have multiple CUDA versions on the same host, you may need to manually specify the CUDA installation directory to be used by CuPy.
CuPy uses the first CUDA installation directory found in the following order:

- The CUDA_PATH environment variable.
- The parent directory of the nvcc command. CuPy looks for the nvcc command in the PATH environment variable.
- /usr/local/cuda
For example, you can build CuPy using a non-default CUDA directory via the CUDA_PATH environment variable:
$ CUDA_PATH=/opt/nvidia/cuda pip install cupy
Note
CUDA installation discovery is also performed at runtime using the rule above. Depending on your system configuration, you may also need to set the LD_LIBRARY_PATH environment variable to $CUDA_PATH/lib64 at runtime.
CuPy always raises cupy.cuda.compiler.CompileException¶
If CuPy raises a CompileException for almost everything, it is possible that CuPy cannot correctly detect the CUDA installed on your system. The following are error messages commonly observed in such cases.
nvrtc: error: failed to load builtins
catastrophic error: cannot open source file "cuda_fp16.h"
error: cannot overload functions distinguished by return type alone
error: identifier "__half_raw" is undefined
Please try setting the LD_LIBRARY_PATH and CUDA_PATH environment variables. For example, if you have CUDA installed at /usr/local/cuda-9.2:
$ export CUDA_PATH=/usr/local/cuda-9.2
$ export LD_LIBRARY_PATH=$CUDA_PATH/lib64:$LD_LIBRARY_PATH
Also see Working with Custom CUDA Installation.
Build fails on Ubuntu 16.04, CentOS 6 or 7¶
In order to build CuPy from source on systems with legacy GCC (g++-5 or earlier), you need to manually set up g++-6 or later and configure the NVCC environment variable.
On Ubuntu 16.04:
$ sudo add-apt-repository ppa:ubuntu-toolchain-r/test
$ sudo apt update
$ sudo apt install g++-6
$ export NVCC="nvcc --compiler-bindir gcc-6"
On CentOS 6 / 7:
$ sudo yum install centos-release-scl
$ sudo yum install devtoolset-7-gcc-c++
$ source /opt/rh/devtoolset-7/enable
$ export NVCC="nvcc --compiler-bindir gcc"
Using CuPy on AMD GPU (experimental)¶
CuPy has experimental support for AMD GPUs (ROCm).
Requirements¶
- ROCm: v4.0 / v4.2 / v4.3
See the ROCm Installation Guide for details.
The following ROCm libraries are required:
$ sudo apt install hipblas hipsparse rocsparse rocrand rocthrust rocsolver rocfft hipcub rocprim rccl
Environment Variables¶
When building or running CuPy for ROCm, the following environment variables are effective.
- ROCM_HOME: directory containing the ROCm software (e.g., /opt/rocm).
Docker¶
You can try running CuPy for ROCm using Docker.
$ docker run -it --device=/dev/kfd --device=/dev/dri --group-add video cupy/cupy-rocm
Installing Binary Packages¶
Wheels (precompiled binary packages) are available for Linux (x86_64). Package names are different depending on your ROCm version.
ROCm | Command
---|---
v4.0 | pip install cupy-rocm-4-0
v4.2 | pip install cupy-rocm-4-2
v4.3 | pip install cupy-rocm-4-3
Building CuPy for ROCm From Source¶
To build CuPy from source, set the CUPY_INSTALL_USE_HIP, ROCM_HOME, and HCC_AMDGPU_TARGET environment variables. (HCC_AMDGPU_TARGET is the ISA name supported by your GPU. Run rocminfo and use the value displayed in the Name: line (e.g., gfx900). You can specify a comma-separated list of ISAs if you have multiple GPUs of different architectures.)
$ export CUPY_INSTALL_USE_HIP=1
$ export ROCM_HOME=/opt/rocm
$ export HCC_AMDGPU_TARGET=gfx906
$ pip install cupy
Note
If you don't specify the HCC_AMDGPU_TARGET environment variable, CuPy will be built for the GPU architectures available on the build host. This behavior is specific to ROCm builds; when building CuPy for NVIDIA CUDA, the build result is not affected by the host configuration.
Limitations¶
The following features are not available due to limitations of ROCm or because they are specific to CUDA:
- CUDA Array Interface
- cuTENSOR
- Handling extremely large arrays whose size is around the 32-bit boundary (HIP is known to fail with sizes 2**32-1024)
- Atomic addition in FP16 (cupy.ndarray.scatter_add and cupyx.scatter_add)
- Multi-GPU FFT and FFT callbacks
- Some random number generation algorithms
- Several options in the RawKernel/RawModule APIs: Jitify, dynamic parallelism
- Per-thread default stream
- Random generation API (cupy.random.Generator) for ROCm versions older than 4.3
The following features are not yet supported:
- Sparse matrices (cupyx.scipy.sparse)
- cuDNN (hipDNN)
- Hermitian/symmetric eigenvalue solver (cupy.linalg.eigh)
- Polynomial roots (uses the Hermitian/symmetric eigenvalue solver)
The following features may not work in edge cases (e.g., some combinations of dtype):
Note
We are investigating the root causes of the issues. They are not necessarily CuPy’s issues, but ROCm may have some potential bugs.
User Guide¶
This user guide provides an overview of CuPy and explains its important features; details are found in CuPy API Reference.
Basics of CuPy¶
In this section, you will learn about the following things:

- Basics of cupy.ndarray
- The concept of the current device
- Host-device and device-device array transfer
Basics of cupy.ndarray¶
CuPy is a GPU array backend that implements a subset of the NumPy interface.
In the following code, cp is an abbreviation of cupy, following the standard convention of abbreviating numpy as np:
>>> import numpy as np
>>> import cupy as cp
The cupy.ndarray class is at the core of CuPy and is a replacement class for NumPy's numpy.ndarray.
>>> x_gpu = cp.array([1, 2, 3])
x_gpu above is an instance of cupy.ndarray.
As one can see, CuPy’s syntax here is identical to that of NumPy.
The main difference between cupy.ndarray and numpy.ndarray is that CuPy arrays are allocated on the current device, which we will talk about later.
Most array manipulations are also done in a way similar to NumPy. Take the Euclidean norm (a.k.a. L2 norm), for example. NumPy has the numpy.linalg.norm() function that calculates it on the CPU.
>>> x_cpu = np.array([1, 2, 3])
>>> l2_cpu = np.linalg.norm(x_cpu)
Using CuPy, we can perform the same calculations on the GPU in a similar way:
>>> x_gpu = cp.array([1, 2, 3])
>>> l2_gpu = cp.linalg.norm(x_gpu)
CuPy implements many functions on cupy.ndarray objects. See the reference for the supported subset of the NumPy API. Knowledge of NumPy will help you utilize most of the CuPy features. We therefore recommend you familiarize yourself with the NumPy documentation.
Current Device¶
CuPy has a concept of a current device, which is the default GPU device on which the allocation, manipulation, calculation, etc., of arrays takes place. Suppose the ID of the current device is 0. In such a case, the following code would create an array x_on_gpu0 on GPU 0.
>>> x_on_gpu0 = cp.array([1, 2, 3, 4, 5])
To switch to another GPU device, use the Device context manager:
>>> with cp.cuda.Device(1):
... x_on_gpu1 = cp.array([1, 2, 3, 4, 5])
>>> x_on_gpu0 = cp.array([1, 2, 3, 4, 5])
All CuPy operations (except for multi-GPU features and device-to-device copy) are performed on the currently active device.
In general, CuPy functions expect that the array is on the same device as the current one. Passing an array stored on a non-current device may work depending on the hardware configuration but is generally discouraged as it may not be performant.
Note
If the array's device and the current device mismatch, CuPy functions try to establish peer-to-peer memory access (P2P) between them so that the current device can directly read the array from another device. Note that P2P is available only when the topology permits it. If P2P is unavailable, such an attempt will fail with a ValueError.
The cupy.ndarray.device attribute indicates the device on which the array is allocated.
>>> with cp.cuda.Device(1):
... x = cp.array([1, 2, 3, 4, 5])
>>> x.device
<CUDA Device 1>
Note
When only one device is available, explicit device switching is not needed.
Current Stream¶
Associated with the concept of current devices are current streams, which help avoid explicitly passing streams in every single operation so as to keep the APIs pythonic and user-friendly. In CuPy, all CUDA operations such as data transfer (see the Data Transfer section) and kernel launches are enqueued onto the current stream, and the queued tasks on the same stream will be executed in serial (but asynchronously with respect to the host).
The default current stream in CuPy is CUDA's null stream (i.e., stream 0). It is also known as the legacy default stream, which is unique per device. However, it is possible to change the current stream using the cupy.cuda.Stream API; please see Accessing CUDA Functionalities for examples. The current stream in CuPy can be retrieved using cupy.cuda.get_current_stream().
It is worth noting that CuPy's current stream is managed on a per-thread, per-device basis, meaning that on different Python threads or different devices the current stream (if not the null stream) can be different.
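For example, the current stream can be inspected as follows (a minimal sketch; when no user-defined stream is in use, the null stream is returned):

>>> cp.cuda.get_current_stream()
<Stream 0 (device -1)>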
Data Transfer¶
Move arrays to a device¶
cupy.asarray() can be used to move a numpy.ndarray, a list, or any object that can be passed to numpy.array() to the current device:
>>> x_cpu = np.array([1, 2, 3])
>>> x_gpu = cp.asarray(x_cpu) # move the data to the current device.
cupy.asarray() can accept cupy.ndarray, which means we can transfer an array between devices with this function.
>>> with cp.cuda.Device(0):
...     x_gpu_0 = cp.array([1, 2, 3])  # create an array on GPU 0
>>> with cp.cuda.Device(1):
... x_gpu_1 = cp.asarray(x_gpu_0) # move the array to GPU 1
Note
cupy.asarray() does not copy the input array if possible. So, if you pass an array on the current device, it returns the input object itself. If you do want to copy the array in this situation, you can use cupy.array() with copy=True. In fact, cupy.asarray() is equivalent to cupy.array(arr, dtype, copy=False).
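For example, the no-copy behavior can be observed directly:

>>> x_gpu = cp.array([1, 2, 3])
>>> cp.asarray(x_gpu) is x_gpu           # no copy is made
True
>>> cp.array(x_gpu, copy=True) is x_gpu  # an explicit copy is a new object
False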
Move array from a device to the host¶
Moving a device array to the host can be done by cupy.asnumpy() as follows:
>>> x_gpu = cp.array([1, 2, 3]) # create an array in the current device
>>> x_cpu = cp.asnumpy(x_gpu) # move the array to the host.
We can also use cupy.ndarray.get():
>>> x_cpu = x_gpu.get()
Memory management¶
Check Memory Management for a detailed description of how memory is managed in CuPy using memory pools.
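As a quick sketch, the default memory pool can be inspected and trimmed as follows:

>>> mempool = cp.get_default_memory_pool()
>>> a = cp.arange(1000)
>>> mempool.used_bytes() > 0   # the allocation above came from the pool
True
>>> del a
>>> mempool.free_all_blocks()  # release cached blocks back to the device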
How to write CPU/GPU agnostic code¶
CuPy's compatibility with NumPy makes it possible to write CPU/GPU agnostic code. For this purpose, CuPy implements the cupy.get_array_module() function that returns a reference to cupy if any of its arguments resides on a GPU, and numpy otherwise.
Here is an example of a CPU/GPU agnostic function that computes log(1 + exp(x)):
>>> # Stable implementation of log(1 + exp(x))
>>> def softplus(x):
... xp = cp.get_array_module(x) # 'xp' is a standard usage in the community
... print("Using:", xp.__name__)
... return xp.maximum(0, x) + xp.log1p(xp.exp(-abs(x)))
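For illustration, calling it with a NumPy array runs on the CPU, while a CuPy array runs on the GPU:

>>> softplus(np.array([-1.0, 1.0]))
Using: numpy
array([0.31326169, 1.31326169])
>>> softplus(cp.array([-1.0, 1.0]))
Using: cupy
array([0.31326169, 1.31326169])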
When you need to manipulate CPU and GPU arrays, an explicit data transfer may be required to move them to the same location – either CPU or GPU. For this purpose, CuPy implements two sister methods called cupy.asnumpy() and cupy.asarray(). Here is an example that demonstrates the use of both methods:
>>> x_cpu = np.array([1, 2, 3])
>>> y_cpu = np.array([4, 5, 6])
>>> x_cpu + y_cpu
array([5, 7, 9])
>>> x_gpu = cp.asarray(x_cpu)
>>> x_gpu + y_cpu
Traceback (most recent call last):
...
TypeError: Unsupported type <class 'numpy.ndarray'>
>>> cp.asnumpy(x_gpu) + y_cpu
array([5, 7, 9])
>>> cp.asnumpy(x_gpu) + cp.asnumpy(y_cpu)
array([5, 7, 9])
>>> x_gpu + cp.asarray(y_cpu)
array([5, 7, 9])
>>> cp.asarray(x_gpu) + cp.asarray(y_cpu)
array([5, 7, 9])
The cupy.asnumpy() method returns a NumPy array (array on the host), whereas the cupy.asarray() method returns a CuPy array (array on the current device). Both methods can accept arbitrary input, meaning that they can be applied to any data that is located on either the host or the device and can be converted to an array.
User-Defined Kernels¶
CuPy provides easy ways to define three types of CUDA kernels: elementwise kernels, reduction kernels, and raw kernels. In this documentation, we describe how to define and call each kind of kernel.
Basics of elementwise kernels¶
An elementwise kernel can be defined by the ElementwiseKernel class. An instance of this class defines a CUDA kernel which can be invoked by the __call__ method of the instance.
A definition of an elementwise kernel consists of four parts: an input argument list, an output argument list, a loop body code, and the kernel name. For example, a kernel that computes a squared difference \(f(x, y) = (x - y)^2\) is defined as follows:
>>> squared_diff = cp.ElementwiseKernel(
... 'float32 x, float32 y',
... 'float32 z',
... 'z = (x - y) * (x - y)',
... 'squared_diff')
The argument lists consist of comma-separated argument definitions. Each argument definition consists of a type specifier and an argument name. Names of NumPy data types can be used as type specifiers.
Note
n, i, and names starting with an underscore _ are reserved for internal use.
The above kernel can be called on either scalars or arrays with broadcasting:
>>> x = cp.arange(10, dtype=np.float32).reshape(2, 5)
>>> y = cp.arange(5, dtype=np.float32)
>>> squared_diff(x, y)
array([[ 0., 0., 0., 0., 0.],
[25., 25., 25., 25., 25.]], dtype=float32)
>>> squared_diff(x, 5)
array([[25., 16., 9., 4., 1.],
[ 0., 1., 4., 9., 16.]], dtype=float32)
Output arguments can be explicitly specified (next to the input arguments):
>>> z = cp.empty((2, 5), dtype=np.float32)
>>> squared_diff(x, y, z)
array([[ 0., 0., 0., 0., 0.],
[25., 25., 25., 25., 25.]], dtype=float32)
Type-generic kernels¶
If a type specifier is one character, then it is treated as a type placeholder. It can be used to define type-generic kernels. For example, the above squared_diff kernel can be made type-generic as follows:
>>> squared_diff_generic = cp.ElementwiseKernel(
... 'T x, T y',
... 'T z',
... 'z = (x - y) * (x - y)',
... 'squared_diff_generic')
Type placeholders of the same character in the kernel definition indicate the same type. The actual type of these placeholders is determined by the actual argument types. The ElementwiseKernel class first checks the output arguments and then the input arguments to determine the actual type. If no output arguments are given on the kernel invocation, then only the input arguments are used to determine the type.
The type placeholder can be used in the loop body code:
>>> squared_diff_generic = cp.ElementwiseKernel(
... 'T x, T y',
... 'T z',
... '''
... T diff = x - y;
... z = diff * diff;
... ''',
... 'squared_diff_generic')
More than one type placeholder can be used in a kernel definition. For example, the above kernel can be further made generic over multiple arguments:
>>> squared_diff_super_generic = cp.ElementwiseKernel(
... 'X x, Y y',
... 'Z z',
... 'z = (x - y) * (x - y)',
... 'squared_diff_super_generic')
Note that this kernel requires the output argument to be explicitly specified, because the type Z cannot be automatically determined from the input arguments.
Raw argument specifiers¶
The ElementwiseKernel class does the indexing with broadcasting automatically, which is useful for defining most elementwise computations. On the other hand, we sometimes want to write a kernel with manual indexing for some arguments. We can tell the ElementwiseKernel class to use manual indexing by adding the raw keyword preceding the type specifier.
We can use the special variable i and the method _ind.size() for manual indexing. i indicates the index within the loop. _ind.size() indicates the total number of elements to which the elementwise operation is applied. Note that it represents the size after the broadcast operation.
For example, a kernel that adds two vectors while reversing one of them can be written as follows:
>>> add_reverse = cp.ElementwiseKernel(
... 'T x, raw T y', 'T z',
... 'z = x + y[_ind.size() - i - 1]',
... 'add_reverse')
(Note that this is an artificial example; you could write such an operation just as z = x + y[::-1] without defining a new kernel.)
A raw argument can be used like an array. The indexing operator y[_ind.size() - i - 1] involves an indexing computation on y, so y can be arbitrarily shaped and strided. Note that raw arguments are not involved in the broadcasting. If you want to mark all arguments as raw, you must specify the size argument on invocation, which defines the value of _ind.size() (see the sketch below).
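A minimal sketch of a kernel whose arguments are all raw, so the size argument must be given at call time:

>>> scale_raw = cp.ElementwiseKernel(
...     'raw T x', 'raw T y',
...     'y[i] = 2 * x[i]',
...     'scale_raw')
>>> x = cp.arange(5, dtype=np.float32)
>>> y = cp.empty_like(x)
>>> scale_raw(x, y, size=5)
array([0., 2., 4., 6., 8.], dtype=float32)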
Texture memory¶
Texture objects (TextureObject) can be passed to ElementwiseKernel with their type marked by a unique type placeholder distinct from any other types used in the same kernel, as their actual datatype is determined when populating the texture memory. The texture coordinates can be computed in the kernel by the per-thread loop index i.
Reduction kernels¶
Reduction kernels can be defined by the ReductionKernel class. We can use it by defining four parts of the kernel code:

- Identity value: This value is used for the initial value of reduction.
- Mapping expression: It is used for the pre-processing of each element to be reduced.
- Reduction expression: It is an operator to reduce the multiple mapped values. The special variables a and b are used for its operands.
- Post mapping expression: It is used to transform the resulting reduced values. The special variable a is used as its input. Output should be written to the output parameter.
The ReductionKernel class automatically inserts other code fragments that are required for an efficient and flexible reduction implementation. For example, the L2 norm along specified axes can be written as follows:
>>> l2norm_kernel = cp.ReductionKernel(
... 'T x', # input params
... 'T y', # output params
... 'x * x', # map
... 'a + b', # reduce
... 'y = sqrt(a)', # post-reduction map
... '0', # identity value
... 'l2norm' # kernel name
... )
>>> x = cp.arange(10, dtype=np.float32).reshape(2, 5)
>>> l2norm_kernel(x, axis=1)
array([ 5.477226 , 15.9687195], dtype=float32)
Note
The raw specifier is restricted to usages where the axes to be reduced are put at the head of the shape. This means that if you want to use the raw specifier for at least one argument, the axis argument must be 0 or a contiguous increasing sequence of integers starting from 0, like (0, 1), (0, 1, 2), etc.
Note
Texture memory is not yet supported in ReductionKernel.
Raw kernels¶
Raw kernels can be defined by the RawKernel class. By using raw kernels, you can define kernels from raw CUDA source. A RawKernel object allows you to call the kernel with CUDA's cuLaunchKernel interface. In other words, you have control over grid size, block size, shared memory size, and stream.
>>> add_kernel = cp.RawKernel(r'''
... extern "C" __global__
... void my_add(const float* x1, const float* x2, float* y) {
... int tid = blockDim.x * blockIdx.x + threadIdx.x;
... y[tid] = x1[tid] + x2[tid];
... }
... ''', 'my_add')
>>> x1 = cp.arange(25, dtype=cp.float32).reshape(5, 5)
>>> x2 = cp.arange(25, dtype=cp.float32).reshape(5, 5)
>>> y = cp.zeros((5, 5), dtype=cp.float32)
>>> add_kernel((5,), (5,), (x1, x2, y)) # grid, block and arguments
>>> y
array([[ 0., 2., 4., 6., 8.],
[10., 12., 14., 16., 18.],
[20., 22., 24., 26., 28.],
[30., 32., 34., 36., 38.],
[40., 42., 44., 46., 48.]], dtype=float32)
Raw kernels operating on complex-valued arrays can be created as well:
>>> complex_kernel = cp.RawKernel(r'''
... #include <cupy/complex.cuh>
... extern "C" __global__
... void my_func(const complex<float>* x1, const complex<float>* x2,
... complex<float>* y, float a) {
... int tid = blockDim.x * blockIdx.x + threadIdx.x;
... y[tid] = x1[tid] + a * x2[tid];
... }
... ''', 'my_func')
>>> x1 = cp.arange(25, dtype=cp.complex64).reshape(5, 5)
>>> x2 = 1j*cp.arange(25, dtype=cp.complex64).reshape(5, 5)
>>> y = cp.zeros((5, 5), dtype=cp.complex64)
>>> complex_kernel((5,), (5,), (x1, x2, y, cp.float32(2.0)))  # grid, block and arguments
>>> y
array([[ 0. +0.j, 1. +2.j, 2. +4.j, 3. +6.j, 4. +8.j],
[ 5.+10.j, 6.+12.j, 7.+14.j, 8.+16.j, 9.+18.j],
[10.+20.j, 11.+22.j, 12.+24.j, 13.+26.j, 14.+28.j],
[15.+30.j, 16.+32.j, 17.+34.j, 18.+36.j, 19.+38.j],
[20.+40.j, 21.+42.j, 22.+44.j, 23.+46.j, 24.+48.j]],
dtype=complex64)
Note that while we encourage the usage of complex<T> types for complex numbers (available by including <cupy/complex.cuh> as shown above), for CUDA code already written using functions from cuComplex.h there is no need to make the conversion yourself: just set the option translate_cucomplex=True when creating a RawKernel instance.
The CUDA kernel attributes can be retrieved by either accessing the attributes dictionary or by accessing the RawKernel object's attributes directly; the latter can also be used to set certain attributes:
>>> add_kernel = cp.RawKernel(r'''
... extern "C" __global__
... void my_add(const float* x1, const float* x2, float* y) {
... int tid = blockDim.x * blockIdx.x + threadIdx.x;
... y[tid] = x1[tid] + x2[tid];
... }
... ''', 'my_add')
>>> add_kernel.attributes
{'max_threads_per_block': 1024, 'shared_size_bytes': 0, 'const_size_bytes': 0, 'local_size_bytes': 0, 'num_regs': 10, 'ptx_version': 70, 'binary_version': 70, 'cache_mode_ca': 0, 'max_dynamic_shared_size_bytes': 49152, 'preferred_shared_memory_carveout': -1}
>>> add_kernel.max_dynamic_shared_size_bytes
49152
>>> add_kernel.max_dynamic_shared_size_bytes = 50000 # set a new value for the attribute
>>> add_kernel.max_dynamic_shared_size_bytes
50000
Dynamic parallelism is supported by RawKernel. You just need to provide the linking flag (such as -dc) to RawKernel's options argument. The static CUDA device runtime library (cudadevrt) is automatically discovered by CuPy. For further detail, see the CUDA Toolkit's documentation.
Accessing texture (surface) memory in RawKernel is supported via the CUDA Runtime's Texture (Surface) Object API; see the documentation for TextureObject (SurfaceObject) as well as the CUDA C Programming Guide. For using the Texture Reference API, which is marked as deprecated as of CUDA Toolkit 10.1, see the introduction to RawModule below.
If your kernel relies on C++ std library headers such as <type_traits>, it is likely you will encounter compilation errors. In this case, try enabling CuPy's Jitify support by setting jitify=True when creating the RawKernel instance. It provides basic C++ std support to remedy common errors.
Note
The kernel does not have return values. You need to pass both input arrays and output arrays as arguments.
Note
When using printf() in your CUDA kernel, you may need to synchronize the stream to see the output. You can use cupy.cuda.Stream.null.synchronize() if you are using the default stream.
Note
In all of the examples above, we declare the kernels in an extern "C" block, indicating that C linkage is used. This is to ensure the kernel names are not mangled, so that they can be retrieved by name.
Kernel arguments¶
Python primitive types and NumPy scalars are passed to the kernel by value. Array arguments (pointer arguments) have to be passed as CuPy ndarrays. No validation is performed by CuPy for arguments passed to the kernel, including types and number of arguments.
Especially note that when passing a CuPy ndarray, its dtype should match the type of the argument declared in the function signature of the CUDA source code (unless you are casting arrays intentionally). For example, cupy.float32 and cupy.uint64 arrays must be passed to arguments typed as float* and unsigned long long*, respectively. CuPy does not directly support arrays of non-primitive types such as float3, but nothing prevents you from casting a float* or void* to a float3* in a kernel.
The Python primitive types int, float, complex, and bool map to long long, double, cuDoubleComplex, and bool, respectively.
NumPy scalars (numpy.generic) and NumPy arrays (numpy.ndarray) of size one are passed to the kernel by value. This means that you can pass by value any base NumPy type such as numpy.int8 or numpy.float64, provided the kernel argument matches in size. You can refer to this table to match CuPy/NumPy dtypes and CUDA types:
CuPy/NumPy type | Corresponding kernel types | itemsize (bytes)
---|---|---
bool | bool | 1
int8 | char, signed char | 1
int16 | short, signed short | 2
int32 | int, signed int | 4
int64 | long long, signed long long | 8
uint8 | unsigned char | 1
uint16 | unsigned short | 2
uint32 | unsigned int | 4
uint64 | unsigned long long | 8
float16 | half | 2
float32 | float | 4
float64 | double | 8
complex64 | float2, cuFloatComplex, complex<float> | 8
complex128 | double2, cuDoubleComplex, complex<double> | 16
The CUDA standard guarantees that the size of fundamental types on the host and device always match. The itemsizes of size_t, ptrdiff_t, intptr_t, uintptr_t, long, signed long, and unsigned long are, however, platform dependent.
To pass any CUDA vector builtins such as float3, or any other user-defined structure, as kernel arguments (provided it matches the device-side kernel parameter type), see Custom user types below.
Custom user types¶
It is possible to use custom types (composite types such as structures and structures of structures) as kernel arguments by defining a custom NumPy dtype. When doing this, it is your responsibility to match host and device structure memory layout. The CUDA standard guarantees that the size of fundamental types on the host and device always match. It may however impose device alignment requirements on composite types. This means that for composite types the struct member offsets may be different from what you might expect.
When a kernel argument is passed by value, the CUDA driver will copy exactly sizeof(param_type) bytes starting from the beginning of the NumPy object data pointer, where param_type is the parameter type in your kernel. You have to match param_type's memory layout (e.g., size, alignment, and struct padding/packing) by defining a corresponding NumPy dtype. For builtin CUDA vector types such as int2 and double4 and other packed structures with named members, you can directly define such NumPy dtypes as follows:
>>> import numpy as np
>>> names = ['x', 'y', 'z']
>>> types = [np.float32]*3
>>> float3 = np.dtype({'names': names, 'formats': types})
>>> arg = np.random.rand(3).astype(np.float32).view(float3)
>>> print(arg)
[(0.9940819, 0.62873816, 0.8953669)]
>>> arg['x'] = 42.0
>>> print(arg)
[(42., 0.62873816, 0.8953669)]
Here arg can be used directly as a kernel argument.
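As a hedged sketch, such a struct can be consumed by a raw kernel taking a float3 parameter by value (printf output requires a stream synchronization to become visible):

>>> ker = cp.RawKernel(r'''
... extern "C" __global__ void show_float3(const float3 v) {
...     printf("%f %f %f\n", v.x, v.y, v.z);
... }
... ''', 'show_float3')
>>> ker((1,), (1,), (arg,))  # arg is the size-one NumPy array defined above
>>> cp.cuda.Stream.null.synchronize()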
When there is no need to name fields, you may prefer this syntax to define packed structures such as vectors or matrices:
>>> import numpy as np
>>> float5x5 = np.dtype({'names': ['dummy'], 'formats': [(np.float32,(5,5))]})
>>> arg = np.random.rand(25).astype(np.float32).view(float5x5)
>>> print(arg.itemsize)
100
Here arg represents a 100-byte scalar (i.e., a NumPy array of size 1) that can be passed by value to any kernel.
Kernel parameters are passed by value in a dedicated 4 kB memory bank which has its own cache with broadcast. The upper bound for the total kernel parameter size is thus 4 kB (see this link). It may be important to note that this dedicated memory bank is not shared with the device __constant__ memory space.
For now, CuPy offers no helper routines to create user-defined composite types. Such composite types can, however, be built recursively using NumPy dtype offsets and itemsize capabilities; see cupy/examples/custum_struct for examples of advanced usage.
Warning
You cannot directly pass static arrays as kernel arguments with the type arg[N] syntax, where N is a compile-time constant. The signature of __global__ void kernel(float arg[5]) is seen as __global__ void kernel(float* arg) by the compiler. If you want to pass five floats to the kernel by value, you need to define a custom structure struct float5 { float val[5]; }; and modify the kernel signature to __global__ void kernel(float5 arg).
Raw modules¶
For dealing with a large raw CUDA source or loading an existing CUDA binary, the RawModule class can be handier. It can be initialized either with CUDA source code or with a path to a CUDA binary. It accepts most of the same arguments as RawKernel. The needed kernels can then be retrieved by calling the get_function() method, which returns a RawKernel instance that can be invoked as discussed above.
>>> loaded_from_source = r'''
... extern "C"{
...
... __global__ void test_sum(const float* x1, const float* x2, float* y, \
... unsigned int N)
... {
... unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x;
... if (tid < N)
... {
... y[tid] = x1[tid] + x2[tid];
... }
... }
...
... __global__ void test_multiply(const float* x1, const float* x2, float* y, \
... unsigned int N)
... {
... unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x;
... if (tid < N)
... {
... y[tid] = x1[tid] * x2[tid];
... }
... }
...
... }'''
>>> module = cp.RawModule(code=loaded_from_source)
>>> ker_sum = module.get_function('test_sum')
>>> ker_times = module.get_function('test_multiply')
>>> N = 10
>>> x1 = cp.arange(N**2, dtype=cp.float32).reshape(N, N)
>>> x2 = cp.ones((N, N), dtype=cp.float32)
>>> y = cp.zeros((N, N), dtype=cp.float32)
>>> ker_sum((N,), (N,), (x1, x2, y, N**2)) # y = x1 + x2
>>> assert cp.allclose(y, x1 + x2)
>>> ker_times((N,), (N,), (x1, x2, y, N**2)) # y = x1 * x2
>>> assert cp.allclose(y, x1 * x2)
The instructions above for using complex numbers in RawKernel also apply to RawModule.
For CUDA kernels that need to access global symbols, such as constant memory, the get_global() method can be used; see its documentation for further detail.
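For instance, here is a minimal sketch (assuming a module that declares a __constant__ float symbol named scale): get_global() returns a device pointer that can be wrapped into an ndarray and written to directly.

>>> code_const = r'''
... __constant__ float scale;
... extern "C" __global__ void multiply_by_scale(float* x) {
...     x[threadIdx.x] *= scale;
... }
... '''
>>> mod = cp.RawModule(code=code_const)
>>> scale_arr = cp.ndarray((1,), cp.float32, mod.get_global('scale'))  # zero-copy view of the symbol
>>> scale_arr[...] = 3.0
>>> x = cp.arange(4, dtype=cp.float32)
>>> mod.get_function('multiply_by_scale')((1,), (4,), (x,))
>>> x
array([0., 3., 6., 9.], dtype=float32)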
CuPy also supports the Texture Reference API. A handle to the texture reference in a module can be retrieved by name via get_texref(). Then, you need to pass it to TextureReference, along with a resource descriptor and a texture descriptor, to bind the reference to the array. (The interface of TextureReference is meant to mimic that of TextureObject to help users make the transition to the latter, since as of CUDA Toolkit 10.1 the former is marked as deprecated.)
To support C++ template kernels, RawModule additionally provides a name_expressions argument. A list of template specializations should be provided, so that the corresponding kernels can be generated and retrieved by type:
>>> code = r'''
... template<typename T>
... __global__ void fx3(T* arr, int N) {
... unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x;
... if (tid < N) {
... arr[tid] = arr[tid] * 3;
... }
... }
... '''
>>>
>>> name_exp = ['fx3<float>', 'fx3<double>']
>>> mod = cp.RawModule(code=code, options=('-std=c++11',),
... name_expressions=name_exp)
>>> ker_float = mod.get_function(name_exp[0]) # compilation happens here
>>> N=10
>>> a = cp.arange(N, dtype=cp.float32)
>>> ker_float((1,), (N,), (a, N))
>>> a
array([ 0., 3., 6., 9., 12., 15., 18., 21., 24., 27.], dtype=float32)
>>> ker_double = mod.get_function(name_exp[1])
>>> a = cp.arange(N, dtype=cp.float64)
>>> ker_double((1,), (N,), (a, N))
>>> a
array([ 0., 3., 6., 9., 12., 15., 18., 21., 24., 27.])
Note
The name expressions used both to initialize a RawModule instance and to retrieve the kernels are the original (un-mangled) kernel names with all template parameters unambiguously specified. The name mangling and demangling are handled under the hood, so users do not need to worry about them.
Kernel fusion¶
cupy.fuse() is a decorator that fuses functions. This decorator can be used to define an elementwise or reduction kernel more easily than with ElementwiseKernel or ReductionKernel.
By using this decorator, we can define the squared_diff kernel as follows:
>>> @cp.fuse()
... def squared_diff(x, y):
... return (x - y) * (x - y)
The above kernel can be called on scalars, NumPy arrays, or CuPy arrays, just like the original function.
>>> x_cp = cp.arange(10)
>>> y_cp = cp.arange(10)[::-1]
>>> squared_diff(x_cp, y_cp)
array([81, 49, 25, 9, 1, 1, 9, 25, 49, 81])
>>> x_np = np.arange(10)
>>> y_np = np.arange(10)[::-1]
>>> squared_diff(x_np, y_np)
array([81, 49, 25, 9, 1, 1, 9, 25, 49, 81])
At the first function call, the fused function analyzes the original function based on the abstracted information of arguments (e.g. their dtypes and ndims) and creates and caches an actual CUDA kernel. From the second function call with the same input types, the fused function calls the previously cached kernel, so it is highly recommended to reuse the same decorated functions instead of decorating local functions that are defined multiple times.
cupy.fuse() also supports simple reduction kernels:
>>> @cp.fuse()
... def sum_of_products(x, y):
... return cp.sum(x * y, axis = -1)
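For example (a small usage sketch):

>>> x = cp.arange(6, dtype=np.float32).reshape(2, 3)
>>> y = cp.ones((2, 3), dtype=np.float32)
>>> sum_of_products(x, y)
array([ 3., 12.], dtype=float32)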
You can specify the kernel name by using the kernel_name keyword argument as follows:
>>> @cp.fuse(kernel_name='squared_diff')
... def squared_diff(x, y):
... return (x - y) * (x - y)
Note
Currently, cupy.fuse() can fuse only simple elementwise and reduction operations. Most other routines (e.g., cupy.matmul(), cupy.reshape()) are not supported.
JIT kernel definition¶
The cupyx.jit.rawkernel decorator can create raw CUDA kernels from Python functions.
In this section, a Python function wrapped with the decorator is called a target function.
A target function consists of elementary scalar operations, and users have to manage how to parallelize them. CuPy's array operations, which automatically parallelize (e.g., add(), sum()), are not supported. If a custom kernel based on such array functions is desired, please refer to the Kernel fusion section.
Basic Usage¶
Here is a short example showing how to write a cupyx.jit.rawkernel that copies the values from x to y using a grid-stride loop:
>>> from cupyx import jit
>>>
>>> @jit.rawkernel()
... def elementwise_copy(x, y, size):
... tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x
... ntid = jit.gridDim.x * jit.blockDim.x
... for i in range(tid, size, ntid):
... y[i] = x[i]
>>> size = cp.uint32(2 ** 22)
>>> x = cp.random.normal(size=(size,), dtype=cp.float32)
>>> y = cp.empty((size,), dtype=cp.float32)
>>> elementwise_copy((128,), (1024,), (x, y, size)) # RawKernel style
>>> assert (x == y).all()
>>> elementwise_copy[128, 1024](x, y, size) # Numba style
>>> assert (x == y).all()
Both of the above styles of launching the kernel are supported; see the documentation of cupyx.jit._interface._JitRawKernel for details.
The compilation will be deferred until the first function call. CuPy’s JIT compiler infers the types of arguments at the call time, and will cache the compiled kernels for speeding up any subsequent calls.
See Custom kernels for a full list of APIs.
Basic Design¶
CuPy's JIT compiler generates CUDA code via the Python AST. We decided not to use Python bytecode to analyze the target function, to avoid performance degradation: CUDA source code generated from Python bytecode would not be effectively optimized by the CUDA compiler, because for-loops and other control statements are fully transformed into jump instructions when the target function is converted to bytecode.
Typing rule¶
The types of local variables are inferred at the first assignment in the function. The first assignment must be done at the top level of the function; in other words, it must not be in if/else bodies or for-loops.
Limitations¶
CuPy's JIT compiler uses inspect.getsource() to get the source code of the target function, so the compiler does not work in the following situations:
- In the Python REPL
- Lambda expressions as target functions
Accessing CUDA Functionalities¶
Streams and Events¶
In this section we discuss basic usage of CUDA streams and events. For the API reference please see Streams and events. For their roles in the CUDA programming model, please refer to the CUDA Programming Guide.
CuPy provides the high-level Python APIs Stream and Event for creating streams and events, respectively. Data copies and kernel launches are enqueued onto the Current Stream, which can be queried via get_current_stream() and changed either by setting up a context manager:
>>> import numpy as np
>>>
>>> a_np = np.arange(10)
>>> s = cp.cuda.Stream()
>>> with s:
... a_cp = cp.asarray(a_np) # H2D transfer on stream s
... b_cp = cp.sum(a_cp) # kernel launched on stream s
... assert s == cp.cuda.get_current_stream()
...
>>> # fall back to the previous stream in use (here the default stream)
>>> # when going out of the scope of s
or by using the use() method:
>>> s = cp.cuda.Stream()
>>> s.use()  # any subsequent operations are done on stream s
<Stream ... (device ...)>
>>> b_np = cp.asnumpy(b_cp)
>>> assert s == cp.cuda.get_current_stream()
>>> cp.cuda.Stream.null.use() # fall back to the default (null) stream
<Stream 0 (device -1)>
>>> assert cp.cuda.Stream.null == cp.cuda.get_current_stream()
Events can be created either manually or through the record() method. Event objects can be used for timing GPU activities (via get_elapsed_time()) or setting up inter-stream dependencies:
>>> e1 = cp.cuda.Event()
>>> e1.record()
>>> a_cp = b_cp * a_cp + 8
>>> e2 = cp.cuda.get_current_stream().record()
>>>
>>> # set up a stream order
>>> s2 = cp.cuda.Stream()
>>> s2.wait_event(e2)
>>> with s2:
...     # a_cp is guaranteed to have been updated when this copy (on s2) starts
... a_np = cp.asnumpy(a_cp)
>>>
>>> # timing
>>> e2.synchronize()
>>> t = cp.cuda.get_elapsed_time(e1, e2) # only include the compute time, not the copy time
Just like Device objects, Stream and Event objects can also be used for synchronization.
Note
In CuPy, Stream objects are managed on a per-thread, per-device basis.
Note
On NVIDIA GPUs, there are two stream singleton objects, null and ptds, referred to as the legacy default stream and the per-thread default stream, respectively. CuPy uses the former as the default when no user-defined stream is in use. To change this behavior, set the environment variable CUPY_CUDA_PER_THREAD_DEFAULT_STREAM to 1; see Environment variables. This is not applicable to AMD GPUs.
To interoperate with streams created in other Python libraries, CuPy provides the ExternalStream API to wrap an existing stream pointer (given as a Python int). In this case, the stream lifetime is not managed by CuPy. In addition, you need to make sure the ExternalStream object is used on the device where the stream was created, either manually or by explicitly setting the optional device_id argument. The created ExternalStream object can otherwise be used like a Stream object.
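A hedged sketch of wrapping a foreign stream (here ptr is a hypothetical Python int holding a cudaStream_t created by another library; with PyTorch, for instance, it could come from torch.cuda.current_stream().cuda_stream):

import cupy as cp

def run_on_external_stream(ptr, device_id=0):
    # wrap the foreign stream; its lifetime stays managed by the other library
    s = cp.cuda.ExternalStream(ptr, device_id=device_id)
    with s:
        return cp.arange(10).sum()  # enqueued on the wrapped stream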
CUDA Driver and Runtime API¶
Under construction. Please see Runtime API for the API reference.
Fast Fourier Transform with CuPy¶
CuPy covers the full Fast Fourier Transform (FFT) functionalities provided in NumPy (cupy.fft) and a subset of those in SciPy (cupyx.scipy.fft). In addition to those high-level APIs that can be used as is, CuPy provides additional features to access advanced routines that cuFFT offers for NVIDIA GPUs and to better control the performance and behavior of the FFT routines. Some of these features are experimental (subject to change, deprecation, or removal; see API Compatibility Policy) or may be absent in hipFFT/rocFFT targeting AMD GPUs.
SciPy FFT backend¶
Since SciPy v1.4, a backend mechanism is provided so that users can register different FFT backends and use SciPy's API to perform the actual transform with the target backend, such as CuPy's cupyx.scipy.fft module. For one-time usage, the context manager scipy.fft.set_backend() can be used:
import cupy as cp
import cupyx.scipy.fft as cufft
import scipy.fft
a = cp.random.random(100).astype(cp.complex64)
with scipy.fft.set_backend(cufft):
b = scipy.fft.fft(a) # equivalent to cufft.fft(a)
However, such usage can be tedious. Alternatively, users can register a backend through scipy.fft.register_backend() or scipy.fft.set_global_backend() to avoid using context managers:
import cupy as cp
import cupyx.scipy.fft as cufft
import scipy.fft
scipy.fft.set_global_backend(cufft)
a = cp.random.random(100).astype(cp.complex64)
b = scipy.fft.fft(a) # equivalent to cufft.fft(a)
Note
Please refer to SciPy FFT documentation for further information.
Note
Using the backend together with an explicit plan argument requires SciPy version 1.5.0 or higher. See below for how to create FFT plans.
User-managed FFT plans¶
For performance reasons, users may wish to create, reuse, and manage the FFT plans themselves. CuPy provides a high-level experimental API get_fft_plan() for this need. Users specify the transform to be performed as they would with most of the high-level FFT APIs, and a plan will be generated based on the input.
import cupy as cp
from cupyx.scipy.fft import get_fft_plan
a = cp.random.random((4, 64, 64)).astype(cp.complex64)
plan = get_fft_plan(a, axes=(1, 2), value_type='C2C') # for batched, C2C, 2D transform
The returned plan can be used either explicitly as an argument to the cupyx.scipy.fft APIs:
import cupyx.scipy.fft
# the rest of the arguments must match those used when generating the plan
out = cupyx.scipy.fft.fft2(a, axes=(1, 2), plan=plan)
or as a context manager for the cupy.fft APIs:
with plan:
# the arguments must match those used when generating the plan
out = cp.fft.fft2(a, axes=(1, 2))
FFT plan cache¶
However, there are occasions when users may not want to manage the FFT plans themselves. Moreover, plans could also be reused internally in CuPy's routines, to which user-managed plans would not be applicable. Therefore, starting with CuPy v8 we provide a built-in plan cache, enabled by default. The plan cache is maintained on a per-device, per-thread basis, and can be retrieved by the get_plan_cache() API.
>>> import cupy as cp
>>>
>>> cache = cp.fft.config.get_plan_cache()
>>> cache.show_info()
------------------- cuFFT plan cache (device 0) -------------------
cache enabled? True
current / max size : 0 / 16 (counts)
current / max memsize: 0 / (unlimited) (bytes)
hits / misses: 0 / 0 (counts)
cached plans (most recently used first):
>>> # perform a transform, which would generate a plan and cache it
>>> a = cp.random.random((4, 64, 64))
>>> out = cp.fft.fftn(a, axes=(1, 2))
>>> cache.show_info() # hit = 0
------------------- cuFFT plan cache (device 0) -------------------
cache enabled? True
current / max size : 1 / 16 (counts)
current / max memsize: 262144 / (unlimited) (bytes)
hits / misses: 0 / 1 (counts)
cached plans (most recently used first):
key: ((64, 64), (64, 64), 1, 4096, (64, 64), 1, 4096, 105, 4, 'C', 2, None), plan type: PlanNd, memory usage: 262144
>>> # perform the same transform again, the plan is looked up from cache and reused
>>> out = cp.fft.fftn(a, axes=(1, 2))
>>> cache.show_info() # hit = 1
------------------- cuFFT plan cache (device 0) -------------------
cache enabled? True
current / max size : 1 / 16 (counts)
current / max memsize: 262144 / (unlimited) (bytes)
hits / misses: 1 / 1 (counts)
cached plans (most recently used first):
key: ((64, 64), (64, 64), 1, 4096, (64, 64), 1, 4096, 105, 4, 'C', 2, None), plan type: PlanNd, memory usage: 262144
>>> # clear the cache
>>> cache.clear()
>>> cp.fft.config.show_plan_cache_info() # = cache.show_info(), for all devices
=============== cuFFT plan cache info (all devices) ===============
------------------- cuFFT plan cache (device 0) -------------------
cache enabled? True
current / max size : 0 / 16 (counts)
current / max memsize: 0 / (unlimited) (bytes)
hits / misses: 0 / 0 (counts)
cached plans (most recently used first):
The returned PlanCache object has other methods for finer control, such as setting the cache size (either by counts or by memory usage). If the size is set to 0, the cache is disabled. Please refer to its documentation for more detail.
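For instance (a minimal sketch using the size-control methods):

>>> cache = cp.fft.config.get_plan_cache()
>>> cache.set_size(32)        # keep at most 32 plans per device/thread
>>> cache.set_memsize(2**30)  # cap the total plan work area at 1 GiB
>>> cache.set_size(0)         # setting the size to 0 disables the cache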
Note
As shown above each FFT plan has an associated working area allocated. If an out-of-memory error happens, one may want to inspect, clear, or limit the plan cache.
Note
The plans returned by get_fft_plan() are not cached.
FFT callbacks¶
cuFFT provides FFT callbacks for merging pre- and/or post-processing kernels with the FFT routines so as to reduce access to global memory. This capability is supported experimentally by CuPy. Users need to supply custom load and/or store kernels as strings and set up a context manager via set_cufft_callbacks(). Note that the load (store) kernel pointer has to be named d_loadCallbackPtr (d_storeCallbackPtr).
import cupy as cp

# a load callback that overwrites the input array to 1
code = r'''
__device__ cufftComplex CB_ConvertInputC(
    void *dataIn,
    size_t offset,
    void *callerInfo,
    void *sharedPtr)
{
    cufftComplex x;
    x.x = 1.;
    x.y = 0.;
    return x;
}

__device__ cufftCallbackLoadC d_loadCallbackPtr = CB_ConvertInputC;
'''

a = cp.random.random((64, 128, 128)).astype(cp.complex64)

# this fftn call uses the callback
with cp.fft.config.set_cufft_callbacks(cb_load=code):
    b = cp.fft.fftn(a, axes=(1, 2))

# this one does not
c = cp.fft.fftn(cp.ones(shape=a.shape, dtype=cp.complex64), axes=(1, 2))

# the results agree
assert cp.allclose(b, c)

# "static" plans are also cached, but are distinct from their no-callback counterparts
cp.fft.config.get_plan_cache().show_info()
Note
Internally, this feature requires recompiling a Python module for each distinct pair of load and store kernels. Therefore, the first invocation will be very slow, and this cost is amortized if the callbacks can be reused in subsequent calculations. The compiled modules are cached on disk, in the default location $HOME/.cupy/callback_cache, which can be changed via the CUPY_CACHE_DIR environment variable.
Multi-GPU FFT¶
CuPy currently provides two kinds of experimental support for multi-GPU FFT.
Warning
Using multiple GPUs to perform FFT is not guaranteed to be more performant. As a rule of thumb, if the transform fits on a single GPU, you should avoid using multiple GPUs.
The first kind of support is with the high-level fft() and ifft() APIs, which require the input array to reside on one of the participating GPUs. The multi-GPU calculation is done under the hood, and by the end of the calculation the result again resides on the device where it started. Currently only the 1D complex-to-complex (C2C) transform is supported; complex-to-real (C2R) and real-to-complex (R2C) transforms (such as rfft() and friends) are not. The transform can be either batched (batch size > 1) or not (batch size = 1).
import cupy as cp
cp.fft.config.use_multi_gpus = True
cp.fft.config.set_cufft_gpus([0, 1]) # use GPU 0 & 1
shape = (64, 64) # batch size = 64
dtype = cp.complex64
a = cp.random.random(shape).astype(dtype)  # resides on GPU 0
b = cp.fft.fft(a)  # computed on GPUs 0 & 1, resides on GPU 0
If you need to perform 2D/3D transforms (e.g. fftn()) instead of 1D (e.g. fft()), it would likely still work, but in this particular use case it loops over the transformed axes under the hood (which is exactly what NumPy does too), which could lead to suboptimal performance.
The second kind of usage is to use the low-level, private CuPy APIs. You need to construct a Plan1d object and use it as if you were programming in C/C++ with cuFFT. With this approach, your input array can reside on the host as a numpy.ndarray so that its size can be much larger than what a single GPU can accommodate, which is one of the main reasons to run multi-GPU FFT.
import numpy as np
import cupy as cp

# no need to touch cp.fft.config, as we are using the low-level API
shape = (64, 64)
dtype = np.complex64
a = np.random.random(shape).astype(dtype)  # resides on CPU

if len(shape) == 1:
    batch = 1
    nx = shape[0]
elif len(shape) == 2:
    batch = shape[0]
    nx = shape[1]

# compute via cuFFT
cufft_type = cp.cuda.cufft.CUFFT_C2C  # single-precision C2C
plan = cp.cuda.cufft.Plan1d(nx, cufft_type, batch, devices=[0, 1])
out_cp = np.empty_like(a)  # output on CPU
plan.fft(a, out_cp, cp.cuda.cufft.CUFFT_FORWARD)

out_np = np.fft.fft(a)  # use NumPy's FFT
# np.fft.fft always returns np.complex128
if dtype is np.complex64:
    out_np = out_np.astype(dtype)

# check result
assert np.allclose(out_cp, out_np, rtol=1e-4, atol=1e-7)
For this use case, please consult the cuFFT documentation on multi-GPU transform for further detail.
Note
The multi-GPU plans are cached if auto-generated via the high-level APIs, but not if manually generated via the low-level APIs.
Half-precision FFT¶
cuFFT provides cufftXtMakePlanMany
and cufftXtExec
routines to support a wide range of FFT needs, including 64-bit indexing and half-precision FFT. CuPy provides an experimental support for this capability via the new (though private) XtPlanNd
API. For half-precision FFT, on supported hardware it can be twice as fast than its single-precision counterpart. NumPy does not yet provide the necessary infrastructure for half-precision complex numbers (i.e., numpy.complex32
), though, so the steps for this feature is currently a bit more involved than common cases.
import cupy as cp
import numpy as np

shape = (1024, 256, 256)  # input array shape
idtype = odtype = edtype = 'E'  # = numpy.complex32 in the future

# store the input/output arrays as fp16 arrays twice as long, as complex32 is not yet available
a = cp.random.random((shape[0], shape[1], 2*shape[2])).astype(cp.float16)
out = cp.empty_like(a)

# FFT with cuFFT
plan = cp.cuda.cufft.XtPlanNd(shape[1:],
                              shape[1:], 1, shape[1]*shape[2], idtype,
                              shape[1:], 1, shape[1]*shape[2], odtype,
                              shape[0], edtype,
                              order='C', last_axis=-1, last_size=None)
plan.fft(a, out, cp.cuda.cufft.CUFFT_FORWARD)

# FFT with NumPy
a_np = cp.asnumpy(a).astype(np.float32)  # upcast
a_np = a_np.view(np.complex64)
out_np = np.fft.fftn(a_np, axes=(-2, -1))
out_np = np.ascontiguousarray(out_np).astype(np.complex64)  # downcast
out_np = out_np.view(np.float32)
out_np = out_np.astype(np.float16)

# don't worry about accuracy for now, as we probably lost a lot during casting
print('ok' if cp.mean(cp.abs(out - cp.asarray(out_np))) < 0.1 else 'not ok')
The 64-bit indexing support for all high-level FFT APIs is planned for a future CuPy release.
Memory Management¶
CuPy uses a memory pool for memory allocations by default. The memory pool significantly improves performance by mitigating the overhead of memory allocation and CPU/GPU synchronization.
There are two different memory pools in CuPy:
Device memory pool (GPU device memory), which is used for GPU memory allocations.
Pinned memory pool (non-swappable CPU memory), which is used during CPU-to-GPU data transfer.
Attention
When you monitor memory usage (e.g., using nvidia-smi for GPU memory or ps for CPU memory), you may notice that memory is not freed even after the array instance goes out of scope.
This is expected behavior, as the default memory pool "caches" the allocated memory blocks.
See Low-level CUDA support for the details of memory management APIs.
For using pinned memory more conveniently, we also provide a few high-level APIs in the cupyx namespace, including cupyx.empty_pinned(), cupyx.empty_like_pinned(), cupyx.zeros_pinned(), and cupyx.zeros_like_pinned(). They return NumPy arrays backed by pinned memory. If CuPy's pinned memory pool is in use, the pinned memory is allocated from the pool.
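For example, a pinned host array can serve as a staging buffer for host-to-device transfers (a minimal sketch using the APIs above):
import cupy
import cupyx

# a NumPy array backed by pinned (page-locked) host memory
a_cpu = cupyx.zeros_pinned((1024, 1024), dtype=cupy.float32)
# the host-to-device copy can now use the faster pinned-memory path
a_gpu = cupy.asarray(a_cpu)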
Note
CuPy v8 and above provides an FFT plan cache that can consume a portion of device memory if FFT and related functions are used. The memory taken can be released by shrinking or disabling the cache, as sketched below.
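For example (a minimal sketch using the plan cache APIs described in cuFFT Plan Cache):
import cupy as cp

# release the device memory held by cached FFT plans
cp.fft.config.get_plan_cache().clear()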
Memory Pool Operations¶
The memory pool instance provides statistics about memory allocation.
To access the default memory pool instance, use cupy.get_default_memory_pool()
and cupy.get_default_pinned_memory_pool()
.
You can also free all unused memory blocks held in the memory pool.
See the example code below for details:
import cupy
import numpy
mempool = cupy.get_default_memory_pool()
pinned_mempool = cupy.get_default_pinned_memory_pool()
# Create an array on CPU.
# NumPy allocates 400 bytes in CPU (not managed by CuPy memory pool).
a_cpu = numpy.ndarray(100, dtype=numpy.float32)
print(a_cpu.nbytes) # 400
# You can access statistics of these memory pools.
print(mempool.used_bytes()) # 0
print(mempool.total_bytes()) # 0
print(pinned_mempool.n_free_blocks()) # 0
# Transfer the array from CPU to GPU.
# This allocates 400 bytes from the device memory pool, and another 400
# bytes from the pinned memory pool. The allocated pinned memory will be
# released just after the transfer is complete. Note that the actual
# allocation size may be rounded to larger value than the requested size
# for performance.
a = cupy.array(a_cpu)
print(a.nbytes) # 400
print(mempool.used_bytes()) # 512
print(mempool.total_bytes()) # 512
print(pinned_mempool.n_free_blocks()) # 1
# When the array goes out of scope, the allocated device memory is released
# and kept in the pool for future reuse.
a = None # (or `del a`)
print(mempool.used_bytes()) # 0
print(mempool.total_bytes()) # 512
print(pinned_mempool.n_free_blocks()) # 1
# You can clear the memory pool by calling `free_all_blocks`.
mempool.free_all_blocks()
pinned_mempool.free_all_blocks()
print(mempool.used_bytes()) # 0
print(mempool.total_bytes()) # 0
print(pinned_mempool.n_free_blocks()) # 0
See cupy.cuda.MemoryPool
and cupy.cuda.PinnedMemoryPool
for details.
Limiting GPU Memory Usage¶
You can hard-limit the amount of GPU memory that can be allocated by using the CUPY_GPU_MEMORY_LIMIT environment variable (see Environment variables for details).
# Set the hard-limit to 1 GiB:
# $ export CUPY_GPU_MEMORY_LIMIT="1073741824"
# You can also specify the limit as a fraction of the total amount of memory
# on the GPU. If you have a GPU with 2 GiB of memory, the following is
# equivalent to the above configuration.
# $ export CUPY_GPU_MEMORY_LIMIT="50%"
import cupy
print(cupy.get_default_memory_pool().get_limit()) # 1073741824
You can also set the limit (or override the value specified via the environment variable) using cupy.cuda.MemoryPool.set_limit(). In this way, you can use a different limit for each GPU device.
import cupy
mempool = cupy.get_default_memory_pool()
with cupy.cuda.Device(0):
    mempool.set_limit(size=1024**3)  # 1 GiB
with cupy.cuda.Device(1):
    mempool.set_limit(size=2*1024**3)  # 2 GiB
Note
CUDA allocates some GPU memory outside of the memory pool (such as the CUDA context, library handles, etc.). Depending on the usage, such memory may take from one to a few hundred MiB. It is not counted toward the limit.
Changing Memory Pool¶
You can use your own memory allocator instead of the default memory pool by passing the memory allocation function to cupy.cuda.set_allocator() / cupy.cuda.set_pinned_memory_allocator().
The memory allocator function should take one argument (the requested size in bytes) and return cupy.cuda.MemoryPointer / cupy.cuda.PinnedMemoryPointer, respectively.
CuPy provides two such allocators, for using managed memory and stream ordered memory on the GPU; see cupy.cuda.malloc_managed() and cupy.cuda.malloc_async(), respectively, for details.
To enable a memory pool backed by managed memory, you can construct a new MemoryPool instance with its allocator set to malloc_managed() as follows:
import cupy
# Use managed memory
cupy.cuda.set_allocator(cupy.cuda.MemoryPool(cupy.cuda.malloc_managed).malloc)
Note that if you pass malloc_managed() directly to set_allocator() without constructing a MemoryPool instance, the memory will be released back to the system immediately when freed, which may or may not be desired.
Stream Ordered Memory Allocator is a new feature added in CUDA 11.2. CuPy provides an experimental interface to it. Similar to CuPy's memory pool, the Stream Ordered Memory Allocator also allocates/deallocates memory asynchronously from/to a memory pool in a stream-ordered fashion. The key difference is that it is a built-in feature implemented in the CUDA driver by NVIDIA, so other CUDA applications in the same process can easily allocate memory from the same pool.
To enable a memory pool that manages stream ordered memory, you can construct a new MemoryAsyncPool
instance:
import cupy
# Use asynchronous stream ordered memory
cupy.cuda.set_allocator(cupy.cuda.MemoryAsyncPool().malloc)
# Create a custom stream
s = cupy.cuda.Stream()
# This allocates memory asynchronously on stream s
with s:
    a = cupy.empty((100,), dtype=cupy.float64)
Note that in this case we do not use the MemoryPool class. The MemoryAsyncPool takes a different input argument from that of MemoryPool to indicate which pool to use. Please refer to MemoryAsyncPool's documentation for further detail.
Note that if you pass malloc_async() directly to set_allocator() without constructing a MemoryAsyncPool instance, the device's current memory pool will be used.
When using stream ordered memory, it is important that you maintain correct stream semantics yourself using, for example, the Stream and Event APIs (see Streams and Events for details); CuPy does not attempt to act smartly for you. Upon deallocation, the memory is freed asynchronously either on the stream on which it was allocated (first attempt) or on any current CuPy stream (second attempt). It is permitted for the stream on which the memory was allocated to be destroyed before all memory allocated on it is freed.
In addition, applications/libraries that internally use cudaMalloc (CUDA's default, synchronous allocator) could have unexpected interplay with the Stream Ordered Memory Allocator. Specifically, memory freed to the memory pool might not be immediately visible to cudaMalloc, leading to potential out-of-memory errors. In this case, you can either call free_all_blocks() or manually perform an event/stream/device synchronization, and then retry.
Currently the MemoryAsyncPool interface is experimental. In particular, while its API is largely identical to that of MemoryPool, several of the pool's methods require a sufficiently new driver (and, of course, supported hardware, CUDA version, and platform) due to CUDA's limitations.
You can even disable the default memory pool with the code below. Be sure to do this before any other CuPy operations.
import cupy
# Disable memory pool for device memory (GPU)
cupy.cuda.set_allocator(None)
# Disable memory pool for pinned memory (CPU).
cupy.cuda.set_pinned_memory_allocator(None)
Performance Best Practices¶
Here we gather a few tricks and pieces of advice for improving CuPy's performance.
Benchmarking¶
It is critically important to first identify the performance bottleneck before making any attempt to optimize your code. To help set up a baseline benchmark, CuPy provides a useful utility, cupyx.profiler.benchmark(), for measuring the elapsed time of a Python function on both CPU and GPU:
>>> import cupy as cp
>>> from cupyx.profiler import benchmark
>>>
>>> def my_func(a):
... return cp.sqrt(cp.sum(a**2, axis=-1))
...
>>> a = cp.random.random((256, 1024))
>>> print(benchmark(my_func, (a,), n_repeat=20))
my_func : CPU: 44.407 us +/- 2.428 (min: 42.516 / max: 53.098) us GPU-0: 181.565 us +/- 1.853 (min: 180.288 / max: 188.608) us
Because GPU executions run asynchronously with respect to CPU executions, a common pitfall in GPU programming is to mistakenly measure the elapsed time using CPU timing utilities (such as time.perf_counter() from the Python Standard Library or the %timeit magic from IPython), which have no knowledge of the GPU runtime. cupyx.profiler.benchmark() addresses this by setting up CUDA events on the current stream right before and after the function to be measured and synchronizing over the end event (see Streams and Events for detail). Below we sketch what is done internally in cupyx.profiler.benchmark():
>>> import time
>>> start_gpu = cp.cuda.Event()
>>> end_gpu = cp.cuda.Event()
>>>
>>> start_gpu.record()
>>> start_cpu = time.perf_counter()
>>> out = my_func(a)
>>> end_cpu = time.perf_counter()
>>> end_gpu.record()
>>> end_gpu.synchronize()
>>> t_gpu = cp.cuda.get_elapsed_time(start_gpu, end_gpu)
>>> t_cpu = end_cpu - start_cpu
Additionally, cupyx.profiler.benchmark() performs a few warm-up runs to reduce timing fluctuation and exclude the overhead of first invocations.
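The number of measured repetitions and warm-up runs can be adjusted via keyword arguments (a sketch assuming the n_repeat and n_warmup parameters of benchmark()):
>>> # perform 10 unmeasured warm-up runs before the 20 measured repetitions
>>> perf = benchmark(my_func, (a,), n_repeat=20, n_warmup=10)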
One-Time Overheads¶
Be aware of these overheads when benchmarking CuPy code.
Context Initialization¶
It may take several seconds to call a CuPy function for the first time in a process. This is because the CUDA driver creates a CUDA context during the first CUDA API call in CUDA applications.
Kernel Compilation¶
CuPy uses on-the-fly kernel synthesis. When a kernel call is required, it compiles kernel code optimized for the dimensions and dtypes of the given arguments, sends it to the GPU device, and executes the kernel.
CuPy caches the kernel code sent to the GPU device within the process, which reduces the kernel compilation time on subsequent calls.
The compiled code is also cached in the directory ${HOME}/.cupy/kernel_cache (the path can be overridden by setting the CUPY_CACHE_DIR environment variable).
This allows reusing compiled kernel binaries across processes.
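For example, the cache can be relocated to a faster or shared disk (a sketch; the path here is hypothetical):
import os

# to be safe, set this before importing CuPy so that no kernel is compiled
# with the old cache location in this process
os.environ['CUPY_CACHE_DIR'] = '/tmp/cupy_kernel_cache'  # hypothetical path

import cupy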
In-depth profiling¶
Under construction. To mark regions with NVTX/rocTX ranges, you can use the cupyx.profiler.time_range() API. To start/stop the profiler, you can use the cupyx.profiler.profile() API.
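For example (a minimal sketch; time_range() is used as a context manager, and profile() brackets the region captured by an externally attached profiler such as CUDA Profiler):
import cupy as cp
from cupyx.profiler import profile, time_range

a = cp.random.random((1024, 1024), dtype=cp.float32)

with profile():  # start/stop the attached profiler
    with time_range('matmul'):  # shows up as an NVTX/rocTX range
        b = a @ a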
Use CUB/cuTENSOR backends for reduction and other routines¶
For reduction operations (such as sum(), prod(), amin(), amax(), argmin(), argmax()) and the many routines built upon them, CuPy ships with its own implementations so that things just work out of the box. However, there are dedicated efforts to further accelerate these routines, such as CUB and cuTENSOR.
In order to support more performant backends wherever applicable, starting with v8 CuPy introduces the environment variable CUPY_ACCELERATORS to allow users to specify the desired backends (and the order in which they are tried). For example, consider summing over a 256-cubic array:
>>> from cupyx.profiler import benchmark
>>> a = cp.random.random((256, 256, 256), dtype=cp.float32)
>>> print(benchmark(a.sum, (), n_repeat=100))
sum : CPU: 12.101 us +/- 0.694 (min: 11.081 / max: 17.649) us GPU-0:10174.898 us +/-180.551 (min:10084.576 / max:10595.936) us
We can see that it takes about 10 ms to run (on this GPU). However, if we launch the Python session using CUPY_ACCELERATORS=cub python, we get a ~100x speedup for free (only ~0.1 ms):
>>> print(benchmark(a.sum, (), n_repeat=100))
sum : CPU: 20.569 us +/- 5.418 (min: 13.400 / max: 28.439) us GPU-0: 114.740 us +/- 4.130 (min: 108.832 / max: 122.752) us
CUB is a backend shipped together with CuPy. It also accelerates other routines, such as inclusive scans (e.g. cumsum()), histograms, sparse matrix-vector multiplications (not applicable in CUDA 11), and ReductionKernel.
cuTENSOR offers optimized performance for binary elementwise ufuncs, reductions, and tensor contractions.
If cuTENSOR is installed, setting CUPY_ACCELERATORS=cub,cutensor, for example, would try CUB first and fall back to cuTENSOR if CUB does not provide the needed support. If neither backend is applicable, CuPy falls back to its default implementation.
Note that while in general the accelerated reductions are faster, there can be exceptions depending on the data layout. In particular, the CUB reduction only supports reduction over contiguous axes. In any case, we recommend performing some benchmarks to determine whether CUB/cuTENSOR offers better performance or not.
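The accelerators can also be selected programmatically, which has the same effect as setting the variable on the command line (a minimal sketch):
import os

# CUPY_ACCELERATORS must be set before CuPy is imported
os.environ['CUPY_ACCELERATORS'] = 'cub,cutensor'

import cupy as cp

a = cp.random.random((256, 256, 256), dtype=cp.float32)
out = a.sum()  # tries CUB first, then cuTENSOR, then the default implementation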
Overlapping work using streams¶
Under construction.
Use JIT compiler¶
Under construction. For now please refer to JIT kernel definition for a quick introduction.
Prefer float32 over float64¶
Under construction.
Interoperability¶
CuPy can be used in conjunction with other libraries.
CUDA functionalities¶
Under construction. For using CUDA streams created in foreign libraries in CuPy, see Streams and Events.
NumPy¶
cupy.ndarray implements the __array_ufunc__ interface (see NEP 13 — A Mechanism for Overriding Ufuncs for details). This enables NumPy ufuncs to be applied directly to CuPy arrays. The __array_ufunc__ feature requires NumPy 1.13 or later.
import cupy
import numpy
arr = cupy.random.randn(1, 2, 3, 4).astype(cupy.float32)
result = numpy.sum(arr)
print(type(result)) # => <class 'cupy._core.core.ndarray'>
cupy.ndarray also implements the __array_function__ interface (see NEP 18 — A dispatch mechanism for NumPy's high level array functions for details). This enables code written for NumPy to operate directly on CuPy arrays. The __array_function__ feature requires NumPy 1.16 or later; as of NumPy 1.17, __array_function__ is enabled by default.
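For example, a NumPy function dispatches to CuPy, and the computation stays on the GPU (a minimal sketch):
import cupy
import numpy

a = cupy.random.randn(100)
# numpy.mean() dispatches to CuPy via __array_function__
result = numpy.mean(a)
print(type(result))  # => <class 'cupy._core.core.ndarray'>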
Numba¶
Numba is a Python JIT compiler with NumPy support.
cupy.ndarray implements __cuda_array_interface__, which is the CUDA array interchange interface compatible with Numba v0.39.0 or later (see CUDA Array Interface for details). This means you can pass CuPy arrays to kernels JITed with Numba.
The following is a simple example code borrowed from numba/numba#2860:
import cupy
from numba import cuda
@cuda.jit
def add(x, y, out):
    start = cuda.grid(1)
    stride = cuda.gridsize(1)
    for i in range(start, x.shape[0], stride):
        out[i] = x[i] + y[i]
a = cupy.arange(10)
b = a * 2
out = cupy.zeros_like(a)
print(out) # => [0 0 0 0 0 0 0 0 0 0]
add[1, 32](a, b, out)
print(out) # => [ 0 3 6 9 12 15 18 21 24 27]
In addition, cupy.asarray()
supports zero-copy conversion from Numba CUDA array to CuPy array.
import numpy
import numba
import cupy
x = numpy.arange(10) # type: numpy.ndarray
x_numba = numba.cuda.to_device(x) # type: numba.cuda.cudadrv.devicearray.DeviceNDArray
x_cupy = cupy.asarray(x_numba) # type: cupy.ndarray
Warning
__cuda_array_interface__ specifies that the object lifetime must be managed by the user, so it is undefined behavior if the exported object is destroyed while still in use by the consumer library.
Note
CuPy uses two environment variables to control the exchange behavior: CUPY_CUDA_ARRAY_INTERFACE_SYNC and CUPY_CUDA_ARRAY_INTERFACE_EXPORT_VERSION.
mpi4py¶
MPI for Python (mpi4py) is a Python wrapper for the Message Passing Interface (MPI) libraries.
MPI is the most widely used standard for high-performance inter-process communications. Recently several MPI vendors, including MPICH, Open MPI and MVAPICH, have extended their support beyond the MPI-3.1 standard to enable “CUDA-awareness”; that is, passing CUDA device pointers directly to MPI calls to avoid explicit data movement between the host and the device.
With the __cuda_array_interface__ (as mentioned above) and DLPack data exchange protocols (see DLPack below) implemented in CuPy, mpi4py now provides (experimental) support for passing CuPy arrays to MPI calls, provided that mpi4py is built against a CUDA-aware MPI implementation. The following is a simple example code borrowed from the mpi4py Tutorial:
# To run this script with N MPI processes, do
# mpiexec -n N python this_script.py
import cupy
from mpi4py import MPI
comm = MPI.COMM_WORLD
size = comm.Get_size()
# Allreduce
sendbuf = cupy.arange(10, dtype='i')
recvbuf = cupy.empty_like(sendbuf)
comm.Allreduce(sendbuf, recvbuf)
assert cupy.allclose(recvbuf, sendbuf*size)
This feature was added in mpi4py 3.1.0. See the mpi4py website for more information.
PyTorch¶
PyTorch is a machine learning framework that provides high-performance, differentiable tensor operations.
PyTorch also supports __cuda_array_interface__, so zero-copy data exchange between CuPy and PyTorch can be achieved at no cost. The only caveat is that PyTorch creates CPU tensors by default, which do not have the __cuda_array_interface__ property defined, so users need to ensure the tensor is already on the GPU before exchanging.
>>> import cupy as cp
>>> import torch
>>>
>>> # convert a torch tensor to a cupy array
>>> a = torch.rand((4, 4), device='cuda')
>>> b = cp.asarray(a)
>>> b *= b
>>> b
array([[0.8215962 , 0.82399917, 0.65607935, 0.30354425],
[0.422695 , 0.8367199 , 0.00208597, 0.18545236],
[0.00226746, 0.46201342, 0.6833052 , 0.47549972],
[0.5208748 , 0.6059282 , 0.1909013 , 0.5148635 ]], dtype=float32)
>>> a
tensor([[0.8216, 0.8240, 0.6561, 0.3035],
[0.4227, 0.8367, 0.0021, 0.1855],
[0.0023, 0.4620, 0.6833, 0.4755],
[0.5209, 0.6059, 0.1909, 0.5149]], device='cuda:0')
>>> # check the underlying memory pointer is the same
>>> assert a.__cuda_array_interface__['data'][0] == b.__cuda_array_interface__['data'][0]
>>>
>>> # convert a cupy array to a torch tensor
>>> a = cp.arange(10)
>>> b = torch.as_tensor(a, device='cuda')
>>> b += 3
>>> b
tensor([ 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], device='cuda:0')
>>> a
array([ 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
>>> assert a.__cuda_array_interface__['data'][0] == b.__cuda_array_interface__['data'][0]
PyTorch also supports zero-copy data exchange through DLPack (see DLPack below):
import cupy
import torch
from torch.utils.dlpack import to_dlpack
from torch.utils.dlpack import from_dlpack
# Create a PyTorch tensor.
tx1 = torch.randn(1, 2, 3, 4).cuda()
# Convert it into a DLPack tensor.
dx = to_dlpack(tx1)
# Convert it into a CuPy array.
cx = cupy.from_dlpack(dx)
# Convert it back to a PyTorch tensor.
tx2 = from_dlpack(cx.toDlpack())
The pytorch-pfn-extras library provides additional integration features with PyTorch, including memory pool sharing and stream sharing:
>>> import cupy
>>> import torch
>>> import pytorch_pfn_extras as ppe
>>>
>>> # Perform CuPy memory allocation using the PyTorch memory pool.
>>> ppe.cuda.use_torch_mempool_in_cupy()
>>> torch.cuda.memory_allocated()
0
>>> arr = cupy.arange(10)
>>> torch.cuda.memory_allocated()
512
>>>
>>> # Change the default stream in PyTorch and CuPy:
>>> stream = torch.cuda.Stream()
>>> with ppe.cuda.stream(stream):
... ...
Using custom kernels in PyTorch¶
With the DLPack protocol, it becomes very simple to implement functions in PyTorch using CuPy user-defined kernels. Below is an example of a PyTorch autograd function that computes the forward and backward passes of the logarithm using cupy.RawKernel s.
import cupy
import torch

cupy_custom_kernel_fwd = cupy.RawKernel(
    r"""
    extern "C" __global__
    void cupy_custom_kernel_fwd(const float* x, float* y, int size) {
        int tid = blockDim.x * blockIdx.x + threadIdx.x;
        if (tid < size)
            y[tid] = log(x[tid]);
    }
    """,
    "cupy_custom_kernel_fwd",
)

cupy_custom_kernel_bwd = cupy.RawKernel(
    r"""
    extern "C" __global__
    void cupy_custom_kernel_bwd(const float* x, float* gy, float* gx, int size) {
        int tid = blockDim.x * blockIdx.x + threadIdx.x;
        if (tid < size)
            gx[tid] = gy[tid] / x[tid];
    }
    """,
    "cupy_custom_kernel_bwd",
)

class CuPyLog(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.input = x
        # Enforce contiguous arrays to simplify RawKernel indexing.
        cupy_x = cupy.ascontiguousarray(cupy.from_dlpack(x.detach()))
        cupy_y = cupy.empty(cupy_x.shape, dtype=cupy_x.dtype)
        x_size = cupy_x.size
        bs = 128
        # Launch enough blocks of bs threads each to cover all elements.
        cupy_custom_kernel_fwd(
            ((x_size + bs - 1) // bs,), (bs,), (cupy_x, cupy_y, x_size)
        )
        # The ownership of the device memory backing cupy_y is implicitly
        # transferred to torch_y, so this operation is safe even after
        # going out of scope of this function.
        torch_y = torch.from_dlpack(cupy_y)
        return torch_y

    @staticmethod
    def backward(ctx, grad_y):
        # Enforce contiguous arrays to simplify RawKernel indexing.
        cupy_input = cupy.from_dlpack(ctx.input.detach()).ravel()
        cupy_grad_y = cupy.from_dlpack(grad_y.detach()).ravel()
        cupy_grad_x = cupy.zeros(cupy_grad_y.shape, dtype=cupy_grad_y.dtype)
        gy_size = cupy_grad_y.size
        bs = 128
        # Launch enough blocks of bs threads each to cover all elements.
        cupy_custom_kernel_bwd(
            ((gy_size + bs - 1) // bs,),
            (bs,),
            (cupy_input, cupy_grad_y, cupy_grad_x, gy_size),
        )
        # The ownership of the device memory backing cupy_grad_x is implicitly
        # transferred to torch_grad_x, so this operation is safe even after
        # going out of scope of this function.
        torch_grad_x = torch.from_dlpack(cupy_grad_x)
        return torch_grad_x
Note
Directly feeding a torch.Tensor to cupy.from_dlpack() is only supported in the (new) DLPack data exchange protocol added in CuPy v10+ and PyTorch 1.10+. For earlier versions, you will need to wrap the Tensor with torch.utils.dlpack.to_dlpack() as shown in the above examples.
RMM¶
RMM (RAPIDS Memory Manager) provides highly configurable memory allocators.
RMM provides an interface to allow CuPy to allocate memory from the RMM memory pool instead of from CuPy's own pool. The setup can be as simple as:
import cupy
import rmm
cupy.cuda.set_allocator(rmm.rmm_cupy_allocator)
Sometimes, a more performant allocator may be desirable. RMM provides an option to switch the allocator:
import cupy
import rmm
rmm.reinitialize(pool_allocator=True) # can also set init pool size etc here
cupy.cuda.set_allocator(rmm.rmm_cupy_allocator)
For more information on CuPy’s memory management, see Memory Management.
DLPack¶
DLPack is a specification of tensor structure to share tensors among frameworks.
CuPy supports importing from and exporting to the DLPack data structure (cupy.from_dlpack() and cupy.ndarray.toDlpack()).
Here is a simple example:
import cupy
# Create a CuPy array.
cx1 = cupy.random.randn(1, 2, 3, 4).astype(cupy.float32)
# Convert it into a DLPack tensor.
dx = cx1.toDlpack()
# Convert it back to a CuPy array.
cx2 = cupy.from_dlpack(dx)
TensorFlow also supports DLPack, so zero-copy data exchange between CuPy and TensorFlow through DLPack is possible:
>>> import tensorflow as tf
>>> import cupy as cp
>>>
>>> # convert a TF tensor to a cupy array
>>> with tf.device('/GPU:0'):
... a = tf.random.uniform((10,))
...
>>> a
<tf.Tensor: shape=(10,), dtype=float32, numpy=
array([0.9672388 , 0.57568085, 0.53163004, 0.6536236 , 0.20479882,
0.84908986, 0.5852566 , 0.30355775, 0.1733712 , 0.9177849 ],
dtype=float32)>
>>> a.device
'/job:localhost/replica:0/task:0/device:GPU:0'
>>> cap = tf.experimental.dlpack.to_dlpack(a)
>>> b = cp.from_dlpack(cap)
>>> b *= 3
>>> b
array([1.4949363 , 0.60699713, 1.3276931 , 1.5781245 , 1.1914308 ,
2.3180873 , 1.9560868 , 1.3932796 , 1.9299742 , 2.5352407 ],
dtype=float32)
>>> a
<tf.Tensor: shape=(10,), dtype=float32, numpy=
array([1.4949363 , 0.60699713, 1.3276931 , 1.5781245 , 1.1914308 ,
2.3180873 , 1.9560868 , 1.3932796 , 1.9299742 , 2.5352407 ],
dtype=float32)>
>>>
>>> # convert a cupy array to a TF tensor
>>> a = cp.arange(10)
>>> cap = a.toDlpack()
>>> b = tf.experimental.dlpack.from_dlpack(cap)
>>> b.device
'/job:localhost/replica:0/task:0/device:GPU:0'
>>> b
<tf.Tensor: shape=(10,), dtype=int64, numpy=array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])>
>>> a
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
Be aware that in TensorFlow all tensors are immutable, so in the latter case any changes to b cannot be reflected in the CuPy array a.
Note that, as of DLPack v0.5, for correctness the above approach (implicitly) requires users to ensure that such conversions (both importing and exporting a CuPy array) happen on the same CUDA/HIP stream. If in doubt, the current CuPy stream in use can be fetched by, for example, calling cupy.cuda.get_current_stream(). Please consult the other framework's documentation for how to access and control its streams.
DLPack data exchange protocol¶
To obviate user-managed streams and DLPack tensor objects, the DLPack data exchange protocol provides a mechanism to shift the responsibility from users to libraries. Any compliant object (such as cupy.ndarray) must implement a pair of methods, __dlpack__ and __dlpack_device__. The function cupy.from_dlpack() accepts such an object and returns a cupy.ndarray that is safely accessible on CuPy's current stream. Likewise, cupy.ndarray can be exported via any compliant library's from_dlpack() function.
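For example, with the new protocol the producer object is passed directly, with no intermediate capsule or stream management (a minimal sketch assuming CuPy v10+ and PyTorch 1.10+, per the note above):
import cupy
import torch

t = torch.arange(4, device='cuda')
a = cupy.from_dlpack(t)    # the consumer calls t.__dlpack__() under the hood
t2 = torch.from_dlpack(a)  # and likewise in the other direction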
Note
CuPy uses CUPY_DLPACK_EXPORT_VERSION to control how to handle tensors backed by CUDA managed memory.
Difference between CuPy and NumPy¶
The interface of CuPy is designed to follow that of NumPy. However, there are some differences.
Cast behavior from float to integer¶
Some casting behaviors from floating point to integer are not defined in the C++ specification. Casting a negative float to an unsigned integer, or infinity to an integer, are examples of this. The behavior of NumPy depends on your CPU architecture. This is the result on an Intel CPU:
>>> np.array([-1], dtype=np.float32).astype(np.uint32)
array([4294967295], dtype=uint32)
>>> cupy.array([-1], dtype=np.float32).astype(np.uint32)
array([0], dtype=uint32)
>>> np.array([float('inf')], dtype=np.float32).astype(np.int32)
array([-2147483648], dtype=int32)
>>> cupy.array([float('inf')], dtype=np.float32).astype(np.int32)
array([2147483647], dtype=int32)
Random methods support dtype argument¶
NumPy's random value generator does not support a dtype argument and instead always returns a float64 value. We support the option in CuPy because cuRAND, which is used in CuPy, supports both float32 and float64.
>>> np.random.randn(dtype=np.float32)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: randn() got an unexpected keyword argument 'dtype'
>>> cupy.random.randn(dtype=np.float32)
array(0.10689262300729752, dtype=float32)
Out-of-bounds indices¶
By default, CuPy handles out-of-bounds indices differently from NumPy when using integer array indexing. NumPy raises an error, whereas CuPy wraps them around.
>>> x = np.array([0, 1, 2])
>>> x[[1, 3]] = 10
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
IndexError: index 3 is out of bounds for axis 1 with size 3
>>> x = cupy.array([0, 1, 2])
>>> x[[1, 3]] = 10
>>> x
array([10, 10, 2])
Duplicate values in indices¶
CuPy's __setitem__ behaves differently from NumPy when integer arrays reference the same location multiple times. In that case, the value that is actually stored is undefined. Here is an example with CuPy.
>>> a = cupy.zeros((2,))
>>> i = cupy.arange(10000) % 2
>>> v = cupy.arange(10000).astype(np.float32)
>>> a[i] = v
>>> a
array([ 9150., 9151.])
NumPy stores the value corresponding to the last element among elements referencing duplicate locations.
>>> a_cpu = np.zeros((2,))
>>> i_cpu = np.arange(10000) % 2
>>> v_cpu = np.arange(10000).astype(np.float32)
>>> a_cpu[i_cpu] = v_cpu
>>> a_cpu
array([9998., 9999.])
Zero-dimensional array¶
Reduction methods¶
NumPy's reduction functions (e.g. numpy.sum()) return scalar values (e.g. numpy.float32). However, the CuPy counterparts return zero-dimensional cupy.ndarray s. This is because CuPy scalar values (e.g. cupy.float32) are aliases of NumPy scalar values and are allocated in CPU memory. If these types were returned, it would be necessary to synchronize between GPU and CPU. If you want to use scalar values, cast the returned arrays explicitly.
>>> type(np.sum(np.arange(3))) == np.int64
True
>>> type(cupy.sum(cupy.arange(3))) == cupy._core.core.ndarray
True
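For example, an explicit cast synchronizes and yields a host-side scalar:
>>> s = cupy.sum(cupy.arange(3))
>>> float(s)  # synchronizes between GPU and CPU
3.0
>>> s.item()
3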
Type promotion¶
CuPy automatically promotes the dtypes of cupy.ndarray s in a function with two or more operands; the result dtype is determined by the dtypes of the inputs. This differs from NumPy's type promotion rule when the operands contain zero-dimensional arrays. Zero-dimensional numpy.ndarray s are treated as if they were scalar values when they appear as operands of a NumPy function, which may affect the dtype of the output depending on the values of the "scalar" inputs.
>>> (np.array(3, dtype=np.int32) * np.array([1., 2.], dtype=np.float32)).dtype
dtype('float32')
>>> (np.array(300000, dtype=np.int32) * np.array([1., 2.], dtype=np.float32)).dtype
dtype('float64')
>>> (cupy.array(3, dtype=np.int32) * cupy.array([1., 2.], dtype=np.float32)).dtype
dtype('float64')
Matrix type (numpy.matrix)¶
SciPy returns numpy.matrix (a subclass of numpy.ndarray) when dense matrices are computed from sparse matrices (e.g., coo_matrix + ndarray). However, CuPy returns cupy.ndarray for such operations.
There is no plan to provide a numpy.matrix equivalent in CuPy, because the use of numpy.matrix is no longer recommended since NumPy 1.15.
Data types¶
The data type of CuPy arrays cannot be non-numeric, such as strings or objects. See Overview for details.
Universal Functions only work with CuPy array or scalar¶
Unlike NumPy, universal functions in CuPy only work with CuPy arrays or scalars. They do not accept other objects (e.g., lists or numpy.ndarray s).
>>> np.power([np.arange(5)], 2)
array([[ 0, 1, 4, 9, 16]])
>>> cupy.power([cupy.arange(5)], 2)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: Unsupported type <class 'list'>
Random seed arrays are hashed to scalars¶
Like NumPy, CuPy's RandomState objects accept seeds either as numbers or as full NumPy arrays.
>>> seed = np.array([1, 2, 3, 4, 5])
>>> rs = cupy.random.RandomState(seed=seed)
However, unlike NumPy, array seeds will be hashed down to a single number, and so may not communicate as much entropy to the underlying random number generator.
NaN (not-a-number) handling¶
By default, CuPy's reduction functions (e.g., cupy.sum()) handle NaNs in complex numbers differently from NumPy's counterparts:
>>> a = [0.5 + 3.7j, complex(0.7, np.nan), complex(np.nan, -3.9), complex(np.nan, np.nan)]
>>>
>>> a_np = np.asarray(a)
>>> print(a_np.max(), a_np.min())
(0.7+nanj) (0.7+nanj)
>>>
>>> a_cp = cp.asarray(a_np)
>>> print(a_cp.max(), a_cp.min())
(nan-3.9j) (nan-3.9j)
The reason is that internally the reduction is performed in a strided fashion; thus, it does not ensure a proper comparison order and cannot follow NumPy's rule of always propagating the first-encountered NaN.
API Compatibility Policy¶
This document expresses the design policy on compatibility of CuPy APIs. The development team should follow this policy when deciding to add, extend, or change APIs and their behaviors.
This document is written for both users and developers. Users can decide the level of dependence on CuPy's implementations in their code based on this document. Developers should read through this document before creating pull requests that contain changes to the interface. Note that this document may contain ambiguities on the level of supported compatibilities.
Versioning and Backward Compatibilities¶
The updates of CuPy are classified into three levels: major, minor, and revision. These levels have distinct degrees of backward compatibility.
A major update contains disruptive changes that break backward compatibility.
A minor update contains additions and extensions to the APIs that keep backward compatibility.
A revision update contains improvements to the API implementations without changing any API specifications.
Note that we do not support full backward compatibility, which is almost infeasible for Python-based APIs, since there is no way to completely hide the implementation details.
Processes to Break Backward Compatibilities¶
Deprecation, Dropping, and Its Preparation¶
Any API may be deprecated in some minor update. In such a case, a deprecation note is added to the API documentation, and the API implementation is changed to fire a deprecation warning (if possible). There should be another way to reimplement the same functionality previously written using the deprecated APIs.
Any API may be marked as to be dropped in the future. In such a case, the dropping is stated in the documentation with the major version number on which the API is planned to be dropped, and the API implementation is changed to fire a future warning (if possible).
The actual dropping should be done through the following steps:
Make the API deprecated. At this point, users should not use the deprecated API in new application code.
After that, mark the API as to be dropped in the future. This must be done in a minor update different from that of the deprecation.
At the major version announced in the above update, drop the API.
Consequently, it takes at least two minor versions to drop any APIs after the first deprecation.
API Changes and Its Preparation¶
Any API may be marked as to be changed in the future for changes without backward compatibility. In such a case, the change is stated in the documentation with the version number on which the API is planned to be changed, and the API implementation is changed to fire a future warning on the relevant usages.
The actual change should be done in the following steps:
Announce that the API will be changed in the future. At this point, the actual version of the change need not be accurate.
After the announcement, mark the API as to be changed in the future, with the version number of the planned change. At this point, users should not use the marked API in new application code.
At the major update announced above, change the API.
Supported Backward Compatibility¶
This section defines backward compatibilities that minor updates must maintain.
Documented Interface¶
CuPy has an official API documentation. Many applications can be written based on the documented features. We support backward compatibility of documented features. In other words, code based only on the documented features runs correctly with minor- or revision-updated versions.
Developers are encouraged to use names that make implementation details apparent. For example, attributes outside of the documented APIs should have one or more underscores prefixed to their names.
Undocumented behaviors¶
Behaviors of the CuPy implementation not stated in the documentation are undefined. Undocumented behaviors are not guaranteed to be stable between different minor/revision versions.
A minor update may contain changes to undocumented behaviors. For example, suppose an API X is added in a minor update. In the previous version, attempts to use X cause an AttributeError. This behavior is not stated in the documentation, so it is undefined. Thus, adding the API X in a minor update is permissible.
A revision update may also contain changes to undefined behaviors. A typical example is a bug fix. Another example is an improvement of the implementation, which may change the internal object structures not shown in the documentation. As a consequence, even revision updates do not support compatibility of pickling, unless the full layout of pickled objects is clearly documented.
Documentation Error¶
Compatibility is basically determined based on the documentation, though the documentation sometimes contains errors. Assuming that the documentation is always authoritative over the implementation may make the APIs confusing. We therefore may fix documentation errors in any update, even if doing so breaks compatibility with respect to the documentation.
Note
Developers MUST NOT fix the documentation and implementation of the same functionality at the same time in revision updates as a "bug fix". Such a change completely breaks backward compatibility. If you want to fix bugs on both sides, first fix the documentation to match the implementation, and then start the API changing procedure described above.
Object Attributes and Properties¶
Object attributes and properties are sometimes replaced by each other in minor updates. This does not break user code, except for code that depends on how the attributes and properties are implemented.
Functions and Methods¶
Methods may be replaced by callable attributes, keeping the compatibility of parameters and return values, in minor updates. This does not break user code, except for code that depends on how the methods and callable attributes are implemented.
Exceptions and Warnings¶
The specifications for raising exceptions are considered part of standard backward compatibility. No exception will be raised in future versions for correct usages that the documentation allows, unless the API changing process is completed.
On the other hand, warnings may be added in any minor update for any API. This means minor updates do not keep backward compatibility of warnings.
Installation Compatibility¶
The installation process is another concern for compatibility. We support environmental compatibility in the following ways.
Any changes to dependent libraries that force modifications of existing environments must be done in major updates. Such changes include the following cases:
dropping supported versions of dependent libraries (e.g. dropping cuDNN v2)
adding new mandatory dependencies (e.g. adding h5py to setup_requires)
Support for optional packages/libraries may be added in minor updates (e.g. supporting h5py in optional features).
Note
The installation compatibility does not guarantee that all the features of CuPy run correctly on supported environments. CuPy may contain bugs that only occur in certain environments. Such bugs should be fixed in some update.
API Reference¶
The N-dimensional array (ndarray)¶
cupy.ndarray is the CuPy counterpart of NumPy's numpy.ndarray. It provides an intuitive interface for a fixed-size multidimensional array which resides in a CUDA device. For the basic concept of ndarray s, please refer to the NumPy documentation.
cupy.ndarray: Multi-dimensional array on a CUDA device.
Conversion to/from NumPy arrays¶
cupy.ndarray and numpy.ndarray are not implicitly convertible to each other. That means NumPy functions cannot take cupy.ndarray s as inputs, and vice versa.
To convert numpy.ndarray to cupy.ndarray, use cupy.array() or cupy.asarray().
To convert cupy.ndarray to numpy.ndarray, use cupy.asnumpy() or cupy.ndarray.get().
Note that converting between cupy.ndarray and numpy.ndarray incurs data transfer between the host (CPU) device and the GPU device, which is costly in terms of performance.
cupy.array: Creates an array on the current device.
cupy.asarray: Converts an object to array.
cupy.asnumpy: Returns an array on the host memory from an arbitrary source array.
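For example (a minimal round trip using the functions above):
import cupy
import numpy

x_cpu = numpy.array([1, 2, 3])
x_gpu = cupy.asarray(x_cpu)   # host (CPU) to device (GPU) transfer
x_cpu2 = cupy.asnumpy(x_gpu)  # device (GPU) to host (CPU) transfer
assert (x_cpu == x_cpu2).all()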
Code compatibility features¶
cupy.ndarray is designed to be interchangeable with numpy.ndarray in terms of code compatibility as much as possible. But occasionally, you will need to know whether the arrays you're handling are cupy.ndarray or numpy.ndarray. One example is when invoking module-level functions such as cupy.sum() or numpy.sum(). In such situations, cupy.get_array_module() can be used.
cupy.get_array_module: Returns the array module for arguments.
cupyx.scipy.get_array_module: Returns the array module for arguments.
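For example, a single function can serve both NumPy and CuPy arrays (a minimal sketch; the softplus function here is only for illustration):
import cupy
import numpy

def softplus(x):
    xp = cupy.get_array_module(x)  # returns numpy or cupy based on the input
    return xp.log1p(xp.exp(-abs(x))) + xp.maximum(x, 0)

print(softplus(numpy.array([-1.0, 1.0])))  # runs on the CPU
print(softplus(cupy.array([-1.0, 1.0])))   # runs on the GPU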
Universal functions (cupy.ufunc)¶
CuPy provides universal functions (a.k.a. ufuncs) to support various elementwise operations. CuPy's ufuncs support the following features of NumPy's:
Broadcasting
Output type determination
Casting rules
CuPy's ufuncs currently do not provide methods such as reduce, accumulate, reduceat, outer, and at.
Available ufuncs¶
Math operations¶
cupy.add: Adds two arrays elementwise.
cupy.subtract: Subtracts arguments elementwise.
cupy.multiply: Multiplies two arrays elementwise.
cupy.matmul: matmul(x1, x2, /, out=None, **kwargs)
cupy.divide: Elementwise true division.
cupy.logaddexp: Computes log(exp(x1) + exp(x2)) elementwise.
cupy.logaddexp2: Computes log2(exp2(x1) + exp2(x2)) elementwise.
cupy.true_divide: Elementwise true division.
cupy.floor_divide: Elementwise floor division.
cupy.negative: Takes numerical negative elementwise.
cupy.positive: Takes numerical positive elementwise.
cupy.power: Computes x1 ** x2 elementwise.
cupy.remainder: Computes the remainder of Python division elementwise.
cupy.mod: Computes the remainder of Python division elementwise.
cupy.fmod: Computes the remainder of C division elementwise.
cupy.absolute: Elementwise absolute value function.
cupy.rint: Rounds each element of an array to the nearest integer.
cupy.sign: Elementwise sign function.
cupy.conj: Returns the complex conjugate, element-wise.
cupy.conjugate: Returns the complex conjugate, element-wise.
cupy.exp: Elementwise exponential function.
cupy.exp2: Elementwise exponentiation with base 2.
cupy.log: Elementwise natural logarithm function.
cupy.log2: Elementwise binary logarithm function.
cupy.log10: Elementwise common logarithm function.
cupy.expm1: Computes exp(x) - 1 elementwise.
cupy.log1p: Computes log(1 + x) elementwise.
cupy.sqrt: Elementwise square root function.
cupy.square: Elementwise square function.
cupy.cbrt: Elementwise cube root function.
cupy.reciprocal: Computes 1 / x elementwise.
cupy.gcd: Computes the gcd of x1 and x2 elementwise.
cupy.lcm: Computes the lcm of x1 and x2 elementwise.
Trigonometric functions¶
cupy.sin: Elementwise sine function.
cupy.cos: Elementwise cosine function.
cupy.tan: Elementwise tangent function.
cupy.arcsin: Elementwise inverse-sine function (a.k.a. arcsine).
cupy.arccos: Elementwise inverse-cosine function (a.k.a. arccosine).
cupy.arctan: Elementwise inverse-tangent function (a.k.a. arctangent).
cupy.arctan2: Elementwise inverse-tangent of the ratio of two arrays.
cupy.hypot: Computes the hypotenuse of orthogonal vectors of given length.
cupy.sinh: Elementwise hyperbolic sine function.
cupy.cosh: Elementwise hyperbolic cosine function.
cupy.tanh: Elementwise hyperbolic tangent function.
cupy.arcsinh: Elementwise inverse of hyperbolic sine function.
cupy.arccosh: Elementwise inverse of hyperbolic cosine function.
cupy.arctanh: Elementwise inverse of hyperbolic tangent function.
cupy.degrees: Converts angles from radians to degrees elementwise.
cupy.radians: Converts angles from degrees to radians elementwise.
cupy.deg2rad: Converts angles from degrees to radians elementwise.
cupy.rad2deg: Converts angles from radians to degrees elementwise.
Bit-twiddling functions¶
cupy.bitwise_and: Computes the bitwise AND of two arrays elementwise.
cupy.bitwise_or: Computes the bitwise OR of two arrays elementwise.
cupy.bitwise_xor: Computes the bitwise XOR of two arrays elementwise.
cupy.invert: Computes the bitwise NOT of an array elementwise.
cupy.left_shift: Shifts the bits of each integer element to the left.
cupy.right_shift: Shifts the bits of each integer element to the right.
Comparison functions¶
cupy.greater: Tests elementwise if x1 > x2.
cupy.greater_equal: Tests elementwise if x1 >= x2.
cupy.less: Tests elementwise if x1 < x2.
cupy.less_equal: Tests elementwise if x1 <= x2.
cupy.not_equal: Tests elementwise if x1 != x2.
cupy.equal: Tests elementwise if x1 == x2.
cupy.logical_and: Computes the logical AND of two arrays.
cupy.logical_or: Computes the logical OR of two arrays.
cupy.logical_xor: Computes the logical XOR of two arrays.
cupy.logical_not: Computes the logical NOT of an array.
cupy.maximum: Takes the maximum of two arrays elementwise.
cupy.minimum: Takes the minimum of two arrays elementwise.
cupy.fmax: Takes the maximum of two arrays elementwise.
cupy.fmin: Takes the minimum of two arrays elementwise.
Floating functions¶
cupy.isfinite: Tests finiteness elementwise.
cupy.isinf: Tests if each element is the positive or negative infinity.
cupy.isnan: Tests if each element is a NaN.
cupy.signbit: Tests elementwise if the sign bit is set (i.e. less than zero).
cupy.copysign: Returns the first argument with the sign bit of the second elementwise.
cupy.nextafter: Computes the nearest neighbor float values towards the second argument.
cupy.modf: Extracts the fractional and integral parts of an array elementwise.
cupy.ldexp: Computes x1 * 2 ** x2 elementwise.
cupy.frexp: Decomposes each element to mantissa and two's exponent.
cupy.fmod: Computes the remainder of C division elementwise.
cupy.floor: Rounds each element of an array to its floor integer.
cupy.ceil: Rounds each element of an array to its ceiling integer.
cupy.trunc: Rounds each element of an array towards zero.
ufunc.at¶
Currently, CuPy does not support at for ufuncs in general. However, cupyx.scatter_add() can substitute for add.at, as both behave identically.
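For example (a minimal sketch; duplicate indices accumulate, like numpy.add.at):
import cupy
import cupyx

a = cupy.zeros((4,), dtype=cupy.float32)
i = cupy.array([0, 1, 1, 3])
v = cupy.ones((4,), dtype=cupy.float32)
cupyx.scatter_add(a, i, v)
print(a)  # => [1. 2. 0. 1.]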
Generalized Universal Functions¶
In addition to regular ufuncs, CuPy also provides a wrapper class to convert regular CuPy functions into Generalized Universal Functions, as in NumPy (https://numpy.org/doc/stable/reference/c-api/generalized-ufuncs.html). This allows keyword arguments such as axes, order, and dtype to be used automatically, without needing to explicitly implement them in the wrapped function.
Creates a Generalized Universal Function by wrapping a user-provided function with the signature.
Routines (NumPy)¶
The following pages describe NumPy-compatible routines. These functions cover a subset of NumPy routines.
Array creation routines¶
Ones and zeros¶
cupy.empty: Returns an array without initializing the elements.
cupy.empty_like: Returns a new array with same shape and dtype of a given array.
cupy.eye: Returns a 2-D array with ones on the diagonals and zeros elsewhere.
cupy.identity: Returns a 2-D identity array.
cupy.ones: Returns a new array of given shape and dtype, filled with ones.
cupy.ones_like: Returns an array of ones with same shape and dtype as a given array.
cupy.zeros: Returns a new array of given shape and dtype, filled with zeros.
cupy.zeros_like: Returns an array of zeros with same shape and dtype as a given array.
cupy.full: Returns a new array of given shape and dtype, filled with a given value.
cupy.full_like: Returns a full array with same shape and dtype as a given array.
From existing data¶
cupy.array: Creates an array on the current device.
cupy.asarray: Converts an object to array.
cupy.asanyarray: Converts an object to array.
cupy.ascontiguousarray: Returns a C-contiguous array.
cupy.copy: Creates a copy of a given array on the current device.
cupy.frombuffer: Interpret a buffer as a 1-dimensional array.
cupy.fromfile: Reads an array from a file.
cupy.fromfunction: Construct an array by executing a function over each coordinate.
cupy.fromiter: Create a new 1-dimensional array from an iterable object.
cupy.fromstring: A new 1-D array initialized from text data in a string.
cupy.loadtxt: Load data from a text file.
Numerical ranges¶
cupy.arange: Returns an array with evenly spaced values within a given interval.
cupy.linspace: Returns an array with evenly-spaced values within a given interval.
cupy.logspace: Returns an array with evenly-spaced values on a log-scale.
cupy.meshgrid: Return coordinate matrices from coordinate vectors.
cupy.mgrid: Construct a multi-dimensional "meshgrid".
cupy.ogrid: Construct a multi-dimensional "meshgrid".
Building matrices¶
cupy.diag: Returns a diagonal or a diagonal array.
cupy.diagflat: Creates a diagonal array from the flattened input.
cupy.tri: Creates an array with ones at and below the given diagonal.
cupy.tril: Returns a lower triangle of an array.
cupy.triu: Returns an upper triangle of an array.
Array manipulation routines¶
Basic operations¶
cupy.copyto: Copies values from one array to another with broadcasting.
cupy.shape: Returns the shape of an array.
Changing array shape¶
cupy.reshape: Returns an array with new shape and same elements.
cupy.ravel: Returns a flattened array.
See also
Transpose-like operations¶
cupy.moveaxis: Moves axes of an array to new positions.
cupy.rollaxis: Moves the specified axis backwards to the given place.
cupy.swapaxes: Swaps the two axes.
cupy.transpose: Permutes the dimensions of an array.
See also
Changing number of dimensions¶
cupy.atleast_1d: Converts arrays to arrays with dimensions >= 1.
cupy.atleast_2d: Converts arrays to arrays with dimensions >= 2.
cupy.atleast_3d: Converts arrays to arrays with dimensions >= 3.
cupy.broadcast: Object that performs broadcasting.
cupy.broadcast_to: Broadcast an array to a given shape.
cupy.broadcast_arrays: Broadcasts given arrays.
cupy.expand_dims: Expands given arrays.
cupy.squeeze: Removes size-one axes from the shape of an array.
Changing kind of array¶
cupy.asarray: Converts an object to array.
cupy.asanyarray: Converts an object to array.
cupy.asfortranarray: Return an array laid out in Fortran order in memory.
cupy.ascontiguousarray: Returns a C-contiguous array.
cupy.require: Return an array which satisfies the requirements.
Joining arrays¶
cupy.concatenate: Joins arrays along an axis.
cupy.stack: Stacks arrays along a new axis.
cupy.vstack: Stacks arrays vertically.
cupy.hstack: Stacks arrays horizontally.
cupy.dstack: Stacks arrays along the third axis.
cupy.column_stack: Stacks 1-D and 2-D arrays as columns into a 2-D array.
Splitting arrays¶
cupy.split: Splits an array into multiple sub arrays along a given axis.
cupy.array_split: Splits an array into multiple sub arrays along a given axis.
cupy.dsplit: Splits an array into multiple sub arrays along the third axis.
cupy.hsplit: Splits an array into multiple sub arrays horizontally.
cupy.vsplit: Splits an array into multiple sub arrays along the first axis.
Tiling arrays¶
cupy.tile: Construct an array by repeating A the number of times given by reps.
cupy.repeat: Repeat arrays along an axis.
Adding and removing elements¶
cupy.append: Append values to the end of an array.
cupy.resize: Return a new array with the specified shape.
cupy.unique: Find the unique elements of an array.
cupy.trim_zeros: Trim the leading and/or trailing zeros from a 1-D array or sequence.
Rearranging elements¶
cupy.flip: Reverse the order of elements in an array along the given axis.
cupy.fliplr: Flip array in the left/right direction.
cupy.flipud: Flip array in the up/down direction.
cupy.reshape: Returns an array with new shape and same elements.
cupy.roll: Roll array elements along a given axis.
cupy.rot90: Rotate an array by 90 degrees in the plane specified by axes.
Binary operations¶
Elementwise bit operations¶
cupy.bitwise_and: Computes the bitwise AND of two arrays elementwise.
cupy.bitwise_or: Computes the bitwise OR of two arrays elementwise.
cupy.bitwise_xor: Computes the bitwise XOR of two arrays elementwise.
cupy.invert: Computes the bitwise NOT of an array elementwise.
cupy.left_shift: Shifts the bits of each integer element to the left.
cupy.right_shift: Shifts the bits of each integer element to the right.
Bit packing¶
cupy.packbits: Packs the elements of a binary-valued array into bits in a uint8 array.
cupy.unpackbits: Unpacks elements of a uint8 array into a binary-valued output array.
Output formatting¶
cupy.binary_repr: Return the binary representation of the input number as a string.
Data type routines¶
cupy.can_cast: Returns True if cast between data types can occur according to the casting rule.
cupy.result_type: Returns the type that results from applying the NumPy type promotion rules to the arguments.
cupy.common_type: Return a scalar type which is common to the input arrays.
Creating data types¶
Data type information¶
Data type testing¶
Miscellaneous¶
Discrete Fourier Transform (cupy.fft)¶
Standard FFTs¶
cupy.fft.fft: Compute the one-dimensional FFT.
cupy.fft.ifft: Compute the one-dimensional inverse FFT.
cupy.fft.fft2: Compute the two-dimensional FFT.
cupy.fft.ifft2: Compute the two-dimensional inverse FFT.
cupy.fft.fftn: Compute the N-dimensional FFT.
cupy.fft.ifftn: Compute the N-dimensional inverse FFT.
Real FFTs¶
cupy.fft.rfft: Compute the one-dimensional FFT for real input.
cupy.fft.irfft: Compute the one-dimensional inverse FFT for real input.
cupy.fft.rfft2: Compute the two-dimensional FFT for real input.
cupy.fft.irfft2: Compute the two-dimensional inverse FFT for real input.
cupy.fft.rfftn: Compute the N-dimensional FFT for real input.
cupy.fft.irfftn: Compute the N-dimensional inverse FFT for real input.
Hermitian FFTs¶
cupy.fft.hfft: Compute the FFT of a signal that has Hermitian symmetry.
cupy.fft.ihfft: Compute the FFT of a signal that has Hermitian symmetry.
Helper routines¶
cupy.fft.fftfreq: Return the FFT sample frequencies.
cupy.fft.rfftfreq: Return the FFT sample frequencies for real input.
cupy.fft.fftshift: Shift the zero-frequency component to the center of the spectrum.
cupy.fft.ifftshift: The inverse of fftshift().
CuPy-specific APIs¶
See the description below for details.
cupy.fft.config.set_cufft_callbacks: A context manager for setting up load and/or store callbacks.
cupy.fft.config.set_cufft_gpus: Set the GPUs to be used in multi-GPU FFT.
cupy.fft.config.get_plan_cache: Get the per-thread, per-device plan cache, or create one if not found.
cupy.fft.config.show_plan_cache_info: Show all of the plan caches' info on this thread.
Normalization¶
The default normalization (norm is "backward" or None) leaves the direct transforms unscaled and scales the inverse transforms by \(1/n\). If the keyword argument norm is "forward", it is the exact opposite of "backward": the direct transforms are scaled by \(1/n\) and the inverse transforms are unscaled. Finally, if the keyword argument norm is "ortho", both transforms are scaled by \(1/\sqrt{n}\).
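For example, a round trip with a consistent norm recovers the input regardless of where the scaling is applied (a minimal sketch):
import cupy as cp

a = cp.random.random(8).astype(cp.complex64)
# "forward": the direct transform is scaled by 1/n and the inverse is unscaled
b = cp.fft.ifft(cp.fft.fft(a, norm="forward"), norm="forward")
assert cp.allclose(a, b)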
Code compatibility features¶
NumPy's FFT functions always return numpy.ndarray s of dtype numpy.complex128 or numpy.float64. CuPy's functions do not follow this behavior: they return arrays of dtype numpy.complex64 or numpy.float32 if the dtype of the input is numpy.float16, numpy.float32, or numpy.complex64.
Internally, cupy.fft always generates a cuFFT plan (see the cuFFT documentation for detail) corresponding to the desired transform. When possible, an n-dimensional plan will be used, as opposed to applying separate 1D plans for each axis to be transformed. Using n-dimensional planning can provide better performance for multidimensional transforms, but requires more GPU memory than separable 1D planning. The user can disable n-dimensional planning by setting cupy.fft.config.enable_nd_planning = False. This ability to adjust the planning type is a deviation from the NumPy API, which does not use precomputed FFT plans.
Moreover, the automatic plan generation can be suppressed by using an existing plan returned by cupyx.scipy.fftpack.get_fft_plan()
as a context manager. This is again a deviation from NumPy.
Finally, when using the high-level NumPy-like FFT APIs as listed above, internally the cuFFT plans are cached for possible reuse. The plan cache can be retrieved by get_plan_cache()
, and its current status can be queried by show_plan_cache_info()
. For finer control of the plan cache, see cuFFT Plan Cache.
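For example, a minimal sketch of suppressing automatic plan generation and inspecting the cache:
import cupy
from cupyx.scipy.fftpack import get_fft_plan

x = cupy.random.random((256, 256)).astype(cupy.complex64)
plan = get_fft_plan(x)        # precompute a cuFFT plan for this shape and dtype
with plan:                    # suppress automatic plan generation
    y = cupy.fft.fft2(x)

cache = cupy.fft.config.get_plan_cache()  # the per-thread, per-device plan cache
cupy.fft.config.show_plan_cache_info()    # print the cache status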
Multi-GPU FFT¶
cupy.fft can use multiple GPUs. To enable (disable) this feature, set cupy.fft.config.use_multi_gpus to True (False). Next, to set the number of GPUs or the participating GPU IDs, use the function cupy.fft.config.set_cufft_gpus(). All of the limitations listed in the cuFFT documentation apply here. In particular, using more than one GPU does not guarantee better performance.
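A minimal sketch, assuming at least two GPUs are visible to the process:
import cupy

cupy.fft.config.use_multi_gpus = True
cupy.fft.config.set_cufft_gpus([0, 1])  # use GPUs 0 and 1

a = cupy.random.random(2 ** 20).astype(cupy.complex64)
b = cupy.fft.fft(a)  # the 1D transform now runs across both GPUs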
Functional programming¶
Note
cupy.vectorize applies the JIT compiler to the given Python function. See JIT kernel definition for details.
|
Apply a function to 1-D slices along the given axis. |
|
Generalized function class. |
|
Evaluate a piecewise-defined function. |
Indexing routines¶
Generating index arrays¶
|
Return the indices of the elements that are non-zero. |
|
Return elements, either from x or y, depending on condition. |
|
Returns an array representing the indices of a grid. |
|
Construct an open mesh from multiple sequences. |
|
Converts a tuple of index arrays into an array of flat indices, applying boundary modes to the multi-index. |
|
Converts array of flat indices into a tuple of coordinate arrays. |
|
Return the indices to access the main diagonal of an array. |
|
Return the indices to access the main diagonal of an n-dimensional array. |
Indexing-like operations¶
|
Takes elements of an array at specified indices along an axis. |
|
Take values from the input array by matching 1d index and data slices. |
|
|
|
Returns selected slices of an array along given axis. |
|
Returns a diagonal or a diagonal array. |
|
Returns specified diagonals. |
|
Return an array drawn from elements in choicelist, depending on conditions. |
|
Create a view into the array with the given shape and strides. |
Inserting data into arrays¶
|
Change elements of an array based on conditional and input values. |
|
Replaces specified elements of an array with given values. |
|
Changes elements of an array inplace, based on a conditional mask and input values. |
|
Fills the main diagonal of the given array of any dimensionality. |
Input and output¶
NumPy binary files (NPY, NPZ)¶
|
Loads arrays or pickled objects from |
|
Saves an array to a binary file in |
|
Saves one or more arrays into a file in uncompressed |
|
Saves one or more arrays into a file in compressed |
Text files¶
|
Load data from a text file. |
|
Save an array to a text file. |
|
Load data from text file, with missing values handled as specified. |
|
A new 1-D array initialized from text data in a string. |
String formatting¶
|
Return a string representation of an array. |
|
Returns the string representation of an array. |
|
Returns the string representation of the content of an array. |
Base-n representations¶
|
Return the binary representation of the input number as a string. |
|
Return a string representation of a number in the given base system. |
Linear algebra (cupy.linalg)¶
See also
Matrix and vector products¶
|
Returns a dot product of two arrays. |
|
Returns the dot product of two vectors. |
|
Returns the inner product of two arrays. |
|
Returns the outer product of two vectors. |
matmul(x1, x2, /, out=None, **kwargs) |
|
|
Returns the tensor dot product of two arrays along specified axes. |
|
Evaluates the Einstein summation convention on the operands. |
|
Raise a square matrix to the (integer) power n. |
|
Returns the kronecker product of two arrays. |
Decompositions¶
Cholesky decomposition. |
|
|
QR decomposition. |
|
Singular Value Decomposition. |
Matrix eigenvalues¶
|
Return the eigenvalues and eigenvectors of a complex Hermitian (conjugate symmetric) or a real symmetric matrix. |
|
Compute the eigenvalues of a complex Hermitian or real symmetric matrix. |
Norms and other numbers¶
|
Returns one of matrix norms specified by |
|
Returns the determinant of an array. |
|
Return matrix rank of array using SVD method |
Returns sign and logarithm of the determinant of an array. |
|
|
Returns the sum along the diagonals of an array. |
Solving equations and inverting matrices¶
|
Solves a linear matrix equation. |
|
Solves tensor equations denoted by |
|
Return the least-squares solution to a linear matrix equation. |
|
Computes the inverse of a matrix. |
|
Compute the Moore-Penrose pseudoinverse of a matrix. |
|
Computes the inverse of a tensor. |
Logic functions¶
Truth value testing¶
|
Tests whether all array elements along a given axis evaluate to True. |
|
Tests whether any array elements along a given axis evaluate to True. |
Array contents¶
Tests finiteness elementwise. |
|
Tests if each element is the positive or negative infinity. |
|
Tests if each element is a NaN. |
Array type testing¶
|
Returns a bool array, where True if input element is complex. |
|
Check for a complex type or an array of complex numbers. |
|
Returns True if the array is Fortran contiguous but not C contiguous. |
|
Returns a bool array, where True if input element is real. |
|
Return True if x is a not complex type or an array of complex numbers. |
|
Returns True if the type of num is a scalar type. |
Logic operations¶
Computes the logical AND of two arrays. |
|
Computes the logical OR of two arrays. |
|
Computes the logical NOT of an array. |
|
Computes the logical XOR of two arrays. |
Comparison¶
|
Returns True if two arrays are element-wise equal within a tolerance. |
|
Returns a boolean array where two arrays are equal within a tolerance. |
|
Returns |
Tests elementwise if |
|
Tests elementwise if |
|
Tests elementwise if |
|
Tests elementwise if |
|
Tests elementwise if |
|
Tests elementwise if |
Mathematical functions¶
Trigonometric functions¶
Elementwise sine function. |
|
Elementwise cosine function. |
|
Elementwise tangent function. |
|
Elementwise inverse-sine function (a.k.a. arcsine). |
|
Elementwise inverse-cosine function (a.k.a. arccosine). |
|
Elementwise inverse-tangent function (a.k.a. arctangent). |
|
Computes the hypotenuse of orthogonal vectors of given length. |
|
Elementwise inverse-tangent of the ratio of two arrays. |
|
Converts angles from radians to degrees elementwise. |
|
Converts angles from degrees to radians elementwise. |
|
|
Unwrap by taking the complement of large deltas with respect to the period. |
Converts angles from degrees to radians elementwise. |
|
Converts angles from radians to degrees elementwise. |
Hyperbolic functions¶
Elementwise hyperbolic sine function. |
|
Elementwise hyperbolic cosine function. |
|
Elementwise hyperbolic tangent function. |
|
Elementwise inverse of hyperbolic sine function. |
|
Elementwise inverse of hyperbolic cosine function. |
|
Elementwise inverse of hyperbolic tangent function. |
Rounding¶
|
Rounds to the given number of decimals. |
|
|
Rounds each element of an array to the nearest integer. |
|
If the given value x is positive, it returns floor(x). |
|
Rounds each element of an array to its floor integer. |
|
Rounds each element of an array to its ceiling integer. |
|
Rounds each element of an array towards zero. |
Sums, products, differences¶
|
Returns the product of an array along given axes. |
|
Returns the sum of an array along given axes. |
|
Returns the product of an array along given axes treating Not a Numbers (NaNs) as zero. |
|
Returns the sum of an array along given axes treating Not a Numbers (NaNs) as zero. |
|
Returns the cumulative product of an array along a given axis. |
|
Returns the cumulative sum of an array along a given axis. |
|
Returns the cumulative product of an array along a given axis treating Not a Numbers (NaNs) as one. |
|
Returns the cumulative sum of an array along a given axis treating Not a Numbers (NaNs) as zero. |
|
Calculate the n-th discrete difference along the given axis. |
|
Return the gradient of an N-dimensional array. |
|
Returns the cross product of two vectors. |
Exponents and logarithms¶
Elementwise exponential function. |
|
Computes |
|
Elementwise exponentiation with base 2. |
|
Elementwise natural logarithm function. |
|
Elementwise common logarithm function. |
|
Elementwise binary logarithm function. |
|
Computes |
|
Computes |
|
Computes |
Other special functions¶
Modified Bessel function of the first kind, order 0. |
|
Elementwise sinc function. |
Floating point routines¶
Tests elementwise if the sign bit is set (i.e. |
|
Returns the first argument with the sign bit of the second elementwise. |
|
Decomposes each element into a mantissa and an exponent of two. |
|
Computes |
|
Computes the nearest neighbor float values towards the second argument. |
Rational routines¶
Computes lcm of |
|
Computes gcd of |
Arithmetic operations¶
Adds two arrays elementwise. |
|
Computes |
|
Takes numerical positive elementwise. |
|
Takes numerical negative elementwise. |
|
Multiplies two arrays elementwise. |
|
Elementwise true division (i.e. |
|
Computes |
|
Subtracts arguments elementwise. |
|
Elementwise true division (i.e. |
|
Elementwise floor division (i.e. |
|
Computes the remainder of C division elementwise. |
|
Computes the remainder of Python division elementwise. |
|
Extracts the fractional and integral parts of an array elementwise. |
|
Computes the remainder of Python division elementwise. |
|
Handling complex numbers¶
Returns the angle of the complex argument. |
|
|
Returns the real part of the elements of the array. |
|
Returns the imaginary part of the elements of the array. |
Returns the complex conjugate, element-wise. |
|
Returns the complex conjugate, element-wise. |
Miscellaneous¶
|
Returns the discrete, linear convolution of two one-dimensional sequences. |
|
Clips the values of an array to a given interval. |
Elementwise square root function. |
|
Elementwise cube root function. |
|
Elementwise square function. |
|
Elementwise absolute value function. |
|
Elementwise sign function. |
|
Takes the maximum of two arrays elementwise. |
|
Takes the minimum of two arrays elementwise. |
|
Takes the maximum of two arrays elementwise. |
|
Takes the minimum of two arrays elementwise. |
|
|
Replace NaN with zero and infinity with large finite numbers (default behaviour) or with the numbers defined by the user using the nan, posinf and/or neginf keywords. |
|
One-dimensional linear interpolation. |
Miscellaneous routines¶
Memory ranges¶
Utility¶
|
Prints the current runtime configuration to standard output. |
Polynomials¶
Power Series (cupy.polynomial.polynomial)¶
Misc Functions¶
|
Computes the Vandermonde matrix of given degree. |
Computes the companion matrix of c. |
Polyutils¶
Poly1d¶
Basics¶
|
A one-dimensional polynomial class. |
|
Evaluates a polynomial at specific values. |
|
Computes the roots of a polynomial with given coefficients. |
Random sampling (cupy.random)¶
Differences between cupy.random and numpy.random:
Most functions under cupy.random support the dtype option, which does not exist in the corresponding NumPy APIs. This option enables generation of float32 values directly without any space overhead (see the example below).
cupy.random.default_rng() uses the XORWOW bit generator by default.
Random states cannot be serialized. See the description below for details.
CuPy does not guarantee that the same number generator is used across major versions. This means that numbers generated by cupy.random under a new major version may not be the same as those from the previous one, even if the same seed and distribution are used.
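A minimal sketch of the two points above:
import cupy

# dtype is a CuPy-specific option: float32 samples are generated directly
x = cupy.random.uniform(0, 1, size=(1000,), dtype=cupy.float32)

# The new Generator API, backed by the XORWOW bit generator by default
rng = cupy.random.default_rng(seed=42)
y = rng.standard_normal(1000)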
New Random Generator API¶
Random Generator¶
|
Construct a new Generator with the default BitGenerator (XORWOW). |
|
Container for the BitGenerators. |
Bit Generators¶
|
Generic BitGenerator. |
CuPy provides the following bit generator implementations:
|
BitGenerator that uses cuRAND XORWOW device generator. |
|
BitGenerator that uses cuRAND MRG32k3a device generator. |
|
BitGenerator that uses the cuRAND Philox4x32-10 device generator. |
Legacy Random Generation¶
|
Portable container of a pseudo-random number generator. |
Functions in cupy.random¶
|
Beta distribution. |
|
Binomial distribution. |
|
Returns random bytes. |
|
Chi-square distribution. |
|
Returns an array of random values from a given 1-D array. |
|
Dirichlet distribution. |
|
Exponential distribution. |
|
F distribution. |
|
Gamma distribution. |
|
Geometric distribution. |
|
Returns an array of samples drawn from a Gumbel distribution. |
|
Hypergeometric distribution. |
|
Laplace distribution. |
|
Logistic distribution. |
|
Returns an array of samples drawn from a log normal distribution. |
|
Log series distribution. |
|
Returns an array from multinomial distribution. |
|
Multivariate normal distribution. |
|
Negative binomial distribution. |
|
Noncentral chisquare distribution. |
|
Noncentral F distribution. |
|
Returns an array of normally distributed samples. |
|
Pareto II or Lomax distribution. |
|
Returns a permuted range or a permutation of an array. |
|
Poisson distribution. |
|
Power distribution. |
|
Returns an array of uniform random values over the interval |
|
Returns a scalar or an array of integer values over |
|
Returns an array of standard normal random values. |
|
Returns an array of random values over the interval |
|
Return a scalar or an array of integer values over |
|
Returns an array of random values over the interval |
|
Returns an array of random values over the interval |
|
Rayleigh distribution. |
|
Returns an array of random values over the interval |
|
Resets the state of the random number generator with a seed. |
|
Shuffles an array. |
|
Standard cauchy distribution. |
|
Standard exponential distribution. |
|
Standard gamma distribution. |
|
Returns an array of samples drawn from the standard normal distribution. |
|
Standard Student’s t distribution. |
|
Triangular distribution. |
|
Returns an array of uniformly-distributed samples over an interval. |
|
von Mises distribution. |
|
Wald distribution. |
|
Weibull distribution. |
|
Zipf distribution. |
CuPy does not provide cupy.random.get_state nor cupy.random.set_state at this time. Use the following CuPy-specific APIs instead. Note that these functions use a cupy.random.RandomState instance to represent the internal state, which cannot be serialized.
Gets the state of the random number generator for the current device. |
|
|
Sets the state of the random number generator for the current device. |
Set routines¶
Sorting, searching, and counting¶
Sorting¶
|
Returns a sorted copy of an array with a stable sorting algorithm. |
|
Perform an indirect sort using an array of keys. |
|
Returns the indices that would sort an array with a stable sorting. |
|
Returns a copy of an array sorted along the first axis. |
|
Sort a complex array using the real part first, then the imaginary part. |
|
Returns a partitioned copy of an array. |
|
Returns the indices that would partially sort an array. |
See also
Searching¶
|
Returns the indices of the maximum along an axis. |
|
Return the indices of the maximum values in the specified axis ignoring NaNs. |
|
Returns the indices of the minimum along an axis. |
|
Return the indices of the minimum values in the specified axis ignoring NaNs. |
|
Return the indices of the elements that are non-zero. |
|
Return the indices of the elements that are non-zero. |
|
Return indices that are non-zero in the flattened version of a. |
|
Return elements, either from x or y, depending on condition. |
|
Finds indices where elements should be inserted to maintain order. |
|
Return the elements of an array that satisfy some condition. |
Counting¶
|
Counts the number of non-zero values in the array. |
Statistics¶
Order statistics¶
|
Returns the minimum of an array or the minimum along an axis. |
|
Returns the maximum of an array or the maximum along an axis. |
|
Returns the minimum of an array along an axis ignoring NaN. |
|
Returns the maximum of an array along an axis ignoring NaN. |
|
Returns the range of values (maximum - minimum) along an axis. |
|
Computes the q-th percentile of the data along the specified axis. |
|
Computes the q-th quantile of the data along the specified axis. |
Averages and variances¶
|
Compute the median along the specified axis. |
|
Returns the weighted average along an axis. |
|
Returns the arithmetic mean along an axis. |
|
Returns the standard deviation along an axis. |
|
Returns the variance along an axis. |
|
Compute the median along the specified axis, while ignoring NaNs. |
|
Returns the arithmetic mean along an axis ignoring NaN values. |
|
Returns the standard deviation along an axis ignoring NaN values. |
|
Returns the variance along an axis ignoring NaN values. |
Correlations¶
|
Returns the Pearson product-moment correlation coefficients of an array. |
|
Returns the cross-correlation of two 1-dimensional sequences. |
|
Returns the covariance matrix of an array. |
Histograms¶
|
Computes the histogram of a set of data. |
|
Compute the bi-dimensional histogram of two data samples. |
|
Compute the multidimensional histogram of some data. |
|
Count number of occurrences of each value in array of non-negative ints. |
|
Finds the indices of the bins to which each value in input array belongs. |
Test support (cupy.testing)¶
Asserts¶
Hint
These APIs can accept both numpy.ndarray and cupy.ndarray.
|
Raises an AssertionError if objects are not equal up to desired precision. |
|
Raises an AssertionError if objects are not equal up to desired tolerance. |
|
Compare two arrays relatively to their spacing. |
|
Check that all items of arrays differ in at most N Units in the Last Place. |
|
Raises an AssertionError if two array_like objects are not equal. |
|
Raises an AssertionError if array_like objects are not ordered by less than. |
CuPy-specific APIs¶
Asserts¶
|
Compares lists of arrays pairwise with |
NumPy-CuPy Consistency Check¶
The following decorators are for testing consistency between CuPy's functions and the corresponding NumPy ones; a usage example follows the table below.
|
Decorator that checks NumPy results and CuPy ones are close. |
|
Decorator that checks NumPy results and CuPy ones are almost equal. |
|
Decorator that checks results of NumPy and CuPy are equal w.r.t. |
|
Decorator that checks results of NumPy and CuPy ones are equal w.r.t. |
|
Decorator that checks NumPy results and CuPy ones are equal. |
|
Decorator that checks the resulting lists of NumPy and CuPy’s one are equal. |
|
Decorator that checks the CuPy result is less than NumPy result. |
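A minimal sketch of a consistency test (assuming pytest-style collection; shaped_arange is the cupy.testing helper that builds a small test array for the given module and dtype):
import cupy
from cupy import testing

class TestSum:
    @testing.for_all_dtypes()
    @testing.numpy_cupy_allclose()
    def test_sum(self, xp, dtype):
        a = testing.shaped_arange((2, 3), xp, dtype)  # xp is numpy or cupy
        return a.sum()  # NumPy and CuPy results are compared automatically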
Parameterized dtype Test¶
The following decorators offer the standard way to write tests parameterized over a single dtype or a combination of dtypes.
|
Decorator for parameterized dtype test. |
|
Decorator that checks the fixture with all dtypes. |
|
Decorator that checks the fixture with float dtypes. |
|
Decorator that checks the fixture with signed dtypes. |
|
Decorator that checks the fixture with unsigned dtypes. |
|
Decorator that checks the fixture with integer and optionally bool dtypes. |
|
Decorator that checks the fixture with complex dtypes. |
|
Decorator that checks the fixture with a product set of dtypes. |
|
Decorator that checks the fixture with a product set of all dtypes. |
|
Decorator for parameterized test w.r.t. |
|
Decorator for parameterized test w.r.t. |
|
Decorator for parameterized test w.r.t. |
Parameterized order Test¶
The following decorators offer the standard way to parameterize tests with orders.
|
Decorator to parameterize tests with order. |
|
Decorator that checks the fixture with orders ‘C’ and ‘F’. |
Routines (SciPy)¶
The following pages describe SciPy-compatible routines. These functions cover a subset of SciPy routines.
Discrete Fourier transforms (cupyx.scipy.fft)¶
See also
Fast Fourier Transforms (FFTs)¶
|
Compute the one-dimensional FFT. |
|
Compute the one-dimensional inverse FFT. |
|
Compute the two-dimensional FFT. |
|
Compute the two-dimensional inverse FFT. |
|
Compute the N-dimensional FFT. |
|
Compute the N-dimensional inverse FFT. |
|
Compute the one-dimensional FFT for real input. |
|
Compute the one-dimensional inverse FFT for real input. |
|
Compute the two-dimensional FFT for real input. |
|
Compute the two-dimensional inverse FFT for real input. |
|
Compute the N-dimensional FFT for real input. |
|
Compute the N-dimensional inverse FFT for real input. |
|
Compute the FFT of a signal that has Hermitian symmetry. |
|
Compute the FFT of a signal that has Hermitian symmetry. |
|
Compute the FFT of a two-dimensional signal that has Hermitian symmetry. |
|
Compute the Inverse FFT of a two-dimensional signal that has Hermitian symmetry. |
|
Compute the FFT of a N-dimensional signal that has Hermitian symmetry. |
|
Compute the Inverse FFT of a N-dimensional signal that has Hermitian symmetry. |
Helper functions¶
|
Shift the zero-frequency component to the center of the spectrum. |
|
The inverse of |
|
Return the FFT sample frequencies. |
|
Return the FFT sample frequencies for real input. |
|
Find the next fast size to |
Code compatibility features¶
As with other FFT modules in CuPy, FFT functions in this module can take advantage of an existing cuFFT plan (returned by get_fft_plan()) to accelerate the computation. The plan can be either passed in explicitly via the keyword-only plan argument or used as a context manager.
The boolean switch cupy.fft.config.enable_nd_planning also affects the FFT functions in this module, see Discrete Fourier Transform (cupy.fft). This switch is ignored when planning manually using get_fft_plan().
Like in scipy.fft, all FFT functions in this module have an optional argument overwrite_x (default is False), which has the same semantics as in scipy.fft: when it is set to True, the input array x can (but is not guaranteed to) be overwritten arbitrarily. For this reason, when an in-place FFT is desired, the user should always reassign the input in the following manner: x = cupyx.scipy.fftpack.fft(x, ..., overwrite_x=True, ...).
The cupyx.scipy.fft module can also be used as a backend for scipy.fft, e.g. by registering it with scipy.fft.set_backend(cupyx.scipy.fft). This allows scipy.fft to work with both numpy and cupy arrays; see the sketch after this list. For more information, see SciPy FFT backend.
The boolean switch cupy.fft.config.use_multi_gpus also affects the FFT functions in this module, see Discrete Fourier Transform (cupy.fft). Moreover, this switch is honored when planning manually using get_fft_plan().
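A minimal sketch of the backend usage described above:
import cupy
import scipy.fft
import cupyx.scipy.fft

a = cupy.random.random(1024)
with scipy.fft.set_backend(cupyx.scipy.fft):
    b = scipy.fft.fft(a)  # dispatched to CuPy; b is a cupy.ndarray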
Legacy discrete Fourier transforms (cupyx.scipy.fftpack)¶
Note
As of SciPy version 1.4.0, scipy.fft is recommended over scipy.fftpack. Consider using cupyx.scipy.fft instead.
Fast Fourier Transforms (FFTs)¶
|
Compute the one-dimensional FFT. |
|
Compute the one-dimensional inverse FFT. |
|
Compute the two-dimensional FFT. |
|
Compute the two-dimensional inverse FFT. |
|
Compute the N-dimensional FFT. |
|
Compute the N-dimensional inverse FFT. |
|
Compute the one-dimensional FFT for real input. |
|
Compute the one-dimensional inverse FFT for real input. |
|
Generate a CUDA FFT plan for transforming up to three axes. |
Code compatibility features¶
As with other FFT modules in CuPy, FFT functions in this module can take advantage of an existing cuFFT plan (returned by get_fft_plan()) to accelerate the computation; a sketch follows this list. The plan can be either passed in explicitly via the plan argument or used as a context manager. The plan argument is currently experimental and its interface may change in a future version. The get_fft_plan() function has no counterpart in scipy.fftpack.
The boolean switch cupy.fft.config.enable_nd_planning also affects the FFT functions in this module, see Discrete Fourier Transform (cupy.fft). This switch is ignored when planning manually using get_fft_plan().
Like in scipy.fftpack, all FFT functions in this module have an optional argument overwrite_x (default is False), which has the same semantics as in scipy.fftpack: when it is set to True, the input array x can (but is not guaranteed to) be overwritten arbitrarily. For this reason, when an in-place FFT is desired, the user should always reassign the input in the following manner: x = cupyx.scipy.fftpack.fft(x, ..., overwrite_x=True, ...).
The boolean switch cupy.fft.config.use_multi_gpus also affects the FFT functions in this module, see Discrete Fourier Transform (cupy.fft). Moreover, this switch is honored when planning manually using get_fft_plan().
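A minimal sketch of passing a precomputed plan explicitly (the plan argument is experimental, as noted above):
import cupy
from cupyx.scipy import fftpack

x = cupy.random.random(1024).astype(cupy.complex64)
plan = fftpack.get_fft_plan(x)
y = fftpack.fft(x, plan=plan)  # reuse the plan instead of creating a new one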
Linear algebra (cupyx.scipy.linalg)¶
Basics¶
|
Solve the equation a x = b for x, assuming a is a triangular matrix. |
|
Make a copy of a matrix with elements above the |
|
Make a copy of a matrix with elements below the |
Decompositions¶
|
LU decomposition. |
|
LU decomposition. |
|
Solve an equation system, |
Special Matrices¶
|
Create a block diagonal matrix from provided arrays. |
|
Construct a circulant matrix. |
|
Create a companion matrix. |
|
Construct a convolution matrix. |
|
Discrete Fourier transform matrix. |
|
Returns a symmetric Fiedler matrix |
Returns a Fiedler companion matrix |
|
|
Construct an Hadamard matrix. |
|
Construct a Hankel matrix. |
|
Create an Helmert matrix of order |
|
Create a Hilbert matrix of order |
|
Kronecker product. |
|
Create a Leslie matrix. |
|
Construct a Toeplitz matrix. |
|
Construct ( |
Multidimensional image processing (cupyx.scipy.ndimage)¶
Filters¶
|
Multi-dimensional convolution. |
|
One-dimensional convolution. |
|
Multi-dimensional correlate. |
|
One-dimensional correlate. |
|
Multi-dimensional Gaussian filter. |
|
One-dimensional Gaussian filter along the given axis. |
|
Multi-dimensional gradient magnitude using Gaussian derivatives. |
|
Multi-dimensional Laplace filter using Gaussian second derivatives. |
|
Compute a multi-dimensional filter using the provided raw kernel or reduction kernel. |
|
Compute a 1D filter along the given axis using the provided raw kernel. |
|
Multi-dimensional gradient magnitude filter using a provided derivative function. |
|
Multi-dimensional Laplace filter using a provided second derivative function. |
|
Multi-dimensional Laplace filter based on approximate second derivatives. |
|
Multi-dimensional maximum filter. |
|
Compute the maximum filter along a single axis. |
|
Multi-dimensional median filter. |
|
Multi-dimensional minimum filter. |
|
Compute the minimum filter along a single axis. |
|
Multi-dimensional percentile filter. |
|
Compute a Prewitt filter along the given axis. |
|
Multi-dimensional rank filter. |
|
Compute a Sobel filter along the given axis. |
|
Multi-dimensional uniform filter. |
|
One-dimensional uniform filter along the given axis. |
Fourier filters¶
|
Multidimensional ellipsoid Fourier filter. |
|
Multidimensional Gaussian Fourier filter. |
|
Multidimensional Fourier shift filter. |
|
Multidimensional uniform Fourier filter. |
Interpolation¶
|
Apply an affine transformation. |
|
Map the input array to new coordinates by interpolation. |
|
Rotate an array. |
|
Shift an array. |
|
Multidimensional spline filter. |
|
Calculate a 1-D spline filter along the given axis. |
|
Zoom an array. |
Measurements¶
|
Calculate the center of mass of the values of an array at labels. |
|
Calculate the minimums and maximums of the values of an array at labels, along with their positions. |
|
Calculate the histogram of the values of an array, optionally at labels. |
|
Labels features in an array. |
|
Array resulting from applying |
|
Calculate the maximum of the values of an array over labeled regions. |
|
Find the positions of the maximums of the values of an array at labels. |
|
Calculates the mean of the values of an n-D image array, optionally at specified sub-regions. |
|
Calculate the median of the values of an array over labeled regions. |
|
Calculate the minimum of the values of an array over labeled regions. |
|
Find the positions of the minimums of the values of an array at labels. |
|
Calculates the standard deviation of the values of an n-D image array, optionally at specified sub-regions. |
|
Calculates the sum of the values of an n-D image array, optionally at specified sub-regions. |
|
Calculates the variance of the values of an n-D image array, optionally at specified sub-regions. |
Morphology¶
|
Multidimensional binary closing with the given structuring element. |
|
Multidimensional binary dilation with the given structuring element. |
|
Multidimensional binary erosion with a given structuring element. |
|
Fill the holes in binary objects. |
|
Multidimensional binary hit-or-miss transform. |
|
Multidimensional binary opening with the given structuring element. |
|
Multidimensional binary propagation with the given structuring element. |
|
Multidimensional black tophat filter. |
|
Generate a binary structure for binary morphological operations. |
|
Calculates a multi-dimensional greyscale closing. |
|
Calculates a greyscale dilation. |
|
Calculates a greyscale erosion. |
|
Calculates a multi-dimensional greyscale opening. |
|
Iterate a structure by dilating it with itself. |
|
Multidimensional morphological gradient. |
|
Multidimensional morphological laplace. |
|
Multidimensional white tophat filter. |
OpenCV mode¶
cupyx.scipy.ndimage supports an additional mode, opencv. If it is given, the function performs like cv2.warpAffine or cv2.resize. Example:
import cupyx.scipy.ndimage
import cupy as cp
import cv2

im = cp.asarray(cv2.imread('TODO'))  # fill in your image path; move the image to the GPU
trans_mat = cp.eye(4)
trans_mat[0][0] = trans_mat[1][1] = 0.5
smaller_shape = (im.shape[0] // 2, im.shape[1] // 2, 3)
smaller = cp.zeros(smaller_shape)  # preallocate memory for the resized image
cupyx.scipy.ndimage.affine_transform(im, trans_mat, output_shape=smaller_shape,
                                     output=smaller, mode='opencv')
cv2.imwrite('smaller.jpg', cp.asnumpy(smaller))  # save the smaller image locally
Signal processing (cupyx.scipy.signal)¶
Convolution¶
|
Convolve two N-dimensional arrays. |
|
Cross-correlate two N-dimensional arrays. |
|
Convolve two N-dimensional arrays using FFT. |
|
Convolve two N-dimensional arrays using the overlap-add method. |
|
Convolve two 2-dimensional arrays. |
|
Cross-correlate two 2-dimensional arrays. |
|
Convolve with a 2-D separable FIR filter. |
|
Find the fastest convolution/correlation method. |
Filtering¶
|
Perform an order filter on an N-D array. |
|
Perform a median filter on an N-dimensional array. |
|
Median filter a 2-dimensional array. |
|
Perform a Wiener filter on an N-dimensional array. |
Sparse matrices (cupyx.scipy.sparse)¶
CuPy supports sparse matrices using cuSPARSE. These matrices have the same interfaces as SciPy's sparse matrices.
Conversion to/from SciPy sparse matrices¶
cupyx.scipy.sparse.*_matrix and scipy.sparse.*_matrix are not implicitly convertible to each other. That means SciPy functions cannot take cupyx.scipy.sparse.*_matrix objects as inputs, and vice versa.
To convert SciPy sparse matrices to CuPy, pass them to the constructor of each CuPy sparse matrix class.
To convert CuPy sparse matrices to SciPy, use the get method of each CuPy sparse matrix class.
Note that converting between CuPy and SciPy incurs data transfer between the host (CPU) and the GPU device, which is costly in terms of performance. A minimal sketch of the round trip follows.
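import scipy.sparse
import cupyx.scipy.sparse

a_cpu = scipy.sparse.random(1000, 1000, density=0.01, format='csr')
a_gpu = cupyx.scipy.sparse.csr_matrix(a_cpu)  # host-to-device transfer
b_cpu = a_gpu.get()                           # device-to-host transfer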
Conversion to/from CuPy ndarrays¶
To convert a CuPy ndarray to a CuPy sparse matrix, pass it to the constructor of each CuPy sparse matrix class.
To convert a CuPy sparse matrix to a CuPy ndarray, use the toarray method of each CuPy sparse matrix instance (e.g., cupyx.scipy.sparse.csr_matrix.toarray()).
Converting between CuPy ndarrays and CuPy sparse matrices does not incur data transfer; the data is copied within the GPU device.
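A minimal sketch (all data stays on the GPU):
import cupy
import cupyx.scipy.sparse

dense = cupy.random.random((100, 100))
dense[dense < 0.9] = 0                     # make the matrix mostly zero
sp = cupyx.scipy.sparse.csr_matrix(dense)  # ndarray -> sparse, no host transfer
back = sp.toarray()                        # sparse -> ndarray, no host transfer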
Contents¶
Sparse matrix classes¶
|
COOrdinate format sparse matrix. |
|
Compressed Sparse Column matrix. |
|
Compressed Sparse Row matrix. |
|
Sparse matrix with DIAgonal storage. |
|
Base class of all sparse matrices. |
Functions¶
Building sparse matrices:
|
Creates a sparse matrix with ones on diagonal. |
|
Creates an identity matrix in sparse format. |
|
Kronecker product of sparse matrices A and B. |
|
Kronecker sum of sparse matrices A and B. |
|
Construct a sparse matrix from diagonals. |
|
Creates a sparse matrix from diagonals. |
|
Returns the lower triangular portion of a matrix in sparse format |
|
Returns the upper triangular portion of a matrix in sparse format |
|
Builds a sparse matrix from sparse sub-blocks |
|
Stacks sparse matrices horizontally (column wise) |
|
Stacks sparse matrices vertically (row wise) |
|
Generates a random sparse matrix. |
|
Generates a random sparse matrix. |
Sparse matrix tools:
|
Returns the indices and values of the nonzero elements of a matrix |
Identifying sparse matrices:
|
Checks if a given matrix is a sparse matrix. |
|
Checks if a given matrix is a sparse matrix. |
Checks if a given matrix is of CSC format. |
|
Checks if a given matrix is of CSR format. |
|
Checks if a given matrix is of COO format. |
|
Checks if a given matrix is of DIA format. |
Sparse linear algebra (cupyx.scipy.sparse.linalg)¶
Abstract linear operators¶
|
Common interface for performing matrix vector products |
Return A as a LinearOperator. |
Solving linear problems¶
Direct methods for linear equation systems:
|
Solves a sparse linear system |
|
Solves a sparse triangular system |
|
Return a function for solving a sparse linear system, with A pre-factorized. |
Iterative methods for linear equation systems:
|
Uses Conjugate Gradient iteration to solve |
|
Uses Generalized Minimal RESidual iteration to solve |
|
Use Conjugate Gradient Squared iteration to solve |
|
Uses MINimum RESidual iteration to solve |
Iterative methods for least-squares problems:
|
Solves linear system with QR decomposition. |
|
Iterative solver for least-squares problems. |
Matrix factorizations¶
Eigenvalue problems:
|
Find |
|
Locally Optimal Block Preconditioned Conjugate Gradient Method (LOBPCG) |
Singular values problems:
|
Finds the largest |
Complete or incomplete LU factorizations:
|
Computes the LU decomposition of a sparse square matrix. |
|
Computes the incomplete LU decomposition of a sparse square matrix. |
|
Compressed sparse graph routines (cupyx.scipy.sparse.csgraph)¶
Note
The csgraph module uses pylibcugraph as a backend. You need to install the pylibcugraph package (https://anaconda.org/rapidsai/pylibcugraph) from the rapidsai Conda channel to use the features listed on this page.
Note
Currently, the csgraph module is not supported on AMD ROCm platforms.
Contents¶
|
Analyzes the connected components of a sparse graph |
Special functions (cupyx.scipy.special)¶
Bessel functions¶
Bessel function of the first kind of order 0. |
|
Bessel function of the first kind of order 1. |
|
Bessel function of the second kind of order 0. |
|
Bessel function of the second kind of order 1. |
|
Modified Bessel function of order 0. |
|
Modified Bessel function of order 1. |
Information Theory functions¶
Elementwise function for computing entropy. |
|
Elementwise function for computing relative entropy. |
|
Elementwise function for computing Kullback-Leibler divergence. |
|
Elementwise function for computing the Huber loss. |
|
Elementwise function for computing the Pseudo-Huber loss. |
Error function and Fresnel integrals¶
Error function. |
|
Complementary error function. |
|
Scaled complementary error function. |
|
Inverse function of error function. |
|
Inverse function of complementary error function. |
Statistical functions (cupyx.scipy.stats)¶
CuPy-specific functions¶
CuPy-specific functions are placed under the cupyx namespace.
Returns the reciprocal square root. |
|
|
Adds given values to specified elements of an array. |
|
Stores a maximum value of elements specified by indices to an array. |
|
Stores a minimum value of elements specified by indices to an array. |
|
Returns a new, uninitialized NumPy array with the given shape and dtype. |
|
Returns a new, uninitialized NumPy array with the same shape and dtype as those of the given array. |
|
Returns a new, zero-initialized NumPy array with the given shape and dtype. |
|
Returns a new, zero-initialized NumPy array with the same shape and dtype as those of the given array. |
Profiling utilities¶
|
Timing utility for measuring time spent by both CPU and GPU. |
|
Mark function calls with ranges using NVTX/rocTX. |
Enable CUDA profiling during with statement. |
DLPack utilities¶
Below are helper functions for creating a cupy.ndarray from either a DLPack tensor or any object supporting the DLPack data exchange protocol. For further detail see DLPack.
|
Zero-copy conversion between array objects compliant with the DLPack data exchange protocol. |
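A minimal sketch of consuming an object via the protocol, assuming the helper listed above is cupy.from_dlpack (here a CuPy array is round-tripped through the protocol for illustration):
import cupy

a = cupy.arange(10)
b = cupy.from_dlpack(a)  # zero-copy: b shares the underlying GPU memory with a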
Automatic Kernel Parameters Optimizations (cupyx.optimizing)¶
|
Context manager that optimizes kernel launch parameters. |
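A minimal sketch, assuming the context manager listed above is cupyx.optimizing.optimize:
import cupy
import cupyx.optimizing

x = cupy.random.random((1000, 1000))
with cupyx.optimizing.optimize():
    y = x.sum()  # launch parameters of the reduction kernel are auto-tuned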
Low-level CUDA support¶
Device management¶
|
Object that represents a CUDA device. |
Memory management¶
Returns CuPy default memory pool for GPU memory. |
|
Returns CuPy default memory pool for pinned memory. |
|
|
Memory allocation on a CUDA device. |
|
Asynchronous memory allocation on a CUDA device. |
|
Managed memory (Unified memory) allocation on a CUDA device. |
|
CUDA memory that is not owned by CuPy. |
|
Pinned memory allocation on host. |
|
Pointer to a location in device memory. |
|
Pointer to pinned memory. |
|
Allocate managed memory (unified memory). |
|
(Experimental) Allocate memory from Stream Ordered Memory Allocator. |
|
Calls the current allocator. |
|
Calls the current allocator. |
Returns the current allocator for GPU memory. |
|
|
Sets the current allocator for GPU memory. |
|
Sets a thread-local allocator for GPU memory inside |
Sets the current allocator for the pinned memory. |
|
|
Memory pool for all GPU devices on the host. |
|
(Experimental) CUDA memory pool for all GPU devices on the host. |
|
Memory pool for pinned memory on the host. |
Allocator with python functions to perform memory allocation. |
|
|
Allocator with C function pointers to allocation routines. |
Memory hook¶
Base class of hooks for Memory allocations. |
|
Memory hook that prints debug information. |
|
Code line CuPy memory profiler. |
Streams and events¶
|
CUDA stream. |
|
CUDA stream not managed by CuPy. |
Gets current CUDA stream. |
|
|
CUDA event, a synchronization point of CUDA streams. |
|
Gets the elapsed time between two events. |
Texture and surface memory¶
A class that holds the channel format description. |
|
Allocate a CUDA array (cudaArray_t) that can be used as texture memory. |
|
A class that holds the resource description. |
|
A class that holds the texture description. |
|
A class that holds a texture object. |
|
A class that holds a surface object. |
|
A class that holds a texture reference. |
Profiler¶
Enable CUDA profiling during with statement. |
|
Initialize the CUDA profiler. |
|
Enable profiling. |
|
Disable profiling. |
|
|
Marks an instantaneous event (marker) in the application. |
|
Marks an instantaneous event (marker) in the application. |
|
Starts a nested range. |
|
Starts a nested range. |
Ends a nested range. |
NCCL¶
|
Initialize an NCCL communicator for one device controlled by one process. |
Returns the runtime version of NCCL. |
|
Start a group of NCCL calls. |
|
End a group of NCCL calls. |
Runtime API¶
CuPy wraps CUDA Runtime APIs to provide access to native CUDA operations. Please check the CUDA Runtime API documentation to use these functions.
Custom kernels¶
|
User-defined elementwise kernel. |
|
User-defined reduction kernel. |
|
User-defined custom kernel. |
|
User-defined custom module. |
|
Decorator that fuses a function. |
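For example, a minimal element-wise kernel sketch:
import cupy

# Computes z = (x - y)^2 for each element pair
squared_diff = cupy.ElementwiseKernel(
    'float32 x, float32 y',   # input argument list
    'float32 z',              # output argument list
    'z = (x - y) * (x - y)',  # loop body written in CUDA C
    'squared_diff')           # kernel name

x = cupy.arange(5, dtype=cupy.float32)
y = cupy.ones(5, dtype=cupy.float32)
z = squared_diff(x, y)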
JIT kernel definition¶
|
A decorator that compiles a Python function into a CUDA kernel. |
dim3 threadIdx |
|
dim3 blockDim |
|
dim3 blockIdx |
|
dim3 gridDim |
|
Compute the thread index in the grid. |
|
Compute the grid size. |
|
Returns the lane ID of the calling thread, ranging in |
|
Returns the number of threads in a warp. |
|
Calls |
|
Calls |
|
Calls the |
|
Calls the |
|
Calls the |
|
Calls the |
|
Allocates shared memory and returns the 1-dim array. |
|
Calls the |
|
Calls the |
|
Calls the |
|
Calls the |
|
Calls the |
|
Calls the |
|
Calls the |
|
Calls the |
|
Calls the |
|
Calls the |
|
Calls the |
|
|
JIT CUDA kernel object. |
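For example, a minimal JIT kernel sketch using the device functions listed above:
import cupy
from cupyx import jit

@jit.rawkernel()
def elementwise_copy(x, y, size):
    tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x
    ntid = jit.gridDim.x * jit.blockDim.x
    for i in range(tid, size, ntid):  # grid-stride loop
        y[i] = x[i]

size = 2 ** 20
x = cupy.arange(size, dtype=cupy.float32)
y = cupy.empty_like(x)
elementwise_copy((128,), (1024,), (x, y, cupy.uint32(size)))  # grid, block, args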
Kernel binary memoization¶
|
Makes a function memoizing the result for each argument and device. |
Clears the memoized results for all functions decorated by memoize. |
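A minimal sketch of memoizing an expensive, device-dependent setup:
import cupy

@cupy.memoize(for_each_device=True)
def get_double_kernel():
    # Compiled once per device; subsequent calls return the cached object
    return cupy.ElementwiseKernel('float32 x', 'float32 y', 'y = 2 * x', 'double_it')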
Distributed¶
The following pages describe the APIs used to easily perform communication between different processes in CuPy.
|
Start cupyx.distributed and obtain a communicator. |
|
Interface that uses NVIDIA’s NCCL to perform communications. |
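A minimal sketch of a two-process setup (a hypothetical launcher passes the rank; the communicator method names follow the NCCL backend listed above and should be treated as an assumption):
import cupy
from cupyx.distributed import init_process_group

def worker(rank):
    cupy.cuda.Device(rank).use()
    comm = init_process_group(2, rank)  # 2 participating devices, this rank
    x = cupy.ones((10,), dtype=cupy.float32)
    y = cupy.empty_like(x)
    comm.all_reduce(x, y)  # sum across the two processes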
Environment variables¶
For runtime¶
Here are the environment variables that CuPy uses at runtime.
- CUDA_PATH¶
Path to the directory containing CUDA. The parent of the directory containing nvcc is used as the default. When nvcc is not found, /usr/local/cuda is used. See Working with Custom CUDA Installation for details.
- CUPY_CACHE_DIR¶
Default: ${HOME}/.cupy/kernel_cache
Path to the directory to store the kernel cache. See Performance Best Practices for details.
- CUPY_CACHE_SAVE_CUDA_SOURCE¶
Default: 0
If set to 1, the CUDA source file will be saved along with the compiled binary in the cache directory for debugging purposes. Note: the source file will not be saved if the compiled binary is already stored in the cache.
- CUPY_CACHE_IN_MEMORY¶
Default: 0
If set to 1, CUPY_CACHE_DIR and CUPY_CACHE_SAVE_CUDA_SOURCE will be ignored, and the cache is kept in memory. This environment variable allows reducing disk I/O, but is ignored when nvcc is set to be the compiler backend.
- CUPY_DUMP_CUDA_SOURCE_ON_ERROR¶
Default: 0
If set to 1, when CUDA kernel compilation fails, CuPy dumps the CUDA kernel code to standard error.
- CUPY_CUDA_COMPILE_WITH_DEBUG¶
Default: 0
If set to 1, CUDA kernels will be compiled with debug information (--device-debug and --generate-line-info).
- CUPY_GPU_MEMORY_LIMIT¶
Default: 0 (unlimited)
The amount of memory that can be allocated for each device. The value can be specified in absolute bytes or as a fraction (e.g., "90%") of the total memory of each GPU. See Memory Management for details.
- CUPY_SEED¶
Set the seed for random number generators.
- CUPY_EXPERIMENTAL_SLICE_COPY¶
Default: 0
If set to 1, the following syntax is enabled: cupy_ndarray[:] = numpy_ndarray
- CUPY_ACCELERATORS¶
Default: "" (no accelerators)
A comma-separated string of backend names (cub or cutensor) which indicates the acceleration backends used in CuPy operations and their priority. All accelerators are disabled by default.
- CUPY_TF32¶
Default: 0
If set to 1, CUDA libraries are allowed to use TF32 Tensor Core compute for 32-bit floating point computation.
- CUPY_CUDA_ARRAY_INTERFACE_SYNC¶
Default: 1
This controls CuPy's behavior as a Consumer. If set to 0, a stream synchronization will not be performed when a device array provided by an external library that implements the CUDA Array Interface is being consumed by CuPy. For more detail, see the Synchronization requirement in the CUDA Array Interface v3 documentation.
- CUPY_CUDA_ARRAY_INTERFACE_EXPORT_VERSION¶
Default: 3
This controls CuPy's behavior as a Producer. If set to 2, the CuPy stream on which the data is being operated will not be exported and thus the Consumer (another library) will not perform any stream synchronization. For more detail, see the Synchronization requirement in the CUDA Array Interface v3 documentation.
- CUPY_DLPACK_EXPORT_VERSION¶
Default: 0.6
This controls CuPy's DLPack support. Currently, setting a value smaller than 0.6 would disguise managed memory as normal device memory, which enables data exchanges with libraries that have not updated their DLPack support, whereas starting with 0.6 CUDA managed memory can be correctly recognized as a valid device type.
- NVCC¶
Default: nvcc
Defines the compiler to use when compiling CUDA source. Note that most CuPy kernels are built with NVRTC; this environment variable is only effective for RawKernel/RawModule with the nvcc backend or when using cub as the accelerator.
- CUPY_CUDA_PER_THREAD_DEFAULT_STREAM¶
Default: 0
If set to 1, CuPy will use the CUDA per-thread default stream, effectively causing each host thread to automatically execute in its own stream, unless the CUDA default (null) stream or a user-created stream is specified. If set to 0 (default), the CUDA default (null) stream is used, unless the per-thread default stream (ptds) or a user-created stream is specified.
- CUPY_COMPILE_WITH_PTX¶
Default: 0
By default, CuPy directly compiles kernels into SASS (CUBIN) to support CUDA Enhanced Compatibility. If set to 1, CuPy instead compiles kernels into PTX and lets the CUDA Driver assemble SASS from PTX. This option is only effective for CUDA 11.1 or later; CuPy always compiles into PTX on earlier CUDA versions. Also, this option only applies when NVRTC is selected as the compilation backend. The NVCC backend always compiles into SASS (CUBIN).
- CUDA Toolkit Environment Variables
In addition to the environment variables listed above, as in any CUDA program, all of the CUDA environment variables listed in the CUDA Toolkit Documentation will also be honored.
Note
When the CUPY_ACCELERATORS or NVCC environment variables are set, g++-6 or later is required as the runtime host compiler. Please refer to Installing CuPy from Source for details on how to install g++.
For installation¶
These environment variables are used during installation (building CuPy from source).
- CUTENSOR_PATH¶
Path to the cuTENSOR root directory that contains lib and include directories. (experimental)
- CUPY_INSTALL_USE_HIP¶
Default: 0
If set to 1, CuPy is built for the AMD ROCm Platform (experimental). For building the ROCm support, see Installing Binary Packages for further detail.
- CUPY_USE_CUDA_PYTHON¶
Default: 0
If set to 1, CuPy is built using CUDA Python.
- CUPY_NVCC_GENERATE_CODE¶
Build CuPy for a particular CUDA architecture. For example:
CUPY_NVCC_GENERATE_CODE="arch=compute_60,code=sm_60"
For specifying multiple archs, concatenate the arch=... strings with semicolons (;). If current is specified, then it will automatically detect the currently installed GPU architectures at build time. When this is not set, the default is to support all architectures.
- CUPY_NUM_BUILD_JOBS¶
Default: 4
Sets the number of processes used to build the extensions in parallel.
- CUPY_NUM_NVCC_THREADS¶
Default: 2
Sets the number of threads used to compile files using nvcc.
Additionally, the environment variables CUDA_PATH and NVCC are also respected at build time.
Comparison Table¶
Here is a list of NumPy / SciPy APIs and their corresponding CuPy implementations. A - in the CuPy column denotes that the CuPy implementation is not provided yet. We welcome contributions for these functions.
NumPy / CuPy APIs¶
The NumPy / CuPy comparison covers the following categories: Module-Level, Linear Algebra, Discrete Fourier Transform, Random Sampling, and Polynomials (including Power Series and Polyutils).
SciPy / CuPy APIs¶
The SciPy / CuPy comparison covers the following categories: Discrete Fourier Transform, Legacy Discrete Fourier Transform, Advanced Linear Algebra, Multidimensional Image Processing, Signal processing, Sparse Matrices, Sparse Linear Algebra, Compressed sparse graph routines, Special Functions, and Statistical Functions.
(The per-function entries of these tables are not recoverable from this copy and are omitted.)
Footnotes
- 1: object and string dtypes are not supported on GPU and are thus left unimplemented in CuPy.
- 2: Not supported as it has been deprecated in NumPy.
- 3: Use of numpy.matrix is discouraged in NumPy and thus we have no plan to add it to CuPy.
- 4: datetime64 and timedelta64 dtypes are currently unsupported.
- 5: Floating point error handling depends on CPU-specific features which are not available on GPU.
- 6: Structured arrays and record arrays are currently unsupported.
- 7: Use of numpy.poly1d is discouraged in NumPy and thus we have stopped adding functions with that interface.
- 8: Not supported as GPUs only support little-endian byte encoding.
Python Array API Support¶
The Python array API standard aims to provide a coherent set of APIs for array and tensor libraries developed by the community to build upon. This solves the API fragmentation issue across the community by offering concrete function signatures, semantics and scopes of coverage, enabling backend-agnostic code for better portability.
CuPy provides experimental support based on NumPy's NEP-47, which is in turn based on the draft standard to be finalized in 2021. All of the functionality can be accessed through the cupy.array_api namespace.
The key difference between NumPy and CuPy is that CuPy is a GPU-only library, so CuPy users should be aware of potential device management issues. As in regular CuPy code, the GPU to use can be specified via Device objects; see Device management.
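A minimal sketch of using the namespace:
from cupy import array_api as xp

a = xp.asarray([1.0, 2.0, 3.0])
b = xp.add(a, xp.ones(3))  # runs on the current CUDA device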
Array API Functions¶
This section is a full list of implemented APIs. Each function is an Array API compatible wrapper for the NumPy function noted alongside it; see its docstring for more information. For the detailed documentation, see the array API specification.
- cupy.array_api.abs(x, /) wraps np.abs
- cupy.array_api.acos(x, /) wraps np.arccos
- cupy.array_api.acosh(x, /) wraps np.arccosh
- cupy.array_api.add(x1, x2, /) wraps np.add
- cupy.array_api.all(x, /, *, axis=None, keepdims=False) wraps np.all
- cupy.array_api.any(x, /, *, axis=None, keepdims=False) wraps np.any
- cupy.array_api.arange(start, /, stop=None, step=1, *, dtype=None, device=None) wraps np.arange
- cupy.array_api.argmax(x, /, *, axis=None, keepdims=False) wraps np.argmax
- cupy.array_api.argmin(x, /, *, axis=None, keepdims=False) wraps np.argmin
- cupy.array_api.argsort(x, /, *, axis=-1, descending=False, stable=True) wraps np.argsort
- cupy.array_api.asarray(obj, /, *, dtype=None, device=None, copy=None) wraps np.asarray
- cupy.array_api.asin(x, /) wraps np.arcsin
- cupy.array_api.asinh(x, /) wraps np.arcsinh
- cupy.array_api.atan(x, /) wraps np.arctan
- cupy.array_api.atan2(x1, x2, /) wraps np.arctan2
- cupy.array_api.atanh(x, /) wraps np.arctanh
- cupy.array_api.bitwise_and(x1, x2, /) wraps np.bitwise_and
- cupy.array_api.bitwise_invert(x, /) wraps np.invert
- cupy.array_api.bitwise_left_shift(x1, x2, /) wraps np.left_shift
- cupy.array_api.bitwise_or(x1, x2, /) wraps np.bitwise_or
- cupy.array_api.bitwise_right_shift(x1, x2, /) wraps np.right_shift
- cupy.array_api.bitwise_xor(x1, x2, /) wraps np.bitwise_xor
- cupy.array_api.broadcast_arrays(*arrays) wraps np.broadcast_arrays
- cupy.array_api.broadcast_to(x, /, shape) wraps np.broadcast_to
- cupy.array_api.can_cast(from_, to, /) wraps np.can_cast
- cupy.array_api.ceil(x, /) wraps np.ceil
- cupy.array_api.concat(arrays, /, *, axis=0) wraps np.concatenate
- cupy.array_api.cos(x, /) wraps np.cos
- cupy.array_api.cosh(x, /) wraps np.cosh
- cupy.array_api.divide(x1, x2, /) wraps np.divide
- cupy.array_api.empty(shape, *, dtype=None, device=None) wraps np.empty
- cupy.array_api.empty_like(x, /, *, dtype=None, device=None) wraps np.empty_like
- cupy.array_api.equal(x1, x2, /) wraps np.equal
- cupy.array_api.exp(x, /) wraps np.exp
- cupy.array_api.expand_dims(x, /, *, axis) wraps np.expand_dims
- cupy.array_api.expm1(x, /) wraps np.expm1
- cupy.array_api.eye(n_rows, n_cols=None, /, *, k=0, dtype=None, device=None) wraps np.eye
- cupy.array_api.finfo(type, /) wraps np.finfo (returns a finfo_object)
- cupy.array_api.flip(x, /, *, axis=None) wraps np.flip
- cupy.array_api.floor(x, /) wraps np.floor
- cupy.array_api.floor_divide(x1, x2, /) wraps np.floor_divide
- cupy.array_api.from_dlpack(x, /) wraps np.from_dlpack
- cupy.array_api.full(shape, fill_value, *, dtype=None, device=None) wraps np.full
- cupy.array_api.full_like(x, /, fill_value, *, dtype=None, device=None) wraps np.full_like
- cupy.array_api.greater(x1, x2, /) wraps np.greater
- cupy.array_api.greater_equal(x1, x2, /) wraps np.greater_equal
- cupy.array_api.iinfo(type, /) wraps np.iinfo (returns an iinfo_object)
- cupy.array_api.isfinite(x, /) wraps np.isfinite
- cupy.array_api.isinf(x, /) wraps np.isinf
- cupy.array_api.isnan(x, /) wraps np.isnan
- cupy.array_api.less(x1, x2, /) wraps np.less
- cupy.array_api.less_equal(x1, x2, /) wraps np.less_equal
- cupy.array_api.linspace(start, stop, /, num, *, dtype=None, device=None, endpoint=True) wraps np.linspace
- cupy.array_api.log(x, /) wraps np.log
- cupy.array_api.log10(x, /) wraps np.log10
- cupy.array_api.log1p(x, /) wraps np.log1p
- cupy.array_api.log2(x, /) wraps np.log2
- cupy.array_api.logaddexp(x1, x2) wraps np.logaddexp
- cupy.array_api.logical_and(x1, x2, /) wraps np.logical_and
- cupy.array_api.logical_not(x, /) wraps np.logical_not
- cupy.array_api.logical_or(x1, x2, /) wraps np.logical_or
- cupy.array_api.logical_xor(x1, x2, /) wraps np.logical_xor
- cupy.array_api.matmul(x1, x2, /) wraps np.matmul
- cupy.array_api.meshgrid(*arrays, indexing='xy') wraps np.meshgrid
- cupy.array_api.multiply(x1, x2, /) wraps np.multiply
- cupy.array_api.negative(x, /) wraps np.negative
- cupy.array_api.nonzero(x, /) wraps np.nonzero (returns Tuple[Array, ...])
- cupy.array_api.not_equal(x1, x2, /) wraps np.not_equal
- cupy.array_api.ones(shape, *, dtype=None, device=None) wraps np.ones
- cupy.array_api.ones_like(x, /, *, dtype=None, device=None) wraps np.ones_like
- cupy.array_api.permute_dims(x, /, axes) wraps np.transpose
- cupy.array_api.positive(x, /) wraps np.positive
- cupy.array_api.pow(x1, x2, /) wraps np.power
- cupy.array_api.remainder(x1, x2, /) wraps np.remainder
- cupy.array_api.reshape(x, /, shape) wraps np.reshape
- cupy.array_api.result_type(*arrays_and_dtypes) wraps np.result_type (returns a Dtype)
- cupy.array_api.roll(x, /, shift, *, axis=None) wraps np.roll
- cupy.array_api.round(x, /) wraps np.round
- cupy.array_api.sign(x, /) wraps np.sign
- cupy.array_api.sin(x, /) wraps np.sin
- cupy.array_api.sinh(x, /) wraps np.sinh
- cupy.array_api.sort(x, /, *, axis=-1, descending=False, stable=True) wraps np.sort
- cupy.array_api.sqrt(x, /) wraps np.sqrt
- cupy.array_api.square(x, /) wraps np.square
- cupy.array_api.squeeze(x, /, axis) wraps np.squeeze
- cupy.array_api.stack(arrays, /, *, axis=0) wraps np.stack
- cupy.array_api.subtract(x1, x2, /) wraps np.subtract
- cupy.array_api.tan(x, /) wraps np.tan
- cupy.array_api.tanh(x, /) wraps np.tanh
- cupy.array_api.tril(x, /, *, k=0) wraps np.tril
- cupy.array_api.triu(x, /, *, k=0) wraps np.triu
- cupy.array_api.trunc(x, /) wraps np.trunc
- cupy.array_api.unique_all(x, /) wraps np.unique (returns a UniqueAllResult)
- cupy.array_api.unique_inverse(x, /) wraps np.unique (returns a UniqueInverseResult)
- cupy.array_api.unique_values(x, /) wraps np.unique
- cupy.array_api.where(condition, x1, x2, /) wraps np.where
Array API Compliant Object¶
Array is a wrapper class built upon cupy.ndarray to enforce strict compliance with the array API standard. See the documentation for detail.
This object should not be constructed directly. Rather, use one of the creation functions, such as cupy.array_api.asarray().
Array: n-d array object for the array API namespace.
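As a minimal sketch (not from the original reference), creation functions return the wrapper type rather than cupy.ndarray:
import cupy.array_api as xp

x = xp.asarray([1.0, 2.0, 3.0])  # returns the strict Array wrapper
print(type(x))                   # cupy.array_api._array_object.Array
y = xp.add(x, x)                 # operations stay within the wrapper type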
Contribution Guide¶
This is a guide for all contributions to CuPy. The development of CuPy takes place on the official repository at GitHub. Anyone who wants to register an issue or send a pull request should read through this document.
Classification of Contributions¶
There are several ways to contribute to the CuPy community:
1. Registering an issue
2. Sending a pull request (PR)
3. Sending a question to CuPy's Gitter channel, CuPy User Group, or StackOverflow
4. Open-sourcing an external example
5. Writing a post about CuPy
This document mainly focuses on 1 and 2, though other contributions are also appreciated.
Development Cycle¶
This section explains the development process of CuPy. Before contributing to CuPy, it is strongly recommended to understand the development cycle.
Versioning¶
The versioning of CuPy follows PEP 440 and a part of Semantic versioning.
The version number consists of three or four parts: X.Y.Zw, where X denotes the major version, Y denotes the minor version, Z denotes the revision number, and the optional w denotes the pre-release suffix.
While the major, minor, and revision numbers follow the rule of semantic versioning, the pre-release suffix follows PEP 440 so that the version string stays friendly to the Python ecosystem.
Note that a major update basically does not contain compatibility-breaking changes from the last release candidate (RC). This is not a strict rule, though; if there is a critical API bug that we have to fix for the major version, we may introduce breaking changes in the major version release.
As for the backward compatibility, see API Compatibility Policy.
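As an illustration, the following sketch checks such version strings with the third-party packaging library (an assumption for demonstration, not a CuPy dependency; packaging implements PEP 440):
from packaging.version import Version

for s in ['10.0.0a1', '10.0.0b1', '10.0.0rc1', '10.0.0']:
    print(s, Version(s).is_prerelease)
# 10.0.0a1 True
# 10.0.0b1 True
# 10.0.0rc1 True
# 10.0.0 False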
Release Cycle¶
CuPy is developed along two tracks of versions. The first one is the track of stable versions, which is a series of revision updates for the latest major version. The second one is the track of development versions, which is a series of pre-releases for the upcoming major version.
Consider that X.0.0 is the latest major version and Y.0.0, Z.0.0 are the succeeding major versions. Then, the timeline of the updates is depicted by the following table.
Date | ver X | ver Y | ver Z
---|---|---|---
0 weeks | X.0.0rc1 | – | –
4 weeks | X.0.0 | Y.0.0a1 | –
8 weeks | X.1.0* | Y.0.0b1 | –
12 weeks | X.2.0* | Y.0.0rc1 | –
16 weeks | – | Y.0.0 | Z.0.0a1
(* These might be revision releases)
The dates shown in the left-most column are relative to the release of X.0.0rc1. In particular, each revision/minor release is made four weeks after the previous one of the same major version, and the pre-release of the upcoming major version is made at the same time. Whether these releases are revision or minor is determined based on the contents of each update.
Note that there are only three stable releases for the versions X.x.x. During the parallel development of Y.0.0 and Z.0.0a1, the version Y is treated as an almost-stable version and Z is treated as a development version.
If there is a critical bug found in X.x.x after stopping the development of version X, we may release a hot-fix for this version at any time.
We create a milestone for each upcoming release at GitHub. The GitHub milestone is basically used for collecting the issues and PRs resolved in the release.
Git Branches¶
The master branch is used to develop pre-release versions. It means that alpha, beta, and RC updates are developed on the master branch. This branch contains the most up-to-date source tree that includes features newly added after the latest major version.
The stable version is developed on an individual branch named vN, where “N” reflects the version number (we call it a versioned branch). For example, v1.0.0, v1.0.1, and v1.0.2 will be developed on the v1 branch.
Notes for contributors:
When you send a pull request, you basically have to send it to the master branch. If the change can also be applied to the stable version, a core team member will apply the same change to the stable version so that the change is also included in the next revision update.
If the change is only applicable to the stable version and not to the master branch, please send it to the versioned branch. We basically only accept changes to the latest versioned branch (where the stable version is developed) unless the fix is critical.
If you want to make a new feature of the master branch available in the current stable version, please send a backport PR to the stable version (the latest vN branch). See the next section for details.
Note: a change that can be applied to both branches should be sent to the master branch. Each release of the stable version is also merged to the development version so that the change is also reflected in the next major version.
Feature Backport PRs¶
We basically do not backport any new features of the development version to the stable versions. If you desire to include a feature in the current stable version and you can work on the backport, we welcome such a contribution. In such a case, you have to send a backport PR to the latest vN branch. Note that we do not accept any feature backport PRs to older versions because we are not running quality assurance workflows (e.g., CI) for older versions, so we cannot ensure that the PR is correctly ported.
There are some rules on sending a backport PR.
Start the PR title from the prefix [backport].
Clarify the original PR number in the PR description (something like “This is a backport of #XXXX”).
(optional) Write to the PR description the motivation of backporting the feature to the stable version.
Please follow these rules when you create a feature backport PR.
Note: PRs that do not include any changes/additions to APIs (e.g. bug fixes, documentation improvements) are usually backported by core dev members. It is also appreciated to make such a backport PR by any contributors, though, so that the overall development proceeds more smoothly!
Issues and Pull Requests¶
In this section, we explain how to file issues and send pull requests (PRs).
Issue/PR Labels¶
Issues and PRs are labeled by the following tags:
Bug: bug reports (issues) and bug fixes (PRs)
Enhancement: implementation improvements without breaking the interface
Feature: feature requests (issues) and their implementations (PRs)
NoCompat: disrupts backward compatibility
Test: test fixes and updates
Document: document fixes and improvements
Example: fixes and improvements on the examples
Install: fixes installation script
Contribution-Welcome: issues that we request for contribution (only issues are categorized to this)
Other: other issues and PRs
Multiple tags might be labeled to one issue/PR. Note that revision releases cannot include PRs in Feature and NoCompat categories.
How to File an Issue¶
On registering an issue, write precise explanations on how you want CuPy to be. Bug reports must include necessary and sufficient conditions to reproduce the bugs. Feature requests must include what you want to do (and why you want to do it, if needed) with CuPy. You may include your thoughts on how to realize it in the feature request, though the what part is the most important for discussions.
Warning
If you have a question on usages of CuPy, it is highly recommended to send a post to CuPy’s Gitter channel, CuPy User Group or StackOverflow instead of the issue tracker. The issue tracker is not a place to share knowledge on practices. We may suggest these places and immediately close how-to question issues.
How to Send a Pull Request¶
If you can write code to fix an issue, we encourage you to send a PR.
First of all, before starting to write any code, do not forget to confirm the following points.
Read through the Coding Guidelines and Unit Testing.
Check the appropriate branch that you should send the PR to, following Git Branches. If you do not have any idea about selecting a branch, please choose the master branch.
In particular, check the branch before writing any code. The current source tree of the chosen branch is the starting point of your change.
After writing your code (including unit tests and hopefully documentation!), send a PR on GitHub. You have to write a precise explanation of what you fixed and how; it is the first documentation of your code that developers read, and a very important part of your PR.
Once you send a PR, it is automatically tested on GitHub Actions. After the automatic test passes, core developers will start reviewing your code. Note that this automatic PR test only includes CPU tests.
Note
We are also running continuous integration with GPU tests for the master branch and the versioned branch of the latest major version. Since this service is currently running on our internal server, we do not use it for automatic PR tests to keep the server secure.
If you are planning to add a new feature or modify existing APIs, it is recommended to open an issue and discuss the design first. A design discussion costs the core developers less than a code review. By following the outcome of the discussion, you can send a PR that will be reviewed smoothly in a shorter time.
Even if your code is not complete, you can send a pull request as a work-in-progress PR by putting the [WIP] prefix on the PR title. If you write a precise explanation about the PR, core developers and other contributors can join the discussion about how to proceed with the PR. A WIP PR is also useful for having discussions based on concrete code.
Coding Guidelines¶
Note
Coding guidelines are updated at v5.0. Those who have contributed to older versions should read the guidelines again.
We use PEP8 and a part of OpenStack Style Guidelines related to general coding style as our basic style guidelines.
You can use the autopep8 and flake8 commands to check your code.
In order to avoid confusion from using different tool versions, we pin the versions of those tools. Install them with the following command (from within the top directory of CuPy repository):
$ pip install -e '.[stylecheck]'
And check your code with:
$ autopep8 path/to/your/code.py
$ flake8 path/to/your/code.py
To check Cython code, use the .flake8.cython configuration file:
$ flake8 --config=.flake8.cython path/to/your/cython/code.pyx
autopep8 can automatically correct Python code to conform to the PEP 8 style guide:
$ autopep8 --in-place path/to/your/code.py
The flake8 command lets you know which parts of your code do not obey our style guidelines. Before sending a pull request, be sure to check that your code passes the flake8 check.
Note that the flake8 command is not perfect. It does not check some of the style guidelines. Here is a (non-complete) list of the rules that flake8 cannot check.
Relative imports are prohibited. [H304]
Importing non-module symbols is prohibited.
Import statements must be organized into three parts: standard libraries, third-party libraries, and internal imports. [H306]
In addition, we restrict the usage of shortcut symbols in our code base. They are symbols imported by packages and sub-packages of cupy. For example, cupy.cuda.Device is a shortcut of cupy.cuda.device.Device. It is not allowed to use such shortcuts in the cupy library implementation. Note that you can still use them in the tests and examples directories.
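A minimal sketch of the rule (using the cupy.cuda.Device example mentioned above):
import cupy
import cupy.cuda.device

# Allowed in the cupy library implementation: refer to the defining module.
d = cupy.cuda.device.Device(0)

# The shortcut form is fine in tests and examples, but not in library code.
d = cupy.cuda.Device(0)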
Once you send a pull request, your coding style is automatically checked by GitHub Actions. The reviewing process starts after the check passes.
CuPy is designed based on NumPy's API design. CuPy's source code and documents contain the original NumPy ones. Please note the following when writing documentation.
In order to identify overlapping parts, it is preferable to add a remark that the document is copied or altered from the original one. It is also preferable to briefly explain the specification of the function in a short paragraph and refer to the corresponding function in NumPy so that users can read the detailed document. However, it is acceptable to include a complete copy of the document with such a remark if it cannot be summarized in such a way.
If a function in CuPy only implements a subset of the features of the original one, the document should explicitly describe only what is implemented.
For changes that modify or add new Cython files, please make sure the pointer types follow these guidelines (#1913).
Pointers should be void* if only used within Cython, or intptr_t if exposed to the Python space.
Memory sizes should be size_t.
Memory offsets should be ptrdiff_t.
Note
We are incrementally enforcing the above rules, so some existing code may not follow the above guidelines, but please ensure all new contributions do.
Unit Testing¶
Testing is one of the most important parts of your code. You must write test cases and verify your implementation by following our testing guide.
Note that we are using pytest and mock package for testing, so install them before writing your code:
$ pip install pytest mock
How to Run Tests¶
In order to run unit tests at the repository root, you first have to build Cython files in place by running the following command:
$ pip install -e .
Note
When you modify *.pxd files, before running pip install -e ., you must clean the *.cpp and *.so files once with the following command, because Cython does not automatically rebuild those files nicely:
$ git clean -fdx
Once Cython modules are built, you can run unit tests by running the following command at the repository root:
$ python -m pytest
CUDA must be installed to run unit tests.
Some GPU tests require cuDNN to run. In order to skip unit tests that require cuDNN, specify the -m='not cudnn' option:
$ python -m pytest path/to/your/test.py -m='not cudnn'
Some GPU tests involve multiple GPUs. If you want to run GPU tests with an insufficient number of GPUs, specify the number of available GPUs in CUPY_TEST_GPU_LIMIT. For example, if you have only one GPU, launch pytest with the following command to skip multi-GPU tests:
$ export CUPY_TEST_GPU_LIMIT=1
$ python -m pytest path/to/gpu/test.py
Following this naming convention, you can run all the tests by running the following command at the repository root:
$ python -m pytest
Or you can also specify a root directory to search test scripts from:
$ python -m pytest tests/cupy_tests # to just run tests of CuPy
$ python -m pytest tests/install_tests # to just run tests of installation modules
If you modify the code related to existing unit tests, you must run appropriate commands.
Test File and Directory Naming Conventions¶
Tests are put into the tests/cupy_tests directory. In order to enable the test runner to find test scripts correctly, we use a special naming convention for the test subdirectories and the test scripts.
The name of each subdirectory of tests must end with the _tests suffix.
The name of each test script must start with the test_ prefix.
When we write a test for a module, we choose a path and file name for the test script so that its correspondence to the tested module is clear. For example, if you want to write a test for a module cupy.x.y.z, the test script must be located at tests/cupy_tests/x_tests/y_tests/test_z.py.
How to Write Tests¶
There are many examples of unit tests under the tests directory, so reading some of them is a good and recommended way to learn how to write tests for CuPy.
They simply use the unittest package of the standard library, while some tests are using utilities from cupy.testing.
In addition to the Coding Guidelines mentioned above, the following rules are applied to the test code:
All test classes must inherit from unittest.TestCase.
Use unittest features to write tests, except for the following cases:
Use the assert statement instead of self.assert* methods (e.g., write assert x == 1 instead of self.assertEqual(x, 1)).
Use with pytest.raises(...): instead of with self.assertRaises(...):.
Note
We are incrementally applying the above style.
Some existing tests may be using the old style (self.assertRaises, etc.), but all newly written tests should follow the above style.
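A minimal sketch of the preferred style (the test names are illustrative):
import unittest

import pytest

import cupy


class TestExampleStyle(unittest.TestCase):

    def test_shape(self):
        x = cupy.zeros((2, 3))
        assert x.shape == (2, 3)  # plain assert, not self.assertEqual

    def test_invalid_reshape(self):
        with pytest.raises(ValueError):  # not self.assertRaises
            cupy.zeros((2, 3)).reshape((5,))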
Even if your patch includes GPU-related code, your tests should not fail without GPU capability.
Test functions that require CUDA must be tagged with the cupy.testing.attr.gpu decorator:
import unittest
from cupy.testing import attr
class TestMyFunc(unittest.TestCase):
...
@attr.gpu
def test_my_gpu_func(self):
...
The functions tagged by the gpu decorator are skipped if the CUPY_TEST_GPU_LIMIT=0 environment variable is set. We also have the cupy.testing.attr.cudnn decorator to let pytest know that the test depends on cuDNN. The test functions decorated by cudnn are skipped if -m='not cudnn' is given.
The test functions decorated by gpu must not depend on multiple GPUs. In order to write tests for multiple GPUs, use the cupy.testing.attr.multi_gpu() decorator instead:
import unittest
from cupy.testing import attr
class TestMyFunc(unittest.TestCase):
...
@attr.multi_gpu(2) # specify the number of required GPUs here
def test_my_two_gpu_func(self):
...
If your test requires too much time, add the cupy.testing.attr.slow decorator. The test functions decorated by slow are skipped if -m='not slow' is given:
import unittest
from cupy.testing import attr
class TestMyFunc(unittest.TestCase):
...
@attr.slow
def test_my_slow_func(self):
...
Note
If you want to filter by more than one attribute, combine them with the and operator, like -m='not cudnn and not slow'. See the pytest documentation for details.
Once you send a pull request, your coding style is automatically checked by the CI as described above. Since the automatic PR check does not have CUDA capability, we cannot run GPU unit tests automatically. The reviewing process starts after the automatic check passes; note that reviewers have to verify CUDA-related code without automated checks.
Note
Some of numerically unstable tests might cause errors irrelevant to your changes. In such a case, we ignore the failures and go on to the review process, so do not worry about it!
Documentation¶
When adding a new feature to the framework, you also need to document it in the reference.
Note
If you are unsure about how to fix the documentation, you can submit a pull request without doing so. Reviewers will help you fix the documentation appropriately.
The documentation source is stored under docs directory and written in reStructuredText format.
To build the documentation, you need to install Sphinx:
$ pip install -r docs/requirements.txt
Then you can build the documentation in HTML format locally:
$ cd docs
$ make html
HTML files are generated under the build/html directory. Open index.html with a browser and see if it is rendered as expected.
Note
Docstrings (documentation comments in the source code) are collected from the installed CuPy module. If you modified docstrings, make sure to install the module (e.g., using pip install -e .) before building the documentation.
Tips for Developers¶
Here are some tips for developers hacking CuPy source code.
Install as Editable¶
During development, we recommend using pip with the -e option to install in editable mode:
$ pip install -e .
Please note that even with -e, you will have to rerun pip install -e . to regenerate C++ sources using Cython if you modified Cython source files (e.g., *.pyx files).
Use ccache¶
The NVCC environment variable can be specified at build time to use a custom command instead of nvcc. You can speed up rebuilds using ccache (v3.4 or later) by:
$ export NVCC='ccache nvcc'
Limit Architecture¶
Use the CUPY_NVCC_GENERATE_CODE environment variable to reduce the build time by limiting the target CUDA architectures. For example, if you only run your CuPy build with NVIDIA P100 and V100, you can use:
$ export CUPY_NVCC_GENERATE_CODE=arch=compute_60,code=sm_60;arch=compute_70,code=sm_70
See Environment variables for the description.
Upgrade Guide¶
This page covers changes introduced in each major version that users should know when migrating from older releases. Please see also the Compatibility Matrix for supported environments of each major version.
CuPy v10¶
Dropping CUDA 9.2 / 10.0 / 10.1 Support¶
CUDA 10.1 or earlier is no longer supported. Use CUDA 10.2 or later.
Dropping NCCL v2.4 / v2.6 / v2.7 Support¶
NCCL v2.4, v2.6, and v2.7 are no longer supported.
Dropping Python 3.6 Support¶
Python 3.6 is no longer supported.
Dropping NumPy 1.17 Support¶
NumPy 1.17 is no longer supported.
Change in cupy.cuda.Device Behavior¶
Current device set via use() will not be restored when exiting a with block¶
The current device set via cupy.cuda.Device.use() will not be reactivated when exiting a device context manager. Existing code mixing the with device: block and device.use() may get different results between CuPy v10 and v9.
with cupy.cuda.Device(1) as d1:
d2 = cupy.cuda.Device(0).use()
with d1:
pass
cupy.cuda.Device() # -> CuPy v10 returns device 1 instead of device 0
Changes in cupy.cuda.Stream
Behavior¶
Stream is now managed per-device¶
Previously, it was the users' responsibility to keep the current stream consistent with the current CUDA device. For example, the following code raises an error in CuPy v9 or earlier:
import cupy
with cupy.cuda.Device(0):
# Create a stream on device 0.
s0 = cupy.cuda.Stream()
with cupy.cuda.Device(1):
with s0:
# Try to use the stream on device 1
cupy.arange(10) # -> CUDA_ERROR_INVALID_HANDLE: invalid resource handle
CuPy v10 manages the current stream per-device, eliminating the need to switch the stream every time the active device changes. When using CuPy v10, the above example behaves differently: whenever a stream is created, it is automatically associated with the current device and is ignored when switching devices. In earlier versions, trying to use s0 on device 1 raises an error because s0 is associated with device 0; in v10, s0 is ignored and the default stream of device 1 is used instead.
Current stream set via use() will not be restored when exiting a with block¶
As with the cupy.cuda.Device change above, the current stream set via cupy.cuda.Stream.use() will not be reactivated when exiting a stream context manager. Existing code mixing the with stream: block and stream.use() may get different results between CuPy v10 and v9.
s1 = cupy.cuda.Stream()
s2 = cupy.cuda.Stream()
s3 = cupy.cuda.Stream()
with s1:
s2.use()
with s3:
pass
cupy.cuda.get_current_stream() # -> CuPy v10 returns `s1` instead of `s2`.
Big-Endian Arrays Automatically Converted to Little-Endian¶
cupy.array(), cupy.asarray(), and their variants now always transfer the data to GPU in little-endian byte order. Previously CuPy was copying the given numpy.ndarray to GPU as-is, regardless of the endianness. In CuPy v10, big-endian arrays are converted to little-endian before the transfer, which is the native byte order on GPUs. This change eliminates the need to manually change the array endianness before creating the CuPy array.
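A sketch of the new behavior (the printed output is an assumption based on the description above):
import numpy
import cupy

x = numpy.arange(3, dtype='>i4')  # big-endian array on the host
y = cupy.asarray(x)               # transferred as little-endian in v10
print(y.dtype)                    # int32 (native byte order)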
Baseline API Update¶
Baseline API has been bumped from NumPy 1.20 and SciPy 1.6 to NumPy 1.21 and SciPy 1.7. CuPy v10 will follow the upstream products’ specifications of these baseline versions.
API Changes¶
Device synchronize detection APIs (cupyx.allow_synchronize() and cupyx.DeviceSynchronized), introduced as an experimental feature in CuPy v8, have been marked as deprecated because it is impossible to detect synchronizations reliably.
An internal API cupy.cuda.compile_with_cache() has been marked as deprecated as there are better alternatives (see RawModule added since CuPy v7 and RawKernel since v5). While it has a longstanding history, this API has never been meant to be public. We encourage downstream libraries and users to migrate to the aforementioned public APIs. See User-Defined Kernels for their tutorials.
The DLPack routine cupy.fromDlpack() is deprecated in favor of cupy.from_dlpack(), which addresses potential data race issues.
A new module cupyx.profiler is added to host all profiling related APIs in CuPy. Accordingly, profiling APIs have been relocated to this module, and the old routines are deprecated.
cupy.ndarray.__pos__() now returns a copy (the same as cupy.positive()) instead of returning self.
Note that deprecated APIs may be removed in the future CuPy releases.
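For example, a sketch of the preferred DLPack flow (assuming the producer implements the __dlpack__ protocol, as cupy.ndarray does in v10):
import cupy

a = cupy.arange(4)
b = cupy.from_dlpack(a)  # preferred over the deprecated cupy.fromDlpack()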
Update of Docker Images¶
CuPy official Docker images (see Installation for details) are now updated to use CUDA 11.4 and ROCm 4.3.
CuPy v9¶
Dropping Support of CUDA 9.0¶
CUDA 9.0 is no longer supported. Use CUDA 9.2 or later.
Dropping Support of cuDNN v7.5 and NCCL v2.3¶
cuDNN v7.5 (or earlier) and NCCL v2.3 (or earlier) are no longer supported.
Dropping Support of NumPy 1.16 and SciPy 1.3¶
NumPy 1.16 and SciPy 1.3 are no longer supported.
Dropping Support of Python 3.5¶
Python 3.5 is no longer supported in CuPy v9.
NCCL and cuDNN No Longer Included in Wheels¶
NCCL and cuDNN shared libraries are no longer included in wheels (see #4850 for discussions). You can install them manually after installing the wheel if you don't have a previous installation; see Installation for details.
cuTENSOR Enabled in Wheels¶
cuTENSOR can now be used when installing CuPy via wheels.
cupy.cuda.{nccl,cudnn} Modules Need Explicit Import¶
Previously the cupy.cuda.nccl and cupy.cuda.cudnn modules were automatically imported. Since CuPy v9, these modules need to be explicitly imported (i.e., import cupy.cuda.nccl / import cupy.cuda.cudnn).
Baseline API Update¶
Baseline API has been bumped from NumPy 1.19 and SciPy 1.5 to NumPy 1.20 and SciPy 1.6. CuPy v9 will follow the upstream products’ specifications of these baseline versions.
Following NumPy 1.20, aliases for the Python scalar types (cupy.bool, cupy.int, cupy.float, and cupy.complex) are now deprecated. cupy.bool_, cupy.int_, cupy.float_, and cupy.complex_ should be used instead when required.
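A sketch of the migration:
import cupy

# Deprecated: dtype=cupy.float (alias of the plain Python float type)
x = cupy.zeros(3, dtype=cupy.float_)  # use the underscore-suffixed alias
y = cupy.ones(3, dtype=bool)          # or the plain Python type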
Update of Docker Images¶
CuPy official Docker images (see Installation for details) are now updated to use CUDA 11.2 and Python 3.8.
CuPy v8¶
Dropping Support of CUDA 8.0 and 9.1¶
CUDA 8.0 and 9.1 are no longer supported. Use CUDA 9.0, 9.2, 10.0, or later.
Dropping Support of NumPy 1.15 and SciPy 1.2¶
NumPy 1.15 (or earlier) and SciPy 1.2 (or earlier) are no longer supported.
Update of Docker Images¶
CuPy official Docker images (see Installation for details) are now updated to use CUDA 10.2 and Python 3.6.
SciPy and Optuna are now pre-installed.
CUB Support and Compiler Requirement¶
The CUB module is now built by default. You can enable the use of CUB by setting CUPY_ACCELERATORS="cub" (see CUPY_ACCELERATORS for details).
Due to this change, g++-6 or later is required when building CuPy from the source. See Installation for details.
The following environment variables are no longer effective:
CUB_DISABLED: use CUPY_ACCELERATORS as aforementioned.
CUB_PATH: no longer required, as CuPy uses either the CUB source bundled with CUDA (only when using CUDA 11.0 or later) or the one in the CuPy distribution.
API Changes¶
cupy.scatter_add, which was deprecated in CuPy v4, has been removed. Use cupyx.scatter_add() instead.
The cupy.sparse module has been deprecated and will be removed in future releases. Use cupyx.scipy.sparse instead.
The dtype argument of cupy.ndarray.min() and cupy.ndarray.max() has been removed to align with the NumPy specification.
cupy.allclose() now returns the result as a 0-dim GPU array instead of a Python bool to avoid device synchronization.
cupy.RawModule now delays the compilation to the time of the first call to align the behavior with cupy.RawKernel.
The cupy.cuda.*_enabled flags (nccl_enabled, nvtx_enabled, etc.) have been deprecated. Use the cupy.cuda.*.available flags (cupy.cuda.nccl.available, cupy.cuda.nvtx.available, etc.) instead.
The CHAINER_SEED environment variable is no longer effective. Use CUPY_SEED instead.
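For example, the cupy.allclose() change can be observed as follows (a sketch):
import cupy

r = cupy.allclose(cupy.arange(3), cupy.arange(3))
print(type(r))  # cupy.ndarray (0-dim), not Python bool
print(bool(r))  # True; the conversion synchronizes with the device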
CuPy v7¶
Dropping Support of Python 2.7 and 3.4¶
Starting from CuPy v7, Python 2.7 and 3.4 are no longer supported, as they reached their end-of-life (EOL) in January 2020 (2.7) and March 2019 (3.4). Python 3.5.1 is the minimum Python version supported by CuPy v7. Please upgrade your Python to any of the later versions listed under Installation if you are using an affected version.
CuPy v6¶
Binary Packages Ignore LD_LIBRARY_PATH¶
Prior to CuPy v6, the LD_LIBRARY_PATH environment variable could be used to override the cuDNN / NCCL libraries bundled in the binary distribution (also known as wheels). In CuPy v6, LD_LIBRARY_PATH is ignored during discovery of cuDNN / NCCL; CuPy binary distributions always use the libraries that come with the package, to avoid errors caused by unexpected overrides.
CuPy v5¶
cupyx.scipy Namespace¶
The cupyx.scipy namespace has been introduced to provide CUDA-enabled SciPy functions. The cupy.sparse module has been renamed to cupyx.scipy.sparse; cupy.sparse will be kept as an alias for backward compatibility.
Dropped Support for CUDA 7.0 / 7.5¶
CuPy v5 no longer supports CUDA 7.0 / 7.5.
Update of Docker Images¶
CuPy official Docker images (see Installation for details) are now updated to use CUDA 9.2 and cuDNN 7.
To use these images, you may need to upgrade the NVIDIA driver on your host. See Requirements of nvidia-docker for details.
CuPy v4¶
Note
The version number has been bumped from v2 to v4 to align with the versioning of Chainer. Therefore, CuPy v3 does not exist.
Default Memory Pool¶
Prior to CuPy v4, memory pool was only enabled by default when CuPy is used with Chainer. In CuPy v4, memory pool is now enabled by default, even when you use CuPy without Chainer. The memory pool significantly improves the performance by mitigating the overhead of memory allocation and CPU/GPU synchronization.
Attention
When you monitor GPU memory usage (e.g., using nvidia-smi), you may notice that GPU memory is not freed even after the array instance goes out of scope. This is expected behavior, as the default memory pool “caches” the allocated memory blocks.
To access the default memory pool instances, use get_default_memory_pool() and get_default_pinned_memory_pool(). You can access the statistics and free all unused memory blocks “cached” in the memory pool.
import cupy
a = cupy.ndarray(100, dtype=cupy.float32)
mempool = cupy.get_default_memory_pool()
# For performance, the size of actual allocation may become larger than the requested array size.
print(mempool.used_bytes()) # 512
print(mempool.total_bytes()) # 512
# Even if the array goes out of scope, its memory block is kept in the pool.
a = None
print(mempool.used_bytes()) # 0
print(mempool.total_bytes()) # 512
# You can clear the memory block by calling `free_all_blocks`.
mempool.free_all_blocks()
print(mempool.used_bytes()) # 0
print(mempool.total_bytes()) # 0
You can even disable the default memory pool by the code below. Be sure to do this before any other CuPy operations.
import cupy
cupy.cuda.set_allocator(None)
cupy.cuda.set_pinned_memory_allocator(None)
Compute Capability¶
CuPy v4 now requires NVIDIA GPU with Compute Capability 3.0 or larger. See the List of CUDA GPUs to check if your GPU supports Compute Capability 3.0.
CUDA Stream¶
As CUDA Stream is fully supported in CuPy v4, cupy.cuda.RandomState.set_stream, the function to change the stream used by the random number generator, has been removed. Please use cupy.cuda.Stream.use() instead. See the discussion in #306 for more details.
cupyx Namespace¶
The cupyx namespace has been introduced to provide features specific to CuPy (i.e., features not provided in NumPy) while avoiding collisions in the future. See CuPy-specific functions for the list of such functions. Under this rule, cupy.scatter_add() has been moved to cupyx.scatter_add(). cupy.scatter_add() is still available as an alias, but it is encouraged to use cupyx.scatter_add() instead.
Update of Docker Images¶
CuPy official Docker images (see Installation for details) are now updated to use CUDA 8.0 and cuDNN 6.0. This change was introduced because CUDA 7.5 does not support NVIDIA Pascal GPUs.
To use these images, you may need to upgrade the NVIDIA driver on your host. See Requirements of nvidia-docker for details.
CuPy v2¶
Changed Behavior of count_nonzero Function¶
For performance reasons, cupy.count_nonzero() has been changed to return a zero-dimensional ndarray instead of int when axis=None. See the discussion in #154 for more details.
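A sketch of the behavior:
import cupy

n = cupy.count_nonzero(cupy.arange(5))  # 0-dim cupy.ndarray, not Python int
print(int(n))                           # 4; the conversion synchronizes with the device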
Compatibility Matrix¶
CuPy | CC 1 | CUDA | ROCm | cuTENSOR | NCCL | cuDNN | Python | NumPy | SciPy | Baseline API Spec. | Docs
---|---|---|---|---|---|---|---|---|---|---|---
v10 | 3.0~ | 10.2~ | 4.0~ | 1.3~ | 2.8~ | 7.6~ | 3.7~ | 1.18~ | 1.4~ | NumPy 1.21 & SciPy 1.7 |
v9 | 3.0~8.x | 9.2~11.5 | 3.5~4.3 | 1.2~1.3 | 2.4 & 2.6~2.11 | 7.6~8.2 | 3.6~3.9 | 1.17~1.21 | 1.4~1.7 | NumPy 1.20 & SciPy 1.6 |
v8 | 3.0~8.x | 9.0 & 9.2~11.2 | 3.x 2 | 1.2 | 2.0~2.8 | 7.0~8.1 | 3.5~3.9 | 1.16~1.20 | 1.3~1.6 | NumPy 1.19 & SciPy 1.5 |
v7 | 3.0~8.x | 8.0~11.0 | 2.x 2 | 1.0 | 1.3~2.7 | 5.0~8.0 | 3.5~3.8 | 1.9~1.19 | (not specified) | (not specified) |
v6 | 3.0~7.x | 8.0~10.1 | n/a | n/a | 1.3~2.4 | 5.0~7.5 | 2.7 & 3.4~3.8 | 1.9~1.17 | (not specified) | (not specified) |
v5 | 3.0~7.x | 8.0~10.1 | n/a | n/a | 1.3~2.4 | 5.0~7.5 | 2.7 & 3.4~3.7 | 1.9~1.16 | (not specified) | (not specified) |
v4 | 3.0~7.x | 7.0~9.2 | n/a | n/a | 1.3~2.2 | 4.0~7.1 | 2.7 & 3.4~3.6 | 1.9~1.14 | (not specified) | (not specified) |
License¶
Copyright (c) 2015 Preferred Infrastructure, Inc.
Copyright (c) 2015 Preferred Networks, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
NumPy¶
CuPy is designed based on NumPy's API. CuPy's source code and documents contain the original NumPy ones.
Copyright (c) 2005-2016, NumPy Developers.
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
Neither the name of the NumPy Developers nor the names of any contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
SciPy¶
CuPy is designed based on SciPy's API. CuPy's source code and documents contain the original SciPy ones.
Copyright (c) 2001, 2002 Enthought, Inc.
All rights reserved.
Copyright (c) 2003-2016 SciPy Developers.
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
Neither the name of Enthought nor the names of the SciPy Developers may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.