CuPy – A NumPy-compatible array library accelerated by CUDA

Overview

CuPy is an implementation of a NumPy-compatible multi-dimensional array on CUDA. CuPy consists of cupy.ndarray, the core multi-dimensional array class, and many functions on it. It supports a subset of the numpy.ndarray interface.

The following is a brief overview of the supported subset of the NumPy interface:

  • Basic indexing (indexing by ints, slices, newaxes, and Ellipsis)

  • Most of Advanced indexing (except for some indexing patterns with boolean masks)

  • Data types (dtypes): bool_, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float16, float32, float64, complex64, complex128

  • Most of the array creation routines (empty, ones_like, diag, etc.)

  • Most of the array manipulation routines (reshape, rollaxis, concatenate, etc.)

  • All operators with broadcasting

  • All universal functions for elementwise operations (except those for complex numbers).

  • Linear algebra functions, including product (dot, matmul, etc.) and decomposition (cholesky, svd, etc.), accelerated by cuBLAS.

  • Reduction along axes (sum, max, argmax, etc.)

CuPy also includes the following features for performance:

  • User-defined elementwise CUDA kernels

  • User-defined reduction CUDA kernels

  • Fusing CUDA kernels to optimize user-defined calculation

  • Customizable memory allocator and memory pool

  • cuDNN utilities

CuPy uses on-the-fly kernel synthesis: when a kernel call is required, it compiles kernel code optimized for the shapes and dtypes of the given arguments, sends it to the GPU device, and executes the kernel. The compiled code is cached to the $(HOME)/.cupy/kernel_cache directory (this cache path can be overridden by setting the CUPY_CACHE_DIR environment variable). This may make the first kernel call slower, but the slowdown is resolved at the second execution. CuPy also caches the kernel code sent to the GPU device within the process, which reduces the kernel transfer time on further calls.
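
For example, the cache directory can be redirected from Python by setting the environment variable before CuPy compiles its first kernel (a minimal sketch; the path shown is arbitrary):

>>> import os
>>> os.environ['CUPY_CACHE_DIR'] = '/tmp/cupy_kernel_cache'  # set before any kernel is compiled
>>> import cupy as cp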

Installation Guide

Requirements

Recent Linux distributions such as Ubuntu or CentOS are recommended.

These components must be installed to use CuPy:

  • NVIDIA CUDA GPU with Compute Capability 3.0 or larger.

  • CUDA Toolkit: v9.0 / v9.2 / v10.0 / v10.1 / v10.2 / v11.0 / v11.1 / v11.2

  • Python: v3.5.1+ / v3.6.0+ / v3.7.0+ / v3.8.0+ / v3.9.0+

Note

On Windows, CuPy only supports Python 3.6.0 or later.

Python Dependencies

The NumPy/SciPy-compatible API in CuPy v8 is based on NumPy 1.19 and SciPy 1.5.

Note

SciPy and Optuna are optional dependencies and will not be installed automatically.

Note

Before installing CuPy, we recommend upgrading setuptools and pip:

$ python -m pip install -U setuptools pip

Additional CUDA Libraries

Some CUDA features in CuPy are activated only when the corresponding libraries are installed.

  • cuTENSOR: v1.2

  • NCCL: v2.0 / v2.1 / v2.2 / v2.3 / v2.4 / v2.5 / v2.6 / v2.7 / v2.8

    • The library to perform collective multi-GPU / multi-node computations.

  • cuDNN: v7.0 / v7.1 / v7.2 / v7.3 / v7.4 / v7.5 / v7.6 / v8.0 / v8.1

    • The library to accelerate deep neural network computations.

Installing CuPy

Wheels (precompiled binary packages) are available for Linux (x86_64, Python 3.5+) and Windows (amd64, Python 3.6+). Package names are different depending on your CUDA Toolkit version.

CUDA     Command
v9.0     $ pip install cupy-cuda90
v9.2     $ pip install cupy-cuda92
v10.0    $ pip install cupy-cuda100
v10.1    $ pip install cupy-cuda101
v10.2    $ pip install cupy-cuda102
v11.0    $ pip install cupy-cuda110
v11.1    $ pip install cupy-cuda111

Note

Wheel packages are built with NCCL (Linux only) and cuDNN support enabled.

  • NCCL library is bundled with these packages. You don’t have to install it manually.

  • cuDNN library is bundled with these packages except for CUDA 10.1+. For CUDA 10.1+, you need to manually download and install cuDNN v8.x library to use cuDNN features.

Note

Use pip install --pre cupy-cudaXXX if you want to install prerelease (development) versions.

When using wheels, please be careful not to install multiple CuPy packages at the same time. These packages, as well as the cupy package (source installation), conflict with each other. Please make sure that only one CuPy package (cupy or cupy-cudaXX where XX is a CUDA version) is installed:

$ pip freeze | grep cupy

Installing CuPy from Conda-Forge

Conda/Anaconda is a cross-platform package management solution widely used in scientific computing and other fields. The above pip install instruction is compatible with conda environments. Alternatively, for Linux 64 systems once the CUDA driver is correctly set up, you can install CuPy from the conda-forge channel:

$ conda install -c conda-forge cupy

and conda will install pre-built CuPy and most of the optional dependencies for you, including CUDA runtime libraries (cudatoolkit), NCCL, and cuDNN. It is not necessary to install CUDA Toolkit in advance. If you need to enforce the installation of a particular CUDA version (say 10.0) for driver compatibility, you can do:

$ conda install -c conda-forge cupy cudatoolkit=10.0

Note

cuTENSOR is available on conda-forge for CUDA 10.1+ and is an optional dependency. To install CuPy with the cuTENSOR support enabled, you can do:

$ conda install -c conda-forge cupy cutensor cudatoolkit=10.2

Note that cupy and cutensor must be installed at the same time (as shown above) in order for the conda solver to pick up the right package; otherwise, the cuTENSOR support is disabled.

Note

If you encounter any problem with CuPy from conda-forge, please feel free to report to cupy-feedstock, and we will help investigate if it is just a packaging issue in conda-forge’s recipe or a real issue in CuPy.

Note

If you did not install CUDA Toolkit yourself, the nvcc compiler might not be available. The cudatoolkit package from Anaconda does not include nvcc.

Installing CuPy from Source

Use of wheel packages is recommended whenever possible. However, if wheels cannot meet your requirements (e.g., you are running a non-Linux environment or want to use a version of CUDA / cuDNN / NCCL not supported by wheels), you can also build CuPy from source.

Note

CuPy source build requires g++-6 or later. For Ubuntu 18.04, run apt-get install g++. For Ubuntu 16.04, CentOS 6 or 7, follow the instructions in Build fails on Ubuntu 16.04, CentOS 6 or 7 below.

Note

When installing CuPy from source, features provided by additional CUDA libraries will be disabled if these libraries are not available at the build time. See Installing cuDNN and NCCL for the instructions.

Note

If you upgrade or downgrade the version of CUDA Toolkit, cuDNN, NCCL or cuTENSOR, you may need to reinstall CuPy. See Reinstalling CuPy for details.

You can install the latest stable release version of the CuPy source package via pip.

$ pip install cupy

If you want to install the latest development version of CuPy from a cloned Git repository:

$ git clone --recursive https://github.com/cupy/cupy.git
$ cd cupy
$ pip install .

Note

To build the source tree downloaded from GitHub, you need to install Cython 0.29.22 or later (pip install cython). You don’t have to install Cython to build source packages hosted on PyPI as they include pre-generated C++ source files.

Uninstalling CuPy

Use pip to uninstall CuPy:

$ pip uninstall cupy

Note

If you are using a wheel, cupy shall be replaced with cupy-cudaXX (where XX is a CUDA version number).

Note

If CuPy is installed via conda, please do conda uninstall cupy instead.

Upgrading CuPy

Just use pip install with -U option:

$ pip install -U cupy

Note

If you are using a wheel, cupy shall be replaced with cupy-cudaXX (where XX is a CUDA version number).

Reinstalling CuPy

To reinstall CuPy, please uninstall CuPy and then install it. When reinstalling CuPy, we recommend using --no-cache-dir option as pip caches the previously built binaries:

$ pip uninstall cupy
$ pip install cupy --no-cache-dir

Note

If you are using a wheel, cupy shall be replaced with cupy-cudaXX (where XX is a CUDA version number).

Using CuPy inside Docker

We provide official Docker images. Use the NVIDIA Container Toolkit to run the CuPy image with GPU support. You can log in to the environment with bash and run the Python interpreter:

$ docker run --gpus all -it cupy/cupy /bin/bash

Or run the interpreter directly:

$ docker run --gpus all -it cupy/cupy /usr/bin/python

FAQ

pip fails to install CuPy

Please make sure that you are using the latest setuptools and pip:

$ pip install -U setuptools pip

Use the -vvvv option with the pip command. This will display all installation logs:

$ pip install cupy -vvvv

If you are using sudo to install CuPy, note that the sudo command does not propagate environment variables. If you need to pass environment variables (e.g., CUDA_PATH), specify them inside sudo like this:

$ sudo CUDA_PATH=/opt/nvidia/cuda pip install cupy

If you are using certain versions of conda, it may fail to build CuPy with the error g++: error: unrecognized command line option ‘-R’. This is due to a bug in conda (see conda/conda#6030 for details). If you encounter this problem, please upgrade your conda.

Installing cuDNN and NCCL

We recommend installing cuDNN and NCCL using binary packages (i.e., using apt or yum) provided by NVIDIA.

If you want to install the tar-gz version of cuDNN and NCCL, we recommend installing them under the CUDA_PATH directory. For example, if you are using Ubuntu, copy the *.h files to the include directory and the *.so* files to the lib64 directory:

$ cp /path/to/cudnn.h $CUDA_PATH/include
$ cp /path/to/libcudnn.so* $CUDA_PATH/lib64

The destination directories depend on your environment.

If you want to use cuDNN or NCCL installed in another directory, please set the CFLAGS, LDFLAGS and LD_LIBRARY_PATH environment variables before installing CuPy:

$ export CFLAGS=-I/path/to/cudnn/include
$ export LDFLAGS=-L/path/to/cudnn/lib
$ export LD_LIBRARY_PATH=/path/to/cudnn/lib:$LD_LIBRARY_PATH

Working with Custom CUDA Installation

If you have installed CUDA in a non-default directory or have multiple CUDA versions installed on the same host, you may need to manually specify the CUDA installation directory to be used by CuPy.

CuPy uses the first CUDA installation directory found in the following order.

  1. CUDA_PATH environment variable.

  2. The parent directory of the nvcc command. CuPy looks for the nvcc command in the PATH environment variable.

  3. /usr/local/cuda

For example, you can build CuPy using a non-default CUDA directory by setting the CUDA_PATH environment variable:

$ CUDA_PATH=/opt/nvidia/cuda pip install cupy

Note

CUDA installation discovery is also performed at runtime using the rule above. Depending on your system configuration, you may also need to set LD_LIBRARY_PATH environment variable to $CUDA_PATH/lib64 at runtime.

CuPy always raises cupy.cuda.compiler.CompileException

If CuPy raises a CompileException for almost everything, it is possible that CuPy cannot correctly detect the CUDA installed on your system. The following are error messages commonly observed in such cases.

  • nvrtc: error: failed to load builtins

  • catastrophic error: cannot open source file "cuda_fp16.h"

  • error: cannot overload functions distinguished by return type alone

  • error: identifier "__half_raw" is undefined

Please try setting the LD_LIBRARY_PATH and CUDA_PATH environment variables. For example, if you have CUDA installed at /usr/local/cuda-9.0:

$ export CUDA_PATH=/usr/local/cuda-9.0
$ export LD_LIBRARY_PATH=$CUDA_PATH/lib64:$LD_LIBRARY_PATH

Also see Working with Custom CUDA Installation.

Build fails on Ubuntu 16.04, CentOS 6 or 7

In order to build CuPy from source on systems with legacy GCC (g++-5 or earlier), you need to manually set up g++-6 or later and configure the NVCC environment variable.

On Ubuntu 16.04:

$ sudo add-apt-repository ppa:ubuntu-toolchain-r/test
$ sudo apt update
$ sudo apt install g++-6
$ export NVCC="nvcc --compiler-bindir gcc-6"

On CentOS 6 / 7:

$ sudo yum install centos-release-scl
$ sudo yum install devtoolset-7-gcc-c++
$ source /opt/rh/devtoolset-7/enable
$ export NVCC="nvcc --compiler-bindir gcc-7"

Tutorial

Basics of CuPy

In this section, you will learn about the following things:

  • Basics of cupy.ndarray

  • The concept of current device

  • host-device and device-device array transfer

Basics of cupy.ndarray

CuPy is a GPU array backend that implements a subset of the NumPy interface. In the following code, cp is an abbreviation of cupy, just as np is customarily used for numpy:

>>> import numpy as np
>>> import cupy as cp

The cupy.ndarray class is at its core, and it is a compatible GPU alternative to numpy.ndarray.

>>> x_gpu = cp.array([1, 2, 3])

x_gpu in the above example is an instance of cupy.ndarray. You can see that its creation is identical to NumPy’s, except that numpy is replaced with cupy. The main difference between cupy.ndarray and numpy.ndarray is that the content is allocated on device memory. Its data is allocated on the current device, which will be explained later.

Most of the array manipulations are also done in a way similar to NumPy. Take the Euclidean norm (a.k.a. L2 norm) for example. NumPy has numpy.linalg.norm() to calculate it on the CPU.

>>> x_cpu = np.array([1, 2, 3])
>>> l2_cpu = np.linalg.norm(x_cpu)

We can calculate it on GPU with CuPy in a similar way:

>>> x_gpu = cp.array([1, 2, 3])
>>> l2_gpu = cp.linalg.norm(x_gpu)

CuPy implements many functions on cupy.ndarray objects. See the reference for the supported subset of the NumPy API. Understanding NumPy helps you make the most of CuPy’s features, so we recommend reading the NumPy documentation.

Current Device

CuPy has a concept of the current device, which is the default device on which the allocation, manipulation, calculation, etc., of arrays take place. Suppose the ID of the current device is 0. The following code allocates array contents on GPU 0.

>>> x_on_gpu0 = cp.array([1, 2, 3, 4, 5])

The current device can be changed by cupy.cuda.Device.use() as follows:

>>> x_on_gpu0 = cp.array([1, 2, 3, 4, 5])
>>> cp.cuda.Device(1).use()
>>> x_on_gpu1 = cp.array([1, 2, 3, 4, 5])

If you want to switch the current GPU temporarily, the with statement comes in handy.

>>> with cp.cuda.Device(1):
...    x_on_gpu1 = cp.array([1, 2, 3, 4, 5])
>>> x_on_gpu0 = cp.array([1, 2, 3, 4, 5])

Most operations of CuPy are done on the current device. Be careful that processing an array on a non-current device will cause an error:

>>> with cp.cuda.Device(0):
...    x_on_gpu0 = cp.array([1, 2, 3, 4, 5])
>>> with cp.cuda.Device(1):
...    x_on_gpu0 * 2  # raises error
Traceback (most recent call last):
...
ValueError: Array device must be same as the current device: array device = 0 while current = 1

The cupy.ndarray.device attribute indicates the device on which the array is allocated.

>>> with cp.cuda.Device(1):
...    x = cp.array([1, 2, 3, 4, 5])
>>> x.device
<CUDA Device 1>

Note

If the environment has only one device, such explicit device switching is not needed.

Data Transfer

Move arrays to a device

cupy.asarray() can be used to move a numpy.ndarray, a list, or any object that can be passed to numpy.array() to the current device:

>>> x_cpu = np.array([1, 2, 3])
>>> x_gpu = cp.asarray(x_cpu)  # move the data to the current device.

cupy.asarray() can accept cupy.ndarray, which means we can transfer the array between devices with this function.

>>> with cp.cuda.Device(0):
...     x_gpu_0 = cp.array([1, 2, 3])  # create an array on GPU 0
>>> with cp.cuda.Device(1):
...     x_gpu_1 = cp.asarray(x_gpu_0)  # move the array to GPU 1

Note

cupy.asarray() does not copy the input array if possible. So, if you pass an array on the current device, it returns the input object itself.

If you do want to copy the array in this situation, you can use cupy.array() with copy=True. Actually cupy.asarray() is equivalent to cupy.array(arr, dtype, copy=False).
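
A minimal check of this behavior:

>>> x_gpu = cp.array([1, 2, 3])
>>> cp.asarray(x_gpu) is x_gpu        # no copy for an array already on the current device
True
>>> cp.array(x_gpu, copy=True) is x_gpu  # an explicit copy is made
False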

Move array from a device to the host

Moving a device array to the host can be done by cupy.asnumpy() as follows:

>>> x_gpu = cp.array([1, 2, 3])  # create an array in the current device
>>> x_cpu = cp.asnumpy(x_gpu)  # move the array to the host.

We can also use cupy.ndarray.get():

>>> x_cpu = x_gpu.get()

How to write CPU/GPU agnostic code

The compatibility of CuPy with NumPy enables us to write CPU/GPU agnostic code. This can be made easy by the cupy.get_array_module() function. This function returns the numpy or cupy module based on its arguments. A CPU/GPU agnostic function can be defined using it as follows:

>>> # Stable implementation of log(1 + exp(x))
>>> def softplus(x):
...     xp = cp.get_array_module(x)
...     return xp.maximum(0, x) + xp.log1p(xp.exp(-abs(x)))

Sometimes, an explicit conversion to a host or device array may be required. cupy.asarray() and cupy.asnumpy() can be used in agnostic implementations to get host or device arrays from either CuPy or NumPy arrays.

>>> y_cpu = np.array([4, 5, 6])
>>> x_cpu + y_cpu
array([5, 7, 9])
>>> x_gpu + y_cpu
Traceback (most recent call last):
...
TypeError: Unsupported type <class 'numpy.ndarray'>
>>> cp.asnumpy(x_gpu) + y_cpu
array([5, 7, 9])
>>> cp.asnumpy(x_gpu) + cp.asnumpy(y_cpu)
array([5, 7, 9])
>>> x_gpu + cp.asarray(y_cpu)
array([5, 7, 9])
>>> cp.asarray(x_gpu) + cp.asarray(y_cpu)
array([5, 7, 9])

User-Defined Kernels

CuPy provides easy ways to define three types of CUDA kernels: elementwise kernels, reduction kernels and raw kernels. In this documentation, we describe how to define and call each kind of kernel.

Basics of elementwise kernels

An elementwise kernel can be defined by the ElementwiseKernel class. An instance of this class defines a CUDA kernel, which can be invoked by the __call__ method of the instance.

A definition of an elementwise kernel consists of four parts: an input argument list, an output argument list, a loop body code, and the kernel name. For example, a kernel that computes a squared difference \(f(x, y) = (x - y)^2\) is defined as follows:

>>> squared_diff = cp.ElementwiseKernel(
...    'float32 x, float32 y',
...    'float32 z',
...    'z = (x - y) * (x - y)',
...    'squared_diff')

The argument lists consist of comma-separated argument definitions. Each argument definition consists of a type specifier and an argument name. Names of NumPy data types can be used as type specifiers.

Note

n, i, and names starting with an underscore _ are reserved for internal use.

The above kernel can be called on either scalars or arrays with broadcasting:

>>> x = cp.arange(10, dtype=np.float32).reshape(2, 5)
>>> y = cp.arange(5, dtype=np.float32)
>>> squared_diff(x, y)
array([[ 0.,  0.,  0.,  0.,  0.],
       [25., 25., 25., 25., 25.]], dtype=float32)
>>> squared_diff(x, 5)
array([[25., 16.,  9.,  4.,  1.],
       [ 0.,  1.,  4.,  9., 16.]], dtype=float32)

Output arguments can be explicitly specified (next to the input arguments):

>>> z = cp.empty((2, 5), dtype=np.float32)
>>> squared_diff(x, y, z)
array([[ 0.,  0.,  0.,  0.,  0.],
       [25., 25., 25., 25., 25.]], dtype=float32)

Type-generic kernels

If a type specifier is one character, then it is treated as a type placeholder. It can be used to define type-generic kernels. For example, the above squared_diff kernel can be made type-generic as follows:

>>> squared_diff_generic = cp.ElementwiseKernel(
...     'T x, T y',
...     'T z',
...     'z = (x - y) * (x - y)',
...     'squared_diff_generic')

Type placeholders of the same character in the kernel definition indicate the same type. The actual type of these placeholders is determined by the actual argument type. The ElementwiseKernel class first checks the output arguments and then the input arguments to determine the actual type. If no output arguments are given on the kernel invocation, then only the input arguments are used to determine the type.
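
For example, reusing squared_diff_generic defined above, the placeholder T is resolved from the dtype of the inputs when no output argument is given (a minimal check):

>>> squared_diff_generic(cp.arange(5, dtype=np.float64),
...                      cp.arange(5, dtype=np.float64)).dtype
dtype('float64')
>>> squared_diff_generic(cp.arange(5, dtype=np.int32),
...                      cp.arange(5, dtype=np.int32)).dtype
dtype('int32')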

The type placeholder can be used in the loop body code:

>>> squared_diff_generic = cp.ElementwiseKernel(
...     'T x, T y',
...     'T z',
...     '''
...         T diff = x - y;
...         z = diff * diff;
...     ''',
...     'squared_diff_generic')

More than one type placeholder can be used in a kernel definition. For example, the above kernel can be further made generic over multiple arguments:

>>> squared_diff_super_generic = cp.ElementwiseKernel(
...     'X x, Y y',
...     'Z z',
...     'z = (x - y) * (x - y)',
...     'squared_diff_super_generic')

Note that this kernel requires the output argument to be explicitly specified, because the type Z cannot be automatically determined from the input arguments.
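
For example, an explicit float64 output can be supplied, reusing x and y from above (a minimal sketch):

>>> z = cp.empty((2, 5), dtype=np.float64)
>>> squared_diff_super_generic(x, y, z)  # Z is taken from the explicit output argument
array([[ 0.,  0.,  0.,  0.,  0.],
       [25., 25., 25., 25., 25.]])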

Raw argument specifiers

The ElementwiseKernel class does the indexing with broadcasting automatically, which is useful to define most elementwise computations. On the other hand, we sometimes want to write a kernel with manual indexing for some arguments. We can tell the ElementwiseKernel class to use manual indexing by adding the raw keyword preceding the type specifier.

We can use the special variable i and the method _ind.size() for manual indexing. i indicates the index within the loop. _ind.size() indicates the total number of elements to apply the elementwise operation to. Note that it represents the size after the broadcast operation.

For example, a kernel that adds two vectors with reversing one of them can be written as follows:

>>> add_reverse = cp.ElementwiseKernel(
...     'T x, raw T y', 'T z',
...     'z = x + y[_ind.size() - i - 1]',
...     'add_reverse')

(Note that this is an artificial example and you can write such an operation just as z = x + y[::-1] without defining a new kernel.) A raw argument can be used like an array. The indexing operator y[_ind.size() - i - 1] involves an indexing computation on y, so y can be arbitrarily shaped and strided.

Note that raw arguments are not involved in the broadcasting. If you want to mark all arguments as raw, you must specify the size argument on invocation, which defines the value of _ind.size().
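
For example, a kernel whose arguments are all raw must be given the size argument explicitly (a minimal sketch; copy_reverse is a hypothetical kernel name):

>>> copy_reverse = cp.ElementwiseKernel(
...     'raw T x', 'raw T y',
...     'y[i] = x[_ind.size() - i - 1]',
...     'copy_reverse')
>>> x = cp.arange(5, dtype=np.float32)
>>> y = cp.empty_like(x)
>>> copy_reverse(x, y, size=x.size)
array([4., 3., 2., 1., 0.], dtype=float32)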

Reduction kernels

Reduction kernels can be defined by the ReductionKernel class. We can use it by defining four parts of the kernel code:

  1. Identity value: This value is used for the initial value of reduction.

  2. Mapping expression: It is used for the pre-processing of each element to be reduced.

  3. Reduction expression: It is an operator to reduce the multiple mapped values. The special variables a and b are used for its operands.

  4. Post mapping expression: It is used to transform the resulting reduced values. The special variable a is used as its input. Output should be written to the output parameter.

The ReductionKernel class automatically inserts other code fragments that are required for an efficient and flexible reduction implementation.

For example, L2 norm along specified axes can be written as follows:

>>> l2norm_kernel = cp.ReductionKernel(
...     'T x',  # input params
...     'T y',  # output params
...     'x * x',  # map
...     'a + b',  # reduce
...     'y = sqrt(a)',  # post-reduction map
...     '0',  # identity value
...     'l2norm'  # kernel name
... )
>>> x = cp.arange(10, dtype=np.float32).reshape(2, 5)
>>> l2norm_kernel(x, axis=1)
array([ 5.477226 , 15.9687195], dtype=float32)

Note

The raw specifier is restricted to usages where the axes to be reduced are at the head of the shape. This means that, if you want to use the raw specifier for at least one argument, the axis argument must be 0 or a contiguous increasing sequence of integers starting from 0, like (0, 1), (0, 1, 2), etc.

Raw kernels

Raw kernels can be defined by the RawKernel class. By using raw kernels, you can define kernels from raw CUDA source.

RawKernel object allows you to call the kernel with CUDA’s cuLaunchKernel interface. In other words, you have control over grid size, block size, shared memory size and stream.

>>> add_kernel = cp.RawKernel(r'''
... extern "C" __global__
... void my_add(const float* x1, const float* x2, float* y) {
...     int tid = blockDim.x * blockIdx.x + threadIdx.x;
...     y[tid] = x1[tid] + x2[tid];
... }
... ''', 'my_add')
>>> x1 = cp.arange(25, dtype=cp.float32).reshape(5, 5)
>>> x2 = cp.arange(25, dtype=cp.float32).reshape(5, 5)
>>> y = cp.zeros((5, 5), dtype=cp.float32)
>>> add_kernel((5,), (5,), (x1, x2, y))  # grid, block and arguments
>>> y
array([[ 0.,  2.,  4.,  6.,  8.],
       [10., 12., 14., 16., 18.],
       [20., 22., 24., 26., 28.],
       [30., 32., 34., 36., 38.],
       [40., 42., 44., 46., 48.]], dtype=float32)

Raw kernels operating on complex-valued arrays can be created as well:

>>> complex_kernel = cp.RawKernel(r'''
... #include <cupy/complex.cuh>
... extern "C" __global__
... void my_func(const complex<float>* x1, const complex<float>* x2,
...              complex<float>* y, float a) {
...     int tid = blockDim.x * blockIdx.x + threadIdx.x;
...     y[tid] = x1[tid] + a * x2[tid];
... }
... ''', 'my_func')
>>> x1 = cp.arange(25, dtype=cp.complex64).reshape(5, 5)
>>> x2 = 1j*cp.arange(25, dtype=cp.complex64).reshape(5, 5)
>>> y = cp.zeros((5, 5), dtype=cp.complex64)
>>> complex_kernel((5,), (5,), (x1, x2, y, cp.float32(2.0)))  # grid, block and arguments
>>> y
array([[ 0. +0.j,  1. +2.j,  2. +4.j,  3. +6.j,  4. +8.j],
       [ 5.+10.j,  6.+12.j,  7.+14.j,  8.+16.j,  9.+18.j],
       [10.+20.j, 11.+22.j, 12.+24.j, 13.+26.j, 14.+28.j],
       [15.+30.j, 16.+32.j, 17.+34.j, 18.+36.j, 19.+38.j],
       [20.+40.j, 21.+42.j, 22.+44.j, 23.+46.j, 24.+48.j]],
      dtype=complex64)

Note that while we encourage the usage of complex<T> types for complex numbers (available by including <cupy/complex.cuh> as shown above), for CUDA codes already written using functions from cuComplex.h there is no need to make the conversion yourself: just set the option translate_cucomplex=True when creating a RawKernel instance.

The CUDA kernel attributes can be retrieved by either accessing the attributes dictionary, or by accessing the RawKernel object’s attributes directly; the latter can also be used to set certain attributes:

>>> add_kernel = cp.RawKernel(r'''
... extern "C" __global__
... void my_add(const float* x1, const float* x2, float* y) {
...     int tid = blockDim.x * blockIdx.x + threadIdx.x;
...     y[tid] = x1[tid] + x2[tid];
... }
... ''', 'my_add')
>>> add_kernel.attributes  
{'max_threads_per_block': 1024, 'shared_size_bytes': 0, 'const_size_bytes': 0, 'local_size_bytes': 0, 'num_regs': 10, 'ptx_version': 70, 'binary_version': 70, 'cache_mode_ca': 0, 'max_dynamic_shared_size_bytes': 49152, 'preferred_shared_memory_carveout': -1}
>>> add_kernel.max_dynamic_shared_size_bytes  
49152
>>> add_kernel.max_dynamic_shared_size_bytes = 50000  # set a new value for the attribute  
>>> add_kernel.max_dynamic_shared_size_bytes  
50000

Dynamic parallelism is supported by RawKernel. You just need to provide the linking flag (such as -dc) to RawKernel’s options argument. The static CUDA device runtime library (cudadevrt) is automatically discovered by CuPy. For further detail, see CUDA Toolkit’s documentation.

Accessing texture (surface) memory in RawKernel is supported via CUDA Runtime’s Texture (Surface) Object API, see the documentation for TextureObject (SurfaceObject) as well as CUDA C Programming Guide. For using the Texture Reference API, which is marked as deprecated as of CUDA Toolkit 10.1, see the introduction to RawModule below.

Note

The kernel does not have return values. You need to pass both input arrays and output arrays as arguments.

Note

No validation will be performed by CuPy for arguments passed to the kernel, including types and number of arguments. Especially note that when passing ndarray, its dtype should match with the type of the argument declared in the method signature of the CUDA source code (unless you are casting arrays intentionally). For example, cupy.float32 and cupy.uint64 arrays must be passed to the argument typed as float* and unsigned long long*. For Python primitive types, int, float, complex and bool map to long long, double, cuDoubleComplex and bool, respectively.
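
As a hedged illustration of these rules, the following hypothetical kernel takes a Python float for a double parameter and a Python int for a long long parameter (a minimal sketch):

>>> scale_kernel = cp.RawKernel(r'''
... extern "C" __global__
... void scale(const float* x, float* y, double a, long long n) {
...     long long tid = blockDim.x * blockIdx.x + threadIdx.x;
...     if (tid < n) {
...         y[tid] = (float)(a * x[tid]);
...     }
... }
... ''', 'scale')
>>> x = cp.arange(5, dtype=cp.float32)
>>> y = cp.empty_like(x)
>>> scale_kernel((1,), (5,), (x, y, 2.0, x.size))  # 2.0 -> double, x.size -> long long
>>> y
array([0., 2., 4., 6., 8.], dtype=float32)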

Note

When using printf() in your CUDA kernel, you may need to synchronize the stream to see the output. You can use cupy.cuda.Stream.null.synchronize() if you are using the default stream.
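
A minimal sketch (hello is a hypothetical kernel name):

>>> hello_kernel = cp.RawKernel(r'''
... extern "C" __global__
... void hello() {
...     printf("Hello from thread %d\n", threadIdx.x);
... }
... ''', 'hello')
>>> hello_kernel((1,), (4,), ())
>>> cp.cuda.Stream.null.synchronize()  # flush the printf output of the default stream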

Raw modules

For dealing with a large raw CUDA source or loading an existing CUDA binary, the RawModule class can be more handy. It can be initialized either by CUDA source code or by a path to a CUDA binary. The needed kernels can then be retrieved by calling the get_function() method, which returns a RawKernel instance that can be invoked as discussed above.

>>> loaded_from_source = r'''
... extern "C"{
...
... __global__ void test_sum(const float* x1, const float* x2, float* y, \
...                          unsigned int N)
... {
...     unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x;
...     if (tid < N)
...     {
...         y[tid] = x1[tid] + x2[tid];
...     }
... }
...
... __global__ void test_multiply(const float* x1, const float* x2, float* y, \
...                               unsigned int N)
... {
...     unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x;
...     if (tid < N)
...     {
...         y[tid] = x1[tid] * x2[tid];
...     }
... }
...
... }'''
>>> module = cp.RawModule(code=loaded_from_source)
>>> ker_sum = module.get_function('test_sum')
>>> ker_times = module.get_function('test_multiply')
>>> N = 10
>>> x1 = cp.arange(N**2, dtype=cp.float32).reshape(N, N)
>>> x2 = cp.ones((N, N), dtype=cp.float32)
>>> y = cp.zeros((N, N), dtype=cp.float32)
>>> ker_sum((N,), (N,), (x1, x2, y, N**2))   # y = x1 + x2
>>> assert cp.allclose(y, x1 + x2)
>>> ker_times((N,), (N,), (x1, x2, y, N**2)) # y = x1 * x2
>>> assert cp.allclose(y, x1 * x2)

The instruction above for using complex numbers in RawKernel also applies to RawModule.

For CUDA kernels that need to access global symbols, such as constant memory, the get_global() method can be used, see its documentation for further detail.
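The following is a hedged sketch of this workflow (the constant symbol coeff and kernel multiply_by_coeff are hypothetical names); it assumes that the pointer returned by get_global() can be wrapped as a cupy.ndarray and written to like any other device array:

>>> code_const = r'''
... extern "C" {
... __constant__ float coeff[8];
... __global__ void multiply_by_coeff(float* x) {
...     int tid = threadIdx.x;
...     x[tid] *= coeff[tid];
... }
... }'''
>>> mod = cp.RawModule(code=code_const)
>>> memptr = mod.get_global('coeff')               # pointer to the constant-memory symbol
>>> coeff = cp.ndarray((8,), cp.float32, memptr)   # wrap the symbol as an ndarray view
>>> coeff[:] = cp.arange(8, dtype=cp.float32)      # copy values into constant memory
>>> x = cp.ones(8, dtype=cp.float32)
>>> mod.get_function('multiply_by_coeff')((1,), (8,), (x,))
>>> x
array([0., 1., 2., 3., 4., 5., 6., 7.], dtype=float32)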

CuPy also supports the Texture Reference API. A handle to the texture reference in a module can be retrieved by name via get_texref(). Then, you need to pass it to TextureReference, along with a resource descriptor and texture descriptor, for binding the reference to the array. (The interface of TextureReference is meant to mimic that of TextureObject to help users make the transition to the latter, since as of CUDA Toolkit 10.1 the former is marked as deprecated.)

Kernel fusion

cupy.fuse() is a decorator that fuses functions. This decorator can be used to define an elementwise or reduction kernel more easily than ElementwiseKernel or ReductionKernel.

By using this decorator, we can define the squared_diff kernel as follows:

>>> @cp.fuse()
... def squared_diff(x, y):
...     return (x - y) * (x - y)

The above kernel can be called on either scalars, NumPy arrays or CuPy arrays, like the original function.

>>> x_cp = cp.arange(10)
>>> y_cp = cp.arange(10)[::-1]
>>> squared_diff(x_cp, y_cp)
array([81, 49, 25,  9,  1,  1,  9, 25, 49, 81])
>>> x_np = np.arange(10)
>>> y_np = np.arange(10)[::-1]
>>> squared_diff(x_np, y_np)
array([81, 49, 25,  9,  1,  1,  9, 25, 49, 81])

At the first function call, the fused function analyzes the original function based on the abstracted information of arguments (e.g. their dtypes and ndims) and creates and caches an actual CUDA kernel. From the second function call with the same input types, the fused function calls the previously cached kernel, so it is highly recommended to reuse the same decorated functions instead of decorating local functions that are defined multiple times.

cupy.fuse() also supports simple reduction kernels.

>>> @cp.fuse()
... def sum_of_products(x, y):
...     return cp.sum(x * y, axis = -1)
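
A minimal usage sketch of the fused reduction above:

>>> x = cp.arange(10).reshape(2, 5)
>>> y = cp.ones((2, 5), dtype=x.dtype)
>>> sum_of_products(x, y)
array([10, 35])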

You can specify the kernel name by using the kernel_name keyword argument as follows:

>>> @cp.fuse(kernel_name='squared_diff')
... def squared_diff(x, y):
...     return (x - y) * (x - y)

Note

Currently, cupy.fuse() can fuse only simple elementwise and reduction operations. Most other routines (e.g. cupy.matmul(), cupy.reshape()) are not supported.

API Reference

This is the official reference of CuPy, a multi-dimensional array library on CUDA with a subset of the NumPy interface.


Multi-Dimensional Array (ndarray)

cupy.ndarray is the CuPy counterpart of NumPy numpy.ndarray. It provides an intuitive interface for a fixed-size multidimensional array which resides in a CUDA device.

For the basic concept of ndarrays, please refer to the NumPy documentation.

cupy.ndarray

Multi-dimensional array on a CUDA device.

Code compatibility features

cupy.ndarray is designed to be interchangeable with numpy.ndarray in terms of code compatibility as much as possible. But occasionally, you will need to know whether the arrays you’re handling are cupy.ndarray or numpy.ndarray. One example is when invoking module-level functions such as cupy.sum() or numpy.sum(). In such situations, cupy.get_array_module() can be used.
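
For example, reusing the cp and np abbreviations from the tutorial:

>>> cp.get_array_module(cp.zeros(3)) is cp
True
>>> cp.get_array_module(np.zeros(3)) is np
True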

cupy.get_array_module

Returns the array module for arguments.

cupyx.scipy.get_array_module

Returns the array module for arguments.

Conversion to/from NumPy arrays

cupy.ndarray and numpy.ndarray are not implicitly convertible to each other. That means, NumPy functions cannot take cupy.ndarrays as inputs, and vice versa.

Note that converting between cupy.ndarray and numpy.ndarray incurs data transfer between the host (CPU) device and the GPU device, which is costly in terms of performance.

cupy.array

Creates an array on the current device.

cupy.asarray

Converts an object to array.

cupy.asnumpy

Returns an array on the host memory from an arbitrary source array.

Universal Functions (ufunc)

CuPy provides universal functions (a.k.a. ufuncs) to support various elementwise operations. CuPy’s ufuncs support the following features of NumPy’s ufuncs:

  • Broadcasting

  • Output type determination

  • Casting rules

CuPy’s ufunc currently does not provide methods such as reduce, accumulate, reduceat, outer, and at.

Ufunc class

cupy.ufunc

Universal function.

Available ufuncs

Math operations

cupy.add

Adds two arrays elementwise.

cupy.subtract

Subtracts arguments elementwise.

cupy.multiply

Multiplies two arrays elementwise.

cupy.divide

Elementwise true division.

cupy.logaddexp

Computes log(exp(x1) + exp(x2)) elementwise.

cupy.logaddexp2

Computes log2(exp2(x1) + exp2(x2)) elementwise.

cupy.true_divide

Elementwise true division.

cupy.floor_divide

Elementwise floor division.

cupy.negative

Takes numerical negative elementwise.

cupy.power

Computes x1 ** x2 elementwise.

cupy.remainder

Computes the remainder of Python division elementwise.

cupy.mod

Computes the remainder of Python division elementwise.

cupy.fmod

Computes the remainder of C division elementwise.

cupy.absolute

Elementwise absolute value function.

cupy.rint

Rounds each element of an array to the nearest integer.

cupy.sign

Elementwise sign function.

cupy.exp

Elementwise exponential function.

cupy.exp2

Elementwise exponentiation with base 2.

cupy.log

Elementwise natural logarithm function.

cupy.log2

Elementwise binary logarithm function.

cupy.log10

Elementwise common logarithm function.

cupy.expm1

Computes exp(x) - 1 elementwise.

cupy.log1p

Computes log(1 + x) elementwise.

cupy.sqrt

Elementwise square root function.

cupy.square

Elementwise square function.

cupy.reciprocal

Computes 1 / x elementwise.

cupy.gcd

Computes gcd of x1 and x2 elementwise.

cupy.lcm

Computes lcm of x1 and x2 elementwise.

Trigonometric functions

cupy.sin

Elementwise sine function.

cupy.cos

Elementwise cosine function.

cupy.tan

Elementwise tangent function.

cupy.arcsin

Elementwise inverse-sine function (a.k.a. arcsine).

cupy.arccos

Elementwise inverse-cosine function (a.k.a. arccosine).

cupy.arctan

Elementwise inverse-tangent function (a.k.a. arctangent).

cupy.arctan2

Elementwise inverse-tangent of the ratio of two arrays.

cupy.hypot

Computes the hypotenuse of orthogonal vectors of given length.

cupy.sinh

Elementwise hyperbolic sine function.

cupy.cosh

Elementwise hyperbolic cosine function.

cupy.tanh

Elementwise hyperbolic tangent function.

cupy.arcsinh

Elementwise inverse of hyperbolic sine function.

cupy.arccosh

Elementwise inverse of hyperbolic cosine function.

cupy.arctanh

Elementwise inverse of hyperbolic tangent function.

cupy.deg2rad

Converts angles from degrees to radians elementwise.

cupy.rad2deg

Converts angles from radians to degrees elementwise.

Bit-twiddling functions

cupy.bitwise_and

Computes the bitwise AND of two arrays elementwise.

cupy.bitwise_or

Computes the bitwise OR of two arrays elementwise.

cupy.bitwise_xor

Computes the bitwise XOR of two arrays elementwise.

cupy.invert

Computes the bitwise NOT of an array elementwise.

cupy.left_shift

Shifts the bits of each integer element to the left.

cupy.right_shift

Shifts the bits of each integer element to the right.

Comparison functions

cupy.greater

Tests elementwise if x1 > x2.

cupy.greater_equal

Tests elementwise if x1 >= x2.

cupy.less

Tests elementwise if x1 < x2.

cupy.less_equal

Tests elementwise if x1 <= x2.

cupy.not_equal

Tests elementwise if x1 != x2.

cupy.equal

Tests elementwise if x1 == x2.

cupy.logical_and

Computes the logical AND of two arrays.

cupy.logical_or

Computes the logical OR of two arrays.

cupy.logical_xor

Computes the logical XOR of two arrays.

cupy.logical_not

Computes the logical NOT of an array.

cupy.maximum

Takes the maximum of two arrays elementwise.

cupy.minimum

Takes the minimum of two arrays elementwise.

cupy.fmax

Takes the maximum of two arrays elementwise.

cupy.fmin

Takes the minimum of two arrays elementwise.

Floating functions

cupy.isfinite

Tests finiteness elementwise.

cupy.isinf

Tests if each element is the positive or negative infinity.

cupy.isnan

Tests if each element is a NaN.

cupy.signbit

Tests elementwise if the sign bit is set.

cupy.copysign

Returns the first argument with the sign bit of the second elementwise.

cupy.nextafter

Computes the nearest neighbor float values towards the second argument.

cupy.modf

Extracts the fractional and integral parts of an array elementwise.

cupy.ldexp

Computes x1 * 2 ** x2 elementwise.

cupy.frexp

Decomposes each element to mantissa and two’s exponent.

cupy.fmod

Computes the remainder of C division elementwise.

cupy.floor

Rounds each element of an array to its floor integer.

cupy.ceil

Rounds each element of an array to its ceiling integer.

cupy.trunc

Rounds each element of an array towards zero.

ufunc.at

Currently, CuPy does not support at for ufuncs in general. However, cupyx.scatter_add() can substitute add.at as both behave identically.
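
A minimal sketch (note that cupyx is imported separately from cupy; duplicate indices accumulate):

>>> import cupyx
>>> a = cp.zeros(5, dtype=cp.float32)
>>> idx = cp.array([0, 1, 1, 3])
>>> cupyx.scatter_add(a, idx, 1)
>>> a
array([1., 2., 0., 1., 0.], dtype=float32)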

Routines

The following pages describe NumPy-compatible routines. These functions cover a subset of NumPy routines.

Array Creation Routines

Basic creation routines

cupy.empty

Returns an array without initializing the elements.

cupy.empty_like

Returns a new array with same shape and dtype of a given array.

cupy.eye

Returns a 2-D array with ones on the diagonals and zeros elsewhere.

cupy.identity

Returns a 2-D identity array.

cupy.ones

Returns a new array of given shape and dtype, filled with ones.

cupy.ones_like

Returns an array of ones with same shape and dtype as a given array.

cupy.zeros

Returns a new array of given shape and dtype, filled with zeros.

cupy.zeros_like

Returns an array of zeros with same shape and dtype as a given array.

cupy.full

Returns a new array of given shape and dtype, filled with a given value.

cupy.full_like

Returns a full array with same shape and dtype as a given array.

Creation from other data

cupy.array

Creates an array on the current device.

cupy.asarray

Converts an object to array.

cupy.asanyarray

Converts an object to array.

cupy.ascontiguousarray

Returns a C-contiguous array.

cupy.copy

Creates a copy of a given array on the current device.

cupy.fromfile

Reads an array from a file.

Numerical ranges

cupy.arange

Returns an array with evenly spaced values within a given interval.

cupy.linspace

Returns an array with evenly-spaced values within a given interval.

cupy.logspace

Returns an array with evenly-spaced values on a log-scale.

cupy.meshgrid

Return coordinate matrices from coordinate vectors.

cupy.mgrid

Construct a multi-dimensional “meshgrid”.

cupy.ogrid

Construct a multi-dimensional “meshgrid”.

Matrix creation

cupy.diag

Returns a diagonal or a diagonal array.

cupy.diagflat

Creates a diagonal array from the flattened input.

cupy.tri

Creates an array with ones at and below the given diagonal.

cupy.tril

Returns a lower triangle of an array.

cupy.triu

Returns an upper triangle of an array.

Array Manipulation Routines

Basic operations

cupy.copyto

Copies values from one array to another with broadcasting.

cupy.shape

Returns the shape of an array

Changing array shape

cupy.reshape

Returns an array with new shape and same elements.

cupy.ravel

Returns a flattened array.

Transpose-like operations

cupy.moveaxis

Moves axes of an array to new positions.

cupy.rollaxis

Moves the specified axis backwards to the given place.

cupy.swapaxes

Swaps the two axes.

cupy.transpose

Permutes the dimensions of an array.

See also

cupy.ndarray.T

Changing number of dimensions

cupy.atleast_1d

Converts arrays to arrays with dimensions >= 1.

cupy.atleast_2d

Converts arrays to arrays with dimensions >= 2.

cupy.atleast_3d

Converts arrays to arrays with dimensions >= 3.

cupy.broadcast

Object that performs broadcasting.

cupy.broadcast_to

Broadcast an array to a given shape.

cupy.broadcast_arrays

Broadcasts given arrays.

cupy.expand_dims

Expands given arrays.

cupy.squeeze

Removes size-one axes from the shape of an array.

Changing kind of array

cupy.asarray

Converts an object to array.

cupy.asanyarray

Converts an object to array.

cupy.asfortranarray

Return an array laid out in Fortran order in memory.

cupy.ascontiguousarray

Returns a C-contiguous array.

cupy.require

Return an array which satisfies the requirements.

Joining arrays

cupy.concatenate

Joins arrays along an axis.

cupy.stack

Stacks arrays along a new axis.

cupy.column_stack

Stacks 1-D and 2-D arrays as columns into a 2-D array.

cupy.dstack

Stacks arrays along the third axis.

cupy.hstack

Stacks arrays horizontally.

cupy.vstack

Stacks arrays vertically.

Splitting arrays

cupy.split

Splits an array into multiple sub arrays along a given axis.

cupy.array_split

Splits an array into multiple sub arrays along a given axis.

cupy.dsplit

Splits an array into multiple sub arrays along the third axis.

cupy.hsplit

Splits an array into multiple sub arrays horizontally.

cupy.vsplit

Splits an array into multiple sub arrays along the first axis.

Tiling arrays

cupy.tile

Construct an array by repeating A the number of times given by reps.

cupy.repeat

Repeat arrays along an axis.

Adding and removing elements

cupy.unique

Find the unique elements of an array.

cupy.trim_zeros

Trim the leading and/or trailing zeros from a 1-D array or sequence.

Rearranging elements

cupy.flip

Reverse the order of elements in an array along the given axis.

cupy.fliplr

Flip array in the left/right direction.

cupy.flipud

Flip array in the up/down direction.

cupy.reshape

Returns an array with new shape and same elements.

cupy.roll

Roll array elements along a given axis.

cupy.rot90

Rotate an array by 90 degrees in the plane specified by axes.

Binary Operations

Elementwise bit operations

cupy.bitwise_and

Computes the bitwise AND of two arrays elementwise.

cupy.bitwise_not

Computes the bitwise NOT of an array elementwise.

cupy.bitwise_or

Computes the bitwise OR of two arrays elementwise.

cupy.bitwise_xor

Computes the bitwise XOR of two arrays elementwise.

cupy.invert

Computes the bitwise NOT of an array elementwise.

cupy.left_shift

Shifts the bits of each integer element to the left.

cupy.right_shift

Shifts the bits of each integer element to the right.

Bit packing

cupy.packbits

Packs the elements of a binary-valued array into bits in a uint8 array.

cupy.unpackbits

Unpacks elements of a uint8 array into a binary-valued output array.

Output formatting

cupy.binary_repr

Return the binary representation of the input number as a string.

Data Type Routines

cupy.can_cast

Returns True if cast between data types can occur according to the casting rule.

cupy.result_type

Returns the type that results from applying the NumPy type promotion rules to the arguments.

cupy.common_type

Return a scalar type which is common to the input arrays.

cupy.promote_types (alias of numpy.promote_types())

cupy.min_scalar_type (alias of numpy.min_scalar_type())

cupy.obj2sctype (alias of numpy.obj2sctype())

Creating data types

cupy.dtype (alias of numpy.dtype)

cupy.format_parser (alias of numpy.format_parser)

Data type information

cupy.finfo (alias of numpy.finfo)

cupy.iinfo (alias of numpy.iinfo)

cupy.MachAr (alias of numpy.MachAr)

Data type testing

cupy.issctype (alias of numpy.issctype())

cupy.issubdtype (alias of numpy.issubdtype())

cupy.issubsctype (alias of numpy.issubsctype())

cupy.issubclass_ (alias of numpy.issubclass_())

cupy.find_common_type (alias of numpy.find_common_type())

Miscellaneous

cupy.typename (alias of numpy.typename())

cupy.sctype2char (alias of numpy.sctype2char())

cupy.mintypecode (alias of numpy.mintypecode())

FFT Functions

Standard FFTs

cupy.fft.fft

Compute the one-dimensional FFT.

cupy.fft.ifft

Compute the one-dimensional inverse FFT.

cupy.fft.fft2

Compute the two-dimensional FFT.

cupy.fft.ifft2

Compute the two-dimensional inverse FFT.

cupy.fft.fftn

Compute the N-dimensional FFT.

cupy.fft.ifftn

Compute the N-dimensional inverse FFT.

Real FFTs

cupy.fft.rfft

Compute the one-dimensional FFT for real input.

cupy.fft.irfft

Compute the one-dimensional inverse FFT for real input.

cupy.fft.rfft2

Compute the two-dimensional FFT for real input.

cupy.fft.irfft2

Compute the two-dimensional inverse FFT for real input.

cupy.fft.rfftn

Compute the N-dimensional FFT for real input.

cupy.fft.irfftn

Compute the N-dimensional inverse FFT for real input.

Hermitian FFTs

cupy.fft.hfft

Compute the FFT of a signal that has Hermitian symmetry.

cupy.fft.ihfft

Compute the FFT of a signal that has Hermitian symmetry.

Helper routines

cupy.fft.fftfreq

Return the FFT sample frequencies.

cupy.fft.rfftfreq

Return the FFT sample frequencies for real input.

cupy.fft.fftshift

Shift the zero-frequency component to the center of the spectrum.

cupy.fft.ifftshift

The inverse of fftshift().

cupy.fft.config.set_cufft_gpus

Set the GPUs to be used in multi-GPU FFT.

cupy.fft.config.get_plan_cache

Get the per-thread, per-device plan cache, or create one if not found.

cupy.fft.config.show_plan_cache_info

Show all of the plan caches’ info on this thread.

Normalization

The default normalization has the direct transforms unscaled and the inverse transforms scaled by \(1/n\). If the keyword argument norm is "ortho", both transforms will be scaled by \(1/\sqrt{n}\).
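
A minimal check of the scaling (the first FFT bin of a constant input equals the sum of its elements):

>>> a = cp.ones(4, dtype=cp.complex64)
>>> cp.fft.fft(a)[0]                 # unscaled: first bin equals n
array(4.+0.j, dtype=complex64)
>>> cp.fft.fft(a, norm='ortho')[0]   # "ortho": scaled by 1/sqrt(n)
array(2.+0.j, dtype=complex64)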

Code compatibility features

FFT functions of NumPy always return numpy.ndarray whose dtype is numpy.complex128 or numpy.float64. CuPy functions do not follow this behavior: they return arrays of dtype numpy.complex64 or numpy.float32 if the dtype of the input is numpy.float16, numpy.float32, or numpy.complex64.

Internally, cupy.fft always generates a cuFFT plan (see the cuFFT documentation for detail) corresponding to the desired transform. When possible, an n-dimensional plan will be used, as opposed to applying separate 1D plans for each axis to be transformed. Using n-dimensional planning can provide better performance for multidimensional transforms, but requires more GPU memory than separable 1D planning. The user can disable n-dimensional planning by setting cupy.fft.config.enable_nd_planning = False. This ability to adjust the planning type is a deviation from the NumPy API, which does not use precomputed FFT plans.

Moreover, the automatic plan generation can be suppressed by using an existing plan returned by cupyx.scipy.fftpack.get_fft_plan() as a context manager. This is again a deviation from NumPy.
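
A hedged sketch of reusing a plan as a context manager (array shape and transform axes here are arbitrary):

>>> import cupyx.scipy.fftpack
>>> a = cp.random.random((4, 64, 64)).astype(cp.complex64)
>>> plan = cupyx.scipy.fftpack.get_fft_plan(a, axes=(1, 2), value_type='C2C')
>>> with plan:
...     b = cp.fft.fft2(a, axes=(1, 2))  # reuses the precomputed plan instead of generating one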

Finally, when using the high-level NumPy-like FFT APIs as listed above, internally the cuFFT plans are cached for possible reuse. The plan cache can be retrieved by get_plan_cache(), and its current status can be queried by show_plan_cache_info(). For finer control of the plan cache, see cuFFT Plan Cache.

Multi-GPU FFT

cupy.fft can use multiple GPUs. To enable (disable) this feature, set cupy.fft.config.use_multi_gpus to True (False). Next, to set the number of GPUs or the participating GPU IDs, use the function cupy.fft.config.set_cufft_gpus(). All of the limitations listed in the cuFFT documentation apply here. In particular, using more than one GPU does not guarantee better performance.
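
A minimal configuration sketch:

>>> cp.fft.config.use_multi_gpus = True
>>> cp.fft.config.set_cufft_gpus([0, 1])  # use GPUs 0 and 1 (an int count also works)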

Functional programming

cupy.piecewise

Evaluate a piecewise-defined function.

Indexing Routines

cupy.c_

cupy.r_

cupy.nonzero

Return the indices of the elements that are non-zero.

cupy.where

Return elements, either from x or y, depending on condition.

cupy.indices

Returns an array representing the indices of a grid.

cupy.ix_

Construct an open mesh from multiple sequences.

cupy.ravel_multi_index

Converts a tuple of index arrays into an array of flat indices, applying boundary modes to the multi-index.

cupy.unravel_index

Converts array of flat indices into a tuple of coordinate arrays.

cupy.take

Takes elements of an array at specified indices along an axis.

cupy.take_along_axis

Take values from the input array by matching 1d index and data slices.

cupy.choose

cupy.compress

Returns selected slices of an array along given axis.

cupy.diag

Returns a diagonal or a diagonal array.

cupy.diag_indices

Return the indices to access the main diagonal of an array.

cupy.diag_indices_from

Return the indices to access the main diagonal of an n-dimensional array.

cupy.diagonal

Returns specified diagonals.

cupy.extract

Return the elements of an array that satisfy some condition.

cupy.select

Return an array drawn from elements in choicelist, depending on conditions.

cupy.lib.stride_tricks.as_strided

Create a view into the array with the given shape and strides.

cupy.place

Change elements of an array based on conditional and input values.

cupy.put

Replaces specified elements of an array with given values.

cupy.putmask

Changes elements of an array inplace, based on a conditional mask and input values.

cupy.fill_diagonal

Fills the main diagonal of the given array of any dimensionality.

cupy.flatiter

Flat iterator object to iterate over arrays.

Input and Output

NumPy binary files (NPY, NPZ)

cupy.load

Loads arrays or pickled objects from .npy, .npz or pickled file.

cupy.save

Saves an array to a binary file in .npy format.

cupy.savez

Saves one or more arrays into a file in uncompressed .npz format.

cupy.savez_compressed

Saves one or more arrays into a file in compressed .npz format.

String formatting

cupy.array_repr

Returns the string representation of an array.

cupy.array_str

Returns the string representation of the content of an array.

Base-n representations

cupy.binary_repr

Return the binary representation of the input number as a string.

cupy.base_repr

Return a string representation of a number in the given base system.

Linear Algebra

Matrix and vector products

cupy.cross

Returns the cross product of two vectors.

cupy.dot

Returns a dot product of two arrays.

cupy.vdot

Returns the dot product of two vectors.

cupy.inner

Returns the inner product of two arrays.

cupy.outer

Returns the outer product of two vectors.

cupy.matmul

Returns the matrix product of two arrays and is the implementation of the @ operator introduced in Python 3.5 following PEP465.

cupy.tensordot

Returns the tensor dot product of two arrays along specified axes.

cupy.einsum

Evaluates the Einstein summation convention on the operands.

cupy.linalg.matrix_power

Raise a square matrix to the (integer) power n.

cupy.kron

Returns the kronecker product of two arrays.

cupyx.scipy.linalg.kron

Kronecker product.

Decompositions

cupy.linalg.cholesky

Cholesky decomposition.

cupy.linalg.qr

QR decomposition.

cupy.linalg.svd

Singular Value Decomposition.

Matrix eigenvalues

cupy.linalg.eigh

Eigenvalues and eigenvectors of a symmetric matrix.

cupy.linalg.eigvalsh

Calculates eigenvalues of a symmetric matrix.

Norms etc.

cupy.linalg.det

Returns the determinant of an array.

cupy.linalg.norm

Returns one of matrix norms specified by ord parameter.

cupy.linalg.matrix_rank

Return matrix rank of array using SVD method

cupy.linalg.slogdet

Returns sign and logarithm of the determinant of an array.

cupy.trace

Returns the sum along the diagonals of an array.

Solving linear equations

cupy.linalg.solve

Solves a linear matrix equation.

cupy.linalg.tensorsolve

Solves tensor equations denoted by ax = b.

cupy.linalg.lstsq

Return the least-squares solution to a linear matrix equation.

cupy.linalg.inv

Computes the inverse of a matrix.

cupy.linalg.pinv

Compute the Moore-Penrose pseudoinverse of a matrix.

cupy.linalg.tensorinv

Computes the inverse of a tensor.

cupyx.scipy.linalg.lu_factor

LU decomposition.

cupyx.scipy.linalg.lu_solve

Solve an equation system, a * x = b, given the LU factorization of a

cupyx.scipy.linalg.solve_triangular

Solve the equation a x = b for x, assuming a is a triangular matrix.

Special Matrices

cupy.tri

Creates an array with ones at and below the given diagonal.

cupy.tril

Returns a lower triangle of an array.

cupy.triu

Returns an upper triangle of an array.

cupyx.scipy.linalg.tri

Construct (N, M) matrix filled with ones at and below the k-th diagonal.

cupyx.scipy.linalg.tril

Make a copy of a matrix with elements above the k-th diagonal zeroed.

cupyx.scipy.linalg.triu

Make a copy of a matrix with elements below the k-th diagonal zeroed.

cupyx.scipy.linalg.toeplitz

Construct a Toeplitz matrix.

cupyx.scipy.linalg.circulant

Construct a circulant matrix.

cupyx.scipy.linalg.hankel

Construct a Hankel matrix.

cupyx.scipy.linalg.hadamard

Construct an Hadamard matrix.

cupyx.scipy.linalg.leslie

Create a Leslie matrix.

cupyx.scipy.linalg.block_diag

Create a block diagonal matrix from provided arrays.

cupyx.scipy.linalg.companion

Create a companion matrix.

cupyx.scipy.linalg.helmert

Create an Helmert matrix of order n.

cupyx.scipy.linalg.hilbert

Create a Hilbert matrix of order n.

cupyx.scipy.linalg.dft

Discrete Fourier transform matrix.

cupyx.scipy.linalg.fiedler

Returns a symmetric Fiedler matrix

cupyx.scipy.linalg.fiedler_companion

Returns a Fiedler companion matrix

cupyx.scipy.linalg.convolution_matrix

Construct a convolution matrix.

Logic Functions

Truth value testing

cupy.all

Tests whether all array elements along a given axis evaluate to True.

cupy.any

Tests whether any array elements along a given axis evaluate to True.

cupy.in1d

Tests whether each element of a 1-D array is also present in a second array.

cupy.isin

Calculates element in test_elements, broadcasting over element only.

Infinities and NaNs

cupy.isfinite

Tests finiteness elementwise.

cupy.isinf

Tests if each element is the positive or negative infinity.

cupy.isnan

Tests if each element is a NaN.

Array type testing

cupy.iscomplex

Returns a bool array, where True if input element is complex.

cupy.iscomplexobj

Check for a complex type or an array of complex numbers.

cupy.isfortran

Returns True if the array is Fortran contiguous but not C contiguous.

cupy.isreal

Returns a bool array, where True if input element is real.

cupy.isrealobj

Returns True if x is not a complex type nor an array of complex numbers.

cupy.isscalar

Returns True if the type of num is a scalar type.

Logic operations

cupy.logical_and

Computes the logical AND of two arrays.

cupy.logical_or

Computes the logical OR of two arrays.

cupy.logical_not

Computes the logical NOT of an array.

cupy.logical_xor

Computes the logical XOR of two arrays.

Comparison

cupy.allclose

Returns True if two arrays are element-wise equal within a tolerance.

cupy.array_equal

Returns True if two arrays are element-wise exactly equal.

cupy.isclose

Returns a boolean array where two arrays are equal within a tolerance.

cupy.greater

Tests elementwise if x1 > x2.

cupy.greater_equal

Tests elementwise if x1 >= x2.

cupy.less

Tests elementwise if x1 < x2.

cupy.less_equal

Tests elementwise if x1 <= x2.

cupy.equal

Tests elementwise if x1 == x2.

cupy.not_equal

Tests elementwise if x1 != x2.

Mathematical Functions

Trigonometric functions

cupy.sin

Elementwise sine function.

cupy.cos

Elementwise cosine function.

cupy.tan

Elementwise tangent function.

cupy.arcsin

Elementwise inverse-sine function (a.k.a. arcsine function).

cupy.arccos

Elementwise inverse-cosine function (a.k.a. arccosine function).

cupy.arctan

Elementwise inverse-tangent function (a.k.a. arctangent function).

cupy.hypot

Computes the hypotenuse of orthogonal vectors of given lengths.

cupy.arctan2

Elementwise inverse-tangent of the ratio of two arrays.

cupy.degrees

Converts angles from radians to degrees elementwise.

cupy.radians

Converts angles from degrees to radians elementwise.

cupy.unwrap

Unwrap by changing deltas between values to 2*pi complement.

cupy.deg2rad

Converts angles from degrees to radians elementwise.

cupy.rad2deg

Converts angles from radians to degrees elementwise.

Hyperbolic functions

cupy.sinh

Elementwise hyperbolic sine function.

cupy.cosh

Elementwise hyperbolic cosine function.

cupy.tanh

Elementwise hyperbolic tangent function.

cupy.arcsinh

Elementwise inverse of hyperbolic sine function.

cupy.arccosh

Elementwise inverse of hyperbolic cosine function.

cupy.arctanh

Elementwise inverse of hyperbolic tangent function.

Rounding

cupy.around

Rounds to the given number of decimals.

cupy.round_

cupy.rint

Rounds each element of an array to the nearest integer.

cupy.fix

If the given value x is positive, it returns floor(x); otherwise it returns ceil(x).

cupy.floor

Rounds each element of an array to its floor integer.

cupy.ceil

Rounds each element of an array to its ceiling integer.

cupy.trunc

Rounds each element of an array towards zero.

Sums, products, differences

cupy.prod

Returns the product of an array along given axes.

cupy.sum

Returns the sum of an array along given axes.

cupy.cumprod

Returns the cumulative product of an array along a given axis.

cupy.cumsum

Returns the cumulative sum of an array along a given axis.

cupy.nansum

Returns the sum of an array along given axes, treating Not a Number (NaN) values as zero.

cupy.nanprod

Returns the product of an array along given axes, treating Not a Number (NaN) values as zero.

cupy.diff

Calculate the n-th discrete difference along the given axis.

Exponents and logarithms

cupy.exp

Elementwise exponential function.

cupy.expm1

Computes exp(x) - 1 elementwise.

cupy.exp2

Elementwise exponentiation with base 2.

cupy.log

Elementwise natural logarithm function.

cupy.log10

Elementwise common logarithm function.

cupy.log2

Elementwise binary logarithm function.

cupy.log1p

Computes log(1 + x) elementwise.

cupy.logaddexp

Computes log(exp(x1) + exp(x2)) elementwise.

cupy.logaddexp2

Computes log2(exp2(x1) + exp2(x2)) elementwise.

Other special functions

cupy.i0

Modified Bessel function of the first kind, order 0.

cupy.sinc

Elementwise sinc function.

Floating point routines

cupy.signbit

Tests elementwise if the sign bit is set (i.e. less than zero).

cupy.copysign

Returns the first argument with the sign bit of the second elementwise.

cupy.frexp

Decomposes each element to mantissa and two’s exponent.

cupy.ldexp

Computes x1 * 2 ** x2 elementwise.

cupy.nextafter

Computes the nearest neighbor float values towards the second argument.

Arithmetic operations

cupy.add

Adds two arrays elementwise.

cupy.reciprocal

Computes 1 / x elementwise.

cupy.negative

Takes numerical negative elementwise.

cupy.multiply

Multiplies two arrays elementwise.

cupy.divide

Elementwise true division (i.e. returns floating values).

cupy.power

Computes x1 ** x2 elementwise.

cupy.subtract

Subtracts arguments elementwise.

cupy.true_divide

Elementwise true division (i.e. returns floating values).

cupy.floor_divide

Elementwise floor division (i.e. integer quotient).

cupy.fmod

Computes the remainder of C division elementwise.

cupy.mod

Computes the remainder of Python division elementwise.

cupy.modf

Extracts the fractional and integral parts of an array elementwise.

cupy.remainder

Computes the remainder of Python division elementwise.

cupy.divmod

Handling complex numbers

cupy.angle

Returns the angle of the complex argument.

cupy.real

Returns the real part of the elements of the array.

cupy.imag

Returns the imaginary part of the elements of the array.

cupy.conj

Returns the complex conjugate, element-wise.

Miscellaneous

cupy.convolve

Returns the discrete, linear convolution of two one-dimensional sequences.

cupy.clip

Clips the values of an array to a given interval.

cupy.sqrt

Elementwise square root function.

cupy.cbrt

Elementwise cube root function.

cupy.square

Elementwise square function.

cupy.absolute

Elementwise absolute value function.

cupy.sign

Elementwise sign function.

cupy.maximum

Takes the maximum of two arrays elementwise.

cupy.minimum

Takes the minimum of two arrays elementwise.

cupy.fmax

Takes the maximum of two arrays elementwise.

cupy.fmin

Takes the minimum of two arrays elementwise.

cupy.nan_to_num

Elementwise nan_to_num function.

cupy.bartlett

Returns the Bartlett window.

cupy.blackman

Returns the Blackman window.

cupy.hamming

Returns the Hamming window.

cupy.hanning

Returns the Hanning window.

cupy.kaiser

Return the Kaiser window.

Padding

cupy.pad

Pads an array with specified widths and values.

Polynomials

Polynomial Package
Polynomial Module

cupy.polynomial.polynomial.polyvander

Computes the Vandermonde matrix of given degree.

cupy.polynomial.polynomial.polycompanion

Computes the companion matrix of c.

Polyutils

cupy.polynomial.polyutils.as_series

Returns argument as a list of 1-d arrays.

cupy.polynomial.polyutils.trimseq

Removes small polynomial series coefficients.

cupy.polynomial.polyutils.trimcoef

Removes small trailing coefficients from a polynomial.

Poly1d
Basics

cupy.poly1d

A one-dimensional polynomial class.

cupy.polyval

Evaluates a polynomial at specific values.

cupy.roots

Computes the roots of a polynomial with given coefficients.

Arithmetic

cupy.polyadd

Computes the sum of two polynomials.

cupy.polysub

Computes the difference of two polynomials.

cupy.polymul

Computes the product of two polynomials.

Random Sampling (cupy.random)

Differences between cupy.random and numpy.random:

  • CuPy provides Legacy Random Generation API (see also: NumPy 1.16 Reference). The new random generator API (numpy.random.Generator class) introduced in NumPy 1.17 has not been implemented yet.

  • Most functions under cupy.random support the dtype option, which does not exist in the corresponding NumPy APIs. This option enables direct generation of float32 values without any space overhead, as shown in the sketch after this list.

  • CuPy does not guarantee that the same random number generator is used across major versions. This means that numbers generated by cupy.random under a new major version may differ from those generated by a previous one, even when the same seed and distribution are used.
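
For example, a minimal sketch of drawing float32 samples directly through the dtype option:

import cupy

x = cupy.random.random_sample((3, 3), dtype=cupy.float32)  # no float64 intermediate
print(x.dtype)  # float32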

Simple random data

cupy.random.rand

Returns an array of uniform random values over the interval [0, 1).

cupy.random.randn

Returns an array of standard normal random values.

cupy.random.randint

Returns a scalar or an array of integer values over [low, high).

cupy.random.random_integers

Return a scalar or an array of integer values over [low, high]

cupy.random.random_sample

Returns an array of random values over the interval [0, 1).

cupy.random.random

Returns an array of random values over the interval [0, 1).

cupy.random.ranf

Returns an array of random values over the interval [0, 1).

cupy.random.sample

Returns an array of random values over the interval [0, 1).

cupy.random.choice

Returns an array of random values from a given 1-D array.

cupy.random.bytes

Returns random bytes.

Permutations

cupy.random.shuffle

Shuffles an array.

cupy.random.permutation

Returns a permuted range or a permutation of an array.

Distributions

cupy.random.beta

Beta distribution.

cupy.random.binomial

Binomial distribution.

cupy.random.chisquare

Chi-square distribution.

cupy.random.dirichlet

Dirichlet distribution.

cupy.random.exponential

Exponential distribution.

cupy.random.f

F distribution.

cupy.random.gamma

Gamma distribution.

cupy.random.geometric

Geometric distribution.

cupy.random.gumbel

Returns an array of samples drawn from a Gumbel distribution.

cupy.random.hypergeometric

Hypergeometric distribution.

cupy.random.laplace

Laplace distribution.

cupy.random.logistic

Logistic distribution.

cupy.random.lognormal

Returns an array of samples drawn from a log normal distribution.

cupy.random.logseries

Log series distribution.

cupy.random.multinomial

Returns an array from multinomial distribution.

cupy.random.multivariate_normal

Multivariate normal distribution.

cupy.random.negative_binomial

Negative binomial distribution.

cupy.random.noncentral_chisquare

Noncentral chisquare distribution.

cupy.random.noncentral_f

Noncentral F distribution.

cupy.random.normal

Returns an array of normally distributed samples.

cupy.random.pareto

Pareto II or Lomax distribution.

cupy.random.poisson

Poisson distribution.

cupy.random.power

Power distribution.

cupy.random.rayleigh

Rayleigh distribution.

cupy.random.standard_cauchy

Standard Cauchy distribution.

cupy.random.standard_exponential

Standard exponential distribution.

cupy.random.standard_gamma

Standard gamma distribution.

cupy.random.standard_normal

Returns an array of samples drawn from the standard normal distribution.

cupy.random.standard_t

Standard Student’s t distribution.

cupy.random.triangular

Triangular distribution.

cupy.random.uniform

Returns an array of uniformly-distributed samples over an interval.

cupy.random.vonmises

von Mises distribution.

cupy.random.wald

Wald distribution.

cupy.random.weibull

Weibull distribution.

cupy.random.zipf

Zipf distribution.

Random generator

cupy.random.RandomState

Portable container of a pseudo-random number generator.

cupy.random.seed

Resets the state of the random number generator with a seed.

cupy.random.get_random_state

Gets the state of the random number generator for the current device.

cupy.random.set_random_state

Sets the state of the random number generator for the current device.

Note

CuPy does not provide cupy.random.get_state nor cupy.random.set_state at this time. Use cupy.random.get_random_state() and cupy.random.set_random_state() instead. Note that these functions use a cupy.random.RandomState instance to represent the internal state, which cannot be serialized.
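
For example, a minimal sketch of seeding and drawing through the per-device state object:

import cupy

cupy.random.seed(123)                 # resets the generator for the current device
rs = cupy.random.get_random_state()   # cupy.random.RandomState instance
x = rs.standard_normal((2, 3))        # samples drawn through the state object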

Sorting, Searching, and Counting

Sorting

cupy.sort

Returns a sorted copy of an array with a stable sorting algorithm.

cupy.lexsort

Perform an indirect sort using an array of keys.

cupy.argsort

Returns the indices that would sort an array with a stable sorting.

cupy.msort

Returns a copy of an array sorted along the first axis.

cupy.sort_complex

Sort a complex array using the real part first, then the imaginary part.

cupy.partition

Returns a partitioned copy of an array.

cupy.argpartition

Returns the indices that would partially sort an array.

Searching

cupy.argmax

Returns the indices of the maximum along an axis.

cupy.nanargmax

Return the indices of the maximum values in the specified axis ignoring NaNs.

cupy.argmin

Returns the indices of the minimum along an axis.

cupy.nanargmin

Return the indices of the minimum values in the specified axis ignoring NaNs.

cupy.nonzero

Return the indices of the elements that are non-zero.

cupy.flatnonzero

Return indices that are non-zero in the flattened version of a.

cupy.where

Return elements, either from x or y, depending on condition.

cupy.argwhere

Return the indices of the elements that are non-zero.

cupy.searchsorted

Finds indices where elements should be inserted to maintain order.

Counting

cupy.count_nonzero

Counts the number of non-zero values in the array.

Statistical Functions

Order statistics

cupy.amin

Returns the minimum of an array or the minimum along an axis.

cupy.amax

Returns the maximum of an array or the maximum along an axis.

cupy.nanmin

Returns the minimum of an array along an axis ignoring NaN.

cupy.nanmax

Returns the maximum of an array along an axis ignoring NaN.

cupy.percentile

Computes the q-th percentile of the data along the specified axis.

Means and variances

cupy.average

Returns the weighted average along an axis.

cupy.mean

Returns the arithmetic mean along an axis.

cupy.var

Returns the variance along an axis.

cupy.std

Returns the standard deviation along an axis.

cupy.nanmean

Returns the arithmetic mean along an axis ignoring NaN values.

cupy.nanvar

Returns the variance along an axis ignoring NaN values.

cupy.nanstd

Returns the standard deviation along an axis ignoring NaN values.

Histograms

cupy.histogram

Computes the histogram of a set of data.

cupy.bincount

Count number of occurrences of each value in array of non-negative ints.

Correlations

cupy.corrcoef

Returns the Pearson product-moment correlation coefficients of an array.

cupy.cov

Returns the covariance matrix of an array.

cupy.correlate

Returns the cross-correlation of two 1-dimensional sequences.

CuPy-specific Functions

CuPy-specific functions are placed under the cupyx namespace; a short example follows the list below.

cupyx.rsqrt

Returns the reciprocal square root.

cupyx.scatter_add

Adds given values to specified elements of an array.

cupyx.scatter_max

Stores a maximum value of elements specified by indices to an array.

cupyx.scatter_min

Stores a minimum value of elements specified by indices to an array.
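
A minimal sketch of cupyx.scatter_add; unlike plain fancy-index assignment, values at duplicated indices are accumulated:

import cupy
import cupyx

a = cupy.zeros(4, dtype=cupy.float32)
i = cupy.array([0, 1, 1, 3])
v = cupy.array([1., 2., 3., 4.], dtype=cupy.float32)
cupyx.scatter_add(a, i, v)
print(a)  # [1. 5. 0. 4.]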

CUB/cuTENSOR backend for reduction routines

Some CuPy reduction routines, including sum(), amin(), amax(), argmin(), argmax(), and other functions built on top of them, can be accelerated by switching to the CUB or cuTENSOR backend. These backends can be enabled by setting the CUPY_ACCELERATORS environment variable as documented here. Note that while the accelerated reductions are generally faster, there can be exceptions depending on the data layout. We recommend that users run some benchmarks to determine whether CUB/cuTENSOR offers better performance for their workload.
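
A minimal sketch, assuming the environment variable is set before CuPy is imported:

# $ export CUPY_ACCELERATORS="cub"
import cupy

x = cupy.random.random((1000, 1000))
y = x.sum(axis=0)  # may be dispatched to the CUB backend; the result is unchanged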

SciPy-compatible Routines

The following pages describe SciPy-compatible routines. These functions cover a subset of SciPy routines.

Discrete Fourier transforms (scipy.fft)

Fast Fourier Transforms

cupyx.scipy.fft.fft

Compute the one-dimensional FFT.

cupyx.scipy.fft.ifft

Compute the one-dimensional inverse FFT.

cupyx.scipy.fft.fft2

Compute the two-dimensional FFT.

cupyx.scipy.fft.ifft2

Compute the two-dimensional inverse FFT.

cupyx.scipy.fft.fftn

Compute the N-dimensional FFT.

cupyx.scipy.fft.ifftn

Compute the N-dimensional inverse FFT.

cupyx.scipy.fft.rfft

Compute the one-dimensional FFT for real input.

cupyx.scipy.fft.irfft

Compute the one-dimensional inverse FFT for real input.

cupyx.scipy.fft.rfft2

Compute the two-dimensional FFT for real input.

cupyx.scipy.fft.irfft2

Compute the two-dimensional inverse FFT for real input.

cupyx.scipy.fft.rfftn

Compute the N-dimensional FFT for real input.

cupyx.scipy.fft.irfftn

Compute the N-dimensional inverse FFT for real input.

cupyx.scipy.fft.hfft

Compute the FFT of a signal that has Hermitian symmetry.

cupyx.scipy.fft.ihfft

Compute the inverse FFT of a signal that has Hermitian symmetry.

Helper functions for FFT

cupyx.scipy.fft.next_fast_len

Find the next fast size to fft.

Code compatibility features
  1. As with other FFT modules in CuPy, FFT functions in this module can take advantage of an existing cuFFT plan (returned by get_fft_plan()) to accelerate the computation. The plan can be either passed in explicitly via the keyword-only plan argument or used as a context manager.

  2. The boolean switch cupy.fft.config.enable_nd_planning also affects the FFT functions in this module, see FFT Functions. This switch is neglected when planning manually using get_fft_plan().

  3. Like in scipy.fft, all FFT functions in this module have an optional argument overwrite_x (default is False), which has the same semantics as in scipy.fft: when it is set to True, the input array x can (not will) be overwritten arbitrarily. For this reason, when an in-place FFT is desired, the user should always reassign the input in the following manner: x = cupyx.scipy.fft.fft(x, ..., overwrite_x=True, ...).

  4. The cupyx.scipy.fft module can also be used as a backend for scipy.fft, e.g. by registering it with scipy.fft.set_backend(cupyx.scipy.fft). This allows scipy.fft to work with both numpy and cupy arrays; see the sketch after this list.

  5. The boolean switch cupy.fft.config.use_multi_gpus also affects the FFT functions in this module, see FFT Functions. Moreover, this switch is honored when planning manually using get_fft_plan().
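
A minimal sketch of item 4, registering cupyx.scipy.fft as the backend for scipy.fft (assuming SciPy 1.4 or newer is installed):

import cupy
import cupyx.scipy.fft as cufft
import scipy.fft

a = cupy.random.random(256).astype(cupy.complex64)
with scipy.fft.set_backend(cufft):
    b = scipy.fft.fft(a)  # dispatched to CuPy and executed on the GPU
print(type(b))            # a cupy.ndarray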

Note

scipy.fft requires SciPy version 1.4.0 or newer.

Note

To use scipy.fft.set_backend() together with an explicit plan argument requires SciPy version 1.5.0 or newer.

Legacy Discrete Fourier transforms (scipy.fftpack)

Note

As of SciPy version 1.4.0, scipy.fft is recommended over scipy.fftpack. Consider using cupyx.scipy.fft instead.

Fast Fourier Transforms

cupyx.scipy.fftpack.fft

Compute the one-dimensional FFT.

cupyx.scipy.fftpack.ifft

Compute the one-dimensional inverse FFT.

cupyx.scipy.fftpack.fft2

Compute the two-dimensional FFT.

cupyx.scipy.fftpack.ifft2

Compute the two-dimensional inverse FFT.

cupyx.scipy.fftpack.fftn

Compute the N-dimensional FFT.

cupyx.scipy.fftpack.ifftn

Compute the N-dimensional inverse FFT.

cupyx.scipy.fftpack.rfft

Compute the one-dimensional FFT for real input.

cupyx.scipy.fftpack.irfft

Compute the one-dimensional inverse FFT for real input.

cupyx.scipy.fftpack.get_fft_plan

Generate a CUDA FFT plan for transforming up to three axes.

Code compatibility features
  1. As with other FFT modules in CuPy, FFT functions in this module can take advantage of an existing cuFFT plan (returned by get_fft_plan()) to accelerate the computation; see the sketch after this list. The plan can be either passed in explicitly via the plan argument or used as a context manager. The plan argument is currently experimental and the interface may be changed in a future version. The get_fft_plan() function has no counterpart in scipy.fftpack.

  2. The boolean switch cupy.fft.config.enable_nd_planning also affects the FFT functions in this module, see FFT Functions. This switch is neglected when planning manually using get_fft_plan().

  3. Like in scipy.fftpack, all FFT functions in this module have an optional argument overwrite_x (default is False), which has the same semantics as in scipy.fftpack: when it is set to True, the input array x can (not will) be overwritten arbitrarily. For this reason, when an in-place FFT is desired, the user should always reassign the input in the following manner: x = cupyx.scipy.fftpack.fft(x, ..., overwrite_x=True, ...).

  4. The boolean switch cupy.fft.config.use_multi_gpus also affects the FFT functions in this module, see FFT Functions. Moreover, this switch is honored when planning manually using get_fft_plan().
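
A minimal sketch of reusing a cuFFT plan via get_fft_plan() used as a context manager:

import cupy
from cupyx.scipy.fftpack import fft, get_fft_plan

x = cupy.random.random(1024).astype(cupy.complex64)
plan = get_fft_plan(x)  # 1-D plan over the last axis
with plan:
    y = fft(x)          # reuses the cached plan instead of creating a new one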

Linear Algebra

Basics

cupyx.scipy.linalg.solve_triangular

Solve the equation a x = b for x, assuming a is a triangular matrix.

Decompositions

cupyx.scipy.linalg.lu_factor

LU decomposition.

cupyx.scipy.linalg.lu_solve

Solve an equation system, a * x = b, given the LU factorization of a

Special Matrices

cupyx.scipy.linalg.block_diag

Create a block diagonal matrix from provided arrays.

cupyx.scipy.linalg.circulant

Construct a circulant matrix.

cupyx.scipy.linalg.companion

Create a companion matrix.

cupyx.scipy.linalg.convolution_matrix

Construct a convolution matrix.

cupyx.scipy.linalg.dft

Discrete Fourier transform matrix.

cupyx.scipy.linalg.fiedler

Returns a symmetric Fiedler matrix

cupyx.scipy.linalg.fiedler_companion

Returns a Fiedler companion matrix

cupyx.scipy.linalg.hadamard

Construct an Hadamard matrix.

cupyx.scipy.linalg.hankel

Construct a Hankel matrix.

cupyx.scipy.linalg.helmert

Create an Helmert matrix of order n.

cupyx.scipy.linalg.hilbert

Create a Hilbert matrix of order n.

cupyx.scipy.linalg.kron

Kronecker product.

cupyx.scipy.linalg.leslie

Create a Leslie matrix.

cupyx.scipy.linalg.toeplitz

Construct a Toeplitz matrix.

cupyx.scipy.linalg.tri

Construct (N, M) matrix filled with ones at and below the k-th diagonal.

cupyx.scipy.linalg.tril

Make a copy of a matrix with elements above the k-th diagonal zeroed.

cupyx.scipy.linalg.triu

Make a copy of a matrix with elements below the k-th diagonal zeroed.

Multi-dimensional image processing

CuPy provides multi-dimensional image processing functions. It supports a subset of the scipy.ndimage interface.
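
A minimal sketch, smoothing a random image on the GPU with a Gaussian filter:

import cupy
import cupyx.scipy.ndimage

img = cupy.random.random((256, 256))
smoothed = cupyx.scipy.ndimage.gaussian_filter(img, sigma=3)
print(smoothed.shape)  # (256, 256)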

Filters

cupyx.scipy.ndimage.convolve

Multi-dimensional convolution.

cupyx.scipy.ndimage.convolve1d

One-dimensional convolution.

cupyx.scipy.ndimage.correlate

Multi-dimensional correlate.

cupyx.scipy.ndimage.correlate1d

One-dimensional correlate.

cupyx.scipy.ndimage.gaussian_filter

Multi-dimensional Gaussian filter.

cupyx.scipy.ndimage.gaussian_filter1d

One-dimensional Gaussian filter along the given axis.

cupyx.scipy.ndimage.gaussian_gradient_magnitude

Multi-dimensional gradient magnitude using Gaussian derivatives.

cupyx.scipy.ndimage.gaussian_laplace

Multi-dimensional Laplace filter using Gaussian second derivatives.

cupyx.scipy.ndimage.generic_filter

Compute a multi-dimensional filter using the provided raw kernel or reduction kernel.

cupyx.scipy.ndimage.generic_filter1d

Compute a 1D filter along the given axis using the provided raw kernel.

cupyx.scipy.ndimage.generic_gradient_magnitude

Multi-dimensional gradient magnitude filter using a provided derivative function.

cupyx.scipy.ndimage.generic_laplace

Multi-dimensional Laplace filter using a provided second derivative function.

cupyx.scipy.ndimage.laplace

Multi-dimensional Laplace filter based on approximate second derivatives.

cupyx.scipy.ndimage.maximum_filter

Multi-dimensional maximum filter.

cupyx.scipy.ndimage.maximum_filter1d

Compute the maximum filter along a single axis.

cupyx.scipy.ndimage.median_filter

Multi-dimensional median filter.

cupyx.scipy.ndimage.minimum_filter

Multi-dimensional minimum filter.

cupyx.scipy.ndimage.minimum_filter1d

Compute the minimum filter along a single axis.

cupyx.scipy.ndimage.percentile_filter

Multi-dimensional percentile filter.

cupyx.scipy.ndimage.prewitt

Compute a Prewitt filter along the given axis.

cupyx.scipy.ndimage.rank_filter

Multi-dimensional rank filter.

cupyx.scipy.ndimage.sobel

Compute a Sobel filter along the given axis.

cupyx.scipy.ndimage.uniform_filter

Multi-dimensional uniform filter.

cupyx.scipy.ndimage.uniform_filter1d

One-dimensional uniform filter along the given axis.

Fourier Filters

cupyx.scipy.ndimage.fourier_gaussian

Multidimensional Gaussian Fourier filter.

cupyx.scipy.ndimage.fourier_shift

Multidimensional Fourier shift filter.

cupyx.scipy.ndimage.fourier_uniform

Multidimensional uniform Fourier filter.

Interpolation

cupyx.scipy.ndimage.affine_transform

Apply an affine transformation.

cupyx.scipy.ndimage.map_coordinates

Map the input array to new coordinates by interpolation.

cupyx.scipy.ndimage.rotate

Rotate an array.

cupyx.scipy.ndimage.shift

Shift an array.

cupyx.scipy.ndimage.zoom

Zoom an array.

Measurements

cupyx.scipy.ndimage.label

Labels features in an array.

cupyx.scipy.ndimage.mean

Calculates the mean of the values of an n-D image array, optionally at specified sub-regions.

cupyx.scipy.ndimage.standard_deviation

Calculates the standard deviation of the values of an n-D image array, optionally at specified sub-regions.

cupyx.scipy.ndimage.sum

Calculates the sum of the values of an n-D image array, optionally at specified sub-regions.

cupyx.scipy.ndimage.variance

Calculates the variance of the values of an n-D image array, optionally at specified sub-regions.

Morphology

cupyx.scipy.ndimage.grey_closing

Calculates a multi-dimensional greyscale closing.

cupyx.scipy.ndimage.grey_dilation

Calculates a greyscale dilation.

cupyx.scipy.ndimage.grey_erosion

Calculates a greyscale erosion.

cupyx.scipy.ndimage.grey_opening

Calculates a multi-dimensional greyscale opening.

OpenCV mode

cupyx.scipy.ndimage supports an additional mode, opencv. If it is given, the function behaves like cv2.warpAffine or cv2.resize. Example:

import cupyx.scipy.ndimage
import cupy as cp
import cv2

im = cp.asarray(cv2.imread('TODO'))  # please fill in your image path; converted to a CuPy array

trans_mat = cp.eye(4)
trans_mat[0][0] = trans_mat[1][1] = 0.5

smaller_shape = (im.shape[0] // 2, im.shape[1] // 2, 3)
smaller = cp.zeros(smaller_shape) # preallocate memory for resized image

cupyx.scipy.ndimage.affine_transform(im, trans_mat, output_shape=smaller_shape,
                                     output=smaller, mode='opencv')

cv2.imwrite('smaller.jpg', cp.asnumpy(smaller)) # smaller image saved locally

Sparse matrices

CuPy supports sparse matrices using cuSPARSE. These matrices have the same interfaces as SciPy's sparse matrices.

Conversion to/from SciPy sparse matrices

cupyx.scipy.sparse.*_matrix and scipy.sparse.*_matrix are not implicitly convertible to each other. That means SciPy functions cannot take cupyx.scipy.sparse.*_matrix objects as inputs, and vice versa.

  • To convert SciPy sparse matrices to CuPy, pass them to the constructor of each CuPy sparse matrix class.

  • To convert CuPy sparse matrices to SciPy, use the get method of each CuPy sparse matrix class.

Note that converting between CuPy and SciPy incurs data transfer between the host (CPU) device and the GPU device, which is costly in terms of performance.

Conversion to/from CuPy ndarrays
  • To convert a CuPy ndarray to a CuPy sparse matrix, pass it to the constructor of each CuPy sparse matrix class.

  • To convert a CuPy sparse matrix to a CuPy ndarray, use the toarray method of each CuPy sparse matrix instance (e.g., cupyx.scipy.sparse.csr_matrix.toarray()).

Converting between CuPy ndarray and CuPy sparse matrices does not incur data transfer; it is copied inside the GPU device.
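
A minimal sketch of these conversions (the shape and density below are arbitrary):

import cupy
import scipy.sparse
import cupyx.scipy.sparse

a_cpu = scipy.sparse.random(1000, 1000, density=0.01, format='csr')
a_gpu = cupyx.scipy.sparse.csr_matrix(a_cpu)  # SciPy -> CuPy (host-to-device copy)
b_cpu = a_gpu.get()                           # CuPy -> SciPy (device-to-host copy)
dense_gpu = a_gpu.toarray()                   # CuPy sparse -> cupy.ndarray (stays on the GPU)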

Sparse matrix classes

cupyx.scipy.sparse.coo_matrix

COOrdinate format sparse matrix.

cupyx.scipy.sparse.csc_matrix

Compressed Sparse Column matrix.

cupyx.scipy.sparse.csr_matrix

Compressed Sparse Row matrix.

cupyx.scipy.sparse.dia_matrix

Sparse matrix with DIAgonal storage.

cupyx.scipy.sparse.spmatrix

Base class of all sparse matrices.

Functions
Building sparse matrices

cupyx.scipy.sparse.bmat

Builds a sparse matrix from sparse sub-blocks

cupyx.scipy.sparse.diags

Construct a sparse matrix from diagonals.

cupyx.scipy.sparse.eye

Creates a sparse matrix with ones on diagonal.

cupyx.scipy.sparse.hstack

Stacks sparse matrices horizontally (column wise)

cupyx.scipy.sparse.identity

Creates an identity matrix in sparse format.

cupyx.scipy.sparse.kron

Kronecker product of sparse matrices A and B.

cupyx.scipy.sparse.spdiags

Creates a sparse matrix from diagonals.

cupyx.scipy.sparse.rand

Generates a random sparse matrix.

cupyx.scipy.sparse.random

Generates a random sparse matrix.

cupyx.scipy.sparse.vstack

Stacks sparse matrices vertically (row wise)

Identifying sparse matrices

cupyx.scipy.sparse.issparse

Checks if a given matrix is a sparse matrix.

cupyx.scipy.sparse.isspmatrix

Checks if a given matrix is a sparse matrix.

cupyx.scipy.sparse.isspmatrix_csc

Checks if a given matrix is of CSC format.

cupyx.scipy.sparse.isspmatrix_csr

Checks if a given matrix is of CSR format.

cupyx.scipy.sparse.isspmatrix_coo

Checks if a given matrix is of COO format.

cupyx.scipy.sparse.isspmatrix_dia

Checks if a given matrix is of DIA format.

Linear Algebra

cupyx.scipy.sparse.linalg.lsqr

Solves linear system with QR decomposition.

cupyx.scipy.sparse.linalg.norm

Norm of a cupyx.scipy.sparse.spmatrix.

Special Functions

Bessel Functions

cupyx.scipy.special.j0

Bessel function of the first kind of order 0.

cupyx.scipy.special.j1

Bessel function of the first kind of order 1.

cupyx.scipy.special.y0

Bessel function of the second kind of order 0.

cupyx.scipy.special.y1

Bessel function of the second kind of order 1.

cupyx.scipy.special.i0

Modified Bessel function of order 0.

cupyx.scipy.special.i1

Modified Bessel function of order 1.

Information Theory Functions

cupyx.scipy.special.entr

Elementwise function for computing entropy.

cupyx.scipy.special.huber

Elementwise function for computing the Huber loss.

cupyx.scipy.special.kl_div

Elementwise function for computing Kullback-Leibler divergence.

cupyx.scipy.special.pseudo_huber

Elementwise function for computing the Pseudo-Huber loss.

cupyx.scipy.special.rel_entr

Elementwise function for computing relative entropy.

Raw Statistical Functions

cupyx.scipy.special.ndtr

Cumulative distribution function of normal distribution.

Error Function

cupyx.scipy.special.erf

Error function.

cupyx.scipy.special.erfc

Complementary error function.

cupyx.scipy.special.erfcx

Scaled complementary error function.

cupyx.scipy.special.erfinv

Inverse function of error function.

cupyx.scipy.special.erfcinv

Inverse function of complementary error function.

Other Special Functions

cupyx.scipy.special.zeta

Hurwitz zeta function.

Signal processing

Convolution

cupyx.scipy.signal.choose_conv_method

Find the fastest convolution/correlation method.

cupyx.scipy.signal.convolve

Convolve two N-dimensional arrays.

cupyx.scipy.signal.convolve2d

Convolve two 2-dimensional arrays.

cupyx.scipy.signal.correlate

Cross-correlate two N-dimensional arrays.

cupyx.scipy.signal.correlate2d

Cross-correlate two 2-dimensional arrays.

Filtering

cupyx.scipy.signal.order_filter

Perform an order filter on an N-D array.

cupyx.scipy.signal.medfilt

Perform a median filter on an N-dimensional array.

cupyx.scipy.signal.medfilt2d

Median filter a 2-dimensional array.

cupyx.scipy.signal.wiener

Perform a Wiener filter on an N-dimensional array.

NumPy-CuPy Generic Code Support

cupy.get_array_module

Returns the array module for arguments.

cupyx.scipy.get_array_module

Returns the array module for arguments.

Memory Management

CuPy uses a memory pool for memory allocations by default. The memory pool significantly improves performance by mitigating the overhead of memory allocation and CPU/GPU synchronization.

There are two different memory pools in CuPy:

  • Device memory pool (GPU device memory), which is used for GPU memory allocations.

  • Pinned memory pool (non-swappable CPU memory), which is used during CPU-to-GPU data transfer.

Attention

When you monitor memory usage (e.g., using nvidia-smi for GPU memory or ps for CPU memory), you may notice that memory is not freed even after the array instance goes out of scope. This is expected behavior, as the default memory pool “caches” the allocated memory blocks.

See Low-Level CUDA Support for the details of memory management APIs.

Memory Pool Operations

The memory pool instance provides statistics about memory allocation. To access the default memory pool instance, use cupy.get_default_memory_pool() and cupy.get_default_pinned_memory_pool(). You can also free all unused memory blocks held in the memory pool. See the example code below for details:

import cupy
import numpy

mempool = cupy.get_default_memory_pool()
pinned_mempool = cupy.get_default_pinned_memory_pool()

# Create an array on CPU.
# NumPy allocates 400 bytes in CPU (not managed by CuPy memory pool).
a_cpu = numpy.ndarray(100, dtype=numpy.float32)
print(a_cpu.nbytes)                      # 400

# You can access statistics of these memory pools.
print(mempool.used_bytes())              # 0
print(mempool.total_bytes())             # 0
print(pinned_mempool.n_free_blocks())    # 0

# Transfer the array from CPU to GPU.
# This allocates 400 bytes from the device memory pool, and another 400
# bytes from the pinned memory pool.  The allocated pinned memory will be
# released just after the transfer is complete.  Note that the actual
# allocation size may be rounded to larger value than the requested size
# for performance.
a = cupy.array(a_cpu)
print(a.nbytes)                          # 400
print(mempool.used_bytes())              # 512
print(mempool.total_bytes())             # 512
print(pinned_mempool.n_free_blocks())    # 1

# When the array goes out of scope, the allocated device memory is released
# and kept in the pool for future reuse.
a = None  # (or `del a`)
print(mempool.used_bytes())              # 0
print(mempool.total_bytes())             # 512
print(pinned_mempool.n_free_blocks())    # 1

# You can clear the memory pool by calling `free_all_blocks`.
mempool.free_all_blocks()
pinned_mempool.free_all_blocks()
print(mempool.used_bytes())              # 0
print(mempool.total_bytes())             # 0
print(pinned_mempool.n_free_blocks())    # 0

See cupy.cuda.MemoryPool and cupy.cuda.PinnedMemoryPool for details.

Limiting GPU Memory Usage

You can hard-limit the amount of GPU memory that can be allocated by using CUPY_GPU_MEMORY_LIMIT environment variable (see Environment variables for details).

# Set the hard-limit to 1 GiB:
#   $ export CUPY_GPU_MEMORY_LIMIT="1073741824"

# You can also specify the limit in fraction of the total amount of memory
# on the GPU. If you have a GPU with 2 GiB memory, the following is
# equivalent to the above configuration.
#   $ export CUPY_GPU_MEMORY_LIMIT="50%"

import cupy
print(cupy.get_default_memory_pool().get_limit())  # 1073741824

You can also set the limit (or override the value specified via the environment variable) using cupy.cuda.MemoryPool.set_limit(). In this way, you can use a different limit for each GPU device.

import cupy

mempool = cupy.get_default_memory_pool()

with cupy.cuda.Device(0):
    mempool.set_limit(size=1024**3)  # 1 GiB

with cupy.cuda.Device(1):
    mempool.set_limit(size=2*1024**3)  # 2 GiB

Note

CUDA allocates some GPU memory outside of the memory pool (such as the CUDA context, library handles, etc.). Depending on the usage, such memory may take from one to a few hundred MiB. That memory is not counted towards the limit.

Changing Memory Pool

You can use your own memory allocator instead of the default memory pool by passing the memory allocation function to cupy.cuda.set_allocator() / cupy.cuda.set_pinned_memory_allocator(). The memory allocator function should take 1 argument (the requested size in bytes) and return cupy.cuda.MemoryPointer / cupy.cuda.PinnedMemoryPointer.
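
As a hedged sketch, the allocator can for example be a memory pool backed by managed (unified) memory; cupy.cuda.memory.malloc_managed is used here as the underlying allocation function:

import cupy

pool = cupy.cuda.MemoryPool(cupy.cuda.memory.malloc_managed)
cupy.cuda.set_allocator(pool.malloc)  # pool.malloc takes the requested size in bytes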

You can even disable the default memory pool by the code below. Be sure to do this before any other CuPy operations.

import cupy

# Disable memory pool for device memory (GPU)
cupy.cuda.set_allocator(None)

# Disable memory pool for pinned memory (CPU).
cupy.cuda.set_pinned_memory_allocator(None)

Low-Level CUDA Support

Device management

cupy.cuda.Device

Object that represents a CUDA device.

Memory management

cupy.get_default_memory_pool

Returns CuPy default memory pool for GPU memory.

cupy.get_default_pinned_memory_pool

Returns CuPy default memory pool for pinned memory.

cupy.cuda.Memory

Memory allocation on a CUDA device.

cupy.cuda.UnownedMemory

CUDA memory that is not owned by CuPy.

cupy.cuda.PinnedMemory

Pinned memory allocation on host.

cupy.cuda.MemoryPointer

Pointer to a location in device memory.

cupy.cuda.PinnedMemoryPointer

Pointer to pinned memory.

cupy.cuda.alloc

Calls the current allocator.

cupy.cuda.alloc_pinned_memory

Calls the current allocator.

cupy.cuda.get_allocator

Returns the current allocator for GPU memory.

cupy.cuda.set_allocator

Sets the current allocator for GPU memory.

cupy.cuda.using_allocator

Sets a thread-local allocator for GPU memory inside the with statement scope.

cupy.cuda.set_pinned_memory_allocator

Sets the current allocator for the pinned memory.

cupy.cuda.MemoryPool

Memory pool for all GPU devices on the host.

cupy.cuda.PinnedMemoryPool

Memory pool for pinned memory on the host.

cupy.cuda.PythonFunctionAllocator

Allocator with python functions to perform memory allocation.

Memory hook

cupy.cuda.MemoryHook

Base class of hooks for Memory allocations.

cupy.cuda.memory_hooks.DebugPrintHook

Memory hook that prints debug information.

cupy.cuda.memory_hooks.LineProfileHook

Code line CuPy memory profiler.

Streams and events

cupy.cuda.Stream

CUDA stream.

cupy.cuda.ExternalStream

CUDA stream.

cupy.cuda.get_current_stream

Gets current CUDA stream.

cupy.cuda.Event

CUDA event, a synchronization point of CUDA streams.

cupy.cuda.get_elapsed_time

Gets the elapsed time between two events.

Texture and surface memory

cupy.cuda.texture.ChannelFormatDescriptor

A class that holds the channel format description.

cupy.cuda.texture.CUDAarray

Allocate a CUDA array (cudaArray_t) that can be used as texture memory.

cupy.cuda.texture.ResourceDescriptor

A class that holds the resource description.

cupy.cuda.texture.TextureDescriptor

A class that holds the texture description.

cupy.cuda.texture.TextureObject

A class that holds a texture object.

cupy.cuda.texture.SurfaceObject

A class that holds a surface object.

cupy.cuda.texture.TextureReference

A class that holds a texture reference.

Profiler

cupy.cuda.profile

Enable CUDA profiling during with statement.

cupy.cuda.profiler.initialize

Initialize the CUDA profiler.

cupy.cuda.profiler.start

Enable profiling.

cupy.cuda.profiler.stop

Disable profiling.

cupy.cuda.nvtx.Mark

Marks an instantaneous event (marker) in the application.

cupy.cuda.nvtx.MarkC

Marks an instantaneous event (marker) in the application.

cupy.cuda.nvtx.RangePush

Starts a nested range.

cupy.cuda.nvtx.RangePushC

Starts a nested range.

cupy.cuda.nvtx.RangePop

Ends a nested range.

NCCL

cupy.cuda.nccl.NcclCommunicator

Initialize an NCCL communicator for one device controlled by one process.

cupy.cuda.nccl.get_build_version

cupy.cuda.nccl.get_version

Returns the runtime version of NCCL.

cupy.cuda.nccl.get_unique_id

cupy.cuda.nccl.groupStart

Start a group of NCCL calls.

cupy.cuda.nccl.groupEnd

End a group of NCCL calls.

Runtime API

CuPy wraps the CUDA Runtime API to provide native CUDA operations. Please check the original CUDA Runtime API documentation for how to use these functions.
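
For example, the thin wrappers can be called directly to query version and memory information:

import cupy

print(cupy.cuda.runtime.driverGetVersion())
print(cupy.cuda.runtime.runtimeGetVersion())
free, total = cupy.cuda.runtime.memGetInfo()  # free/total device memory in bytes
print(free, total)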

cupy.cuda.runtime.driverGetVersion

cupy.cuda.runtime.runtimeGetVersion

cupy.cuda.runtime.getDevice

cupy.cuda.runtime.deviceGetAttribute

cupy.cuda.runtime.deviceGetByPCIBusId

cupy.cuda.runtime.deviceGetPCIBusId

cupy.cuda.runtime.getDeviceCount

cupy.cuda.runtime.setDevice

cupy.cuda.runtime.deviceSynchronize

cupy.cuda.runtime.deviceCanAccessPeer

cupy.cuda.runtime.deviceEnablePeerAccess

cupy.cuda.runtime.deviceGetLimit

cupy.cuda.runtime.deviceSetLimit

cupy.cuda.runtime.malloc

cupy.cuda.runtime.mallocManaged

cupy.cuda.runtime.malloc3DArray

cupy.cuda.runtime.mallocArray

cupy.cuda.runtime.hostAlloc

cupy.cuda.runtime.hostRegister

cupy.cuda.runtime.hostUnregister

cupy.cuda.runtime.free

cupy.cuda.runtime.freeHost

cupy.cuda.runtime.freeArray

cupy.cuda.runtime.memGetInfo

cupy.cuda.runtime.memcpy

cupy.cuda.runtime.memcpyAsync

cupy.cuda.runtime.memcpyPeer

cupy.cuda.runtime.memcpyPeerAsync

cupy.cuda.runtime.memcpy2D

cupy.cuda.runtime.memcpy2DAsync

cupy.cuda.runtime.memcpy2DFromArray

cupy.cuda.runtime.memcpy2DFromArrayAsync

cupy.cuda.runtime.memcpy2DToArray

cupy.cuda.runtime.memcpy2DToArrayAsync

cupy.cuda.runtime.memcpy3D

cupy.cuda.runtime.memcpy3DAsync

cupy.cuda.runtime.memset

cupy.cuda.runtime.memsetAsync

cupy.cuda.runtime.memPrefetchAsync

cupy.cuda.runtime.memAdvise

cupy.cuda.runtime.pointerGetAttributes

cupy.cuda.runtime.streamCreate

cupy.cuda.runtime.streamCreateWithFlags

cupy.cuda.runtime.streamDestroy

cupy.cuda.runtime.streamSynchronize

cupy.cuda.runtime.streamAddCallback

cupy.cuda.runtime.streamQuery

cupy.cuda.runtime.streamWaitEvent

cupy.cuda.runtime.eventCreate

cupy.cuda.runtime.eventCreateWithFlags

cupy.cuda.runtime.eventDestroy

cupy.cuda.runtime.eventElapsedTime

cupy.cuda.runtime.eventQuery

cupy.cuda.runtime.eventRecord

cupy.cuda.runtime.eventSynchronize

cupy.cuda.runtime.ipcGetMemHandle

cupy.cuda.runtime.ipcOpenMemHandle

cupy.cuda.runtime.ipcCloseMemHandle

cupy.cuda.runtime.ipcGetEventHandle

cupy.cuda.runtime.ipcOpenEventHandle

Kernel binary memoization

cupy.memoize

Makes a function memoizing the result for each argument and device.

cupy.clear_memo

Clears the memoized results for all functions decorated by memoize.

Custom kernels

cupy.ElementwiseKernel

User-defined elementwise kernel.

cupy.ReductionKernel

User-defined reduction kernel.

cupy.RawKernel

User-defined custom kernel.

cupy.RawModule

User-defined custom module.

cupy.fuse

Decorator that fuses a function.

Automatic Kernel Parameters Optimizations

cupyx.optimizing.optimize

Context manager that optimizes kernel launch parameters.

Interoperability

CuPy can also be used in conjunction with other frameworks.

NumPy

cupy.ndarray implements the __array_ufunc__ interface (see NEP 13 — A Mechanism for Overriding Ufuncs for details). This enables NumPy ufuncs to be applied directly to CuPy arrays. The __array_ufunc__ feature requires NumPy 1.13 or later.

import cupy
import numpy

arr = cupy.random.randn(1, 2, 3, 4).astype(cupy.float32)
result = numpy.sum(arr)
print(type(result))  # => <class 'cupy.core.core.ndarray'>

cupy.ndarray also implements the __array_function__ interface (see NEP 18 — A dispatch mechanism for NumPy’s high level array functions for details). This enables code written against NumPy to operate directly on CuPy arrays. The __array_function__ feature requires NumPy 1.16 or later; note that this is currently defined as an experimental feature of NumPy and you need to set the environment variable NUMPY_EXPERIMENTAL_ARRAY_FUNCTION=1 to enable it.
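
A minimal sketch of the __array_function__ dispatch (with NumPy 1.16 the environment variable above must be set; NumPy 1.17 and later enable the protocol by default):

import numpy
import cupy

a = cupy.arange(6).reshape(2, 3)
b = numpy.concatenate([a, a])  # dispatched to cupy.concatenate
print(type(b))                 # a cupy.ndarray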

Numba

Numba is a Python JIT compiler with NumPy support.

cupy.ndarray implements __cuda_array_interface__, which is the CUDA array interchange interface compatible with Numba v0.39.0 or later (see CUDA Array Interface for details). It means you can pass CuPy arrays to kernels JITed with Numba. The following is a simple example code borrowed from numba/numba#2860:

import cupy
from numba import cuda

@cuda.jit
def add(x, y, out):
    start = cuda.grid(1)
    stride = cuda.gridsize(1)
    for i in range(start, x.shape[0], stride):
        out[i] = x[i] + y[i]

a = cupy.arange(10)
b = a * 2
out = cupy.zeros_like(a)

print(out)  # => [0 0 0 0 0 0 0 0 0 0]

add[1, 32](a, b, out)

print(out)  # => [ 0  3  6  9 12 15 18 21 24 27]

In addition, cupy.asarray() supports zero-copy conversion from Numba CUDA array to CuPy array.

import numpy
import numba
import cupy

x = numpy.arange(10)  # type: numpy.ndarray
x_numba = numba.cuda.to_device(x)  # type: numba.cuda.cudadrv.devicearray.DeviceNDArray
x_cupy = cupy.asarray(x_numba)  # type: cupy.ndarray

mpi4py

MPI for Python (mpi4py) is a Python wrapper for the Message Passing Interface (MPI) libraries.

MPI is the most widely used standard for high-performance inter-process communications. Recently several MPI vendors, including Open MPI and MVAPICH, have extended their support beyond the v3.1 standard to enable “CUDA-awareness”; that is, passing CUDA device pointers directly to MPI calls to avoid explicit data movement between the host and the device.

With the aforementioned __cuda_array_interface__ standard implemented in CuPy, mpi4py now provides (experimental) support for passing CuPy arrays to MPI calls, provided that mpi4py is built against a CUDA-aware MPI implementation. The following is a simple example code borrowed from mpi4py Tutorial:

# To run this script with N MPI processes, do
# mpiexec -n N python this_script.py

import cupy
from mpi4py import MPI

comm = MPI.COMM_WORLD
size = comm.Get_size()

# Allreduce
sendbuf = cupy.arange(10, dtype='i')
recvbuf = cupy.empty_like(sendbuf)
comm.Allreduce(sendbuf, recvbuf)
assert cupy.allclose(recvbuf, sendbuf*size)

This new feature will be officially released in mpi4py 3.1.0. To try it out, please build mpi4py from source for the time being. See the mpi4py website for more information.

DLPack

DLPack is a specification of tensor structure to share tensors among frameworks.

CuPy supports importing from and exporting to DLPack data structure (cupy.fromDlpack() and cupy.ndarray.toDlpack()).

cupy.fromDlpack

Zero-copy conversion from a DLPack tensor to a ndarray.

Here is a simple example:

import cupy

# Create a CuPy array.
cx1 = cupy.random.randn(1, 2, 3, 4).astype(cupy.float32)

# Convert it into a DLPack tensor.
dx = cx1.toDlpack()

# Convert it back to a CuPy array.
cx2 = cupy.fromDlpack(dx)

Here is an example of converting a PyTorch tensor into a cupy.ndarray.

import cupy
import torch

from torch.utils.dlpack import to_dlpack
from torch.utils.dlpack import from_dlpack

# Create a PyTorch tensor.
tx1 = torch.randn(1, 2, 3, 4).cuda()

# Convert it into a DLPack tensor.
dx = to_dlpack(tx1)

# Convert it into a CuPy array.
cx = cupy.fromDlpack(dx)

# Convert it back to a PyTorch tensor.
tx2 = from_dlpack(cx.toDlpack())

Note that, as of DLPack v0.3, correctness (implicitly) requires that such conversion (both importing and exporting a CuPy array) happen on the same CUDA/HIP stream. If in doubt, the current CuPy stream in use can be fetched by, for example, calling cupy.cuda.get_current_stream(). Please consult the other framework’s documentation for how to access and control its streams. This requirement might be relaxed/changed in a future DLPack version.

Testing Modules

CuPy offers testing utilities to support unit testing. They are under namespace cupy.testing.

Standard Assertions

The assertions have the same names as NumPy’s. The difference from NumPy is that they can accept both numpy.ndarray and cupy.ndarray.

cupy.testing.assert_allclose

Raises an AssertionError if objects are not equal up to desired tolerance.

cupy.testing.assert_array_almost_equal

Raises an AssertionError if objects are not equal up to desired precision.

cupy.testing.assert_array_almost_equal_nulp

Compare two arrays relatively to their spacing.

cupy.testing.assert_array_max_ulp

Check that all items of arrays differ in at most N Units in the Last Place.

cupy.testing.assert_array_equal

Raises an AssertionError if two array_like objects are not equal.

cupy.testing.assert_array_list_equal

Compares lists of arrays pairwise with assert_array_equal.

cupy.testing.assert_array_less

Raises an AssertionError if array_like objects are not ordered by less than.

NumPy-CuPy Consistency Check

The following decorators are for testing consistency between CuPy’s functions and the corresponding NumPy ones; an example follows the list below.

cupy.testing.numpy_cupy_allclose

Decorator that checks NumPy results and CuPy ones are close.

cupy.testing.numpy_cupy_array_almost_equal

Decorator that checks NumPy results and CuPy ones are almost equal.

cupy.testing.numpy_cupy_array_almost_equal_nulp

Decorator that checks results of NumPy and CuPy are equal w.r.t. spacing.

cupy.testing.numpy_cupy_array_max_ulp

Decorator that checks results of NumPy and CuPy are equal w.r.t. units in the last place (ULP).

cupy.testing.numpy_cupy_array_equal

Decorator that checks NumPy results and CuPy ones are equal.

cupy.testing.numpy_cupy_array_list_equal

Decorator that checks the resulting lists of NumPy and CuPy are equal.

cupy.testing.numpy_cupy_array_less

Decorator that checks the CuPy result is less than the NumPy result.
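
A minimal sketch of a consistency test; the decorated test runs once with NumPy and once with CuPy as xp, and the decorator compares the two results (the dtype decorator from the next section is also used):

import unittest
from cupy import testing

class TestLog1p(unittest.TestCase):

    @testing.for_float_dtypes(no_float16=True)
    @testing.numpy_cupy_allclose(rtol=1e-5)
    def test_log1p(self, xp, dtype):
        a = testing.shaped_arange((2, 3), xp, dtype)
        return xp.log1p(a)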

Parameterized dtype Test

The following decorators offer the standard way to parameterize tests with respect to a single dtype or combinations of dtypes.

cupy.testing.for_dtypes

Decorator for parameterized dtype test.

cupy.testing.for_all_dtypes

Decorator that checks the fixture with all dtypes.

cupy.testing.for_float_dtypes

Decorator that checks the fixture with float dtypes.

cupy.testing.for_signed_dtypes

Decorator that checks the fixture with signed dtypes.

cupy.testing.for_unsigned_dtypes

Decorator that checks the fixture with unsigned dtypes.

cupy.testing.for_int_dtypes

Decorator that checks the fixture with integer and optionally bool dtypes.

cupy.testing.for_complex_dtypes

Decorator that checks the fixture with complex dtypes.

cupy.testing.for_dtypes_combination

Decorator that checks the fixture with a product set of dtypes.

cupy.testing.for_all_dtypes_combination

Decorator that checks the fixture with a product set of all dtypes.

cupy.testing.for_signed_dtypes_combination

Decorator for parameterized tests w.r.t. the product set of signed dtypes.

cupy.testing.for_unsigned_dtypes_combination

Decorator for parameterized tests w.r.t. the product set of unsigned dtypes.

cupy.testing.for_int_dtypes_combination

Decorator for parameterized tests w.r.t. the product set of integer and optionally bool dtypes.

Parameterized order Test

The following decorators offer the standard way to parameterize tests with orders.

cupy.testing.for_orders

Decorator to parameterize tests with order.

cupy.testing.for_CF_orders

Decorator that checks the fixture with orders ‘C’ and ‘F’.

Profiling

time range

cupy.prof.TimeRangeDecorator

Decorator to mark function calls with range in NVIDIA profiler

cupy.prof.time_range

A context manager to describe the enclosed block as a nested range
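
A minimal sketch of annotating a region for the NVIDIA profiler:

import cupy
from cupy import prof

with prof.time_range('matmul', color_id=0):
    a = cupy.random.random((512, 512))
    b = cupy.random.random((512, 512))
    c = cupy.matmul(a, b)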

Device synchronization detection

cupyx.allow_synchronize

Allows or disallows device synchronization temporarily in the current thread.

cupyx.DeviceSynchronized

Raised when device synchronization is detected while disallowed.
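
A minimal sketch; which operations synchronize may vary between CuPy versions, so an explicit device synchronization is used here to trigger the detection:

import cupy
import cupyx

try:
    with cupyx.allow_synchronize(False):
        cupy.cuda.Device().synchronize()  # raises because synchronization is disallowed
except cupyx.DeviceSynchronized:
    print('synchronization detected')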

Environment variables

Here are the environment variables CuPy uses.

CUDA_PATH

Path to the directory containing CUDA. The parent of the directory containing nvcc is used as default. When nvcc is not found, /usr/local/cuda is used. See Working with Custom CUDA Installation for details.

CUPY_CACHE_DIR

Path to the directory to store kernel cache. ${HOME}/.cupy/kernel_cache is used by default. See Overview for details.

CUPY_CACHE_SAVE_CUDA_SOURCE

If set to 1, the CUDA source file will be saved along with the compiled binary in the cache directory for debugging purposes. It is disabled by default. Note: the source file will not be saved if the compiled binary is already stored in the cache.

CUPY_CACHE_IN_MEMORY

If set to 1, CUPY_CACHE_DIR (and its default) and CUPY_CACHE_SAVE_CUDA_SOURCE will be ignored, and the cache is kept in memory. This environment variable allows reducing disk I/O, but it is ignored when nvcc is set as the compiler backend.

CUPY_DUMP_CUDA_SOURCE_ON_ERROR

If set to 1, when CUDA kernel compilation fails, CuPy dumps CUDA kernel code to standard error. It is disabled by default.

CUPY_CUDA_COMPILE_WITH_DEBUG

If set to 1, CUDA kernel will be compiled with debug information (--device-debug and --generate-line-info). It is disabled by default.

CUPY_GPU_MEMORY_LIMIT

The amount of memory that can be allocated for each device. The value can be specified in absolute bytes or fraction (e.g., "90%") of the total memory of each GPU. See Memory Management for details. 0 (unlimited) is used by default.

CUPY_SEED

Set the seed for random number generators.

CUPY_EXPERIMENTAL_SLICE_COPY

If set to 1, the following syntax is enabled: cupy_ndarray[:] = numpy_ndarray.

CUPY_ACCELERATORS

A comma-separated string of backend names (cub or cutensor) which indicates the acceleration backends used in CuPy operations and their priority. The default is an empty string (all accelerators are disabled).

CUPY_TF32

If set to 1, it allows CUDA libraries to use TF32 compute on Tensor Cores for 32-bit floating point computation. The default is 0, and TF32 is not used.

CUPY_CUDA_ARRAY_INTERFACE_SYNC

This controls CuPy’s behavior as a Consumer. If set to 0, a stream synchronization will not be performed when a device array provided by an external library that implements the CUDA Array Interface is being consumed by CuPy. Default is 1. For more detail, see the Synchronization requirement in the CUDA Array Interface v3 documentation.

CUPY_CUDA_ARRAY_INTERFACE_EXPORT_VERSION

This controls CuPy’s behavior as a Producer. If set to 2, the CuPy stream on which the data is being operated will not be exported and thus the Consumer (another library) will not perform any stream synchronization. Default is 3. For more detail, see the Synchronization requirement in the CUDA Array Interface v3 documentation.

NVCC

Defines the compiler to use when compiling CUDA source. Note that most CuPy kernels are built with NVRTC; this environment variable only takes effect for RawKernels/RawModules with the nvcc backend or when using cub as the accelerator.

Moreover, as in any CUDA program, all of the CUDA environment variables listed in the CUDA Toolkit Documentation will also be honored. When the CUPY_ACCELERATORS or NVCC environment variables are set, g++-6 or later is required as the runtime host compiler. Please refer to Installing CuPy from Source for details on how to install g++.

For installation

These environment variables are used during installation (building CuPy from source).

CUDA_PATH

See the description above.

CUTENSOR_PATH

Path to the cuTENSOR root directory that contains lib and include directories. (experimental)

NVCC

Define the compiler to use when compiling CUDA files.

CUPY_PYTHON_350_FORCE

Enforce CuPy to be installed against Python 3.5.0 (not recommended).

CUPY_INSTALL_USE_HIP

For building the ROCm support, see Building CuPy for ROCm for further detail.

CUPY_NVCC_GENERATE_CODE

Builds CuPy for a particular CUDA architecture. For example, CUPY_NVCC_GENERATE_CODE="arch=compute_60,code=sm_60". For specifying multiple architectures, concatenate the arch=... strings with semicolons (;). When this is not set, the default is to support all architectures.

Difference between CuPy and NumPy

The interface of CuPy is designed to follow that of NumPy. However, there are some differences.

Cast behavior from float to integer

Some casting behaviors from float to integer are not defined in the C++ specification. Casting a negative float to an unsigned integer, or infinity to an integer, are such examples. The behavior of NumPy depends on your CPU architecture. This is the result on an Intel CPU:

>>> np.array([-1], dtype=np.float32).astype(np.uint32)
array([4294967295], dtype=uint32)
>>> cupy.array([-1], dtype=np.float32).astype(np.uint32)
array([0], dtype=uint32)
>>> np.array([float('inf')], dtype=np.float32).astype(np.int32)
array([-2147483648], dtype=int32)
>>> cupy.array([float('inf')], dtype=np.float32).astype(np.int32)
array([2147483647], dtype=int32)

Random methods support dtype argument

NumPy’s random value generator does not support a dtype argument and instead always returns a float64 value. We support the option in CuPy because cuRAND, which is used in CuPy, supports both float32 and float64.

>>> np.random.randn(dtype=np.float32)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: randn() got an unexpected keyword argument 'dtype'
>>> cupy.random.randn(dtype=np.float32)    
array(0.10689262300729752, dtype=float32)

Out-of-bounds indices

By default, CuPy handles out-of-bounds indices differently from NumPy when using integer array indexing. NumPy raises an error, but CuPy wraps them around.

>>> x = np.array([0, 1, 2])
>>> x[[1, 3]] = 10
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
IndexError: index 3 is out of bounds for axis 1 with size 3
>>> x = cupy.array([0, 1, 2])
>>> x[[1, 3]] = 10
>>> x
array([10, 10,  2])

Duplicate values in indices

CuPy’s __setitem__ behaves differently from NumPy when integer arrays reference the same location multiple times. In that case, the value that is actually stored is undefined. Here is an example with CuPy.

>>> a = cupy.zeros((2,))
>>> i = cupy.arange(10000) % 2
>>> v = cupy.arange(10000).astype(np.float32)
>>> a[i] = v
>>> a  
array([ 9150.,  9151.])

NumPy stores the value corresponding to the last element among elements referencing duplicate locations.

>>> a_cpu = np.zeros((2,))
>>> i_cpu = np.arange(10000) % 2
>>> v_cpu = np.arange(10000).astype(np.float32)
>>> a_cpu[i_cpu] = v_cpu
>>> a_cpu
array([9998., 9999.])

Zero-dimensional array

Reduction methods

NumPy’s reduction functions (e.g. numpy.sum()) return scalar values (e.g. numpy.float32). However, their CuPy counterparts return zero-dimensional cupy.ndarray objects. That is because CuPy scalar values (e.g. cupy.float32) are aliases of NumPy scalar values and are allocated in CPU memory; if these types were returned, a synchronization between GPU and CPU would be required. If you want to use scalar values, cast the returned arrays explicitly.

>>> type(np.sum(np.arange(3))) == np.int64
True
>>> type(cupy.sum(cupy.arange(3))) == cupy.core.core.ndarray
True
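
If a host scalar is needed, cast the zero-dimensional result explicitly; note that this forces a GPU-to-CPU transfer:

>>> int(cupy.sum(cupy.arange(3)))
3
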
Type promotion

In a function with two or more operands, CuPy automatically promotes the dtypes of cupy.ndarray inputs; the result dtype is determined by the dtypes of the inputs. This differs from NumPy’s type promotion rule when the operands contain zero-dimensional arrays: NumPy treats zero-dimensional numpy.ndarray operands as if they were scalar values, which may affect the dtype of the output depending on the values of those “scalar” inputs.

>>> (np.array(3, dtype=np.int32) * np.array([1., 2.], dtype=np.float32)).dtype
dtype('float32')
>>> (np.array(300000, dtype=np.int32) * np.array([1., 2.], dtype=np.float32)).dtype
dtype('float64')
>>> (cupy.array(3, dtype=np.int32) * cupy.array([1., 2.], dtype=np.float32)).dtype
dtype('float64')
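
If you need the NumPy-like float32 result, one way (a minimal sketch) is to cast the zero-dimensional operand to the target dtype before the operation:

>>> (cupy.array(3, dtype=np.int32).astype(np.float32) * cupy.array([1., 2.], dtype=np.float32)).dtype
dtype('float32')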

Data types

The data type of CuPy arrays cannot be non-numeric, such as strings or objects. See Overview for details.

Universal Functions only work with CuPy array or scalar

Unlike NumPy, universal functions in CuPy only work with CuPy arrays or scalars. They do not accept other objects (e.g., lists or numpy.ndarray).

>>> np.power([np.arange(5)], 2)
array([[ 0,  1,  4,  9, 16]])
>>> cupy.power([cupy.arange(5)], 2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: Unsupported type <class 'list'>
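
Convert such objects to cupy.ndarray first, for example with cupy.asarray:

>>> cupy.power(cupy.asarray(np.arange(5)), 2)
array([ 0,  1,  4,  9, 16])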

Random seed arrays are hashed to scalars

Like NumPy, CuPy’s RandomState objects accept seeds either as numbers or as full NumPy arrays.

>>> seed = np.array([1, 2, 3, 4, 5])
>>> rs = cupy.random.RandomState(seed=seed)

However, unlike NumPy, array seeds are hashed down to a single number, so they may not provide as much entropy to the underlying random number generator.

Comparison Table

Here is a list of NumPy / SciPy APIs and their corresponding CuPy implementations.

A “-” in the CuPy column denotes that the CuPy implementation is not provided yet. We welcome contributions for these functions.

NumPy / CuPy APIs

Module-Level

NumPy

CuPy

numpy.abs

cupy.abs

numpy.absolute

cupy.absolute

numpy.add

cupy.add

numpy.add_docstring

-

numpy.add_newdoc

-

numpy.add_newdoc_ufunc

-

numpy.alen

-

numpy.all

cupy.all

numpy.allclose

cupy.allclose

numpy.alltrue

-

numpy.amax

cupy.amax

numpy.amin

cupy.amin

numpy.angle

cupy.angle

numpy.any

cupy.any

numpy.append

-

numpy.apply_along_axis

-

numpy.apply_over_axes

-

numpy.arange

cupy.arange

numpy.arccos

cupy.arccos

numpy.arccosh

cupy.arccosh

numpy.arcsin

cupy.arcsin

numpy.arcsinh

cupy.arcsinh

numpy.arctan

cupy.arctan

numpy.arctan2

cupy.arctan2

numpy.arctanh

cupy.arctanh

numpy.argmax

cupy.argmax

numpy.argmin

cupy.argmin

numpy.argpartition

cupy.argpartition

numpy.argsort

cupy.argsort

numpy.argwhere

cupy.argwhere

numpy.around

cupy.around

numpy.array

cupy.array

numpy.array2string

-

numpy.array_equal

cupy.array_equal

numpy.array_equiv

-

numpy.array_repr

cupy.array_repr

numpy.array_split

cupy.array_split

numpy.array_str

cupy.array_str

numpy.asanyarray

cupy.asanyarray

numpy.asarray

cupy.asarray

numpy.asarray_chkfinite

-

numpy.ascontiguousarray

cupy.ascontiguousarray

numpy.asfarray

-

numpy.asfortranarray

cupy.asfortranarray

numpy.asmatrix

-

numpy.asscalar

-

numpy.atleast_1d

cupy.atleast_1d

numpy.atleast_2d

cupy.atleast_2d

numpy.atleast_3d

cupy.atleast_3d

numpy.average

cupy.average

numpy.bartlett

cupy.bartlett

numpy.base_repr

cupy.base_repr

numpy.binary_repr

cupy.binary_repr

numpy.bincount

cupy.bincount

numpy.bitwise_and

cupy.bitwise_and

numpy.bitwise_not

cupy.bitwise_not

numpy.bitwise_or

cupy.bitwise_or

numpy.bitwise_xor

cupy.bitwise_xor

numpy.blackman

cupy.blackman

numpy.block

-

numpy.bmat

-

numpy.broadcast_arrays

cupy.broadcast_arrays

numpy.broadcast_shapes

-

numpy.broadcast_to

cupy.broadcast_to

numpy.busday_count

-

numpy.busday_offset

-

numpy.byte_bounds

-

numpy.can_cast

cupy.can_cast

numpy.cbrt

cupy.cbrt

numpy.ceil

cupy.ceil

numpy.choose

cupy.choose

numpy.clip

cupy.clip

numpy.column_stack

cupy.column_stack

numpy.common_type

cupy.common_type

numpy.compare_chararrays

-

numpy.compress

cupy.compress

numpy.concatenate

cupy.concatenate

numpy.conj

cupy.conj

numpy.conjugate

cupy.conjugate

numpy.convolve

cupy.convolve

numpy.copy

cupy.copy

numpy.copysign

cupy.copysign

numpy.copyto

cupy.copyto

numpy.corrcoef

cupy.corrcoef

numpy.correlate

cupy.correlate

numpy.cos

cupy.cos

numpy.cosh

cupy.cosh

numpy.count_nonzero

cupy.count_nonzero

numpy.cov

cupy.cov

numpy.cross

cupy.cross

numpy.cumprod

cupy.cumprod

numpy.cumproduct

-

numpy.cumsum

cupy.cumsum

numpy.datetime_as_string

-

numpy.datetime_data

-

numpy.deg2rad

cupy.deg2rad

numpy.degrees

cupy.degrees

numpy.delete

-

numpy.deprecate

-

numpy.deprecate_with_doc

-

numpy.diag

cupy.diag

numpy.diag_indices

cupy.diag_indices

numpy.diag_indices_from

cupy.diag_indices_from

numpy.diagflat

cupy.diagflat

numpy.diagonal

cupy.diagonal

numpy.diff

cupy.diff

numpy.digitize

cupy.digitize

numpy.disp

-

numpy.divide

cupy.divide

numpy.divmod

cupy.divmod

numpy.dot

cupy.dot

numpy.dsplit

cupy.dsplit

numpy.dstack

cupy.dstack

numpy.ediff1d

-

numpy.einsum

cupy.einsum

numpy.einsum_path

-

numpy.empty

cupy.empty

numpy.empty_like

cupy.empty_like

numpy.equal

cupy.equal

numpy.exp

cupy.exp

numpy.exp2

cupy.exp2

numpy.expand_dims

cupy.expand_dims

numpy.expm1

cupy.expm1

numpy.extract

cupy.extract

numpy.eye

cupy.eye

numpy.fabs

-

numpy.fastCopyAndTranspose

-

numpy.fill_diagonal

cupy.fill_diagonal

numpy.find_common_type

cupy.find_common_type (alias of numpy.find_common_type)

numpy.fix

cupy.fix

numpy.flatnonzero

cupy.flatnonzero

numpy.flip

cupy.flip

numpy.fliplr

cupy.fliplr

numpy.flipud

cupy.flipud

numpy.float_power

-

numpy.floor

cupy.floor

numpy.floor_divide

cupy.floor_divide

numpy.fmax

cupy.fmax

numpy.fmin

cupy.fmin

numpy.fmod

cupy.fmod

numpy.format_float_positional

-

numpy.format_float_scientific

-

numpy.frexp

cupy.frexp

numpy.frombuffer

-

numpy.fromfile

cupy.fromfile

numpy.fromfunction

-

numpy.fromiter

-

numpy.frompyfunc

-

numpy.fromregex

-

numpy.fromstring

-

numpy.full

cupy.full

numpy.full_like

cupy.full_like

numpy.gcd

cupy.gcd

numpy.genfromtxt

-

numpy.geomspace

-

numpy.get_array_wrap

-

numpy.get_include

-

numpy.get_printoptions

-

numpy.getbufsize

-

numpy.geterr

-

numpy.geterrcall

-

numpy.geterrobj

-

numpy.gradient

-

numpy.greater

cupy.greater

numpy.greater_equal

cupy.greater_equal

numpy.hamming

cupy.hamming

numpy.hanning

cupy.hanning

numpy.heaviside

-

numpy.histogram

cupy.histogram

numpy.histogram2d

-

numpy.histogram_bin_edges

-

numpy.histogramdd

-

numpy.hsplit

cupy.hsplit

numpy.hstack

cupy.hstack

numpy.hypot

cupy.hypot

numpy.i0

cupy.i0

numpy.identity

cupy.identity

numpy.imag

cupy.imag

numpy.in1d

cupy.in1d

numpy.indices

cupy.indices

numpy.info

-

numpy.inner

cupy.inner

numpy.insert

-

numpy.interp

-

numpy.intersect1d

-

numpy.invert

cupy.invert

numpy.is_busday

-

numpy.isclose

cupy.isclose

numpy.iscomplex

cupy.iscomplex

numpy.iscomplexobj

cupy.iscomplexobj

numpy.isfinite

cupy.isfinite

numpy.isfortran

cupy.isfortran

numpy.isin

cupy.isin

numpy.isinf

cupy.isinf

numpy.isnan

cupy.isnan

numpy.isnat

-

numpy.isneginf

-

numpy.isposinf

-

numpy.isreal

cupy.isreal

numpy.isrealobj

cupy.isrealobj

numpy.isscalar

cupy.isscalar

numpy.issctype

cupy.issctype (alias of numpy.issctype)

numpy.issubclass_

cupy.issubclass_ (alias of numpy.issubclass_)

numpy.issubdtype

cupy.issubdtype (alias of numpy.issubdtype)

numpy.issubsctype

cupy.issubsctype (alias of numpy.issubsctype)

numpy.iterable

-

numpy.ix_

cupy.ix_

numpy.kaiser

cupy.kaiser

numpy.kron

cupy.kron

numpy.lcm

cupy.lcm

numpy.ldexp

cupy.ldexp

numpy.left_shift

cupy.left_shift

numpy.less

cupy.less

numpy.less_equal

cupy.less_equal

numpy.lexsort

cupy.lexsort

numpy.linspace

cupy.linspace

numpy.load

cupy.load

numpy.loads

-

numpy.loadtxt

-

numpy.log

cupy.log

numpy.log10

cupy.log10

numpy.log1p

cupy.log1p

numpy.log2

cupy.log2

numpy.logaddexp

cupy.logaddexp

numpy.logaddexp2

cupy.logaddexp2

numpy.logical_and

cupy.logical_and

numpy.logical_not

cupy.logical_not

numpy.logical_or

cupy.logical_or

numpy.logical_xor

cupy.logical_xor

numpy.logspace

cupy.logspace

numpy.lookfor

-

numpy.mafromtxt

-

numpy.mask_indices

-

numpy.mat

-

numpy.matmul

cupy.matmul

numpy.max

cupy.max

numpy.maximum

cupy.maximum

numpy.maximum_sctype

-

numpy.may_share_memory

cupy.may_share_memory

numpy.mean

cupy.mean

numpy.median

cupy.median

numpy.meshgrid

cupy.meshgrid

numpy.min

cupy.min

numpy.min_scalar_type

cupy.min_scalar_type (alias of numpy.min_scalar_type)

numpy.minimum

cupy.minimum

numpy.mintypecode

cupy.mintypecode (alias of numpy.mintypecode)

numpy.mod

cupy.mod

numpy.modf

cupy.modf

numpy.moveaxis

cupy.moveaxis

numpy.msort

cupy.msort

numpy.multiply

cupy.multiply

numpy.nan_to_num

cupy.nan_to_num

numpy.nanargmax

cupy.nanargmax

numpy.nanargmin

cupy.nanargmin

numpy.nancumprod

-

numpy.nancumsum

-

numpy.nanmax

cupy.nanmax

numpy.nanmean

cupy.nanmean

numpy.nanmedian

-

numpy.nanmin

cupy.nanmin

numpy.nanpercentile

-

numpy.nanprod

cupy.nanprod

numpy.nanquantile

-

numpy.nanstd

cupy.nanstd

numpy.nansum

cupy.nansum

numpy.nanvar

cupy.nanvar

numpy.ndfromtxt

-

numpy.ndim

cupy.ndim

numpy.negative

cupy.negative

numpy.nested_iters

-

numpy.nextafter

cupy.nextafter

numpy.nonzero

cupy.nonzero

numpy.not_equal

cupy.not_equal

numpy.obj2sctype

cupy.obj2sctype (alias of numpy.obj2sctype)

numpy.ones

cupy.ones

numpy.ones_like

cupy.ones_like

numpy.outer

cupy.outer

numpy.packbits

cupy.packbits

numpy.pad

cupy.pad

numpy.partition

cupy.partition

numpy.percentile

cupy.percentile

numpy.piecewise

cupy.piecewise

numpy.place

cupy.place

numpy.poly

-

numpy.polyadd

cupy.polyadd

numpy.polyder

-

numpy.polydiv

-

numpy.polyfit

-

numpy.polyint

-

numpy.polymul

cupy.polymul

numpy.polysub

cupy.polysub

numpy.polyval

cupy.polyval

numpy.positive

-

numpy.power

cupy.power

numpy.printoptions

-

numpy.prod

cupy.prod

numpy.product

-

numpy.promote_types

cupy.promote_types (alias of numpy.promote_types)

numpy.ptp

cupy.ptp

numpy.put

cupy.put

numpy.put_along_axis

-

numpy.putmask

cupy.putmask

numpy.quantile

-

numpy.rad2deg

cupy.rad2deg

numpy.radians

cupy.radians

numpy.ravel

cupy.ravel

numpy.ravel_multi_index

cupy.ravel_multi_index

numpy.real

cupy.real

numpy.real_if_close

-

numpy.recfromcsv

-

numpy.recfromtxt

-

numpy.reciprocal

cupy.reciprocal

numpy.remainder

cupy.remainder

numpy.repeat

cupy.repeat

numpy.require

cupy.require

numpy.reshape

cupy.reshape

numpy.resize

-

numpy.result_type

cupy.result_type

numpy.right_shift

cupy.right_shift

numpy.rint

cupy.rint

numpy.roll

cupy.roll

numpy.rollaxis

cupy.rollaxis

numpy.roots

cupy.roots

numpy.rot90

cupy.rot90

numpy.round

-

numpy.round_

cupy.round_

numpy.row_stack

-

numpy.safe_eval

-

numpy.save

cupy.save

numpy.savetxt

-

numpy.savez

cupy.savez

numpy.savez_compressed

cupy.savez_compressed

numpy.sctype2char

cupy.sctype2char (alias of numpy.sctype2char)

numpy.searchsorted

cupy.searchsorted

numpy.select

cupy.select

numpy.set_numeric_ops

-

numpy.set_printoptions

-

numpy.set_string_function

-

numpy.setbufsize

-

numpy.setdiff1d

-

numpy.seterr

-

numpy.seterrcall

-

numpy.seterrobj

-

numpy.setxor1d

-

numpy.shape

cupy.shape

numpy.shares_memory

cupy.shares_memory

numpy.show_config

cupy.show_config

numpy.sign

cupy.sign

numpy.signbit

cupy.signbit

numpy.sin

cupy.sin

numpy.sinc

cupy.sinc

numpy.sinh

cupy.sinh

numpy.size

cupy.size

numpy.sometrue

-

numpy.sort

cupy.sort

numpy.sort_complex

cupy.sort_complex

numpy.source

-

numpy.spacing

-

numpy.split

cupy.split

numpy.sqrt

cupy.sqrt

numpy.square

cupy.square

numpy.squeeze

cupy.squeeze

numpy.stack

cupy.stack

numpy.std

cupy.std

numpy.subtract

cupy.subtract

numpy.sum

cupy.sum

numpy.swapaxes

cupy.swapaxes

numpy.take

cupy.take

numpy.take_along_axis

cupy.take_along_axis

numpy.tan

cupy.tan

numpy.tanh

cupy.tanh

numpy.tensordot

cupy.tensordot

numpy.tile

cupy.tile

numpy.trace

cupy.trace

numpy.transpose

cupy.transpose

numpy.trapz

-

numpy.tri

cupy.tri

numpy.tril

cupy.tril

numpy.tril_indices

-

numpy.tril_indices_from

-

numpy.trim_zeros

cupy.trim_zeros

numpy.triu

cupy.triu

numpy.triu_indices

-

numpy.triu_indices_from

-

numpy.true_divide

cupy.true_divide

numpy.trunc

cupy.trunc

numpy.typename

cupy.typename (alias of numpy.typename)

numpy.union1d

-

numpy.unique

cupy.unique

numpy.unpackbits

cupy.unpackbits

numpy.unravel_index

cupy.unravel_index

numpy.unwrap

cupy.unwrap

numpy.vander

-

numpy.var

cupy.var

numpy.vdot

cupy.vdot

numpy.vsplit

cupy.vsplit

numpy.vstack

cupy.vstack

numpy.where

cupy.where

numpy.who

cupy.who

numpy.zeros

cupy.zeros

numpy.zeros_like

cupy.zeros_like

Multi-Dimensional Array

NumPy

CuPy

numpy.ndarray.all()

cupy.ndarray.all()

numpy.ndarray.any()

cupy.ndarray.any()

numpy.ndarray.argmax()

cupy.ndarray.argmax()

numpy.ndarray.argmin()

cupy.ndarray.argmin()

numpy.ndarray.argpartition()

cupy.ndarray.argpartition()

numpy.ndarray.argsort()

cupy.ndarray.argsort()

numpy.ndarray.astype()

cupy.ndarray.astype()

numpy.ndarray.byteswap()

-

numpy.ndarray.choose()

cupy.ndarray.choose()

numpy.ndarray.clip()

cupy.ndarray.clip()

numpy.ndarray.compress()

cupy.ndarray.compress()

numpy.ndarray.conj()

cupy.ndarray.conj()

numpy.ndarray.conjugate()

cupy.ndarray.conjugate()

numpy.ndarray.copy()

cupy.ndarray.copy()

numpy.ndarray.cumprod()

cupy.ndarray.cumprod()

numpy.ndarray.cumsum()

cupy.ndarray.cumsum()

numpy.ndarray.diagonal()

cupy.ndarray.diagonal()

numpy.ndarray.dot()

cupy.ndarray.dot()

numpy.ndarray.dump()

cupy.ndarray.dump()

numpy.ndarray.dumps()

cupy.ndarray.dumps()

numpy.ndarray.fill()

cupy.ndarray.fill()

numpy.ndarray.flatten()

cupy.ndarray.flatten()

numpy.ndarray.getfield()

-

numpy.ndarray.item()

cupy.ndarray.item()

numpy.ndarray.itemset()

-

numpy.ndarray.max()

cupy.ndarray.max()

numpy.ndarray.mean()

cupy.ndarray.mean()

numpy.ndarray.min()

cupy.ndarray.min()

numpy.ndarray.newbyteorder()

-

numpy.ndarray.nonzero()

cupy.ndarray.nonzero()

numpy.ndarray.partition()

cupy.ndarray.partition()

numpy.ndarray.prod()

cupy.ndarray.prod()

numpy.ndarray.ptp()

cupy.ndarray.ptp()

numpy.ndarray.put()

cupy.ndarray.put()

numpy.ndarray.ravel()

cupy.ndarray.ravel()

numpy.ndarray.repeat()

cupy.ndarray.repeat()

numpy.ndarray.reshape()

cupy.ndarray.reshape()

numpy.ndarray.resize()

-

numpy.ndarray.round()

cupy.ndarray.round()

numpy.ndarray.searchsorted()

-

numpy.ndarray.setfield()

-

numpy.ndarray.setflags()

-

numpy.ndarray.sort()

cupy.ndarray.sort()

numpy.ndarray.squeeze()

cupy.ndarray.squeeze()

numpy.ndarray.std()

cupy.ndarray.std()

numpy.ndarray.sum()

cupy.ndarray.sum()

numpy.ndarray.swapaxes()

cupy.ndarray.swapaxes()

numpy.ndarray.take()

cupy.ndarray.take()

numpy.ndarray.tobytes()

cupy.ndarray.tobytes()

numpy.ndarray.tofile()

cupy.ndarray.tofile()

numpy.ndarray.tolist()

cupy.ndarray.tolist()

numpy.ndarray.tostring()

-

numpy.ndarray.trace()

cupy.ndarray.trace()

numpy.ndarray.transpose()

cupy.ndarray.transpose()

numpy.ndarray.var()

cupy.ndarray.var()

numpy.ndarray.view()

cupy.ndarray.view()

Random Sampling

NumPy

CuPy

numpy.random.beta

cupy.random.beta

numpy.random.binomial

cupy.random.binomial

numpy.random.bytes

cupy.random.bytes

numpy.random.chisquare

cupy.random.chisquare

numpy.random.choice

cupy.random.choice

numpy.random.default_rng

-

numpy.random.dirichlet

cupy.random.dirichlet

numpy.random.exponential

cupy.random.exponential

numpy.random.f

cupy.random.f

numpy.random.gamma

cupy.random.gamma

numpy.random.geometric

cupy.random.geometric

numpy.random.get_state

-

numpy.random.gumbel

cupy.random.gumbel

numpy.random.hypergeometric

cupy.random.hypergeometric

numpy.random.laplace

cupy.random.laplace

numpy.random.logistic

cupy.random.logistic

numpy.random.lognormal

cupy.random.lognormal

numpy.random.logseries

cupy.random.logseries

numpy.random.multinomial

cupy.random.multinomial

numpy.random.multivariate_normal

cupy.random.multivariate_normal

numpy.random.negative_binomial

cupy.random.negative_binomial

numpy.random.noncentral_chisquare

cupy.random.noncentral_chisquare

numpy.random.noncentral_f

cupy.random.noncentral_f

numpy.random.normal

cupy.random.normal

numpy.random.pareto

cupy.random.pareto

numpy.random.permutation

cupy.random.permutation

numpy.random.poisson

cupy.random.poisson

numpy.random.power

cupy.random.power

numpy.random.rand

cupy.random.rand

numpy.random.randint

cupy.random.randint

numpy.random.randn

cupy.random.randn

numpy.random.random

cupy.random.random

numpy.random.random_integers

cupy.random.random_integers

numpy.random.random_sample

cupy.random.random_sample

numpy.random.ranf

cupy.random.ranf

numpy.random.rayleigh

cupy.random.rayleigh

numpy.random.sample

cupy.random.sample

numpy.random.seed

cupy.random.seed

numpy.random.set_state

-

numpy.random.shuffle

cupy.random.shuffle

numpy.random.standard_cauchy

cupy.random.standard_cauchy

numpy.random.standard_exponential

cupy.random.standard_exponential

numpy.random.standard_gamma

cupy.random.standard_gamma

numpy.random.standard_normal

cupy.random.standard_normal

numpy.random.standard_t

cupy.random.standard_t

numpy.random.triangular

cupy.random.triangular

numpy.random.uniform

cupy.random.uniform

numpy.random.vonmises

cupy.random.vonmises

numpy.random.wald

cupy.random.wald

numpy.random.weibull

cupy.random.weibull

numpy.random.zipf

cupy.random.zipf

SciPy / CuPy APIs

Advanced Linear Algebra

SciPy

CuPy

scipy.linalg.block_diag

cupyx.scipy.linalg.block_diag

scipy.linalg.cdf2rdf

-

scipy.linalg.cho_factor

-

scipy.linalg.cho_solve

-

scipy.linalg.cho_solve_banded

-

scipy.linalg.cholesky_banded

-

scipy.linalg.circulant

cupyx.scipy.linalg.circulant

scipy.linalg.clarkson_woodruff_transform

-

scipy.linalg.companion

cupyx.scipy.linalg.companion

scipy.linalg.convolution_matrix

cupyx.scipy.linalg.convolution_matrix

scipy.linalg.coshm

-

scipy.linalg.cosm

-

scipy.linalg.cossin

-

scipy.linalg.dft

cupyx.scipy.linalg.dft

scipy.linalg.diagsvd

-

scipy.linalg.eig_banded

-

scipy.linalg.eigh_tridiagonal

-

scipy.linalg.eigvals_banded

-

scipy.linalg.eigvalsh_tridiagonal

-

scipy.linalg.expm

-

scipy.linalg.expm_cond

-

scipy.linalg.expm_frechet

-

scipy.linalg.fiedler

cupyx.scipy.linalg.fiedler

scipy.linalg.fiedler_companion

cupyx.scipy.linalg.fiedler_companion

scipy.linalg.find_best_blas_type

-

scipy.linalg.fractional_matrix_power

-

scipy.linalg.funm

-

scipy.linalg.get_blas_funcs

-

scipy.linalg.get_lapack_funcs

-

scipy.linalg.hadamard

cupyx.scipy.linalg.hadamard

scipy.linalg.hankel

cupyx.scipy.linalg.hankel

scipy.linalg.helmert

cupyx.scipy.linalg.helmert

scipy.linalg.hessenberg

-

scipy.linalg.hilbert

cupyx.scipy.linalg.hilbert

scipy.linalg.invhilbert

-

scipy.linalg.invpascal

-

scipy.linalg.khatri_rao

-

scipy.linalg.kron

cupyx.scipy.linalg.kron

scipy.linalg.ldl

-

scipy.linalg.leslie

cupyx.scipy.linalg.leslie

scipy.linalg.logm

-

scipy.linalg.lu

-

scipy.linalg.lu_factor

cupyx.scipy.linalg.lu_factor

scipy.linalg.lu_solve

cupyx.scipy.linalg.lu_solve

scipy.linalg.matmul_toeplitz

-

scipy.linalg.matrix_balance

-

scipy.linalg.null_space

-

scipy.linalg.ordqz

-

scipy.linalg.orth

-

scipy.linalg.orthogonal_procrustes

-

scipy.linalg.pascal

-

scipy.linalg.pinv2

-

scipy.linalg.pinvh

-

scipy.linalg.polar

-

scipy.linalg.qr_delete

-

scipy.linalg.qr_insert

-

scipy.linalg.qr_multiply

-

scipy.linalg.qr_update

-

scipy.linalg.qz

-

scipy.linalg.rq

-

scipy.linalg.rsf2csf

-

scipy.linalg.schur

-

scipy.linalg.signm

-

scipy.linalg.sinhm

-

scipy.linalg.sinm

-

scipy.linalg.solve_banded

-

scipy.linalg.solve_circulant

-

scipy.linalg.solve_continuous_are

-

scipy.linalg.solve_continuous_lyapunov

-

scipy.linalg.solve_discrete_are

-

scipy.linalg.solve_discrete_lyapunov

-

scipy.linalg.solve_lyapunov

-

scipy.linalg.solve_sylvester

-

scipy.linalg.solve_toeplitz

-

scipy.linalg.solve_triangular

cupyx.scipy.linalg.solve_triangular

scipy.linalg.solveh_banded

-

scipy.linalg.sqrtm

-

scipy.linalg.subspace_angles

-

scipy.linalg.svdvals

-

scipy.linalg.tanhm

-

scipy.linalg.tanm

-

scipy.linalg.toeplitz

cupyx.scipy.linalg.toeplitz

scipy.linalg.tri

cupyx.scipy.linalg.tri

scipy.linalg.tril

cupyx.scipy.linalg.tril

scipy.linalg.triu

cupyx.scipy.linalg.triu

Multidimensional Image Processing

SciPy

CuPy

scipy.ndimage.affine_transform

cupyx.scipy.ndimage.affine_transform

scipy.ndimage.binary_closing

-

scipy.ndimage.binary_dilation

-

scipy.ndimage.binary_erosion

-

scipy.ndimage.binary_fill_holes

-

scipy.ndimage.binary_hit_or_miss

-

scipy.ndimage.binary_opening

-

scipy.ndimage.binary_propagation

-

scipy.ndimage.black_tophat

-

scipy.ndimage.center_of_mass

-

scipy.ndimage.convolve

cupyx.scipy.ndimage.convolve

scipy.ndimage.convolve1d

cupyx.scipy.ndimage.convolve1d

scipy.ndimage.correlate

cupyx.scipy.ndimage.correlate

scipy.ndimage.correlate1d

cupyx.scipy.ndimage.correlate1d

scipy.ndimage.distance_transform_bf

-

scipy.ndimage.distance_transform_cdt

-

scipy.ndimage.distance_transform_edt

-

scipy.ndimage.extrema

-

scipy.ndimage.find_objects

-

scipy.ndimage.fourier_ellipsoid

-

scipy.ndimage.fourier_gaussian

cupyx.scipy.ndimage.fourier_gaussian

scipy.ndimage.fourier_shift

cupyx.scipy.ndimage.fourier_shift

scipy.ndimage.fourier_uniform

cupyx.scipy.ndimage.fourier_uniform

scipy.ndimage.gaussian_filter

cupyx.scipy.ndimage.gaussian_filter

scipy.ndimage.gaussian_filter1d

cupyx.scipy.ndimage.gaussian_filter1d

scipy.ndimage.gaussian_gradient_magnitude

cupyx.scipy.ndimage.gaussian_gradient_magnitude

scipy.ndimage.gaussian_laplace

cupyx.scipy.ndimage.gaussian_laplace

scipy.ndimage.generate_binary_structure

-

scipy.ndimage.generic_filter

cupyx.scipy.ndimage.generic_filter

scipy.ndimage.generic_filter1d

cupyx.scipy.ndimage.generic_filter1d

scipy.ndimage.generic_gradient_magnitude

cupyx.scipy.ndimage.generic_gradient_magnitude

scipy.ndimage.generic_laplace

cupyx.scipy.ndimage.generic_laplace

scipy.ndimage.geometric_transform

-

scipy.ndimage.grey_closing

cupyx.scipy.ndimage.grey_closing

scipy.ndimage.grey_dilation

cupyx.scipy.ndimage.grey_dilation

scipy.ndimage.grey_erosion

cupyx.scipy.ndimage.grey_erosion

scipy.ndimage.grey_opening

cupyx.scipy.ndimage.grey_opening

scipy.ndimage.histogram

-

scipy.ndimage.iterate_structure

-

scipy.ndimage.label

cupyx.scipy.ndimage.label

scipy.ndimage.labeled_comprehension

-

scipy.ndimage.laplace

cupyx.scipy.ndimage.laplace

scipy.ndimage.map_coordinates

cupyx.scipy.ndimage.map_coordinates

scipy.ndimage.maximum

-

scipy.ndimage.maximum_filter

cupyx.scipy.ndimage.maximum_filter

scipy.ndimage.maximum_filter1d

cupyx.scipy.ndimage.maximum_filter1d

scipy.ndimage.maximum_position

-

scipy.ndimage.mean

cupyx.scipy.ndimage.mean

scipy.ndimage.median

-

scipy.ndimage.median_filter

cupyx.scipy.ndimage.median_filter

scipy.ndimage.minimum

-

scipy.ndimage.minimum_filter

cupyx.scipy.ndimage.minimum_filter

scipy.ndimage.minimum_filter1d

cupyx.scipy.ndimage.minimum_filter1d

scipy.ndimage.minimum_position

-

scipy.ndimage.morphological_gradient

-

scipy.ndimage.morphological_laplace

-

scipy.ndimage.percentile_filter

cupyx.scipy.ndimage.percentile_filter

scipy.ndimage.prewitt

cupyx.scipy.ndimage.prewitt

scipy.ndimage.rank_filter

cupyx.scipy.ndimage.rank_filter

scipy.ndimage.rotate

cupyx.scipy.ndimage.rotate

scipy.ndimage.shift

cupyx.scipy.ndimage.shift

scipy.ndimage.sobel

cupyx.scipy.ndimage.sobel

scipy.ndimage.spline_filter

-

scipy.ndimage.spline_filter1d

-

scipy.ndimage.standard_deviation

cupyx.scipy.ndimage.standard_deviation

scipy.ndimage.sum

cupyx.scipy.ndimage.sum

scipy.ndimage.sum_labels

-

scipy.ndimage.uniform_filter

cupyx.scipy.ndimage.uniform_filter

scipy.ndimage.uniform_filter1d

cupyx.scipy.ndimage.uniform_filter1d

scipy.ndimage.variance

cupyx.scipy.ndimage.variance

scipy.ndimage.watershed_ift

-

scipy.ndimage.white_tophat

-

scipy.ndimage.zoom

cupyx.scipy.ndimage.zoom

Special Functions

SciPy

CuPy

scipy.special.agm

-

scipy.special.ai_zeros

-

scipy.special.airy

-

scipy.special.airye

-

scipy.special.assoc_laguerre

-

scipy.special.bdtr

-

scipy.special.bdtrc

-

scipy.special.bdtri

-

scipy.special.bdtrik

-

scipy.special.bdtrin

-

scipy.special.bei

-

scipy.special.bei_zeros

-

scipy.special.beip

-

scipy.special.beip_zeros

-

scipy.special.ber

-

scipy.special.ber_zeros

-

scipy.special.bernoulli

-

scipy.special.berp

-

scipy.special.berp_zeros

-

scipy.special.besselpoly

-

scipy.special.beta

-

scipy.special.betainc

-

scipy.special.betaincinv

-

scipy.special.betaln

-

scipy.special.bi_zeros

-

scipy.special.binom

-

scipy.special.boxcox

-

scipy.special.boxcox1p

-

scipy.special.btdtr

-

scipy.special.btdtri

-

scipy.special.btdtria

-

scipy.special.btdtrib

-

scipy.special.c_roots

-

scipy.special.cbrt

-

scipy.special.cg_roots

-

scipy.special.chdtr

-

scipy.special.chdtrc

-

scipy.special.chdtri

-

scipy.special.chdtriv

-

scipy.special.chebyc

-

scipy.special.chebys

-

scipy.special.chebyt

-

scipy.special.chebyu

-

scipy.special.chndtr

-

scipy.special.chndtridf

-

scipy.special.chndtrinc

-

scipy.special.chndtrix

-

scipy.special.clpmn

-

scipy.special.comb

-

scipy.special.cosdg

-

scipy.special.cosm1

-

scipy.special.cotdg

-

scipy.special.dawsn

-

scipy.special.digamma

cupyx.scipy.special.digamma

scipy.special.diric

-

scipy.special.ellip_harm

-

scipy.special.ellip_harm_2

-

scipy.special.ellip_normal

-

scipy.special.ellipe

-

scipy.special.ellipeinc

-

scipy.special.ellipj

-

scipy.special.ellipk

-

scipy.special.ellipkinc

-

scipy.special.ellipkm1

-

scipy.special.entr

cupyx.scipy.special.entr

scipy.special.erf

cupyx.scipy.special.erf

scipy.special.erf_zeros

-

scipy.special.erfc

cupyx.scipy.special.erfc

scipy.special.erfcinv

cupyx.scipy.special.erfcinv

scipy.special.erfcx

cupyx.scipy.special.erfcx

scipy.special.erfi

-

scipy.special.erfinv

cupyx.scipy.special.erfinv

scipy.special.euler

-

scipy.special.eval_chebyc

-

scipy.special.eval_chebys

-

scipy.special.eval_chebyt

-

scipy.special.eval_chebyu

-

scipy.special.eval_gegenbauer

-

scipy.special.eval_genlaguerre

-

scipy.special.eval_hermite

-

scipy.special.eval_hermitenorm

-

scipy.special.eval_jacobi

-

scipy.special.eval_laguerre

-

scipy.special.eval_legendre

-

scipy.special.eval_sh_chebyt

-

scipy.special.eval_sh_chebyu

-

scipy.special.eval_sh_jacobi

-

scipy.special.eval_sh_legendre

-

scipy.special.exp1

-

scipy.special.exp10

-

scipy.special.exp2

-

scipy.special.expi

-

scipy.special.expit

-

scipy.special.expm1

-

scipy.special.expn

-

scipy.special.exprel

-

scipy.special.factorial

-

scipy.special.factorial2

-

scipy.special.factorialk

-

scipy.special.fdtr

-

scipy.special.fdtrc

-

scipy.special.fdtri

-

scipy.special.fdtridfd

-

scipy.special.fresnel

-

scipy.special.fresnel_zeros

-

scipy.special.fresnelc_zeros

-

scipy.special.fresnels_zeros

-

scipy.special.gamma

cupyx.scipy.special.gamma

scipy.special.gammainc

-

scipy.special.gammaincc

-

scipy.special.gammainccinv

-

scipy.special.gammaincinv

-

scipy.special.gammaln

cupyx.scipy.special.gammaln

scipy.special.gammasgn

-

scipy.special.gdtr

-

scipy.special.gdtrc

-

scipy.special.gdtria

-

scipy.special.gdtrib

-

scipy.special.gdtrix

-

scipy.special.gegenbauer

-

scipy.special.genlaguerre

-

scipy.special.geterr

-

scipy.special.h1vp

-

scipy.special.h2vp

-

scipy.special.h_roots

-

scipy.special.hankel1

-

scipy.special.hankel1e

-

scipy.special.hankel2

-

scipy.special.hankel2e

-

scipy.special.he_roots

-

scipy.special.hermite

-

scipy.special.hermitenorm

-

scipy.special.huber

cupyx.scipy.special.huber

scipy.special.hyp0f1

-

scipy.special.hyp1f1

-

scipy.special.hyp2f1

-

scipy.special.hyperu

-

scipy.special.i0

cupyx.scipy.special.i0

scipy.special.i0e

-

scipy.special.i1

cupyx.scipy.special.i1

scipy.special.i1e

-

scipy.special.inv_boxcox

-

scipy.special.inv_boxcox1p

-

scipy.special.it2i0k0

-

scipy.special.it2j0y0

-

scipy.special.it2struve0

-

scipy.special.itairy

-

scipy.special.iti0k0

-

scipy.special.itj0y0

-

scipy.special.itmodstruve0

-

scipy.special.itstruve0

-

scipy.special.iv

-

scipy.special.ive

-

scipy.special.ivp

-

scipy.special.j0

cupyx.scipy.special.j0

scipy.special.j1

cupyx.scipy.special.j1

scipy.special.j_roots

-

scipy.special.jacobi

-

scipy.special.jn

-

scipy.special.jn_zeros

-

scipy.special.jnjnp_zeros

-

scipy.special.jnp_zeros

-

scipy.special.jnyn_zeros

-

scipy.special.js_roots

-

scipy.special.jv

-

scipy.special.jve

-

scipy.special.jvp

-

scipy.special.k0

-

scipy.special.k0e

-

scipy.special.k1

-

scipy.special.k1e

-

scipy.special.kei

-

scipy.special.kei_zeros

-

scipy.special.keip

-

scipy.special.keip_zeros

-

scipy.special.kelvin

-

scipy.special.kelvin_zeros

-

scipy.special.ker

-

scipy.special.ker_zeros

-

scipy.special.kerp

-

scipy.special.kerp_zeros

-

scipy.special.kl_div

cupyx.scipy.special.kl_div

scipy.special.kn

-

scipy.special.kolmogi

-

scipy.special.kolmogorov

-

scipy.special.kv

-

scipy.special.kve

-

scipy.special.kvp

-

scipy.special.l_roots

-

scipy.special.la_roots

-

scipy.special.laguerre

-

scipy.special.lambertw

-

scipy.special.legendre

-

scipy.special.lmbda

-

scipy.special.log1p

-

scipy.special.log_ndtr

-

scipy.special.log_softmax

-

scipy.special.loggamma

-

scipy.special.logit

-

scipy.special.logsumexp

-

scipy.special.lpmn

-

scipy.special.lpmv

-

scipy.special.lpn

-

scipy.special.lqmn

-

scipy.special.lqn

-

scipy.special.mathieu_a

-

scipy.special.mathieu_b

-

scipy.special.mathieu_cem

-

scipy.special.mathieu_even_coef

-

scipy.special.mathieu_modcem1

-

scipy.special.mathieu_modcem2

-

scipy.special.mathieu_modsem1

-

scipy.special.mathieu_modsem2

-

scipy.special.mathieu_odd_coef

-

scipy.special.mathieu_sem

-

scipy.special.modfresnelm

-

scipy.special.modfresnelp

-

scipy.special.modstruve

-

scipy.special.multigammaln

-

scipy.special.nbdtr

-

scipy.special.nbdtrc

-

scipy.special.nbdtri

-

scipy.special.nbdtrik

-

scipy.special.nbdtrin

-

scipy.special.ncfdtr

-

scipy.special.ncfdtri

-

scipy.special.ncfdtridfd

-

scipy.special.ncfdtridfn

-

scipy.special.ncfdtrinc

-

scipy.special.nctdtr

-

scipy.special.nctdtridf

-

scipy.special.nctdtrinc

-

scipy.special.nctdtrit

-

scipy.special.ndtr

cupyx.scipy.special.ndtr

scipy.special.ndtri

-

scipy.special.nrdtrimn

-

scipy.special.nrdtrisd

-

scipy.special.obl_ang1

-

scipy.special.obl_ang1_cv

-

scipy.special.obl_cv

-

scipy.special.obl_cv_seq

-

scipy.special.obl_rad1

-

scipy.special.obl_rad1_cv

-

scipy.special.obl_rad2

-

scipy.special.obl_rad2_cv

-

scipy.special.owens_t

-

scipy.special.p_roots

-

scipy.special.pbdn_seq

-

scipy.special.pbdv

-

scipy.special.pbdv_seq

-

scipy.special.pbvv

-

scipy.special.pbvv_seq

-

scipy.special.pbwa

-

scipy.special.pdtr

-

scipy.special.pdtrc

-

scipy.special.pdtri

-

scipy.special.pdtrik

-

scipy.special.perm

-

scipy.special.poch

-

scipy.special.polygamma

cupyx.scipy.special.polygamma

scipy.special.pro_ang1

-

scipy.special.pro_ang1_cv

-

scipy.special.pro_cv

-

scipy.special.pro_cv_seq

-

scipy.special.pro_rad1

-

scipy.special.pro_rad1_cv

-

scipy.special.pro_rad2

-

scipy.special.pro_rad2_cv

-

scipy.special.ps_roots

-

scipy.special.pseudo_huber

cupyx.scipy.special.pseudo_huber

scipy.special.psi

-

scipy.special.radian

-

scipy.special.rel_entr

cupyx.scipy.special.rel_entr

scipy.special.rgamma

-

scipy.special.riccati_jn

-

scipy.special.riccati_yn

-

scipy.special.roots_chebyc

-

scipy.special.roots_chebys

-

scipy.special.roots_chebyt

-

scipy.special.roots_chebyu

-

scipy.special.roots_gegenbauer

-

scipy.special.roots_genlaguerre

-

scipy.special.roots_hermite

-

scipy.special.roots_hermitenorm

-

scipy.special.roots_jacobi

-

scipy.special.roots_laguerre

-

scipy.special.roots_legendre

-

scipy.special.roots_sh_chebyt

-

scipy.special.roots_sh_chebyu

-

scipy.special.roots_sh_jacobi

-

scipy.special.roots_sh_legendre

-

scipy.special.round

-

scipy.special.s_roots

-

scipy.special.seterr

-

scipy.special.sh_chebyt

-

scipy.special.sh_chebyu

-

scipy.special.sh_jacobi

-

scipy.special.sh_legendre

-

scipy.special.shichi

-

scipy.special.sici

-

scipy.special.sinc

-

scipy.special.sindg

-

scipy.special.smirnov

-

scipy.special.smirnovi

-

scipy.special.softmax

-

scipy.special.spence

-

scipy.special.sph_harm

-

scipy.special.spherical_in

-

scipy.special.spherical_jn

-

scipy.special.spherical_kn

-

scipy.special.spherical_yn

-

scipy.special.stdtr

-

scipy.special.stdtridf

-

scipy.special.stdtrit

-

scipy.special.struve

-

scipy.special.t_roots

-

scipy.special.tandg

-

scipy.special.tklmbda

-

scipy.special.ts_roots

-

scipy.special.u_roots

-

scipy.special.us_roots

-

scipy.special.voigt_profile

-

scipy.special.wofz

-

scipy.special.wrightomega

-

scipy.special.xlog1py

-

scipy.special.xlogy

-

scipy.special.y0

cupyx.scipy.special.y0

scipy.special.y0_zeros

-

scipy.special.y1

cupyx.scipy.special.y1

scipy.special.y1_zeros

-

scipy.special.y1p_zeros

-

scipy.special.yn

-

scipy.special.yn_zeros

-

scipy.special.ynp_zeros

-

scipy.special.yv

-

scipy.special.yve

-

scipy.special.yvp

-

scipy.special.zeta

cupyx.scipy.special.zeta

scipy.special.zetac

-

Miscellaneous functions

  • cupy.may_share_memory

  • cupy.shares_memory

  • cupy.show_config: Prints the current runtime configuration to standard output.

  • cupy.who: Prints the CuPy arrays in the given dictionary.
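
For example (a sketch; both functions print to standard output, and the exact output depends on your environment, so it is omitted here):

>>> import cupy
>>> cupy.show_config()              # prints versions of CuPy, CUDA and related libraries
>>> a = cupy.arange(3)
>>> cupy.who({'a': a})              # prints name, shape, bytes and dtype of each CuPy array in the dict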

API Compatibility Policy

This document states the design policy on compatibility of CuPy APIs. The development team should follow this policy when deciding to add, extend, or change APIs and their behaviors.

This document is written for both users and developers. Users can decide how much to depend on CuPy’s implementations in their code based on this document. Developers should read through this document before creating pull requests that contain changes to the interface. Note that this document may contain ambiguities on the level of supported compatibility.

Versioning and Backward Compatibilities

Updates of CuPy are classified into three levels: major, minor, and revision. These levels carry distinct degrees of backward compatibility.

  • A major update contains disruptive changes that break backward compatibility.

  • A minor update contains additions and extensions to the APIs while keeping the supported backward compatibility.

  • A revision update contains improvements to the API implementations without changing any API specifications.

Note that we do not support full backward compatibility, which is almost infeasible for Python-based APIs, since there is no way to completely hide the implementation details.

Processes to Break Backward Compatibilities

Deprecation, Dropping, and Its Preparation

Any API may be deprecated in some minor update. In that case, a deprecation note is added to the API documentation, and the API implementation is changed to emit a deprecation warning (if possible). There should be an alternative way to achieve what was previously written with the deprecated API.

Any API may be marked as to be dropped in the future. In that case, the dropping is stated in the documentation together with the major version number in which the API is planned to be dropped, and the API implementation is changed to emit a future warning (if possible).

The actual dropping should be done through the following steps:

  • Make the API deprecated. At this point, users should not need the deprecated API in their new application codes.

  • After that, mark the API as to be dropped in the future. This must be done in a minor update different from that of the deprecation.

  • At the major version announced in the above update, drop the API.

Consequently, it takes at least two minor versions to drop any APIs after the first deprecation.

API Changes and Its Preparation

Any API may be marked as to be changed in the future for changes that break backward compatibility. In that case, the change is stated in the documentation together with the version number in which the API is planned to change, and the API implementation is changed to emit a future warning on the affected usages.

The actual change should be done in the following steps:

  • Announce that the API will be changed in the future. At this point, the exact version of the change need not be decided.

  • After the announcement, mark the API as to be changed in the future, with the version number of the planned change. At this point, users should not use the marked API in their new application code.

  • At the major update announced in the above update, change the API.

Supported Backward Compatibility

This section defines backward compatibilities that minor updates must maintain.

Documented Interface

CuPy has official API documentation. Many applications can be written based on the documented features. We support backward compatibility of documented features. In other words, code based only on the documented features runs correctly with minor/revision-updated versions.

Developers are encouraged to use clearly distinguishable names for objects that are implementation details. For example, attributes outside of the documented APIs should have one or more leading underscores in their names.
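
For illustration only (the names below are hypothetical, not actual CuPy APIs):

def documented_function(x):
    """A documented, public API; covered by the compatibility policy."""
    return _internal_helper(x)

def _internal_helper(x):
    # Leading underscore: implementation detail that may change or
    # disappear in any minor/revision update.
    return x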

Undocumented behaviors

Behaviors of CuPy implementation not stated in the documentation are undefined. Undocumented behaviors are not guaranteed to be stable between different minor/revision versions.

A minor update may contain changes to undocumented behaviors. For example, suppose an API X is added in a minor update. In the previous version, attempts to use X cause AttributeError. This behavior is not stated in the documentation, so it is undefined. Thus, adding the API X in a minor version is permissible.

A revision update may also contain changes to undefined behaviors. A typical example is a bug fix. Another example is an improvement of the implementation, which may change internal object structures not shown in the documentation. As a consequence, even revision updates do not guarantee compatibility of pickling, unless the full layout of pickled objects is clearly documented.

Documentation Error

Compatibility is basically determined based on the documentation, though the documentation sometimes contains errors. Treating the documentation as always authoritative over the implementation could make the APIs confusing, so we may fix documentation errors in any update, even if doing so breaks compatibility with respect to the documentation.

Note

Developers MUST NOT fix the documentation and implementation of the same functionality at the same time in revision updates as a “bug fix”. Such a change completely breaks backward compatibility. If you want to fix bugs on both sides, first fix the documentation to match the implementation, and then start the API changing procedure described above.

Object Attributes and Properties

Object attributes and properties are sometimes replaced by each other in minor updates. This does not break user code, unless the code depends on how the attributes and properties are implemented.

Functions and Methods

Methods may be replaced by callable attributes, keeping the compatibility of parameters and return values, in minor updates. This does not break user code, unless the code depends on how the methods and callable attributes are implemented.

Exceptions and Warnings

The specifications of raising exceptions are considered part of the standard backward compatibility. No exception will be raised in future versions for correct usages that the documentation allows, unless the API changing process is completed.

On the other hand, warnings may be added at any minor updates for any APIs. It means minor updates do not keep backward compatibility of warnings.

Installation Compatibility

The installation process is another compatibility concern. We support environment compatibility in the following ways.

  • Any change to the dependent libraries that forces modifications of existing environments must be done in a major update. Such changes include the following cases:

    • dropping supported versions of dependent libraries (e.g. dropping cuDNN v2)

    • adding new mandatory dependencies (e.g. adding h5py to setup_requires)

  • Supporting optional packages/libraries may be done in minor updates (e.g. supporting h5py in optional features).

Note

The installation compatibility does not guarantee that all the features of CuPy run correctly on every supported environment. CuPy may contain bugs that only occur in certain environments. Such bugs should be fixed in some future update.

Contribution Guide

This is a guide for all contributions to CuPy. The development of CuPy takes place on the official repository on GitHub. Anyone who wants to register an issue or send a pull request should read through this document.

Classification of Contributions

There are several ways to contribute to the CuPy community:

  1. Registering an issue

  2. Sending a pull request (PR)

  3. Sending a question to CuPy User Group

  4. Open-sourcing an external example

  5. Writing a post about CuPy

This document mainly focuses on 1 and 2, though other contributions are also appreciated.

Development Cycle

This section explains the development process of CuPy. Before contributing to CuPy, it is strongly recommended to understand the development cycle.

Versioning

The versioning of CuPy follows PEP 440 and a part of Semantic Versioning. The version number consists of three or four parts, X.Y.Zw, where X denotes the major version, Y denotes the minor version, Z denotes the revision number, and the optional w denotes the pre-release suffix (e.g., a1, b1, rc1). While the major, minor, and revision numbers follow the rule of semantic versioning, the pre-release suffix follows PEP 440 so that the version string stays friendly with the Python ecosystem.

Note that a major update basically does not contain compatibility-breaking changes from the last release candidate (RC). This is not a strict rule, though; if there is a critical API bug that we have to fix before the major release, we may introduce breaking changes between the RC and the major version.

As for the backward compatibility, see API Compatibility Policy.

Release Cycle

CuPy releases are developed on two parallel tracks. The first is the track of stable versions, which is a series of revision updates for the latest major version. The second is the track of development versions, which is a series of pre-releases for the upcoming major version.

Consider that X.0.0 is the latest major version and Y.0.0, Z.0.0 are the succeeding major versions. Then, the timeline of the updates is depicted by the following table.

Date        ver X      ver Y      ver Z
0 weeks     X.0.0rc1
4 weeks     X.0.0      Y.0.0a1
8 weeks     X.1.0*     Y.0.0b1
12 weeks    X.2.0*     Y.0.0rc1
16 weeks               Y.0.0      Z.0.0a1

(* These might be revision releases)

The dates shown in the left-most column are relative to the release of X.0.0rc1. In particular, each revision/minor release is made four weeks after the previous one of the same major version, and the pre-release of the upcoming major version is made at the same time. Whether these releases are revision or minor is determined based on the contents of each update.

Note that there are only three stable releases for the versions X.x.x. During the parallel development of Y.0.0 and Z.0.0a1, the version Y is treated as an almost-stable version and Z is treated as a development version.

If there is a critical bug found in X.x.x after stopping the development of version X, we may release a hot-fix for this version at any time.

We create a milestone for each upcoming release at GitHub. The GitHub milestone is basically used for collecting the issues and PRs resolved in the release.

Git Branches

The master branch is used to develop pre-release versions. It means that alpha, beta, and RC updates are developed at the master branch. This branch contains the most up-to-date source tree that includes features newly added after the latest major version.

The stable version is developed on an individual branch named vN, where “N” reflects the major version number (we call it a versioned branch). For example, v1.0.0, v1.0.1, and v1.0.2 are developed on the v1 branch.

Notes for contributors: When you send a pull request, you basically have to send it to the master branch. If the change can also be applied to the stable version, a core team member will apply the same change to the stable version so that the change is also included in the next revision update.

If the change is only applicable to the stable version and not to the master branch, please send it to the versioned branch. We basically only accept changes to the latest versioned branch (where the stable version is developed) unless the fix is critical.

If you want to make a new feature of the master branch available in the current stable version, please send a backport PR to the stable version (the latest vN branch). See the next section for details.

Note: a change that can be applied to both branches should be sent to the master branch. Each release of the stable version is also merged into the development version so that the change is also reflected in the next major version.

Feature Backport PRs

We basically do not backport any new features of the development version to the stable versions. If you want a feature to be included in the current stable version and you can work on the backport, we welcome such a contribution. In that case, you have to send a backport PR to the latest vN branch. Note that we do not accept feature backport PRs to older versions, because we do not run quality assurance workflows (e.g. CI) for older versions and therefore cannot ensure that the PR is ported correctly.

There are some rules on sending a backport PR.

  • Start the PR title with the prefix [backport].

  • Clarify the original PR number in the PR description (something like “This is a backport of #XXXX”).

  • (optional) Write to the PR description the motivation of backporting the feature to the stable version.

Please follow these rules when you create a feature backport PR.

Note: PRs that do not include any changes/additions to APIs (e.g. bug fixes, documentation improvements) are usually backported by core dev members. Backport PRs from any contributor are also appreciated, though, as they help the overall development proceed more smoothly!

Issues and Pull Requests

In this section, we explain how to file issues and send pull requests (PRs).

Issue/PR Labels

Issues and PRs are labeled by the following tags:

  • Bug: bug reports (issues) and bug fixes (PRs)

  • Enhancement: implementation improvements without breaking the interface

  • Feature: feature requests (issues) and their implementations (PRs)

  • NoCompat: disrupts backward compatibility

  • Test: test fixes and updates

  • Document: document fixes and improvements

  • Example: fixes and improvements on the examples

  • Install: fixes installation script

  • Contribution-Welcome: issues for which we request contributions (only issues are given this label)

  • Other: other issues and PRs

Multiple tags may be applied to one issue/PR. Note that revision releases cannot include PRs in the Feature and NoCompat categories.

How to File an Issue

When registering an issue, write a precise explanation of how you want CuPy to behave. Bug reports must include the necessary and sufficient conditions to reproduce the bug. Feature requests must include what you want to do (and why, if needed) with CuPy. You may include your thoughts on how to realize it in the feature request, though the “what” part is the most important for discussion.

Warning

If you have a question about the usage of CuPy, it is highly recommended to send a post to the CuPy User Group instead of the issue tracker. The issue tracker is not a place to share knowledge on practices. We may suggest these places and immediately close how-to question issues.

How to Send a Pull Request

If you can write code to fix an issue, we encourage you to send a PR.

First of all, before starting to write any code, do not forget to confirm the following points.

  • Read through the Coding Guidelines and Unit Testing.

  • Check the appropriate branch to send the PR to, following Git Branches. If you do not have any idea about selecting a branch, please choose the master branch.

In particular, check the branch before writing any code. The current source tree of the chosen branch is the starting point of your change.

After writing your code (including unit tests and, hopefully, documentation!), send a PR on GitHub. You have to write a precise explanation of what you fix and how; it is the first documentation of your code that developers read and a very important part of your PR.

Once you send a PR, it is automatically tested on GitHub Actions for Linux and on AppVeyor for Windows. Your PR needs to pass at least the tests for Linux. After the automatic tests pass, some of the core developers will start reviewing your code. Note that this automatic PR test only includes CPU tests.

Note

We are also running continuous integration with GPU tests for the master branch and the versioned branch of the latest major version. Since this service is currently running on our internal server, we do not use it for automatic PR tests to keep the server secure.

If you are planning to add a new feature or modify existing APIs, it is recommended to open an issue and discuss the design first. A design discussion costs the core developers less than a code review. Following the outcome of the discussion, you can send a PR that can be reviewed smoothly in a shorter time.

Even if your code is not complete, you can send a pull request as a work-in-progress PR by putting the [WIP] prefix in the PR title. If you write a precise explanation about the PR, core developers and other contributors can join the discussion about how to proceed with it. A WIP PR is also useful for having discussions based on concrete code.

Coding Guidelines

Note

The coding guidelines were updated in v5.0. Those who have contributed to older versions should read the guidelines again.

We use PEP8 and a part of OpenStack Style Guidelines related to general coding style as our basic style guidelines.

You can use autopep8 and flake8 commands to check your code.

In order to avoid confusion from using different tool versions, we pin the versions of those tools. Install them with the following command (from within the top directory of CuPy repository):

$ pip install -e '.[stylecheck]'

And check your code with:

$ autopep8 path/to/your/code.py
$ flake8 path/to/your/code.py

To check Cython code, use .flake8.cython configuration file:

$ flake8 --config=.flake8.cython path/to/your/cython/code.pyx

autopep8 can automatically correct Python code to conform to the PEP 8 style guide:

$ autopep8 --in-place path/to/your/code.py

The flake8 command reports the parts of your code that do not follow our style guidelines. Before sending a pull request, be sure to check that your code passes the flake8 check.

Note that the flake8 command is not perfect; it does not check some of the style guidelines. Here is a (non-exhaustive) list of rules that flake8 cannot check.

  • Relative imports are prohibited. [H304]

  • Importing non-module symbols is prohibited.

  • Import statements must be organized into three parts: standard libraries, third-party libraries, and internal imports. [H306]

In addition, we restrict the usage of shortcut symbols in our code base. They are symbols imported by packages and sub-packages of cupy. For example, cupy.cuda.Device is a shortcut of cupy.cuda.device.Device. Using such shortcuts in the cupy library implementation is not allowed. Note that you can still use them in the tests and examples directories.
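
For illustration (a sketch based on the example above, not a complete module):

# Inside the cupy library implementation: spell out the full module path.
from cupy.cuda import device
dev = device.Device()

# In the tests and examples directories, the shortcut symbol is still allowed.
import cupy
dev = cupy.cuda.Device()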

Once you send a pull request, your coding style is automatically checked by GitHub Actions. The reviewing process starts after the check passes.

CuPy is designed based on NumPy’s API design. CuPy’s source code and documents contain parts of the original NumPy ones. Please note the following when writing documentation.

  • In order to identify overlapping parts, it is preferable to add a remark that the document is copied or altered from the original one. It is also preferable to briefly explain the specification of the function in a short paragraph and refer to the corresponding function in NumPy so that users can read the detailed document there. However, it is acceptable to include a complete copy of the document with such a remark if it cannot be summarized in that way.

  • If a function in CuPy only implements a limited subset of the features of the original one, the document should explicitly describe only what is implemented.

For changes that modify or add new Cython files, please make sure the pointer types follow these guidelines (#1913).

  • Pointers should be void* if only used within Cython, or intptr_t if exposed to the Python space.

  • Memory sizes should be size_t.

  • Memory offsets should be ptrdiff_t.

Note

We are incrementally enforcing the above rules, so some existing code may not follow the above guidelines, but please ensure all new contributions do.

Unit Testing

Testing is one of the most important parts of your code. You must write test cases and verify your implementation by following our testing guide.

Note that we are using the pytest and mock packages for testing, so install them before writing your code:

$ pip install pytest mock

How to Run Tests

In order to run unit tests at the repository root, you first have to build Cython files in place by running the following command:

$ pip install -e .

Note

When you modify *.pxd files, before running pip install -e ., you must clean *.cpp and *.so files once with the following command, because Cython does not automatically rebuild those files nicely:

$ git clean -fdx

Once Cython modules are built, you can run unit tests by running the following command at the repository root:

$ python -m pytest

CUDA must be installed to run unit tests.

Some GPU tests require cuDNN to run. In order to skip unit tests that require cuDNN, specify -m='not cudnn' option:

$ python -m pytest path/to/your/test.py -m='not cudnn'

Some GPU tests involve multiple GPUs. If you want to run GPU tests with an insufficient number of GPUs, specify the number of available GPUs in CUPY_TEST_GPU_LIMIT. For example, if you have only one GPU, launch pytest with the following commands to skip multi-GPU tests:

$ export CUPY_TEST_GPU_LIMIT=1
$ python -m pytest path/to/gpu/test.py

Following the test file naming conventions described below, you can run all the tests with the following command at the repository root:

$ python -m pytest

Or you can also specify a root directory to search test scripts from:

$ python -m pytest tests/cupy_tests     # to just run tests of CuPy
$ python -m pytest tests/install_tests  # to just run tests of installation modules

If you modify code covered by existing unit tests, you must run the corresponding test commands and confirm that they still pass.

Test File and Directory Naming Conventions

Tests are placed in the tests/cupy_tests directory. In order to enable the test runner to find test scripts correctly, we use a special naming convention for the test subdirectories and the test scripts.

  • The name of each subdirectory of tests must end with the _tests suffix.

  • The name of each test script must start with the test_ prefix.

When we write a test for a module, we use the appropriate path and file name for the test script whose correspondence to the tested module is clear. For example, if you want to write a test for a module cupy.x.y.z, the test script must be located at tests/cupy_tests/x_tests/y_tests/test_z.py.

How to Write Tests

There are many examples of unit tests under the tests directory, so reading some of them is a good and recommended way to learn how to write tests for CuPy. They simply use the unittest package of the standard library, while some tests are using utilities from cupy.testing.

In addition to the Coding Guidelines mentioned above, the following rules apply to test code (a short example follows the list):

  • All test classes must inherit from unittest.TestCase.

  • Use unittest features to write tests, except for the following cases:

    • Use assert statement instead of self.assert* methods (e.g., write assert x == 1 instead of self.assertEqual(x, 1)).

    • Use with pytest.raises(...): instead of with self.assertRaises(...):.
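
A minimal sketch of a test module following these rules (the class and test names are hypothetical; the attr.gpu decorator, explained below, marks tests that need a CUDA device):

import unittest

import pytest

import cupy
from cupy.testing import attr


class TestReshape(unittest.TestCase):

    @attr.gpu
    def test_shape(self):
        x = cupy.arange(6).reshape(2, 3)
        assert x.shape == (2, 3)

    @attr.gpu
    def test_invalid_shape(self):
        # Reshaping 6 elements into 12 is invalid, so ValueError is expected.
        with pytest.raises(ValueError):
            cupy.arange(6).reshape(4, 3)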

Note

We are incrementally applying the above style. Some existing tests may be using the old style (self.assertRaises, etc.), but all newly written tests should follow the above style.

Even if your patch includes GPU-related code, your tests should not fail in an environment without GPU capability. Test functions that require CUDA must be tagged with the cupy.testing.attr.gpu decorator:

import unittest
from cupy.testing import attr

class TestMyFunc(unittest.TestCase):
    ...

    @attr.gpu
    def test_my_gpu_func(self):
        ...

The functions tagged with the gpu decorator are skipped if the CUPY_TEST_GPU_LIMIT environment variable is set to 0. We also have the cupy.testing.attr.cudnn decorator to let pytest know that the test depends on cuDNN. Test functions decorated with cudnn are skipped if -m='not cudnn' is given.
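For example, a cuDNN-dependent test is tagged in the same way as the gpu example above (the class and method names here are illustrative):

import unittest
from cupy.testing import attr

class TestMyCudnnFunc(unittest.TestCase):
    ...

    @attr.cudnn
    def test_my_cudnn_func(self):
        ...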

The test functions decorated with gpu must not depend on multiple GPUs. In order to write tests for multiple GPUs, use the cupy.testing.attr.multi_gpu() decorator instead:

import unittest
from cupy.testing import attr

class TestMyFunc(unittest.TestCase):
    ...

    @attr.multi_gpu(2)  # specify the number of required GPUs here
    def test_my_two_gpu_func(self):
        ...

If your test requires too much time, add the cupy.testing.attr.slow decorator. Test functions decorated with slow are skipped if -m='not slow' is given:

import unittest
from cupy.testing import attr

class TestMyFunc(unittest.TestCase):
    ...

    @attr.slow
    def test_my_slow_func(self):
        ...

Note

If you want to specify more than one attribute, combine them with the and operator, e.g., -m='not cudnn and not slow'. See the pytest documentation for details.
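For example, the following command skips both cuDNN-dependent and slow tests:

$ python -m pytest path/to/your/test.py -m='not cudnn and not slow'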

Once you send a pull request, Travis-CI automatically checks if your code meets our coding guidelines described above. Since Travis-CI does not support CUDA, we cannot run unit tests automatically. The reviewing process starts after the automatic check passes. Note that reviewers will run the tests without the skip options described above, so CUDA-related code is also checked during review.

Note

Some numerically unstable tests might cause failures unrelated to your changes. In such a case, we ignore the failures and proceed to the review process, so do not worry about them!

Documentation

When adding a new feature to the framework, you also need to document it in the reference.

Note

If you are unsure about how to fix the documentation, you can submit a pull request without doing so. Reviewers will help you fix the documentation appropriately.

The documentation source is stored under docs directory and written in reStructuredText format.

To build the documentation, you need to install Sphinx:

$ pip install sphinx sphinx_rtd_theme

Then you can build the documentation in HTML format locally:

$ cd docs
$ make html

HTML files are generated under the build/html directory. Open index.html in a browser and check that it is rendered as expected.

Note

Docstrings (documentation comments in the source code) are collected from the installed CuPy module. If you modified docstrings, make sure to install the module (e.g., using pip install -e .) before building the documentation.

Tips for Developers

Here are some tips for developers hacking CuPy source code.

Install as Editable

During development, we recommend installing CuPy in editable mode using pip with the -e option:

$ pip install -e .

Please note that even with -e, you will have to rerun pip install -e . to regenerate the C++ sources with Cython whenever you modify Cython source files (e.g., *.pyx files).

Use ccache

The NVCC environment variable can be set at build time to use a custom command instead of nvcc. You can speed up rebuilds by using ccache (v3.4 or later):

$ export NVCC='ccache nvcc'

Limit Architecture

Use the CUPY_NVCC_GENERATE_CODE environment variable to reduce the build time by limiting the target CUDA architectures. For example, if you only run your CuPy build on NVIDIA P100 and V100 GPUs, you can use:

$ export CUPY_NVCC_GENERATE_CODE="arch=compute_60,code=sm_60;arch=compute_70,code=sm_70"

See Environment variables for the description.

Using CuPy on AMD GPU (experimental)

CuPy has experimental support for AMD GPUs (ROCm).

Requirements

The following ROCm libraries are required:

$ sudo apt install hipblas hipsparse rocsparse rocrand rocthrust rocsolver rocfft hipcub rocprim

Before installing CuPy, we recommend you to upgrade setuptools and pip:

$ pip install -U setuptools pip

Building CuPy for ROCm

Currently, you need to build CuPy from source to run it on AMD GPUs.

$ export HCC_AMDGPU_TARGET=gfx900  # This value should be changed based on your GPU
$ export CUPY_INSTALL_USE_HIP=1
$ pip install cupy

Note that HCC_AMDGPU_TARGET must be set to the ISA name supported by your GPU. Run rocminfo and use the value displayed in the Name: line (e.g., gfx900).

You may also need to set ROCM_HOME (e.g., ROCM_HOME=/opt/rocm).

Upgrade Guide

This is a list of changes introduced in each release that users should be aware of when migrating from older versions.

CuPy v8

Dropping Support of CUDA 8.0 and 9.1

CUDA 8.0 and 9.1 are no longer supported. Use CUDA 9.0, 9.2, 10.0, or later.

Dropping Support of NumPy 1.15 and SciPy 1.2

NumPy 1.15 (or earlier) and SciPy 1.2 (or earlier) are no longer supported.

Update of Docker Images

  • CuPy official Docker images (see Installation Guide for details) are now updated to use CUDA 10.2 and Python 3.6.

  • SciPy and Optuna are now pre-installed.

CUB Support and Compiler Requirement

The CUB module is now built by default. You can enable the use of CUB by setting CUPY_ACCELERATORS="cub" (see Environment variables for details).
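For example, in a shell:

$ export CUPY_ACCELERATORS="cub"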

Due to this change, g++-6 or later is required when building CuPy from source. See Installation Guide for details.

The following environment variables are no longer effective:

  • CUB_DISABLED: Use CUPY_ACCELERATORS as aforementioned.

  • CUB_PATH: No longer required as CuPy uses either the CUB source bundled with CUDA (only when using CUDA 11.0 or later) or the one in the CuPy distribution.

API Changes

  • cupy.scatter_add, which was deprecated in CuPy v4, has been removed. Use cupyx.scatter_add() instead.

  • cupy.sparse module has been deprecated and will be removed in future releases. Use cupyx.scipy.sparse instead.

  • The dtype argument of cupy.ndarray.min() and cupy.ndarray.max() has been removed to align with the NumPy specification.

  • cupy.allclose() now returns the result as a 0-dim GPU array instead of a Python bool to avoid device synchronization (see the sketch after this list).

  • cupy.RawModule now delays the compilation to the time of the first call to align the behavior with cupy.RawKernel.

  • cupy.cuda.*_enabled flags (nccl_enabled, nvtx_enabled, etc.) have been deprecated. Use the corresponding cupy.cuda.*.available flags (cupy.cuda.nccl.available, cupy.cuda.nvtx.available, etc.) instead.

  • CHAINER_SEED environment variable is no longer effective. Use CUPY_SEED instead.
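The following sketch illustrates the new cupy.allclose() behavior (the array values here are arbitrary):

import cupy

a = cupy.arange(3, dtype=cupy.float32)
b = cupy.arange(3, dtype=cupy.float32)

result = cupy.allclose(a, b)  # 0-dim cupy.ndarray, not a Python bool
print(result.shape)           # ()
print(bool(result))           # True; explicit conversion synchronizes with the device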

CuPy v7

Dropping Support of Python 2.7 and 3.4

Starting from CuPy v7, Python 2.7 and 3.4 are no longer supported, as they reach their end-of-life (EOL) in January 2020 (2.7) and March 2019 (3.4). Python 3.5.1 is the minimum Python version supported by CuPy v7. If you are using an affected version of Python, please upgrade to one of the later versions listed under Installation Guide.

CuPy v6

Binary Packages Ignore LD_LIBRARY_PATH

Prior to CuPy v6, the LD_LIBRARY_PATH environment variable could be used to override the cuDNN / NCCL libraries bundled in the binary distribution (also known as wheels). In CuPy v6, LD_LIBRARY_PATH is ignored during discovery of cuDNN / NCCL; CuPy binary distributions always use the libraries that come with the package to avoid errors caused by unexpected overrides.

CuPy v5

cupyx.scipy Namespace

The cupyx.scipy namespace has been introduced to provide CUDA-enabled SciPy functions. The cupy.sparse module has been renamed to cupyx.scipy.sparse; cupy.sparse is kept as an alias for backward compatibility.
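A minimal sketch of using the new namespace (the matrix construction shown here is illustrative):

import cupy
import cupyx.scipy.sparse

# Build a CSR sparse matrix on the GPU from a dense cupy.ndarray.
dense = cupy.eye(3, dtype=cupy.float32)
csr = cupyx.scipy.sparse.csr_matrix(dense)
print(csr.nnz)  # 3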

Dropped Support for CUDA 7.0 / 7.5

CuPy v5 no longer supports CUDA 7.0 / 7.5.

Update of Docker Images

CuPy official Docker images (see Installation Guide for details) are now updated to use CUDA 9.2 and cuDNN 7.

To use these images, you may need to upgrade the NVIDIA driver on your host. See Requirements of nvidia-docker for details.

CuPy v4

Note

The version number has been bumped from v2 to v4 to align with the versioning of Chainer. Therefore, CuPy v3 does not exist.

Default Memory Pool

Prior to CuPy v4, the memory pool was only enabled by default when CuPy was used with Chainer. In CuPy v4, the memory pool is enabled by default, even when you use CuPy without Chainer. The memory pool significantly improves performance by mitigating the overhead of memory allocation and CPU/GPU synchronization.

Attention

When you monitor GPU memory usage (e.g., using nvidia-smi), you may notice that GPU memory is not freed even after the array instance goes out of scope. This is expected behavior, as the default memory pool “caches” the allocated memory blocks.

To access the default memory pool instance, use get_default_memory_pool() and get_default_pinned_memory_pool(). You can access the statistics and free all unused memory blocks “cached” in the memory pool.

import cupy
a = cupy.ndarray(100, dtype=cupy.float32)
mempool = cupy.get_default_memory_pool()

# For performance, the size of actual allocation may become larger than the requested array size.
print(mempool.used_bytes())   # 512
print(mempool.total_bytes())  # 512

# Even if the array goes out of scope, its memory block is kept in the pool.
a = None
print(mempool.used_bytes())   # 0
print(mempool.total_bytes())  # 512

# You can clear the memory block by calling `free_all_blocks`.
mempool.free_all_blocks()
print(mempool.used_bytes())   # 0
print(mempool.total_bytes())  # 0

You can even disable the default memory pool with the code below. Be sure to do this before any other CuPy operations.

import cupy
cupy.cuda.set_allocator(None)
cupy.cuda.set_pinned_memory_allocator(None)

Compute Capability

CuPy v4 now requires NVIDIA GPU with Compute Capability 3.0 or larger. See the List of CUDA GPUs to check if your GPU supports Compute Capability 3.0.

CUDA Stream

As CUDA Stream is fully supported in CuPy v4, cupy.cuda.RandomState.set_stream, the function to change the stream used by the random number generator, has been removed. Please use cupy.cuda.Stream.use() instead.

See the discussion in #306 for more details.
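A minimal sketch of switching streams with cupy.cuda.Stream (the computation shown is arbitrary):

import cupy

stream = cupy.cuda.Stream()
stream.use()              # make this stream current for subsequent kernel launches
x = cupy.arange(10) * 2   # issued on the stream above
stream.synchronize()      # wait for the work on the stream to finish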

cupyx Namespace

The cupyx namespace has been introduced to provide features specific to CuPy (i.e., features not provided by NumPy) while avoiding name collisions in the future. See CuPy-specific Functions for the list of such functions.

Following this rule, cupy.scatter_add() has been moved to cupyx.scatter_add(). cupy.scatter_add() is still available as an alias, but you are encouraged to use cupyx.scatter_add() instead.
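A minimal sketch of the relocated function, assuming the cupyx.scatter_add(a, slices, value) signature (array values here are arbitrary):

import cupy
import cupyx

a = cupy.zeros((6,), dtype=cupy.float32)
i = cupy.array([1, 1], dtype=cupy.int32)
v = cupy.array([1., 1.], dtype=cupy.float32)
cupyx.scatter_add(a, i, v)  # duplicate indices are accumulated
print(a)                    # [0. 2. 0. 0. 0. 0.]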

Update of Docker Images

CuPy official Docker images (see Installation Guide for details) are now updated to use CUDA 8.0 and cuDNN 6.0. This change was introduced because CUDA 7.5 does not support NVIDIA Pascal GPUs.

To use these images, you may need to upgrade the NVIDIA driver on your host. See Requirements of nvidia-docker for details.

CuPy v2

Changed Behavior of count_nonzero Function

For performance reasons, cupy.count_nonzero() has been changed to return a zero-dimensional ndarray instead of an int when axis=None. See the discussion in #154 for more details.
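A minimal sketch of the new behavior (array values here are arbitrary):

import cupy

x = cupy.array([0, 1, 2, 0])
n = cupy.count_nonzero(x)  # 0-dim cupy.ndarray, not a Python int
print(n.shape)             # ()
print(int(n))              # 2; converting to int synchronizes with the GPU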

License

Copyright (c) 2015 Preferred Infrastructure, Inc.

Copyright (c) 2015 Preferred Networks, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

NumPy

CuPy is designed based on NumPy’s API. CuPy’s source code and documentation contain parts of the original NumPy code and documentation.

Copyright (c) 2005-2016, NumPy Developers.

All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

  • Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

  • Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

  • Neither the name of the NumPy Developers nor the names of any contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

SciPy

CuPy is designed based on SciPy’s API. CuPy’s source code and documentation contain parts of the original SciPy code and documentation.

Copyright (c) 2001, 2002 Enthought, Inc.

All rights reserved.

Copyright (c) 2003-2016 SciPy Developers.

All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

  1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

  2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

  3. Neither the name of Enthought nor the names of the SciPy Developers may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.