• fluid.dataset
    • DatasetFactory
    • InMemoryDataset
    • QueueDataset

    fluid.dataset

    SourceEnglish

    DatasetFactory

    SourceEnglish

    • class paddle.fluid.dataset.DatasetFactory
    • DatasetFactory是一个按数据集名称创建数据集的 “工厂”,可以创建“QueueDataset”,“InMemoryDataset”或“FileInstantDataset”,默认为“QueueDataset”。

    代码示例

    1. import paddle.fluid as fluid
    2. dataset = paddle.fluid.DatasetFactory().create_dataset("InMemoryDataset")
    • createdataset(_datafeed_class='QueueDataset')
    • 创建“QueueDataset”,“InMemoryDataset” 或 “FileInstantDataset”,默认为“QueueDataset”。

    • 参数:

      • datafeed_class (str) – datafeed类名,为QueueDataset或InMemoryDataset。默认为QueueDataset。代码示例:
    1. import paddle.fluid as fluid
    2. dataset = fluid.DatasetFactory().create_dataset()

    InMemoryDataset

    SourceEnglish

    • class paddle.fluid.dataset.InMemoryDataset
    • InMemoryDataset会向内存中加载数据并在训练前缓冲数据。此类由DatasetFactory创建。

    代码示例:

    1. dataset = paddle.fluid.DatasetFactory().create_dataset(“InMemoryDataset”)
    • load_into_memory()
    • 向内存中加载数据。

    代码示例:

    1. import paddle.fluid as fluid
    2. dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
    3. filelist = ["a.txt", "b.txt"]
    4. dataset.set_filelist(filelist)
    5. dataset.load_into_memory()
    • local_shuffle()
    • 局域shuffle。

    代码示例:

    1. import paddle.fluid as fluid
    2. dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
    3. filelist = ["a.txt", "b.txt"]
    4. dataset.set_filelist(filelist)
    5. dataset.load_into_memory()
    6. dataset.local_shuffle()
    • globalshuffle(_fleet=None)
    • 全局shuffle。

    只能用在分布式模式(单机多进程或多机多进程)中。您如果在分布式模式中运行,应当传递fleet而非None。

    代码示例:

    1. import paddle.fluid as fluid
    2. from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
    3. dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
    4. filelist = ["a.txt", "b.txt"]
    5. dataset.set_filelist(filelist)
    6. dataset.load_into_memory()
    7. dataset.global_shuffle(fleet)
    • 参数:
      • fleet (Fleet) – fleet单例。默认为None。
    • release_memory()
    • 当数据不再使用时,释放InMemoryDataset内存数据。

    代码示例:

    1. import paddle.fluid as fluid
    2. from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
    3. dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
    4. filelist = ["a.txt", "b.txt"]
    5. dataset.set_filelist(filelist)
    6. dataset.load_into_memory()
    7. dataset.global_shuffle(fleet)
    8. exe = fluid.Executor(fluid.CPUPlace())
    9. exe.run(fluid.default_startup_program())
    10. exe.train_from_dataset(fluid.default_main_program(), dataset)dataset.release_memory()
    11. dataset.release_memory()
    • getmemory_data_size(_fleet=None)
    • 用户可以调用此函数以了解加载进内存后所有workers中的ins数量。

    注解

    该函数可能会导致性能不佳,因为它具有barrier。

    • 参数:
      • fleet (Fleet) – fleet对象。返回:内存数据的大小。

    代码示例:

    1. import paddle.fluid as fluid
    2. from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
    3. dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
    4. filelist = ["a.txt", "b.txt"]
    5. dataset.set_filelist(filelist)
    6. dataset.load_into_memory()
    7. print dataset.get_memory_data_size(fleet)
    • getshuffle_data_size(_fleet=None)
    • 获取shuffle数据大小,用户可以调用此函数以了解局域/全局shuffle后所有workers中的ins数量。

    注解

    该函数可能会导致局域shuffle性能不佳,因为它具有barrier。但其不影响局域shuffle。

    • 参数:
      • fleet (Fleet) – fleet对象。返回:shuffle数据的大小。

    代码示例:

    1. import paddle.fluid as fluid
    2. from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
    3. dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
    4. filelist = ["a.txt", "b.txt"]
    5. dataset.set_filelist(filelist)
    6. dataset.load_into_memory()
    7. dataset.global_shuffle(fleet)
    8. print dataset.get_shuffle_data_size(fleet)

    QueueDataset

    SourceEnglish

    • class paddle.fluid.dataset.QueueDataset
    • 流式处理数据。

    代码示例:

    1. import paddle.fluid as fluid
    2. dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
    • local_shuffle()
    • 局域shuffle数据

    QueueDataset中不支持局域shuffle,可能抛出NotImplementedError

    代码示例:

    1. import paddle.fluid as fluid
    2. dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
    3. dataset.local_shuffle()
    • globalshuffle(_fleet=None)
    • 全局shuffle数据

    QueueDataset中不支持全局shuffle,可能抛出NotImplementedError

    代码示例:

    1. import paddle.fluid as fluid
    2. from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
    3. dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
    4. dataset.global_shuffle(fleet)