[TensorFlow] I want to master indexing for RaggedTensor

Introduction

RaggedTensor, which represents variable-length data, can be used in TensorFlow 2.1 and later. However, if you write code for it the same way you would for an ordinary Tensor, you run into all sorts of pitfalls. This time the topic is indexing: retrieving values from a RaggedTensor by specifying particular indices. Once you get used to it, you can perform fairly complicated operations...

Verification environment

Indexing example

Suppose the RaggedTensor x to be indexed is created as follows.

import tensorflow as tf

x = tf.RaggedTensor.from_row_lengths(tf.range(15), tf.range(1, 6))
print(x)
# <tf.RaggedTensor [[0], [1, 2], [3, 4, 5], [6, 7, 8, 9], [10, 11, 12, 13, 14]]>
       col 0  col 1  col 2  col 3  col 4
row 0      0
row 1      1      2
row 2      3      4      5
row 3      6      7      8      9
row 4     10     11     12     13     14

Slicing specific rows

The first operation, retrieving rows, works the same as with a normal Tensor; you can think of it like numpy.ndarray. When you specify a range, **the start index is included and the end index is excluded.** If you are a Python user, this should be no problem.

print(x[2])
# tf.Tensor([3 4 5], shape=(3,), dtype=int32)
print(x[1:4])
# <tf.RaggedTensor [[1, 2], [3, 4, 5], [6, 7, 8, 9]]>
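Negative indices and an explicit (row, column) pair also behave as they do for a normal Tensor. A minimal sketch (x is rebuilt so the block runs on its own):

```python
import tensorflow as tf

x = tf.RaggedTensor.from_row_lengths(tf.range(15), tf.range(1, 6))

# A negative index counts from the last row, as with a normal Tensor
last_row = x[-1]
# A slice with a negative start also works on the row dimension
last_two = x[-2:]
# Two integer indices pick a single element: row 1, column 1
elem = x[1, 1]

print(last_row.numpy().tolist())  # [10, 11, 12, 13, 14]
print(last_two.to_list())         # [[6, 7, 8, 9], [10, 11, 12, 13, 14]]
print(int(elem))                  # 2
```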

However, unlike numpy.ndarray, you apparently cannot select non-contiguous rows by passing a list of indices.

# This works for ndarray
print(x.numpy()[[1, 3]])
# [array([1, 2], dtype=int32) array([6, 7, 8, 9], dtype=int32)]

# Not available for Tensor/RaggedTensor
print(x[[1, 3]])
# InvalidArgumentError: slice index 3 of dimension 0 out of bounds. [Op:StridedSlice] name: strided_slice/

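The list appears to be interpreted as the multi-dimensional index x[1, 3], and row 1 has no element at column 3, hence the error. Use tf.gather instead.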

# Fancy indexing for Tensor/RaggedTensor
print(tf.gather(x, [1, 3], axis=0))
# <tf.RaggedTensor [[1, 2], [6, 7, 8, 9]]>
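tf.gather also accepts repeated or reordered row indices, and tf.ragged.boolean_mask selects rows with a boolean mask instead of a list of indices. A short sketch with the same x:

```python
import tensorflow as tf

x = tf.RaggedTensor.from_row_lengths(tf.range(15), tf.range(1, 6))

# Rows can be repeated and reordered freely
picked = tf.gather(x, [3, 1, 3], axis=0)

# Select rows 1 and 3 with a boolean mask over the row dimension
mask = tf.constant([False, True, False, True, False])
masked = tf.ragged.boolean_mask(x, mask)

print(picked.to_list())  # [[6, 7, 8, 9], [1, 2], [6, 7, 8, 9]]
print(masked.to_list())  # [[1, 2], [6, 7, 8, 9]]
```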

Slicing with a fixed column index

Next is an example of slicing with a fixed column index. Unlike a normal Tensor, whether an element exists at that index depends on the row, so simply writing

print(x[:, 2])
# ValueError: Cannot index into an inner ragged dimension.

does not work. If you specify a range instead,

print(x[:, 2:3])
# <tf.RaggedTensor [[], [], [5], [8], [12]]>

it works. Rows that have no element at the specified index yield [].

       col 0  col 1  col 2  col 3  col 4
row 0      0
row 1      1      2
row 2      3      4      5*
row 3      6      7      8*     9
row 4     10     11     12*    13     14
(* marks the elements returned by x[:, 2:3])
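If you would rather have exactly one value per row, one option is to densify with to_tensor and fill missing slots with a default value; another is to stay on the flat values and keep only rows long enough to have that column. A sketch of both, where the fill value -1 is an arbitrary choice:

```python
import tensorflow as tf

x = tf.RaggedTensor.from_row_lengths(tf.range(15), tf.range(1, 6))
c = 2  # column to extract

# Option 1: pad to a dense (5, 5) Tensor, filling missing slots with -1
dense_col = x.to_tensor(default_value=-1)[:, c]

# Option 2: no padding; only rows with more than c elements contribute
has_col = x.row_lengths() > c
flat_idx = tf.boolean_mask(x.row_starts() + c, has_col)
sparse_col = tf.gather(x.values, flat_idx)

print(dense_col.numpy().tolist())   # [-1, -1, 5, 8, 12]
print(sparse_col.numpy().tolist())  # [5, 8, 12]
```

Option 1 keeps the row alignment at the cost of padding memory; option 2 avoids padding but drops the rows that are too short.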

Slicing with a different column index for each row

If you have a Tensor listing the 2D indices you want to collect, you can use tf.gather_nd().

ind = tf.constant([[0, 0], [1, 1], [2, 0], [4, 3]])
# Collect elements x(0, 0), x(1, 1), x(2, 0), x(4, 3)
print(tf.gather_nd(x, ind))
# tf.Tensor([ 0  2  3 13], shape=(4,), dtype=int32)
       col 0  col 1  col 2  col 3  col 4
row 0      0*
row 1      1      2*
row 2      3*     4      5
row 3      6      7      8      9
row 4     10     11     12     13*    14
(* marks the elements collected by tf.gather_nd(x, ind))
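For a RaggedTensor with one ragged dimension, the same elements can also be read straight from the flat values: the row index picks a start offset in x.values and the column index is added to it. A sketch that checks this against tf.gather_nd:

```python
import tensorflow as tf

x = tf.RaggedTensor.from_row_lengths(tf.range(15), tf.range(1, 6))
ind = tf.constant([[0, 0], [1, 1], [2, 0], [4, 3]])

via_gather_nd = tf.gather_nd(x, ind)

# Start offset of each requested row in x.values, plus the column offset
flat_idx = tf.gather(x.row_starts(), ind[:, 0]) + ind[:, 1]
via_flat = tf.gather(x.values, flat_idx)

print(via_gather_nd.numpy().tolist())  # [0, 2, 3, 13]
print(via_flat.numpy().tolist())       # [0, 2, 3, 13]
```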

Sometimes, on the other hand, you want to fetch exactly one element from every row, but from a different column in each row.

col = tf.constant([0, 0, 2, 1, 2])
# Collect elements x(0, 0), x(1, 0), x(2, 2), x(3, 1), x(4, 2)
# Prepend the row number to each column index, then proceed as before
ind = tf.transpose(tf.stack([tf.range(tf.shape(col)[0]), col]))
print(tf.gather_nd(x, ind))
# tf.Tensor([ 0  1  5  7 12], shape=(5,), dtype=int32)
       col 0  col 1  col 2  col 3  col 4
row 0      0*
row 1      1*     2
row 2      3      4      5*
row 3      6      7*     8      9
row 4     10     11     12*    13     14
(* marks the elements selected by col)
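The transpose of a stack can also be written as a single tf.stack along axis=1, which builds the same (row, column) pairs directly. A small sketch comparing the two:

```python
import tensorflow as tf

col = tf.constant([0, 0, 2, 1, 2])
rows = tf.range(tf.shape(col)[0])

ind_transpose = tf.transpose(tf.stack([rows, col]))
ind_axis1 = tf.stack([rows, col], axis=1)  # shape (5, 2), no transpose needed

print(ind_axis1.numpy().tolist())  # [[0, 0], [1, 0], [2, 2], [3, 1], [4, 2]]
```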

However, this looks like it would be slow, so I looked for a smarter way.

print(tf.gather(x.values, x.row_starts() + col))
# tf.Tensor([ 0  1  5  7 12], shape=(5,), dtype=int32)

This works. The actual values of x are stored in an ordinary Tensor (not a RaggedTensor) that concatenates all rows, so it has one dimension fewer than x; you can get it with x.values. To represent the shape of x, the RaggedTensor also keeps the start index of each row within x.values (x.row_starts()). So you can add the desired column offset to these start indices and gather from x.values.
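To see what x.values and x.row_starts() actually contain, it helps to print the pieces of x. row_splits has one more entry than there are rows, and row i occupies x.values[row_splits[i]:row_splits[i + 1]]:

```python
import tensorflow as tf

x = tf.RaggedTensor.from_row_lengths(tf.range(15), tf.range(1, 6))

values = x.values.numpy().tolist()        # all rows concatenated (1D)
splits = x.row_splits.numpy().tolist()    # row boundaries in `values`
starts = x.row_starts().numpy().tolist()  # splits without the last entry
lengths = x.row_lengths().numpy().tolist()

print(splits)   # [0, 1, 3, 6, 10, 15]
print(starts)   # [0, 1, 3, 6, 10]
print(lengths)  # [1, 2, 3, 4, 5]

# Rebuild the rows by hand: row i is values[splits[i]:splits[i + 1]]
rows = [values[splits[i]:splits[i + 1]] for i in range(len(lengths))]
print(rows == x.to_list())  # True
```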

%timeit tf.gather_nd(x, tf.transpose(tf.stack([tf.range(tf.shape(col)[0]), col])))
# 739 µs ± 75.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit tf.gather(x.values, x.row_starts() + col)
# 124 µs ± 6.47 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

The tf.gather version on x.values is roughly six times faster here (^_^)
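One caveat with the faster version: tf.gather only checks that each index falls inside x.values, not inside its own row, so a column index past the end of a row silently returns an element of the next row. If col might be out of range, a check such as tf.debugging.assert_less against x.row_lengths() can be run first. A sketch, where bad_col is a made-up input for illustration:

```python
import tensorflow as tf

x = tf.RaggedTensor.from_row_lengths(tf.range(15), tf.range(1, 6))
bad_col = tf.constant([1, 0, 2, 1, 2])  # hypothetical input: row 0 has only one element

# Without a check, "row 0, column 1" is really the first element of row 1
silent = tf.gather(x.values, x.row_starts() + bad_col)
print(silent.numpy().tolist())  # [1, 1, 5, 7, 12]

# In eager mode the assertion raises as soon as any col >= row length
try:
    tf.debugging.assert_less(bad_col, x.row_lengths())
    checked = True
except tf.errors.InvalidArgumentError:
    checked = False
print(checked)  # False
```

Whether the extra check is worth its cost depends on how much you trust the source of col.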

To master operations like these, it is worth reading the official documentation.

When the column indices are given as a RaggedTensor

We use the fact that the actual values are stored in a one-dimensional Tensor.

col = tf.ragged.constant([[0], [], [0, 2], [1, 3], [2]])
# Collect elements x(0, 0), x(2, 0), x(2, 2), x(3, 1), x(3, 3), x(4, 2)

# Start index of each row of x within x.values
row_starts = tf.cast(x.row_starts(), "int32")
# For each element of col: find the row it belongs to, look up that row's start index in x, and add the column offset
ind_flat = tf.gather(row_starts, col.value_rowids()) + col.values
ret = tf.gather(x.values, ind_flat)
print(ret)
# tf.Tensor([ 0  3  5  7  9 12], shape=(6,), dtype=int32)
       col 0  col 1  col 2  col 3  col 4
row 0      0*
row 1      1      2
row 2      3*     4      5*
row 3      6      7*     8      9*
row 4     10     11     12*    13     14
(* marks the elements selected by col)
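Printing the intermediate tensors makes the ragged case easier to follow: value_rowids() gives, for each element of col, the row it belongs to, and values gives the column offsets in the same order. This sketch casts the row starts to the dtype of col.values, a small generalization of the fixed "int32" cast above:

```python
import tensorflow as tf

x = tf.RaggedTensor.from_row_lengths(tf.range(15), tf.range(1, 6))
col = tf.ragged.constant([[0], [], [0, 2], [1, 3], [2]])

rowids = col.value_rowids()  # row of each element of col
offsets = col.values         # column of each element of col

# Match the dtype of the column offsets before adding
row_starts = tf.cast(x.row_starts(), offsets.dtype)
starts = tf.gather(row_starts, rowids)  # start of that row in x.values
ind_flat = starts + offsets

print(rowids.numpy().tolist())    # [0, 2, 2, 3, 3, 4]
print(offsets.numpy().tolist())   # [0, 0, 2, 1, 3, 2]
print(ind_flat.numpy().tolist())  # [0, 3, 5, 7, 9, 12]
```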

Keeping the original row structure

The result above is a plain Tensor that just lists the values, so the original row structure is lost. What if you want to keep it? You can build a RaggedTensor by telling it which row each value belongs to. The rows of the result should line up with the rows of col, so these row indices can be taken from col.value_rowids().

print(tf.RaggedTensor.from_value_rowids(ret, col.value_rowids()))
# <tf.RaggedTensor [[0], [], [3, 5], [7, 9], [12]]>
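One thing to watch: from_value_rowids infers the number of rows from the largest row id, so if col ends with empty rows, those rows are dropped unless nrows is passed. RaggedTensor.with_values, which reuses col's row partition directly, avoids this. A sketch using a hypothetical col2 whose last rows are empty:

```python
import tensorflow as tf

x = tf.RaggedTensor.from_row_lengths(tf.range(15), tf.range(1, 6))
col2 = tf.ragged.constant([[0], [1], [], [], []])  # hypothetical: trailing rows empty

row_starts = tf.cast(x.row_starts(), col2.values.dtype)
ret2 = tf.gather(x.values, tf.gather(row_starts, col2.value_rowids()) + col2.values)

# nrows inferred as max(row id) + 1 = 2, so the empty rows disappear
inferred = tf.RaggedTensor.from_value_rowids(ret2, col2.value_rowids())
# Passing nrows explicitly keeps all 5 rows
explicit = tf.RaggedTensor.from_value_rowids(ret2, col2.value_rowids(), nrows=col2.nrows())
# with_values copies col2's row partition and swaps in the new values
reused = col2.with_values(ret2)

print(inferred.to_list())  # [[0], [2]]
print(explicit.to_list())  # [[0], [2], [], [], []]
print(reused.to_list())    # [[0], [2], [], [], []]
```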

When the target RaggedTensor has three or more dimensions

Even when the time-series data is two-dimensional or higher (three or more dimensions as a RaggedTensor once the batch dimension is included), the methods above can be used as they are.

x = tf.RaggedTensor.from_row_lengths(tf.reshape(tf.range(30), (15, 2)), tf.range(1, 6))
print(x)
# <tf.RaggedTensor [[[0, 1]], [[2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]], [[12, 13], [14, 15], [16, 17], [18, 19]], [[20, 21], [22, 23], [24, 25], [26, 27], [28, 29]]]>

The structure of this x can be interpreted as follows.

       col 0     col 1     col 2     col 3     col 4
row 0  [0, 1]
row 1  [2, 3]    [4, 5]
row 2  [6, 7]    [8, 9]    [10, 11]
row 3  [12, 13]  [14, 15]  [16, 17]  [18, 19]
row 4  [20, 21]  [22, 23]  [24, 25]  [26, 27]  [28, 29]

Everything else is exactly the same as before. Note, however, that the returned Tensor is now two-dimensional.

ind = tf.constant([[0, 0], [1, 1], [2, 0], [4, 3]])
# Collect elements x(0, 0), x(1, 1), x(2, 0), x(4, 3)
print(tf.gather_nd(x, ind))
# tf.Tensor(
# [[ 0  1]
#  [ 4  5]
#  [ 6  7]
#  [26 27]], shape=(4, 2), dtype=int32)
col = tf.constant([0, 0, 2, 1, 2])
# Collect elements x(0, 0), x(1, 0), x(2, 2), x(3, 1), x(4, 2)
print(tf.gather(x.values, x.row_starts() + col))
# tf.Tensor(
# [[ 0  1]
#  [ 2  3]
#  [10 11]
#  [14 15]
#  [24 25]], shape=(5, 2), dtype=int32)
col = tf.ragged.constant([[0], [], [0, 2], [1, 3], [2]])
# Collect elements x(0, 0), x(2, 0), x(2, 2), x(3, 1), x(3, 3), x(4, 2)

# Start index of each row of x within x.values
row_starts = tf.cast(x.row_starts(), "int32")
# For each element of col: find the row it belongs to, look up that row's start index in x, and add the column offset
ind_flat = tf.gather(row_starts, col.value_rowids()) + col.values
ret = tf.gather(x.values, ind_flat)
print(ret)
# tf.Tensor(
# [[ 0  1]
#  [ 6  7]
#  [10 11]
#  [14 15]
#  [18 19]
#  [24 25]], shape=(6, 2), dtype=int32)

# To keep the original row structure
print(tf.RaggedTensor.from_value_rowids(ret, col.value_rowids()))
# <tf.RaggedTensor [[[0, 1]], [], [[6, 7], [10, 11]], [[14, 15], [18, 19]], [[24, 25]]]>
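The code carries over unchanged because this x has only one ragged dimension: x.values is an ordinary (15, 2) Tensor, so tf.gather along axis 0 picks whole [a, b] pairs. A quick check, plus an example of taking just the second component of each selected pair:

```python
import tensorflow as tf

x = tf.RaggedTensor.from_row_lengths(tf.reshape(tf.range(30), (15, 2)), tf.range(1, 6))

print(x.ragged_rank)                    # 1
print(x.values.shape)                   # (15, 2)
print(x.row_starts().numpy().tolist())  # [0, 1, 3, 6, 10]

# Gather whole pairs, then keep only the second component of each
col = tf.constant([0, 0, 2, 1, 2])
second = tf.gather(x.values, x.row_starts() + col)[:, 1]
print(second.numpy().tolist())          # [1, 3, 11, 15, 25]
```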
