本文以端到端的数据预处理的过程为例,结合开源的HugginFace数据集以及Ray引擎能力对数据进行对数据进行预处理。
关于使用 Ray Data 来进行离线的批量预处理操作的原因,以及与其他的可替代方案进行对比的内容概述,请参阅Ray Data 概述。
目前很多的数据预处理中会引入小模型进行对数据做打标分类工作,模型推理模块会用到GPU/CPU资源,但是数据处理的两端是load 和store,属于IO密集型操作,如果操作不能pipeline执行会反压造成中间的计算资源包括gpu/cpu的利用率降低,最终造成处理成本变高,因此这里提供基于ray data的解决方案。
从 Hugging Face 下载数据集,并上传到火山引擎对象存储服务 TOS 中,该样例的数据集为 bertram-gilfoyle/CC-MAIN-2021-17-raw。
TOS_FS = TosFileSystem(
key=ENV_AK,
secret=ENV_SK,
endpoint_url=TOS_ENV_ENDPOINT,
region='cn-beijing',
socket_timeout=60,
connection_timeout=60,
max_retry_num=30
)
如果没有集成pyproton,我们可以采用开源的s3fs,配置方式如下。
import s3fs
# 不同版本的设置方式不一样,因此将两种方式都放置进来
s3fs.S3FileSystem.read_timeout = 300
s3fs.S3FileSystem.connect_timeout = 30
s3fs.S3FileSystem.retries = 100
def get_fs(s3_fs=True):
return s3fs.S3FileSystem(anon=False,
key='xx',
secret='xxx==',
endpoint_url='https://tos-s3-cn-beijing.ivolces.com',
config_kwargs={'s3': {'addressing_style': 'virtual'},
'read_timeout': 300,
'connect_timeout': 30, })
# 获取所有parquet文件
file_paths = TOS_FS.ls(input_dir, detail=False)
# 只要parquet后缀的文件,这里可以替换成其他的条件,例如文件名称中包含某个关键字上
file_names = [
file_name for file_name in file_paths if 'parquet' in file_name.split('/')[-1]][:2]
input_dir = '{bucket-name}/ray/raw/'
output_dir = '{bucket-name}/ray/output/'
# 将所有的文件都读出
# 读取parquet文件,并且只获取某些列
# schema
# url text
# --------------------
# https://www.volcengine.com/docs 火山文档的官方链接
ds = ray.data.read_parquet(
file_names, filesystem=fs, concurrency=10, columns=['url', 'text'])
from typing import Dict
from datatrove.pipeline.filters import URLFilter
from datatrove.data import Document
from tldextract import TLDExtract
class CustomeURLFilter(URLFilter):
"""
自定义URL过滤器
过滤掉不安全的URL
过滤掉没有uri的URL
"""
name = "😈 CustomeURLFilter-filter"
def __init__(
self
):
super().__init__
def __call__(self, batch: Dict[str, str]):
data = {}
for (url, text) in batch.items():
# 标题跳过
if url == "url":
continue
# 关键字跳过
if "not safe" in url:
continue
# pseudo code
# metadata = {"url": url}
# document = Document(
# text="fake_text", id="fake_id", metadata=metadata)
# print(f"------------{document.metadata['url']}")
# filtered = super().filter(document)
# if filtered:
# continue
data[url] = text
return data
# 过滤掉不安全的URL
ds = ds.map_batches(CustomeURLFilter, batch_size=100,
num_cpus=1, concurrency=10)
# 写回归档到固定目录下
ds.write_parquet(output_dir, filesystem=fs, compression='snappy')
更多样例可参考 官网文档。