You need to enable JavaScript to run this app.
导航
在 Pytorch 任务中访问和存储 TOS 数据
最近更新时间:2024.11.07 10:28:30首次发布时间:2024.11.07 10:28:30

本文介绍如何在 Pytorch 任务中访问和存储 TOS 数据。

前提条件

  1. 在使用 TOS 之前,请确保您已注册火山引擎账号并完成实名认证,具体步骤,请参见账号注册实名认证
  2. 请确保您已开通 TOS 服务并创建 TOS 存储桶,具体步骤,请参见快速入门
  3. 必须使用 Linux x86-64 操作系统。
  4. 已安装 Python 3.8 及以上版本。

    注意

    不建议使用 3.12+ 版本,因为 PyTorch 不支持。

  5. 已安装 2.0 及以上版本的 Pytorch。

使用 S3 Connector 对接

S3 Connector 是 AWS 提供的 pytorch 连接器,用于增强 torchdata 提供的 S3 相关访问能力,支持 map-style datasets、iterable-stype datasets 以及 checkpointing 接口,详细信息,请参见 S3-connector-for-pytorch

注意事项

当前仅支持 TOS FNS 桶。

操作步骤

  1. 安装 s3 connector。

    1. 执行以下命令安装 s3 connector 基础库。

      pip install s3torchconnector
      
    2. (可选)执行以下命令安装 s3 connector for lighting。

      pip install s3torchconnector[lightning]
      
  2. 配置环境。

    1. 执行以下命令编辑密钥配置文件。

      vim ~/.aws/credentials
      
    2. 执行以下命令配置访问密钥。

      [default]
      aws_access_key_id = <Your Volces AK>
      aws_secret_access_key = <Your Volces SK>
      
    3. 执行以下命令保存配置。

      :wq
      
  3. 运行以下代码,使用 iterable-style datasets。

    说明

    本文以使用北京地域的 bucketname 桶的 shard-data/ 目录为例,建议您根据业务需求修改代码。

    from s3torchconnector import S3IterableDataset, S3ClientConfig
    
    if __name__ == '__main__':
        dataset_uri = 's3://bucketname/shard-data/'
        region = 'cn-beijing'
        endpoint = 'http://tos-s3-cn-beijing.volces.com'
        s3_config = S3ClientConfig(force_path_style=False)
    
        print('测试 iterable-style datasets')
        iterable_dataset = S3IterableDataset.from_prefix(dataset_uri,
                                                         region=region,
                                                         endpoint=endpoint,
                                                         s3client_config=s3_config)
        # 遍历并读取Dataset中的对象。
        for item in iterable_dataset:
            print(item.bucket, item.key)
            data = item.read()
            print(len(data))
    
  4. 运行以下代码,使用 map-style datasets。

    说明

    本文以使用北京地域的 bucketname 桶的 shard-data/ 目录为例,建议您根据业务需求修改代码。

    from s3torchconnector import S3IterableDataset, S3ClientConfig
    
    if __name__ == '__main__':
        dataset_uri = 's3://bucketname/shard-data/'
        region = 'cn-beijing'
        endpoint = 'http://tos-s3-cn-beijing.volces.com'
        s3_config = S3ClientConfig(force_path_style=False)
        
        print('测试 map-style datasets')
        map_dataset = S3MapDataset.from_prefix(dataset_uri,
                                               region=region,
                                               endpoint=endpoint,
                                               s3client_config=s3_config)
        # 遍历并读取Dataset中的对象。
        for item in map_dataset:
            print(item.bucket, item.key)
            data = item.read()
            print(len(data))
            
        # 根据索引读取Dataset中的对象。
        item = map_dataset[0]
        bucket = item.bucket
        key = item.key
        print(bucket, key)
        data = item.read()
        print(len(data))
    
  5. 运行以下代码,保存或加载 checkpoint。

    说明

    本文以使用北京地域的 bucketname 桶的 checkpoint/ 目录为例,建议您根据业务需求修改代码。

    import torch
    import torchvision
    from s3torchconnector import S3Checkpoint, S3ClientConfig
    
    if __name__ == '__main__':
        checkpoint_uri = 's3://bucketname/checkpoint/'
        region = 'cn-beijing'
        endpoint = 'http://tos-s3-cn-beijing.volces.com'
        s3_config = S3ClientConfig(force_path_style=False)
    
        print('测试 checkpointing')
        checkpoint = S3Checkpoint(region=region,
                                  endpoint=endpoint,
                                  s3client_config=s3_config)
    
        model = torchvision.models.resnet18()
    
        with checkpoint.writer(checkpoint_uri + 'epoch0.ckpt') as writer:
            torch.save(model.state_dict(), writer)
    
        with checkpoint.reader(checkpoint_uri + 'epoch0.ckpt') as reader:
            state_dict = torch.load(reader)
    
        model.load_state_dict(state_dict)
    

使用 Fuse 对接

