Reading WebDataset with Ray and Writing to Lance

Background

WebDataset packs a dataset into tar files that are addressed with plain URLs, which makes large datasets more efficient and flexible to manage and consume.
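
For reference, a WebDataset shard is an ordinary tar archive in which files that share a basename form one sample. A quick way to inspect a shard locally (the shard name below is illustrative):

import tarfile

# Files sharing a basename belong to the same sample,
# e.g. 000000.jpg + 000000.json describe one image/metadata pair.
with tarfile.open("shard-000000.tar") as tar:
    for member in tar.getmembers()[:6]:
        print(member.name)
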
Lance is an efficient columnar storage format built on Apache Arrow and designed for fast data storage and retrieval. It is particularly well suited to large-scale analytics and machine-learning workloads: the columnar layout gives it high read/write throughput and good compression on large datasets.
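
As a minimal local sketch of what Lance provides (paths and values are illustrative; the package is installed as pylance):

import lance
import pyarrow as pa

# Write a small Arrow table as a Lance dataset, then read back a
# filtered projection without scanning the unused columns.
table = pa.table({"id": [1, 2, 3], "caption": ["a", "b", "c"]})
lance.write_dataset(table, "/tmp/demo.lance", mode="overwrite")
print(lance.dataset("/tmp/demo.lance").to_table(filter="id > 1", columns=["caption"]))
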
WebDataset, however, is a fixed packed format: the json metadata inside each archive cannot be flattened directly, so the dataset cannot be quickly selected, filtered, or queried.
This article therefore presents a recipe for reading a WebDataset with Ray and writing it to Lance, after which the Lance dataset can be operated on through Ray (see the read-back sketch at the end of the code example).

Data Architecture

(Figure: data architecture)

Code Example

This example uses the **China North 2 (Beijing)** region; replace the region-specific values to match your own environment.

import ray
import logging
from lance.ray.sink import LanceDatasink
from tosfs.core import TosFileSystem
import pyarrow as pa
import numpy as np
from ray.data.dataset import Schema
from ray.data._internal.pandas_block import PandasBlockSchema
import json

# Convert the dicts read from the WebDataset into the PyArrow schema Lance needs

class TypeConverter:

    def __init__(self,
                 dataset_schema: Schema,
                 required_schema: pa.lib.Schema = None,) -> None:
        self.field_names = dataset_schema.names
        if required_schema is not None:
            self.schema = required_schema
        elif isinstance(dataset_schema.base_schema, pa.lib.Schema):
            self.schema = dataset_schema.base_schema
        elif isinstance(dataset_schema.base_schema, PandasBlockSchema):
            self.schema = self.speculation_type(dataset_schema)
        else:
            raise ValueError(
                f"Unsupported schema type: {type(dataset_schema.base_schema)}")

    def speculation_type(self, dataset_schema: Schema) -> pa.lib.Schema:
        # Fallback schema inference from a pandas block schema. Not exercised
        # in this example (required_schema is always passed); the
        # object -> binary mapping is an assumption for raw-bytes columns.
        fields = []
        for name, dtype in zip(dataset_schema.names,
                               dataset_schema.base_schema.types):
            if dtype == np.dtype(object):
                fields.append(pa.field(name, pa.binary()))
            else:
                fields.append(pa.field(name, pa.from_numpy_dtype(dtype)))
        return pa.schema(fields)

    def __call__(self, batches) -> pa.Table:
        columns = {}
        for field_name in self.field_names:
            column_data = batches[field_name]
            if field_name == "json":
                # Flatten the json sidecar: parse each record once, then
                # promote every key present in the target schema to a column.
                parsed_items = [json.loads(item.decode("utf-8"))
                                for item in column_data]
                for key in parsed_items[0].keys():
                    if key in self.schema.names:
                        columns[key] = [item[key] for item in parsed_items]
            else:
                columns[field_name] = column_data
        return pa.Table.from_pydict(columns, schema=self.schema)

# TOS AK
ENV_AK = '{AK}'
# TOS SK
ENV_SK = '{SK}'
# REGION
ENV_REGION = 'cn-beijing'

# Region name used only by the Lance (S3) side
S3_ENV_REGION = 'beijing'

# TOS ENDPOINT
TOS_ENV_ENDPOINT = f"https://tos-{ENV_REGION}.ivolces.com"
# Bucket that stores the data
ENV_BUCKET = "hzw"
# S3-protocol endpoint
S3_ENV_ENDPOINT = f"https://{ENV_BUCKET}.tos-s3-{ENV_REGION}.ivolces.com"
# Directory that stores the data
ENV_BASEDIR = f"{ENV_BUCKET}/datasource"
# File types to read from the WebDataset
FILE_TYPES = ["json", "jpg"]

# Target schema for Lance; see the TypeConverter implementation for how it is applied
REQUIRED_SCHEMA = [pa.field("SAMPLE_ID", pa.float64()),
                   pa.field("HEIGHT", pa.float32()),
                   pa.field("WIDTH", pa.float32()),
                   pa.field("jpg", pa.binary()),
                   pa.field("URL", pa.string()),
                   pa.field("TEXT", pa.string()),
                   pa.field("LICENSE", pa.string()),
                   pa.field("NSFW", pa.string()),
                   pa.field("similarity", pa.float64()),
                   ]
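
# Illustrative shape of one "json" sidecar record that TypeConverter flattens
# into the columns above (keys mirror REQUIRED_SCHEMA; values are made up):
# {"SAMPLE_ID": 1.0, "HEIGHT": 320.0, "WIDTH": 480.0, "URL": "http://...",
#  "TEXT": "a caption", "LICENSE": "?", "NSFW": "UNLIKELY", "similarity": 0.3}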

TOS_ROOT_DIR = f"tos://{ENV_BASEDIR}"
S3_ROOT_DIR = f"s3://{ENV_BASEDIR}"
# Directory containing the source data
input_dir = f"{TOS_ROOT_DIR}/webdataset/"
# Lance writes through the AWS SDK directly, so the output path must use the
# s3:// scheme rather than tos://
output_dir = f"{S3_ROOT_DIR}/lance/"

# Initialize the Ray runtime environment
ray.init(runtime_env={
    "env_vars": {"TOSFS_LOGGING_LEVEL": "INFO", "LANCE_LOG": "INFO"},
    "pip": ["pylance", "tosfs"]
})
# Affects backpressure; tune together with the rest of your workload
data_context = ray.data.context.DataContext.get_current()
data_context.op_resource_reservation_enabled = False
logging.basicConfig(level=logging.INFO)

# Custom parsing setup
# Data source: https://huggingface.co/datasets/laion/clevr-webdataset/tree/main/train
# Data format: https://huggingface.co/datasets/laion/clevr-webdataset
# tosfs settings: https://tosfs.readthedocs.io/en/latest/api.html
print(TOS_ENV_ENDPOINT)
print(TOS_ROOT_DIR)

TOS_FS = TosFileSystem(
    key=ENV_AK,
    secret=ENV_SK,
    endpoint_url=TOS_ENV_ENDPOINT,
    region=ENV_REGION,
    socket_timeout=60,
    connection_timeout=60,
    max_retry_num=30
)

# Storage options for writing Lance (passed through to the AWS SDK)
storage_options = {
    "access_key_id": ENV_AK,
    "secret_access_key": ENV_SK,
    "aws_region": S3_ENV_REGION,
    "aws_endpoint": S3_ENV_ENDPOINT,
    "virtual_hosted_style_request": "true"
}

# Read the WebDataset
ds = ray.data.read_webdataset(
    paths=input_dir,
    filesystem=TOS_FS,
    override_num_blocks=100,
    suffixes=FILE_TYPES,
    concurrency=100,
    decoder=None,
)
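
# With decoder=None, each suffix column ("json", "jpg") holds raw,
# undecoded bytes. Optional sanity check before converting:
# print(ds.schema())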

# Convert according to the required target types.
# Source type examples:
# <class 'object'>  -->  pa.string()
# dict key --> pa.int()
# dict --> pa.int()
# <numpy.ndarray(shape=(320, 480, 3), dtype=uint8)> --> pa.binary()

required_schema = pa.schema(REQUIRED_SCHEMA)

# Type conversion
ds = ds.map_batches(TypeConverter, fn_constructor_args=(
    ds.schema(), required_schema), batch_size=10, num_cpus=1, concurrency=30)
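
# From here on, each batch is a pyarrow.Table matching required_schema,
# with the json sidecar flattened into ordinary columns.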

# Write the data to Lance
sink = LanceDatasink(
    output_dir,
    schema=required_schema,
    max_rows_per_file=100,
    storage_options=storage_options,
    mode="overwrite")
ds.write_datasink(sink)
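
# Operate on the Lance dataset through Ray, as described above: read it
# back and filter on one of the flattened json columns. A minimal sketch;
# the column and threshold are illustrative, and ray.data.read_lance
# requires a recent Ray version.
ds_lance = ray.data.read_lance(output_dir, storage_options=storage_options)
print(ds_lance.filter(lambda row: row["similarity"] > 0.3).take(5))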