Dtype is the reason why Numpy's assert_array_equal doesn't work as expected

Overview

According to the documentation, numpy.testing.assert_array_equal treats NaNs (np.nan) at the same positions as equal, but in my case the assertion failed anyway.

The cause turned out to be that it does not behave as expected when the dtype is `object` (i.e., a mixed-type ndarray).

Environment

Linux (distribution details unconfirmed), Conda, Python v3.7 and NumPy v1.18.1

$ python
Python 3.7.0 (default, Oct  9 2018, 10:31:47) 
[GCC 7.3.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.

Case that works

In an ordinary comparison, NaN == NaN is False, but `assert_array_equal` treats NaNs at matching positions as equal. This works, however, only when the arrays have a `float` dtype:

>>> import numpy as np
>>> from numpy.testing import assert_array_equal as aae

>>> a = np.array([1, np.nan])
>>> b = np.array([1, np.nan])

>>> a.dtype
dtype('float64')
>>> b.dtype
dtype('float64')

>>> a == b
array([ True, False])
>>> aae(a, b)
>>> 
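As a rough mental model of the float case above (a sketch, not NumPy's actual implementation), the NaN-tolerant check amounts to "equal, or NaN in both arrays at that position". The helper name `nan_equal` is hypothetical:

```python
import numpy as np


def nan_equal(x, y):
    """Hypothetical helper: elementwise equality that treats NaN == NaN as True.

    Only meaningful for float arrays, because np.isnan is not defined
    for arbitrary Python objects.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    both_nan = np.isnan(x) & np.isnan(y)  # positions where both sides are NaN
    return (x == y) | both_nan


a = np.array([1, np.nan])
b = np.array([1, np.nan])

print(a == b)                 # plain comparison: the NaN position is False
print(nan_equal(a, b))        # NaN position counted as equal
print(nan_equal(a, b).all())  # overall result, like a passing assert_array_equal
```

For float arrays, NumPy 1.19 and later also offer `np.array_equal(a, b, equal_nan=True)`, which is newer than the v1.18.1 used here.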

Case that does not work

The dtype is specified when creating `a` and `b` because otherwise NumPy would turn them into arrays of strings. For these arrays of `object` dtype, `assert_array_equal` fails:
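To make the "array of strings" point concrete, this short check compares what NumPy infers with and without `dtype=object` for the same mixed list:

```python
import numpy as np

# Without dtype=object, NumPy looks for a common type for 1, 'test' and nan
# and settles on a Unicode string dtype, so every element becomes a string.
s = np.array([1, 'test', np.nan])
print(s.dtype)  # a '<U...' string dtype
print(s)        # the NaN is now the literal string 'nan'

# With dtype=object, each element keeps its original Python type.
o = np.array([1, 'test', np.nan], dtype=object)
print(o.dtype)                        # object
print([type(v).__name__ for v in o])  # int, str, float
```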

>>> a = np.array([1,'test',np.nan], dtype=object)
>>> b = np.array([1,'test',np.nan], dtype=object)

>>> a.dtype
dtype('O')
>>> b.dtype
dtype('O')

>>> a == b
array([ True,  True, False])
>>> aae(a,b)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File ".../anaconda3/envs/opt/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 936, in assert_array_equal
    verbose=verbose, header='Arrays are not equal')
  File ".../anaconda3/envs/opt/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 846, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Arrays are not equal

Mismatched elements: 1 / 3 (33.3%)
 x: array([1, 'test', nan], dtype=object)
 y: array([1, 'test', nan], dtype=object)
>>> 
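A plausible explanation (based on reading NumPy's testing code, so treat it as an assumption) is that the NaN special-casing is only applied to numeric dtypes, since `np.isnan` has no implementation for `object` arrays. The object arrays then fall through to plain `==`, where NaN never equals itself, which matches the `a == b` result above. The sketch below confirms the `np.isnan` part and shows one possible workaround: a hypothetical helper, `assert_object_array_equal`, that compares element by element and treats two float NaNs as equal. It is not a NumPy API.

```python
import math

import numpy as np

a = np.array([1, 'test', np.nan], dtype=object)
b = np.array([1, 'test', np.nan], dtype=object)

# np.isnan has no implementation for object arrays, so it raises TypeError.
try:
    np.isnan(a)
except TypeError as exc:
    print("np.isnan failed:", exc)


def _same(x, y):
    """Hypothetical helper: True if x == y, or if both are float NaNs."""
    if isinstance(x, float) and isinstance(y, float) and math.isnan(x) and math.isnan(y):
        return True
    return x == y


def assert_object_array_equal(x, y):
    """Hypothetical helper: NaN-tolerant equality assertion for object arrays."""
    x = np.asarray(x, dtype=object)
    y = np.asarray(y, dtype=object)
    if x.shape != y.shape:
        raise AssertionError(f"shape mismatch: {x.shape} != {y.shape}")
    # Collect the flat indices of all mismatching elements.
    bad = [i for i, (u, v) in enumerate(zip(x.ravel(), y.ravel())) if not _same(u, v)]
    if bad:
        raise AssertionError(f"arrays differ at flat indices {bad}")


assert_object_array_equal(a, b)  # passes: the two NaNs are treated as equal
```

Casting to `float` first, as in the working case, is not an option here because of the 'test' element, which is why the comparison is done per element.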