将火山 TOS 挂载为本地路径,再使用 torchdata 直接读写本地路径实现对接,本文中以使用 tosutil 工具作为 Fuse 挂载工具,您可以根据实际情况,选择其他工具挂载存储桶。

注意事项

当前仅支持 TOS FNS 桶。

操作步骤

  1. 安装 tosutil 工具,具体步骤,请参见下载与安装

  2. 配置 toustil 工具并检查版本和连通性,具体步骤,请参见快速入门

  3. 使用 mount 命令挂载存储桶并验证挂载结果,具体步骤,请参见挂载桶为本地文件系统目录(mount)
    本文以将 bucketname 桶挂载到 ./root 目录为例。

    ./tosutil mount tos://bucketname ./root -aow -ar -ao
    
  4. 挂载完成后,执行以下命令验证挂载结果。

    ls ./root
    
  5. 运行以下代码,构建 iterable-style datasets。

    说明

    本文以使用 shard-data/ 目录为例,建议您根据业务需求修改代码。

    from torchdata.datapipes.iter import FileLister, FileOpener
    
    if __name__ == '__main__':
        dataset_uri = './root/shard-data/'
    
        print('测试 iterable-style datasets')
        file_lister = FileLister(root=dataset_uri)
        iterable_dataset = FileOpener(file_lister, mode='rb')
        # 遍历并读取Dataset中的对象。
        for _, item in iterable_dataset:
            print(item.bucket, item.key)
            data = item.read()
            print(len(data))
    
  6. 运行以下代码,保存或加载 checkpoint。

    说明

    本文以使用 checkpoint/ 目录为例,建议您根据业务需求修改代码。

    import torch
    import torchvision
    
    if __name__ == '__main__':
        print('测试 checkpointing')
        model = torchvision.models.resnet18()
        
        # 直接将 tos 作为本地文件处理
        with open('./root/checkpoint/epoch0.ckpt', 'wb') as writer:
            torch.save(model.state_dict(), writer)
    
    
        with open('./root/checkpoint/epoch0.ckpt', 'rb') as reader:
            stat_dict = torch.load(reader)
    
        model.load_state_dict(stat_dict)
        print(model)
    

使用 fsspec 对接

通过实现了 fsspec 规范的 tosfs,将火山 TOS 封装为 python 的标准读写接口,再使用 torchdata 直接读写本地路径实现对接,更多信息,请参见 tosfs

注意事项

支持 TOS FNS 桶和 HNS 桶。

操作步骤

  1. 执行以下命令安装 tosfs。

    # 安装 tosfs for fsspec
    pip install tosfs
    
  2. 执行以下命令初始化客户端。

    from tosfs.core import TosFileSystem
    
    # 配置访问密钥等信息
    if __name__ == '__main__':
        tosfs = TosFileSystem(
            key='<Your Volces AK>',
            secret='<Your Volces SK>',
            endpoint_url='http://tos-cn-beijing.volces.com',
            region='cn-beijing',
        )
    
  3. 运行以下代码,使用 iterable-style datasets。

    说明

    本文以使用 北京地域的 bucketname 桶为例,建议您根据业务需求修改代码。

    from torchdata.datapipes.iter import FSSpecFileLister, FSSpecFileOpener
    
    if __name__ == '__main__':
        kwargs = {
            'key' : '<Your Volces AK>',
            'secret' : '<Your Volces SK>',
            'endpoint_url' : 'https://tos-cn-beijing.volces.com',
            'region' :'cn-beijing'
        }
        
        bucket = 'bucketname'
        print('测试 iterable-style datasets')
        file_lister = FSSpecFileLister(root='tos://'+ bucket+'/shard-data/', **kwargs)
        iterable_dataset = FSSpecFileOpener(file_lister, mode='rb', **kwargs)
        # 遍历并读取Dataset中的对象。
        for _, item in iterable_dataset:
            print(item.bucket, item.key)
            data = item.read()
            print(len(data))
    
  4. 运行以下代码,保存或加载 checkpoint。

    说明

    本文以使用 北京地域的 bucketname 桶为例,建议您根据业务需求修改代码。

    import torch
    import torchvision
    from tosfs.core import TosFileSystem
    
    if __name__ == '__main__':
        tosfs = TosFileSystem(
            key='<Your Volces AK>',
            secret='<Your Volces SK>',
            endpoint_url='http://tos-cn-beijing.volces.com',
            region='cn-beijing',
        )
    
        bucket = 'bucketname'
        print('测试 checkpointing')
        model = torchvision.models.resnet18()
        
        # 将 tos 封装为 python 的标准读写接口
        with tosfs.open(bucket+ '/checkpoint/epoch1.ckpt', 'wb') as writer:
            torch.save(model.state_dict(), writer)
        
        with tosfs.open(bucket+'/checkpoint/epoch1.ckpt', 'rb') as reader:
            stat_dict = torch.load(reader)
        
        model.load_state_dict(stat_dict)
        print(model)