Understand numpy's axis

What is an axis?

NumPy works with multidimensional arrays, so a solid understanding of axes is essential.

For example, consider numpy.sum(), which computes the sum of an array's elements.

For scalars and one-dimensional arrays you can simply add up all the values without thinking about axes, but for multidimensional arrays you need to specify which axis to sum along.
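A minimal sketch of that difference, assuming NumPy is installed and imported as `np` (the arrays `v` and `M` are illustrative, not from the original article). A 1D array has only one axis, so summing with or without `axis=0` gives the same number. For a 2D array, the result depends on the axis.

```python
import numpy as np

v = np.array([1, 2, 3])          # 1D: only one axis (axis=0)
M = np.array([[1, 2, 3],
              [4, 5, 6]])        # 2D: axis=0 and axis=1

# For a 1D array, summing everything and summing along axis 0 agree
total_v = np.sum(v)              # 6
along_v = np.sum(v, axis=0)      # 6

# For a 2D array, the choice of axis changes the result
total_M = np.sum(M)              # 21 (all elements)
cols_M = np.sum(M, axis=0)       # [5 7 9] (rows collapsed)
rows_M = np.sum(M, axis=1)       # [ 6 15] (columns collapsed)

print(total_v, along_v)
print(total_M, cols_M, rows_M)
```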

numpy.sum(a, axis): the first argument a is the array, and the second argument axis specifies the axis. The elements are added together along this axis, so that axis no longer appears in the result.
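One way to see what "along this axis" means is to compare shapes. The summed axis is removed from the result's shape. The sketch below uses an illustrative array built with `np.arange` (not from the original article). `keepdims=True` is a standard `numpy.sum` parameter that keeps the summed axis with length 1 instead of dropping it.

```python
import numpy as np

# Illustrative 3D array with shape (2, 3, 4): 2 blocks, 3 rows, 4 columns
a = np.arange(24).reshape(2, 3, 4)

for ax in range(a.ndim):
    s = np.sum(a, axis=ax)
    # The axis we summed over disappears from the shape
    print(f"axis={ax}: shape {s.shape}")

# keepdims=True keeps the summed axis with length 1 (handy for broadcasting)
k = np.sum(a, axis=1, keepdims=True)
print("keepdims:", k.shape)
```

This prints the shapes (3, 4), (2, 4) and (2, 3) for axes 0, 1 and 2, then (2, 1, 4) for the keepdims case.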

Axes of a 2D array

First, let's take a two-dimensional array, the easiest case to picture. In a two-dimensional array, **axis = 0 runs vertically (down the rows) and axis = 1 runs horizontally (across the columns)**.

[Figure: axis directions in a 2D array (numpy_axis_2d.png)]

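```python
import numpy as np

Z = np.array([[0, 1],
              [2, 3]])
# np.sum, not Python's built-in sum(): the built-in treats its second argument as a start value
print("axis=0    ->", np.sum(Z, axis=0))
print("axis=1    ->", np.sum(Z, axis=1))
```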

Execution result

axis=0    -> [2 4]
axis=1    -> [1 5]
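To make the direction concrete, here is a sketch that recomputes both results with plain Python loops. The helper names `sum_axis0` and `sum_axis1` are hypothetical, for illustration only. Summing along axis 0 fixes the column index `j` and runs over the rows `i`; summing along axis 1 does the opposite.

```python
import numpy as np

Z = np.array([[0, 1],
              [2, 3]])

def sum_axis0(m):
    # Hypothetical helper: result[j] = sum over i of m[i, j] (row index i is collapsed)
    n_rows, n_cols = m.shape
    return [sum(int(m[i, j]) for i in range(n_rows)) for j in range(n_cols)]

def sum_axis1(m):
    # Hypothetical helper: result[i] = sum over j of m[i, j] (column index j is collapsed)
    n_rows, n_cols = m.shape
    return [sum(int(m[i, j]) for j in range(n_cols)) for i in range(n_rows)]

print(sum_axis0(Z), np.sum(Z, axis=0))
print(sum_axis1(Z), np.sum(Z, axis=1))
```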

If axis is omitted, all elements are added together and the result is a single scalar.

```python
import numpy as np

Z = np.array([[0, 1],
              [2, 3]])
print(np.sum(Z))  # axis omitted: every element is summed
```

Execution result

6
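Omitting `axis` is the same as passing `axis=None`, its default value. `numpy.sum` also accepts a tuple of axes. A minimal sketch showing that these forms, and a two-step sum, all give the same scalar for this 2D array:

```python
import numpy as np

Z = np.array([[0, 1],
              [2, 3]])

all_default = np.sum(Z)               # axis omitted
all_none = np.sum(Z, axis=None)       # explicit default
all_tuple = np.sum(Z, axis=(0, 1))    # sum over both axes at once
two_step = np.sum(np.sum(Z, axis=0))  # collapse the rows, then sum what is left

print(all_default, all_none, all_tuple, two_step)
```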

Axes of a 3D array

Next, consider a three-dimensional array. In a 3D array, **axis = 0 is the depth direction (across the stacked 2D blocks), axis = 1 runs vertically (down the rows), and axis = 2 runs horizontally (across the columns)**.

[Figure: axis directions in a 3D array (numpy_axis_3d.png)]

```python
import numpy as np

Z = np.array([[[0, 1],
               [2, 3]],
              [[4, 5],
               [6, 7]]])
print("axis=0")
print(np.sum(Z, axis=0))
print("----")
print("axis=1")
print(np.sum(Z, axis=1))
print("----")
print("axis=2")
print(np.sum(Z, axis=2))
```

Execution result

axis=0
[[ 4  6]
 [ 8 10]]
----
axis=1
[[ 2  4]
 [10 12]]
----
axis=2
[[ 1  5]
 [ 9 13]]
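Each 3D result can be reproduced by adding slices of `Z` by hand, which shows exactly which index is being collapsed. A minimal sketch using ordinary NumPy slicing (the variable names are illustrative):

```python
import numpy as np

Z = np.array([[[0, 1],
               [2, 3]],
              [[4, 5],
               [6, 7]]])

# axis=0: add the two 2x2 blocks (depth direction)
by_axis0 = Z[0, :, :] + Z[1, :, :]
# axis=1: within each block, add row 0 and row 1
by_axis1 = Z[:, 0, :] + Z[:, 1, :]
# axis=2: within each block, add column 0 and column 1
by_axis2 = Z[:, :, 0] + Z[:, :, 1]

print(np.array_equal(by_axis0, np.sum(Z, axis=0)))  # True
print(np.array_equal(by_axis1, np.sum(Z, axis=1)))  # True
print(np.array_equal(by_axis2, np.sum(Z, axis=2)))  # True
```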

What does axis = -1 mean?

axis = -1 refers to **the last axis**. In other words, it is the same as axis = 2 for a 3D array and axis = 1 for a 2D array.

```python
import numpy as np

Z = np.array([[[0, 1],
               [2, 3]],
              [[4, 5],
               [6, 7]]])
print("axis=2")
print(np.sum(Z, axis=2))
print("----")
print("axis=-1")
print(np.sum(Z, axis=-1))  # -1 means the last axis, here axis 2
```

Execution result

axis=2
[[ 1  5]
 [ 9 13]]
----
axis=-1
[[ 1  5]
 [ 9 13]]

Both calls give the same result.
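More generally, a negative axis `-k` refers to axis `ndim - k`, so `-1` is the last axis and `-2` is the one before it. A minimal sketch checking this for the 2D and 3D arrays used in this article:

```python
import numpy as np

Z2 = np.array([[0, 1],
               [2, 3]])
Z3 = np.array([[[0, 1],
                [2, 3]],
               [[4, 5],
                [6, 7]]])

for Z in (Z2, Z3):
    for k in range(1, Z.ndim + 1):
        neg, pos = -k, Z.ndim - k
        # Compare the sum along the negative axis with the equivalent positive axis
        same = np.array_equal(np.sum(Z, axis=neg), np.sum(Z, axis=pos))
        print(f"ndim={Z.ndim}: axis={neg} matches axis={pos}: {same}")
```

Every line should report `True`.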
