Python vs numba: the differences
Python is known for being rather slow for some styles of programming. Python scripts are compiled to a machine-independent byte code (sometimes seen as a .pyc file), and then interpreted. The Numba package offers much better performance: it can compile functions at run-time to native machine code, which often runs many times faster. But the result is not quite python compatible. This page highlights some of the ways in which output after using numba can differ from the original python.
Integers in python are unusual, in that they are of arbitrary length. In numba they are 64 bits on 64-bit platforms and 32 bits on 32-bit platforms (now rare, though some Raspberry Pis still qualify). So:
    from numba import jit

    @jit
    def sub(a,b):
        return a-b

    print(sub(1,0))
    print(sub(123456789012345678901,123456789012345678900))
crashes with an overflow error as the argument is too big to convert to a 64-bit integer. That's okay-ish; at least we know that there is a problem. But try
    from numba import jit

    @jit
    def fact(x):
        if (x==0):
            return 1
        return x*fact(x-1)

    for i in range(15,25):
        print(i,fact(i))
When run this produces

    15 1307674368000
    16 20922789888000
    17 355687428096000
    18 6402373705728000
    19 121645100408832000
    20 2432902008176640000
    21 -4249290049419214848
    22 -1250660718674968576
    23 8128291617894825984
    24 -7835185981329244160

Comment out the @jit and one gets the correct

    15 1307674368000
    16 20922789888000
    17 355687428096000
    18 6402373705728000
    19 121645100408832000
    20 2432902008176640000
    21 51090942171709440000
    22 1124000727777607680000
    23 25852016738884976640000
    24 620448401733239439360000
Numba is no worse than C, C++ or Fortran here, but those used to python's behaviour may be surprised, for, with no warning, numba has deviated from python's answer.
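The wrong values above are what one would expect from signed 64-bit arithmetic wrapping around, and they can be reproduced without Numba. The following is a minimal pure python sketch; the helper names fits_int64 and wrap_int64 are invented for this illustration and are not part of Numba.

```python
import math

INT64_MIN = -2**63
INT64_MAX = 2**63 - 1

def fits_int64(n):
    """True if the python int n can be stored in a signed 64-bit integer."""
    return INT64_MIN <= n <= INT64_MAX

def wrap_int64(n):
    """Reduce n modulo 2**64 into the signed 64-bit (two's complement) range."""
    return (n - INT64_MIN) % 2**64 + INT64_MIN

for i in (20, 21):
    exact = math.factorial(i)
    # 20! still fits; 21! does not, and wraps to a negative number
    print(i, fits_int64(exact), wrap_int64(exact))
```

A check such as fits_int64 on the exact python result is one cheap way to find the point at which a jitted integer routine stops being trustworthy.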
Numba defaults to treating global variables as constants, taking their value at the moment the function was first executed. So
    from numba import jit

    rescale=1

    @jit
    def print_it(x):
        global rescale
        print(x/rescale)

    rescale=5
    print_it(15)
    rescale=10
    print_it(30)
will print 3.0 followed by 6.0. Without the @jit it prints 3.0 twice. (Note for those used to other languages. In python the result of division (/) is always a float even when the operands are integers and the result can be expressed exactly as an integer. Numba also obeys this convention. If an int is required, a//b returns floor(a/b) as an integer if a and b are both integers, and floor(a/b) as a float if either or both are floats.)
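The division rules in the note above can be checked in plain python, with no Numba involved:

```python
# True division always gives a float, even when the result is a whole number
print(15 / 5, type(15 / 5).__name__)

# Floor division of two ints gives an int
print(15 // 5, type(15 // 5).__name__)

# Floor division with a float operand gives a float
print(15.0 // 5, type(15.0 // 5).__name__)

# It is a true floor: the result rounds towards minus infinity
print(-7 // 2)
```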
The details are quite confusing. Suppose the main body of the code was
    rescale=5
    print_it(15)
    rescale=10
    print_it(30)
    print_it(30.0)
The numba version now prints 3.0, 6.0 and 3.0. It first compiled print_it to accept an integer argument when rescale was 5. When it was called for a second time with an integer argument it reused the precompiled version with rescale frozen at 5. When called for a third time, but now with a float argument, it has to compile the function afresh for working with floats. It does this with the current value of rescale, which is 10. So now print_it will rescale by 5 if its argument is an integer, and by 10 if it is a float.
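The behaviour just described can be mimicked by a toy model in pure python. This is only a sketch of the idea (one "compiled" version per argument type, each freezing the global at the moment it was created); the class ToyJit is invented for this illustration and says nothing about how Numba is really implemented.

```python
rescale = 1

class ToyJit:
    """Toy stand-in for a jitted print_it: one frozen version per argument type."""

    def __init__(self):
        # Maps argument type -> value of rescale captured at "compile" time
        self.versions = {}

    def __call__(self, x):
        key = type(x)
        if key not in self.versions:
            # First call with this type: "compile", freezing the current global
            self.versions[key] = rescale
        return x / self.versions[key]

print_it = ToyJit()

rescale = 5
first = print_it(15)     # int version created with rescale frozen at 5
rescale = 10
second = print_it(30)    # int version reused, still dividing by 5
third = print_it(30.0)   # float version created now, dividing by 10

print(first, second, third)
```

The toy returns values rather than printing them, purely so that the three results can be compared side by side.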
It is possible to recompile individual functions, although this approach is useful only if global variables change infrequently, as recompilation is slow.
    rescale=5
    print_it(15)
    rescale=10
    print_it.recompile()
    print_it(30)
will print 3.0, 3.0. One may wish to use try ... except: pass around the call to recompile so that it still works in the absence of Numba, or if it is called before the first call to the function itself, as the recompile method is not created until Numba first compiles the function.
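One way of packaging that try ... except is a small helper. The name refresh is invented here, and the sketch assumes that the failure to be tolerated is a missing recompile attribute (an AttributeError); a bare except, as suggested above, would be more permissive.

```python
def refresh(func):
    """Call func.recompile() if it exists; report whether it was called."""
    try:
        func.recompile()
    except AttributeError:
        # Plain python function, or nothing compiled yet: nothing to do
        return False
    return True

def plain(x):
    # An ordinary, un-jitted function has no recompile method
    return x / 2

print(refresh(plain))
```

With this in place the same script runs whether or not the @jit decorators are commented out.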
One reason that Numba is so much faster than python is that it tries to eliminate dynamic typing. In general a python variable can contain any type, and that type may change at any point. This makes python execution slow, as, for every operation, it first has to determine the types of the operands in order to work out the appropriate operation. (E.g. is "+" integer addition, float addition, string concatenation, float addition after the conversion of an int operand to a float, ...) It can take longer to work out what should be done than it takes to do it!
Numba hopes that, if the types of the arguments to a function are known (the function "signature"), then all the other types are deducible. If this is not so, Numba falls back to much slower operations on python objects. Whilst Numba has good support for some packages which add data types to python, such as numpy, it has no support for most, so it cannot do much with code using gmpy2 (GMP), for instance, as it does not recognise the GMP types.
    from numba import jit

    @jit
    def fudge_it(x):
        try:
            if (x==7):
                return "Lucky"
            if (x==13):
                return "Unlucky"
            if (x==int(x)):
                return int(x)
        except:
            pass
        return x

    print(fudge_it(8))
    print(fudge_it(7))
    print(fudge_it(7.0))
    print(fudge_it(8.0))
    print(fudge_it(8.1))
The above causes Numba (0.55.1) to crash. It is a function whose return value depends on the value, rather than just the type, of its argument. Simpler examples may run with a warning.
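Running the same function as plain python, without @jit, makes the difficulty visible: a single argument type (float) can produce three different return types, so knowing the signature is not enough to deduce the type of the result.

```python
def fudge_it(x):
    # Same body as above, but not jitted
    try:
        if x == 7:
            return "Lucky"
        if x == 13:
            return "Unlucky"
        if x == int(x):
            return int(x)
    except Exception:
        pass
    return x

for value in (8, 7, 7.0, 8.0, 8.1):
    result = fudge_it(value)
    # Show the argument type alongside the return type
    print(type(value).__name__, "->", type(result).__name__, result)
```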
Numba is less likely to promote numpy's float32 to a float64 than numpy itself.
    import numpy as np
    from numba import jit

    @jit
    def fudge(x):
        print(x,type(x))
        x=x**40
        print(x,type(x))

    x=np.float32(10)
    print(x,type(x))
    x=x**40
    print(x,type(x))

    fudge(np.float32(10))

This prints

    10.0 <class 'numpy.float32'>
    1e+40 <class 'numpy.float64'>
    10.0 float32
    inf float32
This can lead to unexpected zeros and infinities as the range of float32 is exceeded. But note that (int)*(float32) and (int)+(float32) are always (float64) in both numpy and Numba.
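The overflow itself is a property of float32 and can be seen in numpy alone. The sketch below does not use Numba; it only shows that keeping both operands as float32 overflows, whereas widening to float64 explicitly before the operation does not. Whether an explicit cast is the right remedy inside a given jitted function is left as an assumption to test.

```python
import math
import numpy as np

# Largest finite float32 is about 3.4e38, so 1e40 cannot be represented
print(np.finfo(np.float32).max)

x = np.float32(10)

# Both operands float32: the result stays float32 and overflows to inf
with np.errstate(over="ignore"):
    narrow = x ** np.float32(40)

# Widen explicitly first: the result is a float64 close to 1e40
wide = np.float64(x) ** 40

print(narrow, narrow.dtype)
print(wide, wide.dtype)
```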
Note that numpy can be very keen on promotion, so that
    import numpy as np

    x=np.identity(2,dtype=np.int64)
    big=1<<28
    x=np.array([[big,big-1],[big+1,big]])
    print(x,type(x),x.dtype)
    print(np.linalg.det(x))
reports that the determinant is zero and a float when it is clearly one and when the trivial calculation big**2-(big+1)*(big-1) would not have been close to overflowing the int64 datatype. A value of big as small as 98 million triggers this.
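A likely explanation is that np.linalg.det works in double precision, and the products involved here need more than the 53 bits a double's mantissa provides. The following pure python sketch shows both halves of that: the exact integer calculation gives one, while big**2 and big**2-1 become indistinguishable once converted to floats.

```python
big = 1 << 28

# python ints are exact, so the 2x2 determinant a*d - b*c comes out as 1
det = big * big - (big - 1) * (big + 1)
print(det)

# 2**56 and 2**56 - 1 round to the same double precision value
print(float(big * big) == float(big * big - 1))
```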
Just in case anyone stumbles across this page looking for correct determinants from integer matrices in python, I suggest SymPy:
    import sympy

    big=1<<28
    x=sympy.Matrix([[big,big-1],[big+1,big]])
    print(x.det())
which works equally well for big=1<<10028, a value far outside the range of a double precision float.