As the title suggests, I implemented Single Shot Multibox Detector (SSD) with PyTorch
([https://github.com/jjjkkkjjj/pytorch_SSD](https://github.com/jjjkkkjjj/pytorch_SSD)). ~~However, the computation is slow compared to ssd.pytorch and other implementations (I'll investigate the cause later).~~[^1] That said, I put a lot of effort into the abstraction, so I think it is highly customizable (whether I can properly document how to use it is another question...). If you search for SSD implementations you will find plenty, but this time,
[^1]: I had simply not set the `num_workers` argument when initializing the `DataLoader`... It is now as fast as ssd.pytorch.
- I wanted to customize it freely
- I wanted to understand SSD by implementing it
- I had free time because of corona
These are the reasons I implemented it. There are already many easy-to-understand explanations (references) in other articles, but I would like to summarize things in my own way to organize my thoughts. So, this time I will focus on the dataset.
- I implemented SSD with PyTorch (Dataset edition)
- I implemented SSD with PyTorch (model edition)
PyTorch
First of all, I will briefly touch on PyTorch, the deep learning framework used this time. Actually, when I first tried to implement SSD, I used Tensorflow. However,

- it is difficult to debug
- there are too many similar functions (`compat.v1`, `compat.v2`, etc.?)

so the implementation did not go smoothly. In particular, being hard to debug was fatal for me, and I often couldn't figure out what was going on. Well, it probably just means I didn't understand Tensorflow well enough, so learning how to use it properly might still be worthwhile...
I felt it would never be finished at this rate, so I switched to PyTorch, whose main appeal is that it lets you write operations much like Numpy. After switching to PyTorch, the difficulty of debugging I felt with Tensorflow improved considerably, and the implementation went smoothly. Since I was already used to Numpy-style operations from implementing [Matft](https://github.com/jjjkkkjjj/Matft) ([I made an N-dimensional matrix calculation library Matft with Swift](https://qiita.com/jjjkkkjjj/items/1f2b5c3835b1600d3129)) and [MILES](https://github.com/jjjkkkjjj/MIL), PyTorch was perfect for me.
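As an aside, here is a small illustration of what I mean by Numpy-like operations: tensors can be created, sliced, and printed eagerly, so intermediate values are easy to inspect while debugging. This is just a generic sketch, not code from the repository.

```python
import numpy as np
import torch

a = np.arange(6, dtype=np.float32).reshape(2, 3)
t = torch.from_numpy(a)          # shares memory with the ndarray

# the same kind of slicing / broadcasting as Numpy
print(t[:, 1:])                  # inspect intermediate values eagerly
print((t * 2 + 1).mean(dim=0))   # works just like a.mean(axis=0) in Numpy
```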
First, I would like to briefly touch on what SSD is before summarizing it in detail. SSD is an object detection algorithm that can predict the positions and labels of objects end-to-end. The figure is rough, but if you give it an input image like this, the SSD outputs the positions and labels of the objects in one shot.
What this model does is as follows. I will explain it step by step.
In the original SSD paper, the PASCAL VOC2007, [PASCAL VOC2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/), and COCO2014 datasets are used. ~~COCO has not been implemented yet, so I will explain the VOC dataset.~~ First, let's talk about the VOC dataset.
The structure of the directory is basically unified, and it looks like the following.
voc directory
$ tree -I '*.png|*.xml|*.jpg|*.txt'
└── VOCdevkit
└── VOC20**
├── Annotations
├── ImageSets
│ ├── Action
│ ├── Layout
│ ├── Main
│ └── Segmentation
├── JPEGImages
├── SegmentationClass
└── SegmentationObject
What is required for object detection is `Annotations`, `JPEGImages`, and `ImageSets/Main` directly under `VOC20**`. Each is as follows.
- `Annotations`: the `.xml` annotation files. Each one corresponds to a `.jpg` file in `JPEGImages`.
- `JPEGImages`: the `.jpg` image files. Each one corresponds to an `.xml` file in `Annotations`.
- `ImageSets/Main`: `.txt` files that describe the image sets. Each file lists the base names of the `Annotations` and `JPEGImages` files that belong to that set.

Annotation data (.xml file)

An `.xml` file under the `Annotations` directory looks like the following.
Annotations/~~.xml
<annotation>
<folder>VOC2007</folder>
<filename>000005.jpg</filename>
<source>
<database>The VOC2007 Database</database>
<annotation>PASCAL VOC2007</annotation>
<image>flickr</image>
<flickrid>325991873</flickrid>
</source>
<owner>
<flickrid>archintent louisville</flickrid>
<name>?</name>
</owner>
<size>
<width>500</width>
<height>375</height>
<depth>3</depth>
</size>
<segmented>0</segmented>
<object>
<name>chair</name>
<pose>Rear</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>263</xmin>
<ymin>211</ymin>
<xmax>324</xmax>
<ymax>339</ymax>
</bndbox>
</object>
<object>
<name>chair</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>165</xmin>
<ymin>264</ymin>
<xmax>253</xmax>
<ymax>372</ymax>
</bndbox>
</object>
...
</annotation>
The important points are as follows.

- `<filename>`: the `.jpg` file this annotation data corresponds to
- `<object>`
  - `<name>`: label name
  - `<truncated>`: whether the object is entirely visible (`0`) or only partially visible (`1`)
  - `<difficult>`: whether the object is difficult (`1`) or not (`0`)
  - `<bndbox>`: bounding box (position of the object), in corners notation

To implement a dataset, you need to extend the `Dataset` class and implement `__len__`, which returns the number of samples in the dataset, and `__getitem__`, which returns the input data and the ground-truth label for an `index` within that range.
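As a minimal sketch of that contract (a hypothetical `MyDetectionDataset`, not the class from the repository):

```python
import torch
from torch.utils.data import Dataset

class MyDetectionDataset(Dataset):
    def __init__(self, annopaths):
        # list of annotation file paths; one sample per annotation
        self._annopaths = annopaths

    def __len__(self):
        # number of samples in the dataset
        return len(self._annopaths)

    def __getitem__(self, index):
        # return (input, target) for the given index;
        # here just dummy tensors to show the shapes
        img = torch.zeros(3, 300, 300)    # rgb image
        targets = torch.zeros(1, 4 + 1)   # [xmin, ymin, xmax, ymax, label]
        return img, targets
```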
What the implementation below does is:

- Save the paths of the `.xml` files directly under `Annotations` in `self._annopaths` as a list
- For the `index` given to `__getitem__`, get the image and bounding boxes from `self._annopaths[index]`
- Images are read with OpenCV and returned in **RGB order** (input data)
- Bounding boxes are normalized by the width and height of the image
- Bounding boxes and labels are concatenated and returned (ground-truth label)
The reason for returning RGB order is that PyTorch's pre-trained models take as input images in RGB order, normalized with `mean = (0.485, 0.456, 0.406)` and `std = (0.229, 0.224, 0.225)`. (Reference)

ObjectDetectionDatasetBase
class ObjectDetectionDatasetBase(_DatasetBase):
def __init__(self, ignore=None, transform=None, target_transform=None, augmentation=None):
        pass  # body omitted in this article
def __getitem__(self, index):
"""
:param index: int
:return:
img : rgb image(Tensor or ndarray)
targets : Tensor or ndarray of bboxes and labels [box, label]
= [xmin, ymin, xmax, ymax, label index(or one-hotted label)]
or
= [cx, cy, w, h, label index(or one-hotted label)]
"""
img = self._get_image(index)
bboxes, linds, flags = self._get_bbox_lind(index)
img, bboxes, linds, flags = self.apply_transform(img, bboxes, linds, flags)
# concatenate bboxes and linds
if isinstance(bboxes, torch.Tensor) and isinstance(linds, torch.Tensor):
if linds.ndim == 1:
linds = linds.unsqueeze(1)
targets = torch.cat((bboxes, linds), dim=1)
else:
if linds.ndim == 1:
linds = linds[:, np.newaxis]
targets = np.concatenate((bboxes, linds), axis=1)
return img, targets
def apply_transform(self, img, bboxes, linds, flags):
"""
IMPORTANT: apply transform function in order with ignore, augmentation, transform and target_transform
:param img:
:param bboxes:
:param linds:
:param flags:
:return:
Transformed img, bboxes, linds, flags
"""
# To Percent mode
height, width, channel = img.shape
# bbox = [xmin, ymin, xmax, ymax]
# [bbox[0] / width, bbox[1] / height, bbox[2] / width, bbox[3] / height]
bboxes[:, 0::2] /= float(width)
bboxes[:, 1::2] /= float(height)
if self.ignore:
bboxes, linds, flags = self.ignore(bboxes, linds, flags)
if self.augmentation:
img, bboxes, linds, flags = self.augmentation(img, bboxes, linds, flags)
if self.transform:
img, bboxes, linds, flags = self.transform(img, bboxes, linds, flags)
if self.target_transform:
bboxes, linds, flags = self.target_transform(bboxes, linds, flags)
return img, bboxes, linds, flags
VOCDatasetBase
VOC_class_labels = ['aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor']
VOC_class_nums = len(VOC_class_labels)
class VOCSingleDatasetBase(ObjectDetectionDatasetBase):
def __init__(self, voc_dir, focus, ignore=None, transform=None, target_transform=None, augmentation=None, class_labels=None):
"""
:param voc_dir: str, voc directory path above 'Annotations', 'ImageSets' and 'JPEGImages'
e.g.) voc_dir = '~~~~/trainval/VOCdevkit/voc2007'
:param focus: str, image set name. Assign txt file name under 'ImageSets' directory
:param ignore: target_transforms.Ignore
:param transform: instance of transforms
:param target_transform: instance of target_transforms
:param augmentation: instance of augmentations
:param class_labels: None or list or tuple, if it's None use VOC_class_labels
"""
super().__init__(ignore=ignore, transform=transform, target_transform=target_transform, augmentation=augmentation)
self._voc_dir = voc_dir
self._focus = focus
self._class_labels = _check_ins('class_labels', class_labels, (list, tuple), allow_none=True)
if self._class_labels is None:
self._class_labels = VOC_class_labels
layouttxt_path = os.path.join(self._voc_dir, 'ImageSets', 'Main', self._focus + '.txt')
if os.path.exists(layouttxt_path):
with open(layouttxt_path, 'r') as f:
filenames = f.read().splitlines()
filenames = [filename.split()[0] for filename in filenames]
self._annopaths = [os.path.join(self._voc_dir, 'Annotations', '{}.xml'.format(filename)) for filename in filenames]
else:
raise FileNotFoundError('layout: {} was invalid arguments'.format(focus))
@property
def class_nums(self):
return len(self._class_labels)
@property
def class_labels(self):
return self._class_labels
def _jpgpath(self, filename):
"""
:param filename: path containing .jpg
:return: path of jpg
"""
return os.path.join(self._voc_dir, 'JPEGImages', filename)
def __len__(self):
return len(self._annopaths)
"""
Detail of contents in voc > https://towardsdatascience.com/coco-data-format-for-object-detection-a4c5eaf518c5
VOC bounding box (xmin, ymin, xmax, ymax)
"""
def _get_image(self, index):
"""
:param index: int
:return:
rgb image(ndarray)
"""
root = ET.parse(self._annopaths[index]).getroot()
img = cv2.imread(self._jpgpath(_get_xml_et_value(root, 'filename')))
# pytorch's image order is rgb
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img.astype(np.float32)
def _get_bbox_lind(self, index):
"""
:param index: int
:return:
list of bboxes, list of bboxes' label index, list of flags([difficult, truncated,...])
"""
linds = []
bboxes = []
flags = []
root = ET.parse(self._annopaths[index]).getroot()
for obj in root.iter('object'):
linds.append(self._class_labels.index(_get_xml_et_value(obj, 'name')))
bndbox = obj.find('bndbox')
# bbox = [xmin, ymin, xmax, ymax]
bboxes.append([_get_xml_et_value(bndbox, 'xmin', int), _get_xml_et_value(bndbox, 'ymin', int), _get_xml_et_value(bndbox, 'xmax', int), _get_xml_et_value(bndbox, 'ymax', int)])
flags.append({'difficult': _get_xml_et_value(obj, 'difficult', int) == 1})#,
#'partial': _get_xml_et_value(obj, 'truncated', int) == 1})
return np.array(bboxes, dtype=np.float32), np.array(linds, dtype=np.float32), flags
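For reference, using this class on its own might look like the following (the `voc_dir` path is just a placeholder; no transforms are passed yet, so the image comes back as an ndarray):

```python
dataset = VOCSingleDatasetBase(voc_dir='/path/to/VOCdevkit/VOC2007', focus='trainval')

print(len(dataset))        # number of images listed in ImageSets/Main/trainval.txt
img, targets = dataset[0]  # rgb ndarray and rows of [xmin, ymin, xmax, ymax, label index] (normalized boxes)
```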
The COCO directory structure is similar to VOC in that it is split into annotations and images (`annotations` and `images/{train or val}20**`), but the handling of annotations is slightly different.
├── annotations
│ ├── captions_train2014.json
│ ├── captions_val2014.json
│ ├── instances_train2014.json
│ ├── instances_val2014.json
│ ├── person_keypoints_train2014.json
│ └── person_keypoints_val2014.json
└── images
├── train2014
└── val2014
As you can see, unlike VOC, all annotations are written in a single file. What you need for object detection is the `instances_{train or val}20**.json` file. The format is described in detail in the [official documentation](http://cocodataset.org/#format-data). And since COCO provides a [python api](https://github.com/cocodataset/cocoapi), honestly, you don't need to understand the contents of `instances_{train or val}20**.json` in much detail beyond knowing that it is the annotation file for object detection.
Just in case, checking the format, it looks like this.
{
"info": info,
"images": [image],
"annotations": [annotation],
"licenses": [license],
}
info{
"year": int,
"version": str,
"description": str,
"contributor": str,
"url": str,
"date_created": datetime,
}
image{
"id": int,
"width": int,
"height": int,
"file_name": str,
"license": int,
"flickr_url": str,
"coco_url": str,
"date_captured": datetime,
}
license{
"id": int,
"name": str,
"url": str,
}
The `annotation` and `categories` entries for object detection are as follows.
annotation{
"id": int,
"image_id": int,
"category_id": int,
"segmentation": RLE or [polygon],
"area": float, "bbox": [x,y,width,height],
"iscrowd": 0 or 1,
}
categories[{
"id": int,
"name": str,
"supercategory": str,
}]
Just implement it like VOC. The necessary information is obtained via the API's `COCO` object.
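As a quick, hedged sketch of how that `COCO` object is typically queried (the json path here is just an assumed example):

```python
from pycocotools.coco import COCO

coco = COCO('/path/to/coco2014/annotations/instances_val2014.json')

image_ids = list(coco.imgToAnns.keys())          # only images that have annotations
info = coco.loadImgs(image_ids[0])[0]            # dict with 'file_name', 'width', ...
ann_ids = coco.getAnnIds(image_ids[0])           # annotation ids for this image
anns = coco.loadAnns(ann_ids)                    # list of dicts with 'bbox', 'category_id', ...
cat = coco.loadCats(anns[0]['category_id'])[0]   # dict with 'name', 'supercategory'
print(info['file_name'], anns[0]['bbox'], cat['name'])
```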
COCO_class_labels = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush']
COCO_class_nums = len(COCO_class_labels)
COCO2014_ROOT = os.path.join(DATA_ROOT, 'coco', 'coco2014')
class COCOSingleDatasetBase(ObjectDetectionDatasetBase):
def __init__(self, coco_dir, focus, ignore=None, transform=None, target_transform=None, augmentation=None, class_labels=None):
"""
:param coco_dir: str, coco directory path above 'annotations' and 'images'
e.g.) coco_dir = '~~~~/coco2007/trainval'
:param focus: str or str, directory name under images
e.g.) focus = 'train2014'
:param ignore: target_transforms.Ignore
:param transform: instance of transforms
:param target_transform: instance of target_transforms
:param augmentation: instance of augmentations
:param class_labels: None or list or tuple, if it's None use COCO_class_labels
"""
super().__init__(ignore=ignore, transform=transform, target_transform=target_transform, augmentation=augmentation)
self._coco_dir = coco_dir
self._focus = focus
self._class_labels = _check_ins('class_labels', class_labels, (list, tuple), allow_none=True)
if self._class_labels is None:
self._class_labels = COCO_class_labels
self._annopath = os.path.join(self._coco_dir, 'annotations', 'instances_' + self._focus + '.json')
if os.path.exists(self._annopath):
self._coco = COCO(self._annopath)
else:
raise FileNotFoundError('json: {} was not found'.format('instances_' + self._focus + '.json'))
# remove no annotation image
self._imageids = list(self._coco.imgToAnns.keys())
@property
def class_nums(self):
return len(self._class_labels)
@property
def class_labels(self):
return self._class_labels
def _jpgpath(self, filename):
"""
:param filename: path containing .jpg
:return: path of jpg
"""
return os.path.join(self._coco_dir, 'images', self._focus, filename)
def __len__(self):
return len(self._imageids)
"""
Detail of contents in voc > https://towardsdatascience.com/coco-data-format-for-object-detection-a4c5eaf518c5
VOC bounding box (xmin, ymin, xmax, ymax)
"""
def _get_image(self, index):
"""
:param index: int
:return:
rgb image(ndarray)
"""
"""
self._coco.loadImgs(self._imageids[index]): list of dict, contains;
license: int
file_name: str
coco_url: str
height: int
width: int
date_captured: str
flickr_url: str
id: int
"""
filename = self._coco.loadImgs(self._imageids[index])[0]['file_name']
img = cv2.imread(self._jpgpath(filename))
# pytorch's image order is rgb
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img.astype(np.float32)
def _get_bbox_lind(self, index):
"""
:param index: int
:return:
list of bboxes, list of bboxes' label index, list of flags([difficult, truncated,...])
"""
linds = []
bboxes = []
flags = []
# anno_ids is list
anno_ids = self._coco.getAnnIds(self._imageids[index])
# annos is list of dict
annos = self._coco.loadAnns(anno_ids)
for anno in annos:
"""
anno's keys are;
segmentation: list of float
area: float
iscrowd: int, 0 or 1
image_id: int
bbox: list of float, whose length is 4
category_id: int
id: int
"""
"""
self._coco.loadCats(anno['category_id']) is list of dict, contains;
supercategory: str
id: int
name: str
"""
cat = self._coco.loadCats(anno['category_id'])[0]
linds.append(self.class_labels.index(cat['name']))
# bbox = [xmin, ymin, w, h]
xmin, ymin, w, h = anno['bbox']
# convert to corners
xmax, ymax = xmin + w, ymin + h
bboxes.append([xmin, ymin, xmax, ymax])
"""
flag = {}
keys = ['iscrowd']
for key in keys:
if key in anno.keys():
flag[key] = anno[key] == 1
else:
flag[key] = False
flags.append(flag)
"""
flags.append({'difficult': anno['iscrowd'] == 1})
return np.array(bboxes, dtype=np.float32), np.array(linds, dtype=np.float32), flags
Augmentation

Augmentation is not strictly necessary, but it seems to be important, since even the original paper mentions that

> Data augmentation is crucial

The original paper omits the specific methods, but there are two main types of augmentation. In the following, I will describe how the original image is augmented.
Geometric Distortions

For Geometric Distortions, there are the following methods.

Random Expand

- As the name suggests, the image is randomly expanded in size.
- The margins created by the expansion are filled with the mean `rgb_mean = (103.939, 116.779, 123.68)` used in normalization.

Random Sample

- A patch is sampled at random.
- At that time, the threshold on the degree of overlap (IoU value) between the sampled patch and the bounding boxes is chosen randomly.
  - One of `(0.1, 0.3, 0.5, 0.7, 0.9, None)`
  - `None` means no threshold. However, `IoU = 0` (no overlap at all) is excluded.
- Sampling is repeated until the sampled patch exceeds the threshold.
I will omit the implementation here. Specifically, it is here; see here for other image examples. A rough sketch of the Random Expand idea follows.
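The sketch below is not the repository's actual code, just an illustration of the Random Expand step described above: place the original image at a random position inside a larger canvas filled with the rgb mean, and shift the (pixel-coordinate) bounding boxes accordingly.

```python
import numpy as np

def random_expand(img, bboxes, rgb_mean=(103.939, 116.779, 123.68), max_ratio=4):
    """img: (h, w, 3) ndarray, bboxes: (N, 4) [xmin, ymin, xmax, ymax] in pixels."""
    h, w, _ = img.shape
    ratio = np.random.uniform(1, max_ratio)
    new_h, new_w = int(h * ratio), int(w * ratio)

    # canvas filled with the mean color, original image pasted at a random offset
    canvas = np.full((new_h, new_w, 3), rgb_mean, dtype=img.dtype)
    top = np.random.randint(0, new_h - h + 1)
    left = np.random.randint(0, new_w - w + 1)
    canvas[top:top + h, left:left + w] = img

    # shift the boxes by the same offset
    expanded = bboxes.copy()
    expanded[:, 0::2] += left
    expanded[:, 1::2] += top
    return canvas, expanded
```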
Photometric Distortions

For Photometric Distortions, there are the following five methods.

I will omit the implementation here as well. Specifically, it is here; see here for other image examples. A rough sketch of one such distortion follows.
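As with Random Expand, the following is not the repository's code, just a rough sketch of one typical photometric distortion (random brightness and contrast on an RGB image); the exact five distortions used in the article are not shown here.

```python
import numpy as np

def random_brightness_contrast(img, delta=32, contrast_range=(0.5, 1.5)):
    """img: (h, w, 3) float32 RGB image in [0, 255]."""
    out = img.astype(np.float32)
    if np.random.rand() < 0.5:                  # random brightness shift
        out += np.random.uniform(-delta, delta)
    if np.random.rand() < 0.5:                  # random contrast scaling
        out *= np.random.uniform(*contrast_range)
    return np.clip(out, 0, 255)
```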
Transform
This is the preprocessing of the input image.
- Resize (300x300, 512x512, etc.)
- Convert the RGB input image from `ndarray` to `torch.Tensor`
- Convert from $[0,255]$ to $[0,1]$
- Normalize with `rgb_means = (0.485, 0.456, 0.406)` and `rgb_stds = (0.229, 0.224, 0.225)` as the mean and standard deviation of the image, respectively. (If I remember correctly, the mean and variance of the dataset VGG was trained on?)

The above processing is implemented as follows. PyTorch provides `transforms` for preprocessing input images, but since they are functions for `PIL` images, I made my own versions for OpenCV. If you create your own `transforms`, you need to make them consistent with the processing in the dataset class methods above. This time, the `_apply_transform` method passes `img`, `bboxes`, `linds`, `flags`, that is, the image, bounding boxes, labels, and flag information such as difficult, as arguments. (By the way, I omitted it above, but `augmentation` is implemented in the same way.)
_apply_transform method
if self.ignore:
bboxes, linds, flags = self.ignore(bboxes, linds, flags)
if self.augmentation:
img, bboxes, linds, flags = self.augmentation(img, bboxes, linds, flags)
if self.transform:
img, bboxes, linds, flags = self.transform(img, bboxes, linds, flags)
if self.target_transform:
bboxes, linds, flags = self.target_transform(bboxes, linds, flags)
Therefore, you should implement the `__call__(self, img, bboxes, linds, flags)` method.
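For example, a custom transform following that signature might look like the hypothetical sketch below (not one of the repository's transforms). Since it operates on the OpenCV ndarray, it would go before `ToTensor` in a `Compose`.

```python
import numpy as np

class RandomGrayscale(object):
    """Hypothetical example: convert the RGB image to grayscale with probability p."""
    def __init__(self, p=0.1):
        self.p = p

    def __call__(self, img, bboxes, linds, flags):
        if np.random.rand() < self.p:
            # luminance-weighted average, kept as 3 channels so later transforms still work
            gray = img @ np.array([0.299, 0.587, 0.114], dtype=np.float32)
            img = np.repeat(gray[..., np.newaxis], 3, axis=2)
        return img, bboxes, linds, flags
```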
The `ToTensor` class converts the image from OpenCV's `(h, w, c)` order to `(c, h, w)`, matching the input order of PyTorch's `Conv2d`, which expects `(b, c, h, w)`.
class Compose(object):
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img, bboxes, labels, flags):
for t in self.transforms:
img, bboxes, labels, flags = t(img, bboxes, labels, flags)
return img, bboxes, labels, flags
def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
class ToTensor(object):
"""
Note that convert ndarray to tensor and [0-255] to [0-1]
"""
def __call__(self, img, *args):
# convert ndarray into Tensor
# transpose img's tensor (h, w, c) to pytorch's format (c, h, w). (num, c, h, w)
img = np.transpose(img, (2, 0, 1))
return (torch.from_numpy(img).float() / 255., *args)
class Resize(object):
def __init__(self, size):
"""
:param size: 2d-array-like, (height, width)
"""
self._size = size
def __call__(self, img, *args):
return (cv2.resize(img, self._size), *args)
class Normalize(object):
#def __init__(self, rgb_means=(103.939, 116.779, 123.68), rgb_stds=(1.0, 1.0, 1.0)):
def __init__(self, rgb_means=(0.485, 0.456, 0.406), rgb_stds=(0.229, 0.224, 0.225)):
self.means = np.array(rgb_means, dtype=np.float32).reshape((-1, 1, 1))
if np.any(np.abs(self.means) > 1):
logging.warning("In general, mean value should be less than 1 because img's range is [0-1]")
self.stds = np.array(rgb_stds, dtype=np.float32).reshape((-1, 1, 1))
def __call__(self, img, *args):
if isinstance(img, torch.Tensor):
return ((img.float() - torch.from_numpy(self.means)) / torch.from_numpy(self.stds), *args)
else:
return ((img.astype(np.float32) - self.means) / self.stds, *args)
Example of use
from data import transforms
transform = transforms.Compose(
[transforms.Resize((300, 300)),
transforms.ToTensor(),
transforms.Normalize(rgb_means=(0.485, 0.456, 0.406), rgb_stds=(0.229, 0.224, 0.225))]
)
Target transform
The bounding boxes and labels are converted as follows.

- Convert the bounding boxes from the corners representation to the centroids representation
- Convert the labels to one-hot vectors
- Convert from `ndarray` to `torch.Tensor`
- Concatenate the bounding boxes and labels (`shape = (box num, 4(=(cx, cy, w, h)) + class_nums + 1(=background))`)

There are three box representations:
- centroids representation
  - Uses the center coordinates $(c_x, c_y)$ and the width/height $(w, h)$
- corners representation
  - Uses the upper-left coordinates $(x_{min}, y_{min})$ and the lower-right coordinates $(x_{max}, y_{max})$
- minmax representation
  - Uses the upper-left coordinates $(x_{min}, y_{min})$ and the lower-right coordinates $(x_{max}, y_{max})$

The relationship between the center coordinates $(c_x, c_y)$ and width/height $(w, h)$, and the upper-left coordinates $(x_{min}, y_{min})$ and lower-right coordinates $(x_{max}, y_{max})$, is:
\begin{align}
(c_x,c_y) &= (\frac{x_{min}+x_{max}}{2},\frac{y_{min}+y_{max}}{2}) \\
(w,h) &= (x_{max}-x_{min},y_{max}-y_{min})
\end{align}
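As a quick sanity check with the second `chair` box from the VOC annotation above, $(x_{min}, y_{min}, x_{max}, y_{max}) = (165, 264, 253, 372)$:

\begin{align}
(c_x,c_y) &= (\frac{165+253}{2},\frac{264+372}{2}) = (209,318) \\
(w,h) &= (253-165,372-264) = (88,108)
\end{align}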
The above processing is implemented as follows. PyTorch does not provide `target_transforms` for processing the ground-truth labels of object detection, so you need to create `target_transforms` yourself. Again, as shown below, the `_apply_transform` method passes `bboxes`, `linds`, `flags`, that is, the bounding boxes, labels, and flag information such as difficult, as arguments.
_apply_transform method
if self.ignore:
bboxes, linds, flags = self.ignore(bboxes, linds, flags)
if self.augmentation:
img, bboxes, linds, flags = self.augmentation(img, bboxes, linds, flags)
if self.transform:
img, bboxes, linds, flags = self.transform(img, bboxes, linds, flags)
if self.target_transform:
bboxes, linds, flags = self.target_transform(bboxes, linds, flags)
Therefore, you should implement the `__call__(self, bboxes, linds, flags)` method.
class ToTensor(object):
def __call__(self, bboxes, labels, flags):
return torch.from_numpy(bboxes), torch.from_numpy(labels), flags
class ToCentroids(object):
def __call__(self, bboxes, labels, flags):
# bbox = [xmin, ymin, xmax, ymax]
bboxes = np.concatenate(((bboxes[:, 2:] + bboxes[:, :2]) / 2,
(bboxes[:, 2:] - bboxes[:, :2])), axis=1)
return bboxes, labels, flags
class ToCorners(object):
def __call__(self, bboxes, labels, flags):
# bbox = [cx, cy, w, h]
bboxes = np.concatenate((bboxes[:, :2] - bboxes[:, 2:]/2,
bboxes[:, :2] + bboxes[:, 2:]/2), axis=1)
return bboxes, labels, flags
class OneHot(object):
def __init__(self, class_nums, add_background=True):
self._class_nums = class_nums
self._add_background = add_background
if add_background:
self._class_nums += 1
def __call__(self, bboxes, labels, flags):
if labels.ndim != 1:
raise ValueError('labels might have been already one-hotted or be invalid shape')
labels = _one_hot_encode(labels.astype(np.int), self._class_nums)
labels = np.array(labels, dtype=np.float32)
return bboxes, labels, flags
Example of use
target_transform = target_transforms.Compose(
[target_transforms.ToCentroids(),
target_transforms.OneHot(class_nums=datasets.VOC_class_nums, add_background=True),
target_transforms.ToTensor()]
)
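Putting the pieces together, constructing a dataset with these transforms might look like the sketch below. The `voc_dir` path is a placeholder, and I'm assuming the module layout implied by the earlier snippets (`datasets`, `transforms`, `target_transforms` under `data`).

```python
from data import datasets, transforms, target_transforms

transform = transforms.Compose(
    [transforms.Resize((300, 300)),
     transforms.ToTensor(),
     transforms.Normalize(rgb_means=(0.485, 0.456, 0.406), rgb_stds=(0.229, 0.224, 0.225))]
)
target_transform = target_transforms.Compose(
    [target_transforms.ToCentroids(),
     target_transforms.OneHot(class_nums=datasets.VOC_class_nums, add_background=True),
     target_transforms.ToTensor()]
)

dataset = datasets.VOCSingleDatasetBase(voc_dir='/path/to/VOCdevkit/VOC2007', focus='trainval',
                                        transform=transform, target_transform=target_transform)

img, targets = dataset[0]
print(img.shape)      # torch.Size([3, 300, 300])
print(targets.shape)  # (box num, 4 + class_nums + 1)
```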
That's it for the dataset processing. As usual, it ended up somewhat half-finished, but I don't think there are many articles about dataset processing, so I hope you find it helpful.