本文以端到端的数据预处理的过程为例,结合开源的HugginFace数据集以及Ray引擎能力对数据进行对数据进行预处理。
关于使用 Ray Data 来进行离线的批量预处理操作的原因,以及与其他的可替代方案进行对比的内容概述,请参阅Ray Data 概述。
目前很多的数据预处理中会引入小模型进行对数据做打标分类工作,模型推理模块会用到GPU/CPU资源,但是数据处理的两端是load 和store,属于IO密集型操作,如果操作不能pipeline执行会反压造成中间的计算资源包括gpu/cpu的利用率降低,最终造成处理成本变高,因此这里提供基于ray data的解决方案。
从 Hugging Face 下载数据集,并上传到火山引擎对象存储服务 TOS 中,该样例的数据集为 bertram-gilfoyle/CC-MAIN-2021-17-raw。
TOS_FS = TosFileSystem( key=ENV_AK, secret=ENV_SK, endpoint_url=TOS_ENV_ENDPOINT, region='cn-beijing', socket_timeout=60, connection_timeout=60, max_retry_num=30 )
如果没有集成pyproton,我们可以采用开源的s3fs,配置方式如下。
import s3fs # 不同版本的设置方式不一样,因此将两种方式都放置进来 s3fs.S3FileSystem.read_timeout = 300 s3fs.S3FileSystem.connect_timeout = 30 s3fs.S3FileSystem.retries = 100 def get_fs(s3_fs=True): return s3fs.S3FileSystem(anon=False, key='xx', secret='xxx==', endpoint_url='https://tos-s3-cn-beijing.ivolces.com', config_kwargs={'s3': {'addressing_style': 'virtual'}, 'read_timeout': 300, 'connect_timeout': 30, })
# 获取所有parquet文件 file_paths = TOS_FS.ls(input_dir, detail=False) # 只要parquet后缀的文件,这里可以替换成其他的条件,例如文件名称中包含某个关键字上 file_names = [ file_name for file_name in file_paths if 'parquet' in file_name.split('/')[-1]][:2]
input_dir = '{bucket-name}/ray/raw/' output_dir = '{bucket-name}/ray/output/'
# 将所有的文件都读出 # 读取parquet文件,并且只获取某些列 # schema # url text # -------------------- # https://www.volcengine.com/docs 火山文档的官方链接 ds = ray.data.read_parquet( file_names, filesystem=fs, concurrency=10, columns=['url', 'text'])
from typing import Dict from datatrove.pipeline.filters import URLFilter from datatrove.data import Document from tldextract import TLDExtract class CustomeURLFilter(URLFilter): """ 自定义URL过滤器 过滤掉不安全的URL 过滤掉没有uri的URL """ name = "😈 CustomeURLFilter-filter" def __init__( self ): super().__init__ def __call__(self, batch: Dict[str, str]): data = {} for (url, text) in batch.items(): # 标题跳过 if url == "url": continue # 关键字跳过 if "not safe" in url: continue # pseudo code # metadata = {"url": url} # document = Document( # text="fake_text", id="fake_id", metadata=metadata) # print(f"------------{document.metadata['url']}") # filtered = super().filter(document) # if filtered: # continue data[url] = text return data
# 过滤掉不安全的URL ds = ds.map_batches(CustomeURLFilter, batch_size=100, num_cpus=1, concurrency=10)
# 写回归档到固定目录下 ds.write_parquet(output_dir, filesystem=fs, compression='snappy')
更多样例可参考 官网文档。