ConcatDataset 和 StackDataset¶
在 PyTorch 中,ConcatDataset
和 StackDataset
是两种不同的数据集组合方式。本文介绍了它们的作用及其适用场景。
Python
# 使用 ConcatDataset 连接数据集
concat_dataset = ConcatDataset([dataset1, dataset2])
# 遍历 ConcatDataset
for sample in concat_dataset:
print(sample)
Python
# 使用 StackDataset 组合数据集
stack_dataset = StackDataset(dataset1, dataset2)
# 遍历 StackDataset
for sample in stack_dataset:
print(sample)
作用¶
ConcatDataset¶
- 将多个数据集按顺序拼接,形成一个更大的数据集。
- 遍历时顺序访问每个子数据集的所有样本。
- 例如,
ConcatDataset([dataset1, dataset2])
会先遍历dataset1
,再遍历dataset2
。 - 适用于需要合并多个数据集为单一数据集的情况。
StackDataset¶
- 将多个数据集的样本按索引一一对应组合。
- 遍历时同时从每个子数据集中取出相同索引的样本,并组合成元组或列表。
- 例如,
StackDataset(dataset1, dataset2)
会返回(sample1, sample2)
,其中sample1
来自dataset1
,sample2
来自dataset2
。 - 适用于需要同时处理多个数据集样本的情况,如混合频率的数据集。
代码示例¶
Python
import torch
from torch.utils.data import ConcatDataset, Dataset, StackDataset
# 创建两个简单的数据集
class SimpleDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
# 数据集 1
dataset1 = SimpleDataset([1, 2, 3])
# 数据集 2
dataset2 = SimpleDataset([4, 5, 6])
Python
# 使用 ConcatDataset 连接数据集
concat_dataset = ConcatDataset([dataset1, dataset2])
# 遍历 ConcatDataset
for sample in concat_dataset:
print(sample)
Python
# 使用 StackDataset 组合数据集
stack_dataset = StackDataset(dataset1, dataset2)
# 遍历 StackDataset
for sample in stack_dataset:
print(sample)