The following 6 code examples, extracted from open-source Python projects, show how samples are sorted by length (via Python's built-in list.sort()) inside custom collate_fn functions for torch.utils.data.DataLoader.
import torch


def collate_fn(data):
    """Creates mini-batch tensors from the list of tuples (image, caption).

    We should build a custom collate_fn rather than using the default
    collate_fn, because merging captions (including padding) is not
    supported by the default.

    Args:
        data: list of tuples (image, caption).
            - image: torch tensor of shape (3, 256, 256).
            - caption: torch tensor of shape (?); variable length.

    Returns:
        images: torch tensor of shape (batch_size, 3, 256, 256).
        targets: torch tensor of shape (batch_size, padded_length).
        lengths: list; valid length for each padded caption.
    """
    # Sort the data list by caption length (descending order).
    data.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions = zip(*data)

    # Merge images (from tuple of 3D tensors to a 4D tensor).
    images = torch.stack(images, 0)

    # Merge captions (from tuple of 1D tensors to a 2D tensor).
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]
    return images, targets, lengths
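A collate_fn like this is passed to torch.utils.data.DataLoader, which calls it on each list of sampled items. Below is a minimal usage sketch; ToyCaptionDataset is a hypothetical stand-in (not part of the example above) whose items are (image, caption) tuples with variable-length captions.

import torch
from torch.utils.data import DataLoader, Dataset


class ToyCaptionDataset(Dataset):
    """Hypothetical stand-in: random images, captions of varying length."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        image = torch.randn(3, 256, 256)
        caption = torch.randint(1, 100, (idx % 5 + 3,))  # length 3..7
        return image, caption


loader = DataLoader(ToyCaptionDataset(), batch_size=4, shuffle=True,
                    collate_fn=collate_fn)

for images, targets, lengths in loader:
    # images:  (4, 3, 256, 256)
    # targets: (4, max_length_in_batch), zero-padded
    # lengths: descending list of valid caption lengths
    pass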
import torch


def collate_fn(data):
    """Creates mini-batch tensors from the list of tuples (src_seq, trg_seq).

    We should build a custom collate_fn rather than using the default
    collate_fn, because merging sequences (including padding) is not
    supported by the default. Sequences are padded to the maximum length
    within the mini-batch (dynamic padding).

    Args:
        data: list of tuples (src_seq, trg_seq).
            - src_seq: torch tensor of shape (?); variable length.
            - trg_seq: torch tensor of shape (?); variable length.

    Returns:
        src_seqs: torch tensor of shape (batch_size, padded_length).
        src_lengths: list of length (batch_size); valid length for each padded source sequence.
        trg_seqs: torch tensor of shape (batch_size, padded_length).
        trg_lengths: list of length (batch_size); valid length for each padded target sequence.
    """
    def merge(sequences):
        lengths = [len(seq) for seq in sequences]
        padded_seqs = torch.zeros(len(sequences), max(lengths)).long()
        for i, seq in enumerate(sequences):
            end = lengths[i]
            padded_seqs[i, :end] = seq[:end]
        return padded_seqs, lengths

    # Sort the list by source sequence length (descending order)
    # so that pack_padded_sequence can be used later.
    data.sort(key=lambda x: len(x[0]), reverse=True)

    # Separate source and target sequences.
    src_seqs, trg_seqs = zip(*data)

    # Merge sequences (from tuple of 1D tensors to a 2D tensor).
    src_seqs, src_lengths = merge(src_seqs)
    trg_seqs, trg_lengths = merge(trg_seqs)

    return src_seqs, src_lengths, trg_seqs, trg_lengths
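The descending sort matters because torch.nn.utils.rnn.pack_padded_sequence expects lengths in decreasing order by default (newer PyTorch releases can relax this with enforce_sorted=False). Below is a minimal sketch of feeding the merged output to an RNN; the vocabulary, embedding, and hidden sizes are illustrative assumptions, not values from the project above.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

# Assumed sizes, for illustration only.
embedding = nn.Embedding(num_embeddings=1000, embedding_dim=64, padding_idx=0)
gru = nn.GRU(input_size=64, hidden_size=128, batch_first=True)

# Stand-ins for src_seqs/src_lengths as returned by the collate_fn above:
# a zero-padded LongTensor plus valid lengths in descending order.
src_seqs = torch.tensor([[5, 3, 9, 2], [7, 4, 0, 0]])
src_lengths = [4, 2]

packed = pack_padded_sequence(embedding(src_seqs), src_lengths, batch_first=True)
outputs, hidden = gru(packed)  # padded positions are skipped by the RNN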
import torch


def collate_fn(data):
    """Creates mini-batch tensors from the list of tuples (image, caption).

    We should build a custom collate_fn rather than using the default
    collate_fn, because merging code (including padding) is not supported
    by the default.

    Args:
        data: list of tuples (image, caption).
            - image: torch tensor of shape (3, 256, 256).
            - caption: torch tensor of shape (?); variable length.

    Returns:
        images: torch tensor of shape (batch_size, 3, 256, 256).
        targets: torch tensor of shape (batch_size, padded_length).
        lengths: list; valid length for each padded caption.
    """
    # Sort the data list by code length (descending order).
    data.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions = zip(*data)

    # Merge images (from tuple of 3D tensors to a 4D tensor).
    images = torch.stack(images, 0)

    # Merge code (from tuple of 1D tensors to a 2D tensor).
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]
    return images, targets, lengths
import torch


def train_collate_fn(data):
    """Assembles a list of training samples into a mini-batch."""
    # Sort the data list in descending order by the last tuple element
    # (here, the video id).
    data.sort(key=lambda x: x[-1], reverse=True)
    videos, captions, lengths, video_ids = zip(*data)

    # Merge videos (from tuple of 2D tensors to a 3D tensor).
    videos = torch.stack(videos, 0)

    # Merge captions (from tuple of 1D tensors to a 2D tensor); this
    # requires captions already padded to a common length.
    captions = torch.stack(captions, 0)
    return videos, captions, lengths, video_ids
import torch


def eval_collate_fn(data):
    """Assembles a list of evaluation samples into a mini-batch."""
    # Sort the data list in descending order by the last tuple element
    # (here, the video id).
    data.sort(key=lambda x: x[-1], reverse=True)
    videos, video_ids = zip(*data)

    # Merge videos (from tuple of 2D tensors to a 3D tensor).
    videos = torch.stack(videos, 0)
    return videos, video_ids
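These two functions are typically attached to separate training and evaluation DataLoaders. The sketch below wires up the training side with a hypothetical toy dataset (ToyVideoDataset is an assumption, not part of the examples above); note that train_collate_fn stacks captions directly, so the dataset must return captions already padded to one fixed length.

import torch
from torch.utils.data import DataLoader, Dataset


class ToyVideoDataset(Dataset):
    """Hypothetical stand-in: frame features plus fixed-length captions."""

    def __len__(self):
        return 6

    def __getitem__(self, idx):
        video = torch.randn(10, 512)           # (num_frames, feature_dim)
        caption = torch.randint(1, 50, (12,))  # pre-padded to length 12
        length = 12 - idx % 4                  # valid caption length
        return video, caption, length, idx     # idx doubles as video_id


train_loader = DataLoader(ToyVideoDataset(), batch_size=3,
                          collate_fn=train_collate_fn)
videos, captions, lengths, video_ids = next(iter(train_loader))
# videos: (3, 10, 512); captions: (3, 12); lengths/video_ids: tuples of 3.
# eval_collate_fn would pair the same way with a dataset whose items
# are (video, video_id) tuples.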
import torch


def collate_fn(data):
    """Build mini-batch tensors from a list of (image, caption) tuples.

    Args:
        data: list of (image, caption) tuples.
            - image: torch tensor of shape (3, 256, 256).
            - caption: torch tensor of shape (?); variable length.

    Returns:
        images: torch tensor of shape (batch_size, 3, 256, 256).
        targets: torch tensor of shape (batch_size, padded_length).
        lengths: list; valid length for each padded caption.
        ids: tuple; ids of the sampled (image, caption) pairs.
    """
    # Sort the data list by caption length (descending order).
    data.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions, ids, img_ids = zip(*data)

    # Merge images (convert tuple of 3D tensors to a 4D tensor).
    images = torch.stack(images, 0)

    # Merge captions (convert tuple of 1D tensors to a 2D tensor).
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]
    return images, targets, lengths, ids