In this article, I implement saliency-map-based image cropping with deep learning in Python/PyTorch, working through the relevant papers as I go.
When we talk about images and deep learning, classifying handwritten digits or detecting people usually comes to mind, but I hope this article shows that you can also do things like this.
This article is part of the DeNA 20 New Graduate Advent Calendar 2019 on Qiita. Thanks to the Advent Calendar for giving me the occasion to write it!
Since Advent Calendars cover a wide range of genres, the article itself is aimed at anyone who has done some programming. To actually run the code, I assume you have worked through a deep learning tutorial or two.
The code is written with Jupyter Notebook in mind, so it is easy to try at hand. The listings are collapsed by default; click to expand them as needed.
Only libraries that come preinstalled in Google Colaboratory are used. Because the dataset is large, it can take some effort to get as far as training the model.
Sometimes you want to crop an image to some shape. Icon images, for example, are roughly square, so everyone has probably thought about how to crop a photo when registering for some service. Header images are awkwardly wide, and the required shape is often decided on the spot. When users crop images themselves they can make them look good, but in many cases the application has to automate it.
Suppose the posted images always have to be displayed in a tall 1:3 aspect ratio somewhere on a page. I chose a tall ratio precisely because it seems hard to crop well.
If I crop this photo I took ("what a nice lobby, with a Christmas tree") by hand, I naturally do it like this so that the Christmas tree stays in.
However, nobody can look at and crop every one of a large number of posted images, so it has to be automated. As a first attempt, cropping the center seems like a safe choice, so let's implement that in Python.
import numpy as np
import cv2
import matplotlib.pyplot as plt
def crop(image, aspect_rate=(1, 1)):
"""
Cut out the image from the center so that it has the specified aspect ratio.
Parameters:
-----------------
image : ndarray, (h, w, rgb), uint8
aspect_rate : tuple of int (x, y)
default : (1, 1)
Returns:
-----------------
cropped_image : ndarray, (h, w, rgb), uint8
"""
assert image.dtype==np.uint8
assert image.ndim==3
im_size = (image.shape[1], image.shape[0]) # tuple of int, (width, height)
center = (int(round(im_size[0]/2)), int(round(im_size[1]/2))) # tuple of int, (x, y)
#Find the following four values
# box_x : int,Top left x coordinate to crop, box_y : int,Top left y coordinate to crop
# box_width : int,Cutout width, box_height : int,Cutout height
if im_size[0]>im_size[1]:
box_y = 0
box_height = im_size[1]
box_width = int(round((im_size[1]/aspect_rate[1])*aspect_rate[0]))
if box_width>im_size[0]:
box_x = 0
box_width = im_size[0]
box_height = int(round((im_size[0]/aspect_rate[0])*aspect_rate[1]))
box_y = int(round(center[1]-(box_height/2)))
else:
box_x = int(round(center[0]-(box_width/2)))
else:
box_x = 0
box_width = im_size[0]
box_height = int(round((im_size[0]/aspect_rate[0])*aspect_rate[1]))
if box_height>im_size[1]:
box_y = 0
box_height = im_size[1]
box_width = int(round((im_size[1]/aspect_rate[1])*aspect_rate[0]))
box_x = int(round(center[0]-(box_width/2)))
else:
box_y = int(round(center[1]-(box_height/2)))
cropped_image = image[box_y:box_y+box_height, box_x:box_x+box_width]
return cropped_image
#image: Read the image with OpenCV etc. and make it a NumPy array
image = cv2.imread("tree.jpg ")[:, :, ::-1]
cropped_image = crop(image, aspect_rate=(1, 3))
plt.imshow(cropped_image)
plt.show()
The idea is to use one full axis of the image and compute the length of the other axis from the given aspect ratio (swapping the roles if the result would stick out of the image).
Let's try it.
The Christmas element is gone, and it is now just a photo of a nice lobby. That won't do. Time to fix it with the power of, well, AI.
This time I will imitate the Saliency-Map-based approach that Twitter and Adobe have introduced over the last couple of years. On Twitter [^1], posted images are cropped so that they display nicely on the timeline. Adobe InDesign has a feature called Content-Aware Fit that crops an image to fit a specified frame.
[^1]: Introducing a neural network that optimally and automatically crops images https://blog.twitter.com/ja_jp/topics/product/2018/0125ML-CR.html
Object detection could be used for the same purpose, but the Saliency-Map-based approach is more versatile in that it is not limited to objects with trained labels.
A saliency-map-based cropping method was proposed in Ardizzone et al.'s 2013 paper "Saliency Based Image Cropping" [^2].
A **Saliency Map** expresses, pixel by pixel, where a person's gaze goes when they look at an image. It can be measured from many viewers, as in the lower left of the figure, or estimated by computation. In this figure, the whiter a region, the higher the probability that a viewpoint falls there; the darker, the lower.
Figure: Example of a Saliency Map. Upper left: image. Upper right: measured viewpoints marked with red X. Lower left: Saliency Map. Lower right: Saliency Map colorized and overlaid on the image. This figure visualizes training data from the SALICON dataset [^3]. The red X marks in the upper right are viewpoint data collected by having many people look at the image in the upper left and move the mouse cursor over the parts they were looking at.
Applying a Gaussian filter to that data produces a map, as in the lower left, that gives the probability (0 to 1) of a viewpoint at each pixel. This is the ground-truth Saliency Map we want to compute.
Colorizing it and overlaying it on the image, as in the lower right, shows that the area around the cat is very likely to attract the gaze. Red means the viewpoint probability is close to 1, blue close to 0.
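To give a feel for how such a map is made, the following minimal sketch places made-up fixation points on a blank image and blurs them with a Gaussian filter; it is only an illustration, not the actual SALICON preprocessing code.

```python
import cv2
import numpy as np

# Hypothetical fixation points (x, y) for a 480x640 image
fixations = [(320, 240), (330, 250), (150, 400)]

fixation_map = np.zeros((480, 640), dtype=np.float32)
for x, y in fixations:
    fixation_map[y, x] = 1.0

# Blur the point map into a smooth probability-like map and rescale to 0-1
saliencymap = cv2.GaussianBlur(fixation_map, ksize=(0, 0), sigmaX=19)
saliencymap /= saliencymap.max()
```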
Let's implement Ardizzone's method. For now, we use the ground-truth Saliency Maps of the SALICON dataset as they are. Here are the cat image and its ground-truth (training) Saliency Map.
The method crops so as to include every pixel above a certain probability; in other words, it keeps only the region people are likely to look at.
Figure: Pipeline of Ardizzone's method (quoted from the paper [^ 2])
To summarize the figure in words, there are the following three steps (a compact sketch follows the list):
- Binarize the Saliency Map with a threshold (each pixel becomes 1 or 0)
- Find the bounding box that encloses the region of 1s
- Crop the original image with that bounding box
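As a preview, the whole pipeline fits in a few lines. This is a minimal sketch that assumes the Saliency Map is a 0-255 `uint8` array of the same size as the image; the sections below walk through each step with visualization.

```python
import cv2
import numpy as np

def saliency_based_crop(image, saliencymap, threshold=0.3):
    """Minimal sketch of the three steps: binarize -> bounding box -> crop."""
    binarized = (saliencymap > threshold * 255).astype(np.uint8)  # step 1: binarize with a threshold
    box_x, box_y, box_w, box_h = cv2.boundingRect(binarized)      # step 2: bounding box around the 1s
    return image[box_y:box_y + box_h, box_x:box_x + box_w]        # step 3: crop the original image
```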
NumPy makes the binarization easy. Comparison operators (`>`, `==`) are broadcast, so evaluating `ndarray > float` returns True or False for every element, and the binarization is done.
threshhold = 0.3 #Set threshold, float (0<threshhold<1)
saliencymap_path = 'COCO_train2014_000000196971.png' #Salience Map path
saliencymap = cv2.imread(saliencymap_path)[:, :, ::-1] # ndarray, (h, w, rgb), np.uint8 (0-255)
saliencymap = saliencymap[:, :, 0] # ndarray, (h, w), np.uint8 (0-255)
plt.imshow(saliencymap)
plt.show()
threshhold *= 255 # the Saliency Map read from the image is in 0-255, so rescale the threshold to match
binarized_saliencymap = saliencymap>threshhold # ndarray, (h, w), bool
plt.imshow(binarized_saliencymap)
plt.show()
Figure: Results of binarization
The result is as shown in this figure. By default, matplotlib's `plt.imshow()` shows large values in yellow and small values in purple.
The threshold is a hyperparameter that can be chosen freely; in this article it is fixed at 0.3 throughout.
Next, compute the **bounding box** (the smallest enclosing rectangle) that contains all of the 1s (True) obtained by the binarization.
OpenCV implements this as `cv2.boundingRect()`, so a single call is enough.
Structural Analysis and Shape Descriptors — OpenCV 2.4.13.7 documentation
[Area (contour) features — OpenCV-Python Tutorials 1 documentation](http://labs.eecs.tottori-u.ac.jp/sd/Member/oyamada/OpenCV/html/py_tutorials/py_imgproc/py_contours/py_contour_features/py_contour_features.html)
Rectangles are drawn in matplotlib with `patches.Rectangle()`.
matplotlib.patches.Rectangle — Matplotlib 3.1.1 documentation
import matplotlib.patches as patches
#Convert to a format that OpenCV can handle
binarized_saliencymap = binarized_saliencymap.astype(np.uint8) # ndarray, (h, w), np.uint8 (0 or 1)
box_x, box_y, box_width, box_height = cv2.boundingRect(binarized_saliencymap)
# box_x : int,Top left x coordinate to crop, box_y : int,Top left y coordinate to crop
# box_width : int,Cutout width, box_height : int,Cutout height
#Bounding box drawing
fig = plt.figure()
ax = plt.axes()
bounding_box = patches.Rectangle(xy=(box_x, box_y), width=box_width, height=box_height, ec='#00FF00', fill=False)
ax.imshow(binarized_saliencymap)
ax.add_patch(bounding_box)
plt.show()
Figure: Result of getting the bounding box
You can get the bounding box as shown in this figure. The rectangle information is held as the upper left coordinates and width / height values.
Crop the image based on the obtained bounding box, by slicing the image's ndarray with the bounding box values.
image_path = 'COCO_train2014_000000196971.jpg' #Image path
image = cv2.imread(image_path)[:, :, ::-1] # ndarray, (h, w, rgb), np.uint8 (0-255)
cropped_image = image[box_y:box_y+box_height, box_x:box_x+box_width] # ndarray, (h, w, rgb), np.uint8 (0-255)
#Visualization
fig = plt.figure()
ax = plt.axes()
bounding_box = patches.Rectangle(xy=(box_x, box_y), width=box_width, height=box_height, ec='#00FF00', fill=False)
ax.imshow(image)
ax.add_patch(bounding_box)
plt.show()
plt.imshow(cropped_image)
plt.show()
Figure: Result of cropping with Ardizzone's method
As the figure shows, we obtain a crop of the region the gaze is likely to be directed at.
To make the processing easier to see, let's overlay the colorized Saliency Map and the bounding box on the image. We implement one function that colorizes the Saliency Map and one that overlays it on the image.
def color_saliencymap(saliencymap):
"""
Color and visualize the Salience Map. Set 1 to red and 0 to blue.
Parameters
----------------
saliencymap : ndarray, np.uint8, (h, w) or (h, w, rgb)
Returns
----------------
saliencymap_colored : ndarray, np.uint8, (h, w, rgb)
"""
assert saliencymap.dtype==np.uint8
assert (saliencymap.ndim == 2) or (saliencymap.ndim == 3)
saliencymap_colored = cv2.applyColorMap(saliencymap, cv2.COLORMAP_JET)[:, :, ::-1]
return saliencymap_colored
def overlay_saliencymap_and_image(saliencymap_color, image):
"""
Overlay the image with Salience Map.
Parameters
----------------
saliencymap_color : ndarray, (h, w, rgb), np.uint8
image : ndarray, (h, w, rgb), np.uint8
Returns
----------------
overlaid_image : ndarray(h, w, rgb)
"""
assert saliencymap_color.ndim==3
assert saliencymap_color.dtype==np.uint8
assert image.ndim==3
assert image.dtype==np.uint8
im_size = (image.shape[1], image.shape[0])
saliencymap_color = cv2.resize(saliencymap_color, im_size, interpolation=cv2.INTER_CUBIC)
overlaid_image = cv2.addWeighted(src1=image, alpha=1, src2=saliencymap_color, beta=0.7, gamma=0)
return overlaid_image
saliencymap_colored = color_saliencymap(saliencymap) # ndarray, (h, w, rgb), np.uint8
overlaid_image = overlay_saliencymap_and_image(saliencymap_colored, image) # ndarray, (h, w, rgb), np.uint8
#Visualization
fig = plt.figure()
ax = plt.axes()
bounding_box = patches.Rectangle(xy=(box_x, box_y), width=box_width, height=box_height, ec='#00FF00', fill=False)
ax.imshow(overlaid_image)
ax.add_patch(bounding_box)
plt.show()
Figure: Colorized Saliency Map and bounding box overlaid on the image
As this figure shows, the areas of the Saliency Map with a high probability (shown in red) are enclosed by the box.
With Ardizzone's method, the size and aspect ratio of the crop are determined by the Saliency Map. Since this time I want to crop to a fixed aspect ratio, an extra step is needed.
I couldn't find an existing method for this, so I decided on the following algorithm to determine the crop range.
- Keep the entire range obtained by Ardizzone's method, then extend it in one direction until it reaches the specified aspect ratio.
- If the extended range would stick out of the image, use the whole image along that direction and instead narrow the other direction to adjust.
- When narrowing, choose the sub-range of the range obtained by Ardizzone's method that maximizes the total Saliency Map value.
- When extending, choose the range that maximizes the total Saliency Map value.
In short, keep as much of the range found so far as possible, and among the candidates pick the one that maximizes the sum of the Saliency Map values.
Create a "SaliencyBasedImageCropping class" for cropping, and summarize the code so far below.
import copy
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
class SaliencyBasedImageCropping:
"""
A class for cropping images using a Saliency Map. It uses the method of [1], which keeps the entire region exceeding a certain threshold.
* If no pixels exceed the threshold, the entire image is returned.
[1] Ardizzone, Edoardo, Alessandro Bruno, and Giuseppe Mazzola. "Saliency based image cropping." International Conference on Image Analysis and Processing. Springer, Berlin, Heidelberg, 2013.
Parameters
----------------
aspect_rate : tuple of int (x, y)
If an aspect ratio is specified here, the crop keeps the range obtained by the method of [1] as much as possible while maximizing the total Saliency Map value.
min_size : tuple of int (w, h)
If an axis of the range obtained by the method of [1] is smaller than this value, the range is expanded evenly around its center.
Attributes
----------------
self.aspect_rate : tuple of int (x, y)
self.min_size : tuple of int (w, h)
im_size : tuple of int (w, h)
self.bounding_box_based_on_binary_saliency : list
Range obtained by the method of [1]
box_x : int
box_y : int
box_width : int
box_height : int
self.bounding_box : list
The final cropping range with the adjusted aspect ratio
box_x : int
box_y : int
box_width : int
box_height : int
"""
def __init__(self, aspect_rate=None, min_size=(200, 200)):
assert (aspect_rate is None)or((type(aspect_rate)==tuple)and(len(aspect_rate)==2))
assert (type(min_size)==tuple)and(len(min_size)==2)
self.aspect_rate = aspect_rate
self.min_size = min_size
self.im_size = None
self.bounding_box_based_on_binary_saliency = None
self.bounding_box = None
def _compute_bounding_box_based_on_binary_saliency(self, saliencymap, threshhold):
"""
Find the cropping range from the Saliency Map with Ardizzone's method [1].
Parameters:
-----------------
saliencymap : ndarray, (h, w), np.uint8
0<=saliencymap<=255
threshhold : float
0<threshhold<255
Returns:
-----------------
bounding_box_based_on_binary_saliency : list
box_x : int
box_y : int
box_width : int
box_height : int
"""
assert (threshhold>0)and(threshhold<255)
assert saliencymap.dtype==np.uint8
assert saliencymap.ndim==2
binarized_saliencymap = saliencymap>threshhold
#If no pixel of the Saliency Map exceeds the threshold, treat the whole map as salient
if binarized_saliencymap.sum()==0:
    binarized_saliencymap[:] = True
binarized_saliencymap = (binarized_saliencymap.astype(np.uint8))*255
# binarized_saliencymap : ndarray, (h, w), uint8, 0 or 255
#Small areas are erased by morphology processing (opening)
kernel_size = round(min(self.im_size)*0.02)
kernel = np.ones((kernel_size, kernel_size))
binarized_saliencymap = cv2.morphologyEx(binarized_saliencymap, cv2.MORPH_OPEN, kernel)
box_x, box_y, box_width, box_height = cv2.boundingRect(binarized_saliencymap)
bounding_box_based_on_binary_saliency = [box_x, box_y, box_width, box_height]
return bounding_box_based_on_binary_saliency
def _expand_small_bounding_box_to_minimum_size(self, bounding_box):
"""
If the range is smaller than the specified size, widen it. Spread the range evenly starting from the center of the range. If it goes out of the image, spread it to the opposite side.
Parameters:
-----------------
bounding_box : list
box_x : int
box_y : int
box_width : int
box_height : int
"""
bounding_box = copy.copy(bounding_box) #Copy so that the caller's list is not modified
# axis=0 : x and width, axis=1 : y and height
for axis in range(2):
if bounding_box[axis+2]<self.min_size[axis+0]:
bounding_box[axis+0] -= int(np.floor((self.min_size[axis+0]-bounding_box[axis+2])/2))
bounding_box[axis+2] = self.min_size[axis+0]
if bounding_box[axis+0]<0:
bounding_box[axis+0] = 0
if (bounding_box[axis+0]+bounding_box[axis+2])>self.im_size[axis+0]:
bounding_box[axis+0] -= (bounding_box[axis+0]+bounding_box[axis+2]) - self.im_size[axis+0]
return bounding_box
def _expand_bounding_box_to_specified_aspect_ratio(self, bounding_box, saliencymap):
"""
Expand the range so that it has the specified aspect ratio.
Keep the range obtained with Ardizzone's method [1] as much as possible while finding the range that maximizes the total Saliency Map value.
Parameters
----------------
bounding_box : list
box_x : int
box_y : int
box_width : int
box_height : int
saliencymap : ndarray, (h, w), np.uint8
0<=saliencymap<=255
"""
assert saliencymap.dtype==np.uint8
assert saliencymap.ndim==2
bounding_box = copy.copy(bounding_box)
# axis=0 : x and width, axis=1 : y and height
if bounding_box[2]>bounding_box[3]:
long_length_axis = 0
short_length_axis = 1
else:
long_length_axis = 1
short_length_axis = 0
#In which direction to stretch
rate1 = self.aspect_rate[long_length_axis]/self.aspect_rate[short_length_axis]
rate2 = bounding_box[2+long_length_axis]/bounding_box[2+short_length_axis]
if rate1>rate2:
moved_axis = long_length_axis
fixed_axis = short_length_axis
else:
moved_axis = short_length_axis
fixed_axis = long_length_axis
fixed_length = bounding_box[2+fixed_axis]
moved_length = int(round((fixed_length/self.aspect_rate[fixed_axis])*self.aspect_rate[moved_axis]))
if moved_length > self.im_size[moved_axis]:
#When the size of the image is exceeded when stretched
moved_axis, fixed_axis = fixed_axis, moved_axis
fixed_length = self.im_size[fixed_axis]
moved_length = int(round((fixed_length/self.aspect_rate[fixed_axis])*self.aspect_rate[moved_axis]))
fixed_point = 0
start_point = bounding_box[moved_axis]
end_point = bounding_box[moved_axis]+bounding_box[2+moved_axis]
if fixed_axis==0:
saliencymap_extracted = saliencymap[start_point:end_point, :]
elif fixed_axis==1:
saliencymap_extracted = saliencymap[:, start_point:end_point]
else:
#When stretched to fit within the size of the image
start_point = int(bounding_box[moved_axis]+bounding_box[2+moved_axis]-moved_length)
if start_point<0:
start_point = 0
end_point = int(bounding_box[moved_axis]+moved_length)
if end_point>self.im_size[moved_axis]:
end_point = self.im_size[moved_axis]
if fixed_axis==0:
fixed_point = bounding_box[fixed_axis]
saliencymap_extracted = saliencymap[start_point:end_point, fixed_point:fixed_point+fixed_length]
elif fixed_axis==1:
fixed_point = bounding_box[fixed_axis]
saliencymap_extracted = saliencymap[fixed_point:fixed_point+fixed_length, start_point:end_point]
saliencymap_summed_1d = saliencymap_extracted.sum(moved_axis)
self.saliencymap_summed_slided = np.convolve(saliencymap_summed_1d, np.ones(moved_length), 'valid')
moved_point = np.array(self.saliencymap_summed_slided).argmax() + start_point
if fixed_axis==0:
bounding_box = [fixed_point, moved_point, fixed_length, moved_length]
elif fixed_axis==1:
bounding_box = [moved_point, fixed_point, moved_length, fixed_length]
return bounding_box
def crop_center(self, image):
"""
Crop the center of the image with the specified aspect ratio without using Salience Map.
Parameters:
-----------------
image : ndarray, (h, w, rgb), uint8
Returns:
-----------------
cropped_image : ndarray, (h, w, rgb), uint8
"""
assert image.dtype==np.uint8
assert image.ndim==3
im_size = (image.shape[1], image.shape[0]) # tuple of int, (width, height)
center = (int(round(im_size[0]/2)), int(round(im_size[1]/2))) # tuple of int, (x, y)
if im_size[0]>im_size[1]:
box_y = 0
box_height = im_size[1]
box_width = int(round((im_size[1]/self.aspect_rate[1])*self.aspect_rate[0]))
if box_width>im_size[0]:
box_x = 0
box_width = im_size[0]
box_height = int(round((im_size[0]/self.aspect_rate[0])*self.aspect_rate[1]))
box_y = int(round(center[1]-(box_height/2)))
else:
box_x = int(round(center[0]-(box_width/2)))
else:
box_x = 0
box_width = im_size[0]
box_height = int(round((im_size[0]/self.aspect_rate[0])*self.aspect_rate[1]))
if box_height>im_size[1]:
box_y = 0
box_height = im_size[1]
box_width = int(round((im_size[1]/self.aspect_rate[1])*self.aspect_rate[0]))
box_x = int(round(center[0]-(box_width/2)))
else:
box_y = int(round(center[1]-(box_height/2)))
cropped_image = image[box_y:box_y+box_height, box_x:box_x+box_width]
return cropped_image
def crop(self, image, saliencymap, threshhold=0.3):
"""
Cropping using Salience Map.
Parameters:
-----------------
image : ndarray, (h, w, rgb), np.uint8
saliencymap : ndarray, (h, w), np.uint8
Saliency map's ndarray need not be the same size as image's ndarray. Saliency map is resized within this method.
threshhold : float
0 < threshhold <1
Returns:
-----------------
cropped_image : ndarray, (h, w, rgb), uint8
"""
assert (threshhold>0)and(threshhold<1)
assert image.dtype==np.uint8
assert image.ndim==3
assert saliencymap.dtype==np.uint8
assert saliencymap.ndim==2
threshhold = threshhold*255 # scale to 0 - 255
self.im_size = (image.shape[1], image.shape[0]) # (width, height)
saliencymap = cv2.resize(saliencymap, self.im_size, interpolation=cv2.INTER_CUBIC)
# compute bounding box based on saliency map
bounding_box_based_on_binary_saliency = self._compute_bounding_box_based_on_binary_saliency(saliencymap, threshhold)
bounding_box = self._expand_small_bounding_box_to_minimum_size(bounding_box_based_on_binary_saliency)
if self.aspect_rate is not None:
bounding_box = self._expand_bounding_box_to_specified_aspect_ratio(bounding_box, saliencymap)
box_y = bounding_box[1]
box_x = bounding_box[0]
box_height = bounding_box[3]
box_width = bounding_box[2]
cropped_image = image[box_y:box_y+box_height, box_x:box_x+box_width]
self.bounding_box_based_on_binary_saliency = bounding_box_based_on_binary_saliency
self.bounding_box = bounding_box
return cropped_image
# -------------------
# SETTING
threshhold = 0.3 #Set threshold, float (0<threshhold<1)
saliencymap_path = 'COCO_train2014_000000196971.png' #Salience Map path
image_path = 'COCO_train2014_000000196971.jpg' #Image path
# -------------------
saliencymap = cv2.imread(saliencymap_path)[:, :, ::-1] # ndarray, (h, w, rgb), np.uint8 (0-255)
saliencymap = saliencymap[:, :, 0] # ndarray, (h, w), np.uint8 (0-255)
image = cv2.imread(image_path)[:, :, ::-1] # ndarray, (h, w, rgb), np.uint8 (0-255)
#Visualization of cropped images using Salience Map
cropper = SaliencyBasedImageCropping(aspect_rate=(1, 3))
cropped_image = cropper.crop(image, saliencymap, threshhold=0.3)
plt.imshow(cropped_image)
plt.show()
#Visualization of Salience Map and bounding box
#The red box is the one adjusted to the specified aspect ratio; the green box is the one before adjustment (from the binarized saliency)
saliencymap_colored = color_saliencymap(saliencymap) # ndarray, (h, w, rgb), np.uint8
overlaid_image = overlay_saliencymap_and_image(saliencymap_colored, image) # ndarray, (h, w, rgb), np.uint8
box_x, box_y, box_width, box_height = cropper.bounding_box
box_x_0, box_y_0, box_width_0, box_height_0 = cropper.bounding_box_based_on_binary_saliency
fig = plt.figure()
ax = plt.axes()
bounding_box = patches.Rectangle(xy=(box_x, box_y), width=box_width, height=box_height, ec='#FF0000', fill=False)
bounding_box_based_on_binary_saliency = patches.Rectangle(xy=(box_x_0, box_y_0), width=box_width_0, height=box_height_0, ec='#00FF00', fill=False)
ax.imshow(overlaid_image)
ax.add_patch(bounding_box)
ax.add_patch(bounding_box_based_on_binary_saliency)
plt.show()
#Visualization of image with cropped center for comparison
center_cropped_image = cropper.crop_center(image)
plt.imshow(center_cropped_image)
plt.show()
`np.convolve()` is used to find the range in which the sum of the Saliency Map values is maximal.
numpy.convolve — NumPy v1.17 Manual
This is a one-dimensional convolution function. By convolving with an array of all 1s of the length you want to sum, you can calculate the sum for each fixed range as shown below.
array_1d = np.array([1, 2, 3, 4])
print(np.convolve(array_1d, np.ones(2), 'valid')) # [3. 5. 7.]
A plain Python for loop would be slow here, so we combine NumPy functions as much as possible.
In addition, the binarization step now removes very small regions with a morphological opening. Such regions tend to appear once the Saliency Map is estimated by deep learning later on, which is why this step was added.
Morphology Transformation — OpenCV-Python Tutorials 1 documentation
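As a small self-contained illustration of what the opening does (a toy array, not the article's data): an isolated speck disappears while a larger block survives.

```python
import cv2
import numpy as np

binary = np.zeros((10, 10), dtype=np.uint8)
binary[1, 1] = 255      # an isolated one-pixel speck
binary[4:9, 4:9] = 255  # a solid 5x5 block

kernel = np.ones((3, 3), dtype=np.uint8)
opened = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)

print(opened[1, 1])            # 0   -> the speck has been erased
print(opened[4:9, 4:9].min())  # 255 -> the block survives intact
```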
The green bounding box is the one before the aspect ratio adjustment, and the red one is the box after the adjustment.
Figure: Result of cropping to a 1:3 aspect ratio using the Saliency Map
As shown in figure (a), the crop manages to fit both the cat and what appears to be a bottle of hand soap into the tall frame. Compared with figure (b), which simply crops the center, it captures the parts a person would actually want to see.
Figure: Result of cropping to a 1:1 aspect ratio using the Saliency Map
This is the square (1:1) case. The center crop (fig. (b)) does contain the whole cat, but the Saliency-Map-based crop (fig. (a)) selects a narrower region, so at the same display size the cat appears larger. What matters in cropping is not only whether an object is included but also whether it is shown at a sufficient size.
As it stands, we cannot crop images we prepared ourselves. Since I want to crop the photo of the Christmas tree I took, not an image from the SALICON dataset, I will build a Saliency Map estimation model with deep learning.
The benchmark site MIT Saliency Benchmark [^4] lists many methods for the saliency prediction task; this time I implement SalGAN [^5]. Its score does not seem particularly high, but I chose it because the mechanism looked simple.
The authors' implementation [^6] has also been released, but its framework, Lasagne (Theano), is not one I am familiar with, so I will write the model in PyTorch while referring to it.
"SalGAN: Visual Saliency Prediction with Generative Adversarial Networks" is a paper published in 2017. As the name implies, it estimates the Saliency Map using a **GAN (Generative Adversarial Network)**.
I will skip explaining GANs themselves, since many accessible articles already exist; for example, "GAN (1) Understanding the basic structure - Qiita" is recommended. If you know a typical GAN variant with plenty of implementations and explanations, you can implement this one by thinking in terms of the differences.
Figure: Overall structure of SalGAN (quoted from paper [^ 5])
Saliency prediction is a pixel-wise binary classification problem, since each pixel carries a viewpoint probability (0 to 1); it is close to one-class segmentation. Because the input is an image and the output is also an image (the Saliency Map), the model is a CNN **encoder-decoder**, as shown in the figure. Pix2Pix [^7] is the famous image-to-image GAN, but unlike Pix2Pix, SalGAN does not use a U-Net structure.
The encoder-decoder alone could be trained by minimizing the **binary cross entropy** between the output Saliency Map and the ground truth. SalGAN tries to improve accuracy by adding a network (the **discriminator**) that classifies a Saliency Map as ground truth or estimated.
The loss function of the encoder-decoder part (the **generator**) is as follows: in addition to the usual adversarial loss, the binary cross entropy between the estimated Saliency Map and the ground truth is added, with the ratio controlled by the hyperparameter $\alpha$.
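Written out, it is roughly the following (notation chosen to match the training code later in the article: $I$ is the input image, $S$ the ground-truth map, $\hat{S}$ the estimate, $D$ the discriminator, and $\mathcal{L}_{BCE}$ the binary cross entropy):

$$
\mathcal{L}_{G} = \alpha \cdot \mathcal{L}_{BCE}(\hat{S}, S) + \mathcal{L}_{BCE}\bigl(D(I, \hat{S}), 1\bigr)
$$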
The loss function of the discriminator is as follows; it takes the standard form.
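Again roughly, and matching the training code below (which averages the two terms):

$$
\mathcal{L}_{D} = \frac{1}{2}\Bigl[\mathcal{L}_{BCE}\bigl(D(I, S), 1\bigr) + \mathcal{L}_{BCE}\bigl(D(I, \hat{S}), 0\bigr)\Bigr]
$$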
Next, let's pull the information needed for the implementation out of the paper while quoting it. The overall structure gives a rough idea; here I look for the passages that spell out the details I still want to know.
The encoder part of the network is identical in architecture to VGG-16 (Simonyan and Zisserman, 2015), omitting the final pooling and fully connected layers. The network is initialized with the weights of a VGG-16 model trained on the ImageNet data set for object classification (Deng et al., 2009). Only the last two groups of convolutional layers in VGG-16 are modified during the training for saliency prediction, while the earlier layers remain fixed from the original VGG-16 model.
- The encoder CNN of the generator is VGG16
- The final pooling layer and the fully connected layers are removed
- It is initialized with weights pretrained on ImageNet
- Only the last two groups of convolutional layers are trained
- The weights of the first three groups of convolutional layers stay fixed at their ImageNet-pretrained values
The decoder architecture is structured in the same way as the encoder, but with the ordering of layers reversed, and with pooling layers being replaced by upsampling layers. Again, ReLU non-linearities are used in all convolution layers, and a final 1 × 1 convolution layer with sigmoid non-linearity is added to produce the saliency map. The weights for the decoder are randomly initialized. The final output of the network is a saliency map in the same size to input image.
- The decoder mirrors the encoder, with upsampling layers in place of the pooling layers
- The final layer is a 1x1 convolution followed by a sigmoid
- The decoder weights are initialized randomly
- The output has the same size as the input
The input to the discriminator network is an RGBS image of size 256×192×4 containing both the source image channels and (predicted or ground truth) saliency.
- The discriminator receives not only the Saliency Map but also the original image, as a 4-channel input
- Images are input at 256 x 192 to begin with
We train the networks on the 15,000 images from the SALICON training set using a batch size of 32.
- Use the 15,000 images of the SALICON training set
- Batch size is 32
Since this is not meant to be a faithful reproduction of the paper, I am not fussy about the details. For example, instead of the ReLU used in the paper, I adopt Leaky ReLU, which is generally considered effective.
Now for the code. Since the core is a CNN encoder-decoder GAN, implementations of similar existing methods are a useful reference. For example, eriklindernoren's GitHub repository implements a variety of GANs in PyTorch; the DCGAN implementation [^8] looks like a good starting point.
The generator uses the ImageNet-pretrained VGG16 provided by torchvision [^9]. In SalGAN the front part of the encoder is kept fixed and only the back part is trained, so split it into separate modules, as in `torchvision.models.vgg16(pretrained=True).features[:17]`. You can check which index corresponds to which layer with `print(torchvision.models.vgg16(pretrained=True).features)`.
torchvision.models — PyTorch master documentation
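In the training code later in the article, the fixed part is handled simply by leaving it out of the optimizer. An alternative sketch (not from the original article; the split at index 17 follows the article's code) is to disable gradients for the frozen layers explicitly:

```python
import torchvision

vgg_features = torchvision.models.vgg16(pretrained=True).features
encoder_first = vgg_features[:17]   # layers to keep fixed at their ImageNet weights
encoder_last = vgg_features[17:-1]  # layers that will be trained

# Freeze the front layers so no gradients are computed for them
for param in encoder_first.parameters():
    param.requires_grad = False

print(vgg_features)  # inspect which index corresponds to which layer
```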
from torch import nn
import torchvision
class Generator(nn.Module):
def __init__(self, pretrained=True):
super().__init__()
self.encoder_first = torchvision.models.vgg16(pretrained=pretrained).features[:17] #the part used with its weights kept fixed
self.encoder_last = torchvision.models.vgg16(pretrained=pretrained).features[17:-1] #the part to be trained
self.decoder = nn.Sequential(
nn.Conv2d(512, 512, 3, padding=1),
nn.LeakyReLU(),
nn.Conv2d(512, 512, 3, padding=1),
nn.LeakyReLU(),
nn.Conv2d(512, 512, 3, padding=1),
nn.LeakyReLU(),
nn.Upsample(scale_factor=2),
nn.Conv2d(512, 512, 3, padding=1),
nn.LeakyReLU(),
nn.Conv2d(512, 512, 3, padding=1),
nn.LeakyReLU(),
nn.Conv2d(512, 512, 3, padding=1),
nn.LeakyReLU(),
nn.Upsample(scale_factor=2),
nn.Conv2d(512, 256, 3, padding=1),
nn.LeakyReLU(),
nn.Conv2d(256, 256, 3, padding=1),
nn.LeakyReLU(),
nn.Conv2d(256, 256, 3, padding=1),
nn.LeakyReLU(),
nn.Upsample(scale_factor=2),
nn.Conv2d(256, 128, 3, padding=1),
nn.LeakyReLU(),
nn.Conv2d(128, 128, 3, padding=1),
nn.LeakyReLU(),
nn.Upsample(scale_factor=2),
nn.Conv2d(128, 64, 3, padding=1),
nn.LeakyReLU(),
nn.Conv2d(64, 64, 3, padding=1),
nn.LeakyReLU(),
nn.Conv2d(64, 1, 1, padding=0),
nn.Sigmoid())
def forward(self, x):
x = self.encoder_first(x)
x = self.encoder_last(x)
x = self.decoder(x)
return x
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.main = nn.Sequential(
nn.Conv2d(4, 3, 1, padding=1),
nn.LeakyReLU(inplace=True),
nn.Conv2d(3, 32, 3, padding=1),
nn.LeakyReLU(inplace=True),
nn.MaxPool2d(2, stride=2),
nn.Conv2d(32, 64, 3, padding=1),
nn.LeakyReLU(inplace=True),
nn.Conv2d(64, 64, 3, padding=1),
nn.LeakyReLU(inplace=True),
nn.MaxPool2d(2, stride=2),
nn.Conv2d(64, 64, 3, padding=1),
nn.LeakyReLU(inplace=True),
nn.Conv2d(64, 64, 3, padding=1),
nn.LeakyReLU(inplace=True),
nn.MaxPool2d(2, stride=2))
self.classifier = nn.Sequential(
nn.Linear(64*32*24, 100, bias=True),
nn.Tanh(),
nn.Linear(100, 2, bias=True),
nn.Tanh(),
nn.Linear(2, 1, bias=True),
nn.Sigmoid())
def forward(self, x):
x = self.main(x)
x = x.view(x.shape[0], -1)
x = self.classifier(x)
return x
We need a Dataset class to read the SALICON data. Writing one to match the dataset and the task at hand is a little tedious; the PyTorch tutorial is a good reference for how to write it.
Writing Custom Datasets, DataLoaders and Transforms — PyTorch Tutorials 1.3.1 documentation
Preprocessing with `torchvision.transforms` is also covered there. This time we only resize to 192 x 256 and normalize.
torchvision.transforms — PyTorch master documentation
The SALICON dataset can be downloaded from LSUN’17 Saliency Prediction Challenge | SALICON.
import os
from PIL import Image
import torch.utils.data as data
import torchvision.transforms as transforms
class SALICONDataset(data.Dataset):
def __init__(self, root_dataset_dir, val_mode = False):
"""
Dataset class for reading SALICON datasets
Parameters:
-----------------
root_dataset_dir : str
Directory path above the SALICON dataset
val_mode : bool (default: False)
If False, Train data is read. If True, Validation data is read.
"""
self.root_dataset_dir = root_dataset_dir
self.imgsets_dir = os.path.join(self.root_dataset_dir, 'SALICON/image_sets')
self.img_dir = os.path.join(self.root_dataset_dir, 'SALICON/imgs')
self.distribution_target_dir = os.path.join(self.root_dataset_dir, 'SALICON/algmaps')
self.img_tail = '.jpg'
self.distribution_target_tail = '.png'
self.transform = transforms.Compose([transforms.Resize((192, 256)), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
self.distribution_transform = transforms.Compose([transforms.Resize((192, 256)), transforms.ToTensor()])
if val_mode:
train_or_val = "val"
else:
train_or_val = "train"
imgsets_file = os.path.join(self.imgsets_dir, '{}.txt'.format(train_or_val))
files = []
for data_id in open(imgsets_file).readlines():
data_id = data_id.strip()
img_file = os.path.join(self.img_dir, '{0}{1}'.format(data_id, self.img_tail))
distribution_target_file = os.path.join(self.distribution_target_dir, '{0}{1}'.format(data_id, self.distribution_target_tail))
files.append({
'img': img_file,
'distribution_target': distribution_target_file,
'data_id': data_id
})
self.files = files
def __len__(self):
return len(self.files)
def __getitem__(self, index):
"""
Returns
-----------
data : list
[img, distribution_target, data_id]
"""
data_file = self.files[index]
data = []
img_file = data_file['img']
img = Image.open(img_file)
data.append(img)
distribution_target_file = data_file['distribution_target']
distribution_target = Image.open(distribution_target_file)
data.append(distribution_target)
# transform
data[0] = self.transform(data[0])
data[1] = self.distribution_transform(data[1])
data.append(data_file['data_id'])
return data
Next comes the training code. The key points are how the loss functions are computed and how the generator and the discriminator are trained.
Training for the same 120 epochs as in the paper takes several hours on a GPU.
from datetime import datetime
import torch
from torch.autograd import Variable
#-----------------
# SETTING
root_dataset_dir = "" #Directory path above the SALICON dataset
alpha = 0.005 #Hyperparameter of the generator loss function; the paper recommends 0.005
epochs = 120
batch_size = 32 #32 in the paper
#-----------------
#Use start time for file name
start_time_stamp = '{0:%Y%m%d-%H%M%S}'.format(datetime.now())
save_dir = "./log/"
if not os.path.exists(save_dir):
os.makedirs(save_dir)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#Load data loader
train_dataset = SALICONDataset(
root_dataset_dir,
)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle=True, num_workers = 4, pin_memory=True, sampler=None)
val_dataset = SALICONDataset(
root_dataset_dir,
val_mode=True
)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = 1, shuffle=False, num_workers = 4, pin_memory=True, sampler=None)
#Load model and loss function
loss_func = torch.nn.BCELoss().to(DEVICE)
generator = Generator().to(DEVICE)
discriminator = Discriminator().to(DEVICE)
#Definition of optimization method (using the settings in the paper)
optimizer_G = torch.optim.Adagrad([
{'params': generator.encoder_last.parameters()},
{'params': generator.decoder.parameters()}
], lr=0.0001, weight_decay=3*0.0001)
optimizer_D = torch.optim.Adagrad(discriminator.parameters(), lr=0.0001, weight_decay=3*0.0001)
#Learning
for epoch in range(epochs):
n_updates = 0 #Iteration count
n_discriminator_updates = 0
n_generator_updates = 0
d_loss_sum = 0
g_loss_sum = 0
for i, data in enumerate(train_loader):
imgs = data[0] # ([batch_size, rgb, h, w])
salmaps = data[1] # ([batch_size, 1, h, w])
#Create label for Discriminator
valid = Variable(torch.FloatTensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False).to(DEVICE)
fake = Variable(torch.FloatTensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False).to(DEVICE)
imgs = Variable(imgs).to(DEVICE)
real_salmaps = Variable(salmaps).to(DEVICE)
#Alternately learn Generator and Discriminator for each iteration
if n_updates % 2 == 0:
# -----------------
# Train Generator
# -----------------
optimizer_G.zero_grad()
gen_salmaps = generator(imgs)
#Combine the original image and the generated Salience Map for input to the Discriminator to create a 4-channel array
fake_d_input = torch.cat((imgs, gen_salmaps.detach()), 1) # ([batch_size, rgbs, h, w])
#Calculate the loss function of Generator
g_loss1 = loss_func(gen_salmaps, real_salmaps)
g_loss2 = loss_func(discriminator(fake_d_input), valid)
g_loss = alpha*g_loss1 + g_loss2
g_loss.backward()
optimizer_G.step()
g_loss_sum += g_loss.item()
n_generator_updates += 1
else:
# ---------------------
# Train Discriminator
# ---------------------
optimizer_D.zero_grad()
#Combine the original image and the Salience Map of the correct data for input to the Discriminator to create a 4-channel array
real_d_input = torch.cat((imgs, real_salmaps), 1) # ([batch_size, rgbs, h, w])
#Calculate the loss function of the Discriminator (fake_d_input here is the one built in the preceding generator iteration)
real_loss = loss_func(discriminator(real_d_input), valid)
fake_loss = loss_func(discriminator(fake_d_input), fake)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()
d_loss_sum += d_loss.item()
n_discriminator_updates += 1
n_updates += 1
if n_updates%10==0:
if n_discriminator_updates>0:
print(
"[%d/%d (%d/%d)] [loss D: %f, G: %f]"
% (epoch, epochs-1, i, len(train_loader), d_loss_sum/n_discriminator_updates , g_loss_sum/n_generator_updates)
)
else:
print(
"[%d/%d (%d/%d)] [loss G: %f]"
% (epoch, epochs-1, i, len(train_loader), g_loss_sum/n_generator_updates)
)
#Saving weights
#Save every 5 epochs and the last epoch
if ((epoch+1)%5==0)or(epoch==epochs-1):
generator_save_path = '{}.pkl'.format(os.path.join(save_dir, "{}_generator_epoch{}".format(start_time_stamp, epoch)))
discriminator_save_path = '{}.pkl'.format(os.path.join(save_dir, "{}_discriminator_epoch{}".format(start_time_stamp, epoch)))
torch.save(generator.state_dict(), generator_save_path)
torch.save(discriminator.state_dict(), discriminator_save_path)
#Visualize part of Validation data for each epoch
with torch.no_grad():
print("validation")
for i, data in enumerate(val_loader):
image = Variable(data[0]).to(DEVICE)
gen_salmap = generator(image)
gen_salmap_np = np.array(gen_salmap.data.cpu())[0, 0]
plt.imshow(np.array(image[0].cpu()).transpose(1, 2, 0))
plt.show()
plt.imshow(gen_salmap_np)
plt.show()
if i==1:
break
Let's feed images to the trained SalGAN and estimate their Saliency Maps, checking how it does on images that were not used for training.
generator_path = "" #Path of Generator weight file (pkl) obtained by learning
image_path = "COCO_train2014_000000196971.jpg " #The path of the image you want to enter
generator = Generator().to(DEVICE)
generator.load_state_dict(torch.load(generator_path))
image_pil = Image.open(image_path) #PIL format image input is assumed for transform
image = np.array(image_pil)
plt.imshow(image)
plt.show()
transform = transforms.Compose([transforms.Resize((192, 256)), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
image_torch = transform(image_pil)
image_torch = image_torch.unsqueeze(0).to(DEVICE) # (1, rgb, h, w)
with torch.no_grad():
pred_saliencymap = generator(image_torch)
pred_saliencymap = np.array(pred_saliencymap.cpu())[0, 0]
pred_saliencymap = pred_saliencymap/pred_saliencymap.sum() #Scaling so that the sum is 1
pred_saliencymap = ((pred_saliencymap/pred_saliencymap.max())*255).astype(np.uint8) #Convert to np.uint8 so it can be treated as an image
plt.imshow(pred_saliencymap)
plt.show()
Figure: Example of a Saliency Map estimated by SalGAN
The Saliency Map in figure (b) is the estimate. Compared with the ground truth (fig. (c)) it looks blurrier and more spread out, but it does assign high probability to plausible spots around the pitcher and the batter.
To train the model properly, one should evaluate it with the various saliency metrics used in the paper. Since I have not done that, it is unclear how close this gets to the SalGAN reported in the paper. The main topic this time is cropping, though, so having obtained something that looks qualitatively right, let's move on.
Now we have all the pieces. Combining the Saliency Map estimated by SalGAN with the cropping class lets us crop images nicely.
# -------------------
# SETTING
threshhold = 0.3 #Set threshold, float (0<threshhold<1)
generator_path = "" #Path of Generator weight file (pkl) obtained by learning
image_path = "COCO_train2014_000000196971.jpg " #The path of the image you want to crop
# -------------------
generator = Generator().to(DEVICE)
generator.load_state_dict(torch.load(generator_path))
image_pil = Image.open(image_path) #PIL format image input is assumed for transform
image = np.array(image_pil)
plt.imshow(image)
plt.show()
transform = transforms.Compose([transforms.Resize((192, 256)), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
image_torch = transform(image_pil)
image_torch = image_torch.unsqueeze(0).to(DEVICE) # (1, rgb, h, w)
with torch.no_grad():
pred_saliencymap = generator(image_torch)
pred_saliencymap = np.array(pred_saliencymap.cpu())[0, 0]
pred_saliencymap = pred_saliencymap/pred_saliencymap.sum() #Scaling so that the sum is 1
pred_saliencymap = ((pred_saliencymap/pred_saliencymap.max())*255).astype(np.uint8) #Convert to np.uint8 so it can be treated as an image
plt.imshow(pred_saliencymap)
plt.show()
#Visualization of cropped images using Salience Map
cropper = SaliencyBasedImageCropping(aspect_rate=(1, 3))
cropped_image = cropper.crop(image, pred_saliencymap, threshhold=0.3)
plt.imshow(cropped_image)
plt.show()
#Visualization of Salience Map and bounding box
#The red box is the one adjusted to the specified aspect ratio; the green box is the one before adjustment (from the binarized saliency)
saliencymap_colored = color_saliencymap(pred_saliencymap) # ndarray, (h, w, rgb), np.uint8
overlaid_image = overlay_saliencymap_and_image(saliencymap_colored, image) # ndarray, (h, w, rgb), np.uint8
box_x, box_y, box_width, box_height = cropper.bounding_box
box_x_0, box_y_0, box_width_0, box_height_0 = cropper.bounding_box_based_on_binary_saliency
fig = plt.figure()
ax = plt.axes()
bounding_box = patches.Rectangle(xy=(box_x, box_y), width=box_width, height=box_height, ec='#FF0000', fill=False)
bounding_box_based_on_binary_saliency = patches.Rectangle(xy=(box_x_0, box_y_0), width=box_width_0, height=box_height_0, ec='#00FF00', fill=False)
ax.imshow(overlaid_image)
ax.add_patch(bounding_box)
ax.add_patch(bounding_box_based_on_binary_saliency)
plt.show()
#Visualization of image with cropped center for comparison
center_cropped_image = cropper.crop_center(image)
plt.imshow(center_cropped_image)
plt.show()
Figure: Comparison of crops using the ground-truth map and using SalGAN (baseball image)
Using the ground-truth map (fig. (a)) and using SalGAN (fig. (b)) gives almost the same result: the crop centers on the batter. I doubt the model has learned enough to decide whether the batter or the pitcher draws the eye more, but I am glad both come out the same way here.
You might wonder whether this kind of article is cherry picking that only shows the successes. To measure the result quantitatively, one could check how well these two kinds of crops overlap across the SALICON dataset, an IoU-style calculation like the one used in object detection. I will skip that here, since this is just something I put together.
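If one did want to run that check, the overlap of two crop boxes can be computed with a small helper like the one below (hypothetical code, not used elsewhere in this article; boxes are (x, y, width, height), matching the class above).

```python
def box_iou(box_a, box_b):
    """Intersection over union of two (x, y, width, height) boxes."""
    ax, ay, aw, ah = box_a
    bx, by, bw, bh = box_b
    inter_w = max(0, min(ax + aw, bx + bw) - max(ax, bx))
    inter_h = max(0, min(ay + ah, by + bh) - max(ay, by))
    intersection = inter_w * inter_h
    union = aw * ah + bw * bh - intersection
    return intersection / union if union > 0 else 0.0

# e.g. compare the box from the ground-truth map with the box from SalGAN's estimate
# print(box_iou(cropper_groundtruth.bounding_box, cropper_salgan.bounding_box))
```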
The cat image from the first half of the article is cute, so I used it again, but it is actually part of the training data and therefore not really suitable for validation. Still, let's take a look.
Figure: Comparison of crops using the ground-truth map and using SalGAN (cat image)
Here, too, the results are almost identical. Good.
Finally, let's return to the Christmas tree photo from the beginning. If we can produce a convincing crop of a photo I took myself, one that is not in the dataset, the goal has been achieved.
Figure: Comparison of cropping with SalGAN versus simple center cropping (Christmas tree image)
The result is exactly what I wanted. Figure (a), which keeps the Christmas tree, looks better than figure (b), which loses it. With the power of, well, AI, we now have a mechanism that automatically does something close to what a person would do. It could reduce disappointing crops of empty areas and cut down on the work of cropping images by hand.
The content amounted to implementing two papers, cropping with a Saliency Map [^2] and SalGAN [^5] for estimating the Saliency Map, plus a little glue code.
Something like this can be built from publicly available information alone. Even if you have only gotten as far as running deep learning or machine learning tutorials, I encourage you to take on a small challenge and build something like this yourself!