WebDataset packages datasets into tar files and accesses them through simple URLs, which makes dataset management and consumption more efficient and flexible.
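For context, a WebDataset shard is nothing more than a tar archive in which all files sharing a basename form one sample. The sketch below is not part of the original pipeline; file names and metadata fields are made up purely to illustrate that convention, using only the standard library.

```python
import io
import json
import tarfile

# Each sample is a group of files with the same basename ("000001", "000002", ...)
# and different extensions; a shard is simply a tar archive of such groups.
with tarfile.open("shard-000000.tar", "w") as tar:
    for key in ("000001", "000002"):
        meta = json.dumps({"TEXT": "a caption", "WIDTH": 480, "HEIGHT": 320}).encode("utf-8")
        info = tarfile.TarInfo(name=f"{key}.json")
        info.size = len(meta)
        tar.addfile(info, io.BytesIO(meta))
        # A real shard would also contain f"{key}.jpg" holding the image bytes.
```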
Lance is an efficient columnar storage format built on Apache Arrow and designed for fast data storage and retrieval. It is particularly well suited to large-scale data analysis and machine learning workloads. By exploiting the strengths of columnar storage, Lance delivers efficient read/write performance and good compression when processing large volumes of data.
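As a quick reference (this snippet is not from the original pipeline; the local path and column values are illustrative), the pylance API can write a PyArrow table as a Lance dataset and read it back with projection and filter pushdown:

```python
import lance
import pyarrow as pa

table = pa.table({"id": [1, 2, 3], "text": ["a", "b", "c"]})

# Write the table as a Lance dataset at an illustrative local path.
lance.write_dataset(table, "/tmp/demo.lance", mode="overwrite")

# Columnar scan with projection ("columns") and filter pushdown.
ds = lance.dataset("/tmp/demo.lance")
print(ds.to_table(columns=["text"], filter="id > 1"))
```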
WebDataset, by contrast, is a fixed archive format: the JSON metadata contained in a WebDataset cannot be flattened directly, so the dataset cannot be quickly selected, filtered, or queried.
This article therefore presents a pattern that reads a WebDataset with Ray and writes it to Lance, and then uses Ray to operate on the resulting Lance dataset.
This example uses the **North China 2 (Beijing)** region; substitute the values that match your own environment.
```python
import ray
import logging
from lance.ray.sink import LanceDatasink
from tosfs.core import TosFileSystem
import pyarrow as pa
import numpy as np
from ray.data.dataset import Schema
from ray.data._internal.pandas_block import PandasBlockSchema
from PIL import Image
import io
import json


# Convert the dicts read from the WebDataset into the (PyArrow) schema required by Lance
class TypeConverter:
    def __init__(self, dataset_schema: Schema, required_schema: pa.lib.Schema = None) -> None:
        self.field_names = dataset_schema.names
        if required_schema is not None:
            self.schema = required_schema
        elif isinstance(dataset_schema.base_schema, pa.lib.Schema):
            # The dataset already carries a PyArrow schema, so use it directly
            self.schema = dataset_schema.base_schema
        elif isinstance(dataset_schema.base_schema, PandasBlockSchema):
            # speculation_type (not shown in this snippet) infers a PyArrow schema
            # from a Pandas block schema
            self.schema = self.speculation_type(dataset_schema)
        else:
            raise ValueError(
                f"Unsupported schema type: {type(dataset_schema.base_schema)}")

    def __call__(self, batches) -> pa.Table:
        columns = {}
        for field_name in self.field_names:
            column_data = batches[field_name]
            if field_name == "json":
                # Flatten the JSON metadata: every key that appears in the target
                # schema becomes its own column
                parsed_json = json.loads(column_data[0].decode("utf-8"))
                for key in parsed_json.keys():
                    if key in self.schema.names:
                        columns[key] = [json.loads(item.decode("utf-8"))[key]
                                        for item in column_data]
            else:
                columns[field_name] = column_data
        return pa.Table.from_pydict(columns, schema=self.schema)


# TOS AK
ENV_AK = '{AK}'
# TOS SK
ENV_SK = '{SK}'
# Region
ENV_REGION = 'cn-beijing'
# Region name used only for Lance (S3 protocol)
S3_ENV_REGION = 'beijing'
# TOS endpoint
TOS_ENV_ENDPOINT = f"https://tos-{ENV_REGION}.ivolces.com"
# Bucket that stores the data
ENV_BUCKET = "hzw"
# Endpoint for the S3 protocol
S3_ENV_ENDPOINT = f"https://{ENV_BUCKET}.tos-s3-{ENV_REGION}.ivolces.com"
# Directory that stores the data
ENV_BASEDIR = f"{ENV_BUCKET}/datasource/"
# File types to read from the WebDataset
FILE_TYPES = ["json", "jpg"]
# Schema required by Lance; see the TypeConverter implementation above
REQUIRED_SCHEMA = [pa.field("SAMPLE_ID", pa.float64()),
                   pa.field("HEIGHT", pa.float32()),
                   pa.field("WIDTH", pa.float32()),
                   pa.field("jpg", pa.binary()),
                   pa.field("URL", pa.string()),
                   pa.field("TEXT", pa.string()),
                   pa.field("LICENSE", pa.string()),
                   pa.field("NSFW", pa.string()),
                   pa.field("similarity", pa.float64()),
                   ]

TOS_ROOT_DIR = f"tos://{ENV_BASEDIR}"
S3_ROOT_DIR = f"s3://{ENV_BASEDIR}"
# Directory containing the source data
input_dir = f"{TOS_ROOT_DIR}/webdataset/"
# The output path goes directly through Lance's AWS SDK, so it must use the
# s3:// protocol rather than tos://
output_dir = f"{S3_ROOT_DIR}/lance/"

# Initialize the Ray environment
ray.init(runtime_env={
    "env_vars": {"TOSFS_LOGGING_LEVEL": "INFO", "LANCE_LOG": "INFO"},
    "pip": ["pylance", "tosfs"]
})

# Backpressure-related setting; tune it for your own workload
data_context = ray.data.DataContext.get_current()
data_context.op_resource_reservation_enabled = False

logging.basicConfig(level=logging.INFO)

# Custom parsing
# Data source: https://huggingface.co/datasets/laion/clevr-webdataset/tree/main/train
# Data format: https://huggingface.co/datasets/laion/clevr-webdataset
# Detailed settings: https://tosfs.readthedocs.io/en/latest/api.html
print(TOS_ENV_ENDPOINT)
print(TOS_ROOT_DIR)
TOS_FS = TosFileSystem(
    key=ENV_AK,
    secret=ENV_SK,
    endpoint_url=TOS_ENV_ENDPOINT,
    region=ENV_REGION,
    socket_timeout=60,
    connection_timeout=60,
    max_retry_num=30
)

# Storage options for writing to Lance
storage_options = {
    "access_key_id": ENV_AK,
    "secret_access_key": ENV_SK,
    "aws_region": S3_ENV_REGION,
    "aws_endpoint": S3_ENV_ENDPOINT,
    "virtual_hosted_style_request": "true"
}

# Read the WebDataset
ds = ray.data.read_webdataset(
    paths=input_dir,
    filesystem=TOS_FS,
    override_num_blocks=100,
    suffixes=FILE_TYPES,
    concurrency=100,
    decoder=None,
)

# Convert according to the target types
# Original types:
# <class 'object'>                                  --> pa.string()
# dict key                                          --> pa.int()
# dict                                              --> pa.int()
# <numpy.ndarray(shape=(320, 480, 3), dtype=uint8)> --> pa.binary()
required_schema = pa.schema(REQUIRED_SCHEMA)

# Type conversion
ds = ds.map_batches(TypeConverter,
                    fn_constructor_args=(ds.schema(), required_schema),
                    batch_size=10,
                    num_cpus=1,
                    concurrency=30)

# Write the data to Lance
sink = LanceDatasink(
    output_dir,
    schema=required_schema,
    max_rows_per_file=100,
    storage_options=storage_options,
    mode="overwrite")

ds.write_datasink(sink)
```
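Once the data lands in Lance, the flattened metadata columns can be filtered and queried directly. Below is a minimal read-back sketch, assuming Ray Data's `read_lance` reader and the `output_dir` / `storage_options` defined in the script above; the column list and filter expression are only examples.

```python
# Read the Lance dataset back with Ray; projection and filter are pushed down
# to the Lance scanner, so only matching columns/rows are materialized.
lance_ds = ray.data.read_lance(
    output_dir,
    columns=["URL", "TEXT", "HEIGHT", "WIDTH"],
    filter="HEIGHT >= 256",
    storage_options=storage_options,
)
print(lance_ds.count())
print(lance_ds.take(1))
```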