Pytorch collate_fn est un argument de Dataloader.
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
Cette fois, je voudrais confirmer son comportement et son utilisation.
Lorsque le \ _ \ _ getitem \ _ \ _ défini dans le jeu de données se présente sous la forme d'un lot, chaque élément (image, cible, etc.) est d'abord consolidé dans une liste. Collate_fn le manipule comme décrit dans Pytroch Official, le rendant finalement torche.Tensor C'est une fonction.
dataset_iter = iter(dataset)
for indices in batch_sampler:
yield collate_fn([next(dataset_iter) for _ in indices])
Par défaut, torch.stack () est utilisé pour créer Tensor, mais vous pouvez créer un lot avancé en utilisant votre propre collate_fn.
Le comportement par défaut est presque le même que ci-dessous. (Bien que le nombre de retours dépende de \ _ \ _ getitem \ _ \ _) Il prend le lot comme argument, l'empile et le renvoie.
def collate_fn(batch):
images, targets= list(zip(*batch))
images = torch.stack(images)
targets = torch.stack(targets)
return images, targets
Vous pouvez modifier le contenu de votre propre collate_fn.
Cette fois, nous allons créer un lot de détection d'objets. La détection d'objet entre essentiellement le rectangle de l'objet et son étiquette, mais comme il peut y avoir plusieurs rectangles dans une image, il est nécessaire de connecter quelle image est quel rectangle lors du traitement par lots, et l'index Doit être joint.
[[label, xc, yx, w, h],
[ ],
[ ],...]
#Changer cela vers le bas
[[0, label xc, yx, w, h],
[0, ],
[1, ],...]
La mise en œuvre elle-même n'est pas si difficile.
def batch_idx_fn(batch):
images, bboxes = list(zip(*batch))
targets = []
for idx, bbox in enumerate(bboxes):
target = np.zeros((len(bbox), 6))
target[:, 1:] = bbox
target[:, 0] = idx
targets.append(target)
images = torch.stack(images)
targets = torch.Tensor(np.concatenate(targets)) # [[batch_idx, label, xc, yx, w, h], ...]
return images, targets
Lorsque vous l'utilisez réellement, ce sera comme suit.
test_data_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=1,
shuffle=False,
collate_fn=batch_idx_fn
)
print(iter(test_data_loader).next()[0])
# [[0.0000, 0.0000, 0.6001, 0.5726, 0.1583, 0.1119],
# [0.0000, 9.0000, 0.0568, 0.5476, 0.1150, 0.1143],
# [1.0000, 5.0000, 0.8316, 0.4113, 0.1080, 0.3452],
# [1.0000, 0.0000, 0.3476, 0.6494, 0.1840, 0.1548],
# [2.0000, 2.0000, 0.8276, 0.6763, 0.1720, 0.3240],
# [2.0000, 4.0000, 0.1626, 0.0496, 0.0900, 0.0880],
# [2.0000, 5.0000, 0.2476, 0.2736, 0.1400, 0.5413],
# [2.0000, 5.0000, 0.5786, 0.4523, 0.4210, 0.5480],
# [3.0000, 0.0000, 0.4636, 0.4618, 0.0400, 0.1024],
# [3.0000, 0.0000, 0.5706, 0.5061, 0.0380, 0.0683]]
Autre que lors de l'indexation dans cet article Lorsque la cible change pour chaque lot, Lorsque la cible n'est pas des données numériques qui ne peuvent pas être empilées Je pense qu'il peut être utilisé lorsque vous souhaitez utiliser le même jeu de données avec des modifications légèrement différentes.
Recommended Posts