DataEngine
1. DataEngine 简介
DataEngine 是 LLaMA-Factory v1 数据处理的核心类,继承自 PyTorch 的 Dataset,负责各种插件的接入,其他功能(如数据格式转换、数据加载等)均通过插件的形式实现并接入 DataEngine。
DataEngine接受一个唯一入参:DataArguments 实例,所有的元数据集信息均通过该参数配置传入。
2. DataEngine 与 DataArguments 接口定义
@dataclass
class DataArguments:
""" `DataEngine`初始化入参
args:
dataset (str): 数据集路径,远程数据集 repo id / dataset_info.yaml 路径,或本地数据集路径/dataset_info.yaml路径
cutoff_len (int): 数据集截止长度,即数据集最大样本采样数量
"""
...
class DataEngine(Dataset):
"""数据引擎(DataEngine)
`DataEngine` 负责数据集的加载与统一管理,支持:
- 从本地路径或 Hugging Face Hub 加载数据
- 通过插件机制加载自定义数据
- 构建统一的数据索引
- 支持流式(streaming)与非流式数据访问
attr:
args (DataArguments): 数据参数配置
datasets (dict[str, HFDataset]): 数据集名称到数据对象的映射
dataset_infos (dict[str, DatasetInfo]): 数据集名称到元信息的映射
data_index (list[tuple[str, int]]): 数据索引列表,每项为 (dataset_name, sample_index)
streaming (bool): 是否为流式数据集
"""
def __init__(self, data_args: DataArguments) -> None:
"""初始化 `DataEngine`
初始化时自动执行以下步骤:
1. 调用 `get_dataset_info`, 从 `data_args` 读取并解析数据集元信息
2. 调用 `load_dataset`,根据配置加载数据集
3. 调用 `build_data_index`,构建统一的索引列表
args:
data_args (DataArguments): 数据参数配置对象
"""
...
def get_dataset_info(self) -> None:
"""从配置文件或远程仓库加载数据集元信息
根据 `self.args.dataset` 确定数据源,数据源支持如下选项:
- 本地 YAML 配置文件路径
- Hugging Face Hub 上的 YAML 配置文件路径
- 本地数据集文件路径
- Hugging Face Hub 数据集 repo id
"""
...
def load_dataset(self) -> None:
"""根据数据集元信息加载所有数据集
每个数据集条目可以包含以下字段:
- `hf_hub_url`: 使用 `datasets.load_dataset` 加载
- 本地数据文件:通过 `DataLoaderPlugin` 插件加载
- `streaming`: 是否启用流式模式
更新:
self.datasets (dict): 数据集名称到已加载数据对象的映射
self.streaming (bool): 如果任一数据集为流式模式,则设置为 True
"""
...
def build_data_index(self) -> None:
"""构建统一的数据索引
为所有数据集创建全局索引列表 `(dataset_name, sample_index)`
当启用流式模式时,生成固定长度(例如 1000)的占位索引;
否则,为每条样本建立索引。
插件 `DataIndexPlugin` 可根据数据集大小或权重调整索引分布
"""
...
def _convert_data_sample(self, raw_sample: dict[str, Any], dataset_name: str) -> Sample:
"""将原始样本转换为统一格式
根据 `dataset_info` 中的 `converter` 字段,调用对应的转换插件,
将原始样本标准化为统一的数据结构。
args:
raw_sample (dict[str, Any]): 原始数据样本
dataset_name (str): 样本所属的数据集名称
return:
Sample: 转换后的标准化格式样本
"""
...
def __len__(self) -> int:
"""返回数据集的总样本数
return:
int: 数据集长度
如果为流式数据集,返回 `-1`
"""
...
def __getitem__(self, index: Union[int, Any]) -> Union[Sample, list[Sample]]:
"""根据索引或选择器获取样本
args:
index (Union[int, Any]): 数据索引,int 或 list[int]
return:
Union[Sample, list[Sample]]: 单个样本或样本列表
"""
...
def __iter__(self) -> Iterable:
"""返回数据集迭代器
用于非流式数据集的顺序或随机访问
流式模式下需要实现异步加载逻辑
return:
Iterable: 数据集迭代器。
"""
...
async def __aiter__(self) -> AsyncIterable:
"""返回异步数据集迭代器
用于流式数据集或异步数据加载场景
允许在异步环境中以流的方式读取样本
return:
AsyncIterable: 异步迭代器,按顺序产出样本
"""
...
DataArguments 参数说明:
dataset: 数据集路径,支持本地或远程,当传入本地数据集文件路径时,需要满足该数据集为标准格式;否则需要传入 dataset_info.yaml 来配置数据集的 converter 等元信息,以告知 DataEngine 应当如何处理该数据。
cutoff_len: 数据集的截止长度,即该数据集的最大样本数量。
3. DataEngine 核心方法
3.1 get_dataset_info:加载数据元信息
根据 dataset 参数加载数据集配置,获取数据位置、数据格式、插件配置等所有数据元信息,在实例化 DataEngine 时会自动调用此方法。
3.2 加载数据集:load_dataset
遍历所有数据源,根据不同的数据源加载数据,在实例化 DataEngine 时会自动调用此方法。
for key, value in self.dataset_infos.items():
split = value.get("split", "train")
streaming = value.get("streaming", False)
if "hf_hub_url" in value:
# 从 HF Hub 加载
dataset = load_dataset(value["hf_hub_url"], split=split, streaming=streaming)
else:
# 使用 DataLoaderPlugin 加载本地文件
dataset = DataLoaderPlugin(args=self.args).auto_load_data(value)
self.datasets[key] = dataset
3.3 build_data_index:构建数据索引
为每个数据集创建索引列表 [(dataset_name, sample_index), ...], DataIndexPlugin插件在此处被调用,可控制各数据集的采样频率、采样方式等,在实例化DataEngine时会自动调用此方法。
for dataset_name, dataset in self.datasets.items():
# 创建基础索引
data_index = [(dataset_name, idx) for idx in range(len(dataset))]
# 根据 size 和 weight 调整索引
size = self.dataset_infos[dataset_name].get("size")
weight = self.dataset_infos[dataset_name].get("weight")
if size or weight:
data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight)
self.data_index.extend(data_index)
3.4 _convert_data_sample:数据格式标准化
将原始数据转换为标准格式,DataConverterPlugin插件在此处被调用,具体调用的插件由 get_dataset_info 方法获取的 converter 信息指定,若 converter 为空则假定数据集为标准格式,此方法由DataEngine的 __getitem__ 方法调用。
def _convert_data_sample(self, raw_sample: dict, dataset_name: str) -> Sample:
converter = self.dataset_infos[dataset_name].get("converter")
if converter is not None:
# 使用指定的转换器
from ..plugins.data_plugins.converter import get_converter
return {"_dataset_name": dataset_name, **get_converter(converter)(raw_sample)}
else:
# 已经是标准格式
return {"_dataset_name": dataset_name, **raw_sample}
4. 初始化
DataEngine 初始化过程只需传入一个构建好的 DataArguments 即可,后续可通过该 DataEngine 访问数据集中的数据。
from llamafactory.v1.config.data_args import DataArguments
from llamafactory.v1.core.data_engine import DataEngine
# 1. 创建数据参数
data_args = DataArguments(
dataset="~/data/v1_sft_demo.jsonl",
cutoff_len=2048
)
# 2. 初始化 Data Engine
data_engine = DataEngine(data_args=data_args)
# 3. 访问数据
sample = data_engine[0] # 获取第一个样本
5. 数据访问方式
实例化后的DataEngine支持整数索引、列表索引、以及切片等访问方式,其数据读取用法可等价于 Python 列表。
sample = data_engine[0] # 获取第一个样本
sample = data_engine[0:10] # 获取前 10 个样本
sample = data_engine[[0, 5, 10]] # 获取指定索引的样本