本文介绍如何在 Pytorch 任务中访问和存储 TOS 数据。
注意
不建议使用 3.12+ 版本,因为 PyTorch 不支持。
S3 Connector 是 AWS 提供的 pytorch 连接器,用于增强 torchdata 提供的 S3 相关访问能力,支持 map-style datasets、iterable-stype datasets 以及 checkpointing 接口,详细信息,请参见 S3-connector-for-pytorch。
当前仅支持 TOS FNS 桶。
安装 s3 connector。
执行以下命令安装 s3 connector 基础库。
pip install s3torchconnector
(可选)执行以下命令安装 s3 connector for lighting。
pip install s3torchconnector[lightning]
配置环境。
执行以下命令编辑密钥配置文件。
vim ~/.aws/credentials
执行以下命令配置访问密钥。
[default] aws_access_key_id = <Your Volces AK> aws_secret_access_key = <Your Volces SK>
执行以下命令保存配置。
:wq
运行以下代码,使用 iterable-style datasets。
说明
本文以使用北京地域的 bucketname
桶的 shard-data/
目录为例,建议您根据业务需求修改代码。
from s3torchconnector import S3IterableDataset, S3ClientConfig if __name__ == '__main__': dataset_uri = 's3://bucketname/shard-data/' region = 'cn-beijing' endpoint = 'http://tos-s3-cn-beijing.volces.com' s3_config = S3ClientConfig(force_path_style=False) print('测试 iterable-style datasets') iterable_dataset = S3IterableDataset.from_prefix(dataset_uri, region=region, endpoint=endpoint, s3client_config=s3_config) # 遍历并读取Dataset中的对象。 for item in iterable_dataset: print(item.bucket, item.key) data = item.read() print(len(data))
运行以下代码,使用 map-style datasets。
说明
本文以使用北京地域的 bucketname
桶的 shard-data/
目录为例,建议您根据业务需求修改代码。
from s3torchconnector import S3IterableDataset, S3ClientConfig if __name__ == '__main__': dataset_uri = 's3://bucketname/shard-data/' region = 'cn-beijing' endpoint = 'http://tos-s3-cn-beijing.volces.com' s3_config = S3ClientConfig(force_path_style=False) print('测试 map-style datasets') map_dataset = S3MapDataset.from_prefix(dataset_uri, region=region, endpoint=endpoint, s3client_config=s3_config) # 遍历并读取Dataset中的对象。 for item in map_dataset: print(item.bucket, item.key) data = item.read() print(len(data)) # 根据索引读取Dataset中的对象。 item = map_dataset[0] bucket = item.bucket key = item.key print(bucket, key) data = item.read() print(len(data))
运行以下代码,保存或加载 checkpoint。
说明
本文以使用北京地域的 bucketname
桶的 checkpoint/
目录为例,建议您根据业务需求修改代码。
import torch import torchvision from s3torchconnector import S3Checkpoint, S3ClientConfig if __name__ == '__main__': checkpoint_uri = 's3://bucketname/checkpoint/' region = 'cn-beijing' endpoint = 'http://tos-s3-cn-beijing.volces.com' s3_config = S3ClientConfig(force_path_style=False) print('测试 checkpointing') checkpoint = S3Checkpoint(region=region, endpoint=endpoint, s3client_config=s3_config) model = torchvision.models.resnet18() with checkpoint.writer(checkpoint_uri + 'epoch0.ckpt') as writer: torch.save(model.state_dict(), writer) with checkpoint.reader(checkpoint_uri + 'epoch0.ckpt') as reader: state_dict = torch.load(reader) model.load_state_dict(state_dict)
将火山 TOS 挂载为本地路径,再使用 torchdata 直接读写本地路径实现对接,本文中以使用 tosutil 工具作为 Fuse 挂载工具,您可以根据实际情况,选择其他工具挂载存储桶。
当前仅支持 TOS FNS 桶。
安装 tosutil 工具,具体步骤,请参见下载与安装。
配置 toustil 工具并检查版本和连通性,具体步骤,请参见快速入门。
使用 mount 命令挂载存储桶并验证挂载结果,具体步骤,请参见挂载桶为本地文件系统目录(mount)。
本文以将 bucketname
桶挂载到 ./root
目录为例。
./tosutil mount tos://bucketname ./root -aow -ar -ao
挂载完成后,执行以下命令验证挂载结果。
ls ./root
运行以下代码,构建 iterable-style datasets。
说明
本文以使用 shard-data/
目录为例,建议您根据业务需求修改代码。
from torchdata.datapipes.iter import FileLister, FileOpener if __name__ == '__main__': dataset_uri = './root/shard-data/' print('测试 iterable-style datasets') file_lister = FileLister(root=dataset_uri) iterable_dataset = FileOpener(file_lister, mode='rb') # 遍历并读取Dataset中的对象。 for _, item in iterable_dataset: print(item.bucket, item.key) data = item.read() print(len(data))
运行以下代码,保存或加载 checkpoint。
说明
本文以使用 checkpoint/
目录为例,建议您根据业务需求修改代码。
import torch import torchvision if __name__ == '__main__': print('测试 checkpointing') model = torchvision.models.resnet18() # 直接将 tos 作为本地文件处理 with open('./root/checkpoint/epoch0.ckpt', 'wb') as writer: torch.save(model.state_dict(), writer) with open('./root/checkpoint/epoch0.ckpt', 'rb') as reader: stat_dict = torch.load(reader) model.load_state_dict(stat_dict) print(model)
通过实现了 fsspec 规范的 tosfs,将火山 TOS 封装为 python 的标准读写接口,再使用 torchdata 直接读写本地路径实现对接,更多信息,请参见 tosfs。
支持 TOS FNS 桶和 HNS 桶。
执行以下命令安装 tosfs。
# 安装 tosfs for fsspec pip install tosfs
执行以下命令初始化客户端。
from tosfs.core import TosFileSystem # 配置访问密钥等信息 if __name__ == '__main__': tosfs = TosFileSystem( key='<Your Volces AK>', secret='<Your Volces SK>', endpoint_url='http://tos-cn-beijing.volces.com', region='cn-beijing', )
运行以下代码,使用 iterable-style datasets。
说明
本文以使用 北京地域的 bucketname
桶为例,建议您根据业务需求修改代码。
from torchdata.datapipes.iter import FSSpecFileLister, FSSpecFileOpener if __name__ == '__main__': kwargs = { 'key' : '<Your Volces AK>', 'secret' : '<Your Volces SK>', 'endpoint_url' : 'https://tos-cn-beijing.volces.com', 'region' :'cn-beijing' } bucket = 'bucketname' print('测试 iterable-style datasets') file_lister = FSSpecFileLister(root='tos://'+ bucket+'/shard-data/', **kwargs) iterable_dataset = FSSpecFileOpener(file_lister, mode='rb', **kwargs) # 遍历并读取Dataset中的对象。 for _, item in iterable_dataset: print(item.bucket, item.key) data = item.read() print(len(data))
运行以下代码,保存或加载 checkpoint。
说明
本文以使用 北京地域的 bucketname
桶为例,建议您根据业务需求修改代码。
import torch import torchvision from tosfs.core import TosFileSystem if __name__ == '__main__': tosfs = TosFileSystem( key='<Your Volces AK>', secret='<Your Volces SK>', endpoint_url='http://tos-cn-beijing.volces.com', region='cn-beijing', ) bucket = 'bucketname' print('测试 checkpointing') model = torchvision.models.resnet18() # 将 tos 封装为 python 的标准读写接口 with tosfs.open(bucket+ '/checkpoint/epoch1.ckpt', 'wb') as writer: torch.save(model.state_dict(), writer) with tosfs.open(bucket+'/checkpoint/epoch1.ckpt', 'rb') as reader: stat_dict = torch.load(reader) model.load_state_dict(stat_dict) print(model)