You need to enable JavaScript to run this app.
导航
Ray Data在LLM 数据预处理的最佳实践
最近更新时间:2024.12.12 17:01:09首次发布时间:2024.11.07 19:10:01

场景介绍

本文以端到端的数据预处理的过程为例,结合开源的HugginFace数据集以及Ray引擎能力对数据进行对数据进行预处理。
关于使用 Ray Data 来进行离线的批量预处理操作的原因,以及与其他的可替代方案进行对比的内容概述,请参阅Ray Data 概述

Ray 预处理流程

Image
目前很多的数据预处理中会引入小模型进行对数据做打标分类工作,模型推理模块会用到GPU/CPU资源,但是数据处理的两端是load 和store,属于IO密集型操作,如果操作不能pipeline执行会反压造成中间的计算资源包括gpu/cpu的利用率降低,最终造成处理成本变高,因此这里提供基于ray data的解决方案。

数据准备

Hugging Face 下载数据集,并上传到火山引擎对象存储服务 TOS 中,该样例的数据集为 bertram-gilfoyle/CC-MAIN-2021-17-raw

数据加载

  1. 配置TOS读取的配置项,如果你使用的是火山引擎的Ray集群,我们tosfs优化了TOS的读写,可以采用自研TosFileSystem,tosfs具体的安装过程请详见 官网文档
TOS_FS = TosFileSystem(
    key=ENV_AK,
    secret=ENV_SK,
    endpoint_url=TOS_ENV_ENDPOINT,
    region='cn-beijing',
    socket_timeout=60,
    connection_timeout=60,
    max_retry_num=30
)

如果没有集成pyproton,我们可以采用开源的s3fs,配置方式如下。

import s3fs
 
 # 不同版本的设置方式不一样,因此将两种方式都放置进来
s3fs.S3FileSystem.read_timeout = 300
s3fs.S3FileSystem.connect_timeout = 30
s3fs.S3FileSystem.retries = 100

def get_fs(s3_fs=True):
    return s3fs.S3FileSystem(anon=False,
                             key='xx',
                             secret='xxx==',
                             endpoint_url='https://tos-s3-cn-beijing.ivolces.com',
                             config_kwargs={'s3': {'addressing_style': 'virtual'},
                                            'read_timeout': 300,
                                            'connect_timeout': 30, })
  1. 读取TOS文件夹下的文件。
# 获取所有parquet文件
    file_paths = TOS_FS.ls(input_dir, detail=False)

    # 只要parquet后缀的文件,这里可以替换成其他的条件,例如文件名称中包含某个关键字上
    file_names = [
        file_name for file_name in file_paths if 'parquet' in file_name.split('/')[-1]][:2]
  1. 设置输入输出目录。
input_dir = '{bucket-name}/ray/raw/'
output_dir = '{bucket-name}/ray/output/'
  1. 读取parquet文件。
# 将所有的文件都读出
    # 读取parquet文件,并且只获取某些列
    # schema
    # url              text
    # --------------------
    # https://www.volcengine.com/docs          火山文档的官方链接
    ds = ray.data.read_parquet(
        file_names, filesystem=fs, concurrency=10, columns=['url', 'text'])

数据处理

  1. 对数据进行敏感url 过滤处理,详细内容如下。
from typing import Dict
from datatrove.pipeline.filters import URLFilter
from datatrove.data import Document
from tldextract import TLDExtract

class CustomeURLFilter(URLFilter):
    """
    自定义URL过滤器
    过滤掉不安全的URL
    过滤掉没有uri的URL
    """

    name = "😈 CustomeURLFilter-filter"

    def __init__(
        self
    ):
        super().__init__

    def __call__(self, batch: Dict[str, str]):
        data = {}
        for (url, text) in batch.items():
            # 标题跳过
            if url == "url":
                continue

            # 关键字跳过
            if "not safe" in url:
                continue

            # pseudo code
            # metadata = {"url": url}
            # document = Document(
            #     text="fake_text", id="fake_id", metadata=metadata)
            # print(f"------------{document.metadata['url']}")
            # filtered = super().filter(document)
            # if filtered:
            #     continue

            data[url] = text

        return data
  1. 在ray data 的主流程中批量调用。
# 过滤掉不安全的URL
    ds = ds.map_batches(CustomeURLFilter, batch_size=100,
                        num_cpus=1, concurrency=10)

数据归档

# 写回归档到固定目录下
    ds.write_parquet(output_dir, filesystem=fs, compression='snappy')

源码样例

更多样例可参考 官网文档