How To Compile A Numba Jit'ed Function With Variable Input Type?

Say I have a function that can accept either an int or None as an input argument:

import numba as nb
import numpy as np

jitkw = {"nopython": True, "nogil": True, "error_model": "numpy", "fastmath": True}


@nb.jit("f8(i8)", **jitkw)
def get_random(seed=None):
    np.random.seed(seed)  # seed the generator with the caller's value
    out = np.random.normal()
    return out

I want the function to simply return a normally distributed random number. If I want reproducible results, seed should be an int.

>>> get_random(42)
0.4967141530112327
>>> get_random(42)
0.4967141530112327
>>> get_random(42)
0.4967141530112327

If I want random numbers, seed should be left as None. However, if I do not pass an argument (so seed defaults to None) or explicitly pass seed=None, numba raises a TypeError:

>>> get_random()
TypeError: No matching definition for argument type(s) omitted(default=None)
>>> get_random(None)
TypeError: No matching definition for argument type(s) omitted(default=None)

How can I write the function for this scenario while still declaring the signature and using nopython mode?

My numba version is 0.43.1

Answer

The first problem is that, as of version 0.43.1, numba in nopython mode supports np.random.seed only with an integer argument.

So, unfortunately, you cannot pass in None.


The second problem is that there is (as far as I know) no single signature that tells numba how to handle a missing argument. However, you can supply two signatures (yes, it's very verbose):

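import numba as nb
import numpy as np

jitkw = {"nopython": True, "nogil": True, "error_model": "numpy", "fastmath": True}

@nb.jit(
    [nb.types.float64(nb.types.misc.Omitted(None)),  # seed omitted
     nb.types.float64(nb.types.int64)],               # seed given as an int
    **jitkw)
def get_random(seed=None):
    # seed is not used in this body: np.random.seed(None) is not
    # supported in nopython mode, so both signatures share the unseeded draw.
    return np.random.normal()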

Just a short explanation of the two parts of the signature:

  • nb.types.float64(nb.types.misc.Omitted(None)) is the signature numba uses when the argument is omitted, so seed takes its default value None;
  • nb.types.float64(nb.types.int64) is the signature that expects an integer.

Personally, I wouldn't specify the signature and would simply let numba figure it out. Explicit signatures are seldom worth it in numba, and more often than not they lead to slower and less flexible code.

source: stackoverflow.com