1. Data flow

graph TD
o_data[[raw data]]
map_dataset[Map-style Dataset
 \_\_len\_\_
 \_\_getitem\_\_
\_\_getitems\_\_
]
iter_dataset[Iterable-style Dataset
\_\_iter\_\_
]

subgraph dataloader_parameters
bs[[batch_size]]
drop_last[[drop_last]]
shuffle[[shuffle]]
sampler1[[sampler]]
batch_sampler1[[batch_sampler]]
end

o_data --> map_dataset
o_data --> iter_dataset
sampler1 ==>|passed directly| sampler
batch_sampler1 ==>|passed directly| batch_sampler
bs --> batch_sampler
drop_last --> batch_sampler
shuffle -->|builds a random sampler| sampler
map_dataset --> sampler
map_dataset --> batch_sampler

iter_dataset --> collate_fn

subgraph DataLoader
sampler
batch_sampler
collate_fn
sampler --> batch_sampler
batch_sampler --> collate_fn
end

collate_fn --> batched_data[[batched data]]
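
The parameter-to-sampler wiring in the graph can be checked on a small example. The following is a minimal sketch, assuming a toy TensorDataset of 10 items (an illustration, not part of the test program in this post). It only inspects the sampler and batch_sampler attributes that DataLoader exposes.

```python
import torch
from torch.utils.data import (
    BatchSampler,
    DataLoader,
    RandomSampler,
    SequentialSampler,
    TensorDataset,
)

# Assumption: a toy map-style dataset with 10 samples, for illustration only.
dataset = TensorDataset(torch.arange(10))

# shuffle=True makes the DataLoader build a RandomSampler; batch_size and
# drop_last are then used to wrap that sampler in a BatchSampler.
loader = DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True)
print(type(loader.sampler).__name__)        # RandomSampler
print(type(loader.batch_sampler).__name__)  # BatchSampler
print(len(loader))                          # 2, the partial last batch is dropped

# With the default shuffle=False, a SequentialSampler is built instead.
plain = DataLoader(dataset, batch_size=4)
print(type(plain.sampler).__name__)         # SequentialSampler
print(len(plain))                           # 3, the partial last batch is kept

# The same pipeline built by hand, suitable for passing as batch_sampler=...
manual = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)
print(list(manual))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```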

2. Notes for Map-style Dataset

2.1. collate_fn

With automatic batching, collate_fn receives a list of samples. The DataLoader then behaves roughly like this:

for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

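If both batch_size and batch_sampler are None, automatic batching is disabled. Note that batch_size defaults to 1, so it has to be set to None explicitly. collate_fn then receives a single sample, and the loop is equivalent to: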

for index in sampler:
    yield collate_fn(dataset[index])
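
The two loops above can be compared against a real DataLoader. This is a minimal sketch, assuming a hypothetical toy dataset named Squares and the default collate functions; it is meant to illustrate the equivalence, not to mirror the DataLoader internals exactly.

```python
import torch
from torch.utils.data import (
    BatchSampler,
    DataLoader,
    Dataset,
    SequentialSampler,
    default_collate,
)


class Squares(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i * i)."""

    def __len__(self):
        return 6

    def __getitem__(self, i):
        return torch.tensor(i), torch.tensor(i * i)


dataset = Squares()
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)

# Automatic batching: collate_fn receives a list of samples per batch.
manual = [default_collate([dataset[i] for i in indices]) for indices in batch_sampler]
loaded = list(DataLoader(dataset, batch_sampler=batch_sampler))
for (mx, my), (lx, ly) in zip(manual, loaded):
    assert torch.equal(mx, lx) and torch.equal(my, ly)
print([x.tolist() for x, _ in loaded])  # [[0, 1, 2, 3], [4, 5]]

# Automatic batching disabled: batch_size must be passed as None explicitly.
# Each item is then a single (converted) sample without a batch dimension.
unbatched = list(DataLoader(dataset, batch_size=None))
print(len(unbatched))  # 6
print(unbatched[3])    # [tensor(3), tensor(9)]
```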

If no collate_fn is provided, PyTorch applies a default collate_fn with the following properties:

  • When automatic batching is enabled:
    • It always prepends a new dimension as the batch dimension.
    • It automatically converts NumPy arrays and Python numerical values into PyTorch Tensors.
    • It preserves the data structure, e.g., if each sample is a dictionary, it outputs a dictionary with the same set of keys but batched Tensors as values (or lists if the values cannot be converted into Tensors). The same holds for lists, tuples, namedtuples, etc.
  • When automatic batching is disabled:
    • It simply converts NumPy arrays into PyTorch Tensors and leaves everything else, including Python numerical values, untouched.
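
The default behavior in both modes can be observed by calling the two helper functions directly. This is a small sketch, assuming made-up dictionary samples; default_collate is the function used when automatic batching is enabled, and default_convert is the one used when it is disabled.

```python
import numpy as np
import torch
from torch.utils.data import default_collate, default_convert

# Assumption: two made-up samples mixing an array, a number, and a string.
samples = [
    {"x": np.array([1.0, 2.0]), "label": 0, "name": "a"},
    {"x": np.array([3.0, 4.0]), "label": 1, "name": "b"},
]

# Automatic batching enabled: a batch dimension is prepended, arrays and
# numbers become Tensors, and the dict structure is preserved.
batch = default_collate(samples)
print(batch["x"].shape)  # torch.Size([2, 2])
print(batch["label"])    # tensor([0, 1])
print(batch["name"])     # ['a', 'b'] (strings cannot become Tensors)

# Automatic batching disabled: only the NumPy array is converted,
# no batch dimension is added, and the Python int is left as it is.
single = default_convert(samples[0])
print(single["x"])      # tensor([1., 2.], dtype=torch.float64)
print(single["label"])  # 0
```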

2.2. Passing objects directly

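Once an object is passed directly, the arguments that would otherwise be used to construct it are no longer needed, and setting them is rejected. For example:
If sampler is passed directly, shuffle must not be enabled.
If batch_sampler is passed directly, batch_size, drop_last, shuffle, and sampler must all be left at their default values.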
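
These restrictions can be checked by trying a few argument combinations. The sketch below assumes a toy TensorDataset and a hypothetical helper named try_build; it only reports whether DataLoader accepts the combination or raises a ValueError.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

# Assumption: a toy dataset and samplers, for illustration only.
dataset = TensorDataset(torch.arange(8))
sampler = SequentialSampler(dataset)
batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)


def try_build(**kwargs):
    """Hypothetical helper: return 'ok' or the ValueError raised by DataLoader."""
    try:
        DataLoader(dataset, **kwargs)
        return "ok"
    except ValueError as err:
        return f"ValueError: {err}"


results = {
    "sampler + shuffle": try_build(sampler=sampler, shuffle=True),
    "batch_sampler + batch_size": try_build(batch_sampler=batch_sampler, batch_size=2),
    "batch_sampler + drop_last": try_build(batch_sampler=batch_sampler, drop_last=True),
    "batch_sampler + sampler": try_build(batch_sampler=batch_sampler, sampler=sampler),
    # Accepted: sampler still needs batch_size / drop_last to form batches.
    "sampler + batch_size": try_build(sampler=sampler, batch_size=2, drop_last=True),
    # Accepted: the other arguments are present but keep their default values.
    "batch_sampler + defaults": try_build(
        batch_sampler=batch_sampler, shuffle=False, drop_last=False
    ),
}
for name, outcome in results.items():
    print(f"{name}: {outcome}")
```

The last case is why the test program in this post can pass drop_last=False together with batch_sampler: the value equals the default, so the check does not trigger.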

3. Test program

from typing import Iterator, List

import torch
from torch.utils.data import DataLoader, Dataset, Sampler
 
 
 
class EvenThenOddSampler(Sampler[int]):
    """Yield all even indices first, then all odd indices."""

    def __init__(self, data_source) -> None:
        self.data_source = data_source
 
    def __len__(self) -> int:
        return len(self.data_source)
 
    def __iter__(self) -> Iterator[int]:
        print("__iter__ is called")
        n = len(self.data_source)
 
        for i in range(0, n, 2):
            yield i
 
        for i in range(1, n, 2):
            yield i
 
 
class EvenThenOddBatchSampler(Sampler[List[int]]):
    """Yield batches of indices, even indices first, then odd indices."""

    def __init__(self, data_source, batch_size: int, drop_last: bool = False):
        self.data_source = data_source
        self.batch_size = batch_size
        self.drop_last = drop_last
 
    def __len__(self) -> int:
        n = len(self.data_source)
        if self.drop_last:
            return n // self.batch_size
        else:
            return (n + self.batch_size - 1) // self.batch_size
 
    def __iter__(self) -> Iterator[List[int]]:
        print("__iter__ is called")
        n = len(self.data_source)
 
        indices = list(range(0, n, 2)) + list(range(1, n, 2))
 
        batch = []
        for idx in indices:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
 
        if len(batch) > 0 and not self.drop_last:
            yield batch
 
 
class MyDataset(Dataset):
    def __init__(self):
        super().__init__()
        self.data = torch.arange(0, 64)
        self.labels = torch.arange(0, 64) * 10
 
    def __getitem__(self, idx):
        print(f"__getitem__ is called with idx {idx}")
        return self.data[idx], self.labels[idx]
 
    def __len__(self):
        print("__len__ is called")
        return len(self.data)
 
    # def __getitems__(self, idxs):
    #     print(f"__getitems__ is called with idxs {idxs}")
    #     datalist = []
    #     for idx in idxs:
    #         datalist.append((self.data[idx], self.labels[idx]))
    #     return datalist
 
 
train_dataset = MyDataset()
 
sampler = EvenThenOddSampler(train_dataset)
batch_sampler = EvenThenOddBatchSampler(train_dataset, batch_size=4, drop_last=False)
 
print("dataloader is created")
 
dataloader = DataLoader(
    train_dataset,
    # Enable exactly one configuration at a time:
    # batch_size=4,
    # shuffle=True,
    # sampler=sampler,
    batch_sampler=batch_sampler,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
)
 
for batch in dataloader:
    print(batch)
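
The commented-out `__getitems__` method in the test program can also be tried in isolation. This is a minimal sketch, assuming a recent PyTorch 2.x release (such as the 2.6 version referenced in this post) and a hypothetical dataset class that records which method was called; older releases may fall back to per-index `__getitem__` calls.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class BatchedFetchDataset(Dataset):
    """Hypothetical dataset that records how the DataLoader fetches samples."""

    def __init__(self):
        self.data = torch.arange(8)
        self.getitem_calls = []   # indices seen by __getitem__
        self.getitems_calls = []  # index lists seen by __getitems__

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        self.getitem_calls.append(idx)
        return self.data[idx]

    def __getitems__(self, idxs):
        # Receives the whole list of indices for one batch and must return
        # a list of samples, which is then handed to collate_fn.
        self.getitems_calls.append(list(idxs))
        return [self.data[i] for i in idxs]


dataset = BatchedFetchDataset()
batches = [batch.tolist() for batch in DataLoader(dataset, batch_size=4)]

print(batches)                 # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(dataset.getitems_calls)  # one call per batch
print(dataset.getitem_calls)   # empty when __getitems__ is used
```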
 

4. References

https://docs.pytorch.org/docs/2.6/data