You need to enable JavaScript to run this app.
导航
在Pytorch训练框架中使用TOSFS
最近更新时间:2024.12.02 19:28:55 | 首次发布时间:2024.12.02 19:28:55

Pytorch是主流的训练框架之一,在深度学习领域得到了广泛的应用。许多研究人员和开发者选择Pytorch来构建和训练他们的深度学习模型。Pytorch的优点包括:

  • 简洁易用:Pytorch的API设计简洁直观,易于学习和使用。
  • 动态图机制:Pytorch采用动态图机制,允许在运行时动态构建计算图,方便进行调试和模型的灵活调整。
  • 丰富的功能:Pytorch提供了丰富的功能,包括张量操作、自动求导、神经网络模块等,方便进行深度学习模型的构建和训练。
  • 强大的社区支持:Pytorch拥有庞大的社区,有丰富的文档、教程和开源项目可供参考和学习。

Pytorch支持通过fsspec接口来加载数据集,而TOSFS就是对fsspec接口的标准实现。本文通过demo介绍如何将Pytorch和TOSFS结合,实现以下三个场景:加载/迭代数据集、保存/加载Checkpoint,以及训练。
通过如下命令安装TOSFS的最新版本以及Pytorch相关的依赖包:

说明

建议使用3.9及以上版本的Python环境。

pip install tosfs
pip install torch
pip install torchdata
pip install torchvision
pip install lightning

加载与迭代数据集

import os

import fsspec
from torchdata.datapipes.iter import FSSpecFileLister, FSSpecFileOpener
from tos import EnvCredentialsProvider

# Register TOSFS under the "tos" protocol so fsspec can resolve tos:// URIs.
fsspec.register_implementation("tos", "tosfs.TosFileSystem")

if __name__ == '__main__':

    # Connection settings passed to both datapipes below. The region is read
    # from TOS_REGION and falls back to 'cn-beijing' when the variable is unset.
    kwargs = {
        'endpoint_url': os.environ.get("TOS_ENDPOINT"),
        'credentials_provider': EnvCredentialsProvider(),
        'region': os.environ.get("TOS_REGION", 'cn-beijing'),
    }

    # iterable-style datasets
    file_lister = FSSpecFileLister(root='tos://{your-bucket}/{your-dataset}/', **kwargs)
    iterable_dataset = FSSpecFileOpener(file_lister, mode="rb", **kwargs)

    for _, item in iterable_dataset:
        try:
            data = item.read()
        finally:
            # Release the remote stream as soon as its bytes have been read
            # instead of waiting for garbage collection.
            item.close()

保存与加载Checkpoint

import os

import torch
import torchvision
from tos import EnvCredentialsProvider

from tosfs.core import TosFileSystem

if __name__ == '__main__':
    # Filesystem client; endpoint and region come from the environment.
    tosfs = TosFileSystem(
        endpoint_url=os.environ.get("TOS_ENDPOINT"),
        region=os.environ.get("TOS_REGION"),
        credentials_provider=EnvCredentialsProvider(),
    )

    model = torchvision.models.resnet18()

    # save
    with tosfs.open('{bucket}/checkpoint/epoch1.ckpt', 'wb') as writer:
        torch.save(model.state_dict(), writer)

    # load
    with tosfs.open('{bucket}/checkpoint/epoch1.ckpt', 'rb') as reader:
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered remote object cannot execute code.
        # map_location='cpu' makes loading independent of the saving device.
        state_dict = torch.load(reader, map_location='cpu', weights_only=True)

    model.load_state_dict(state_dict)

训练

import os
import time

import fsspec
import lightning as L
import torch
import torchvision
from PIL import Image
from torchdata.datapipes.iter import FSSpecFileLister, FSSpecFileOpener
from tos import EnvCredentialsProvider

# Map the "tos" protocol to the TOSFS implementation for fsspec lookups.
fsspec.register_implementation(name="tos", cls="tosfs.TosFileSystem")

class VisionModel(L.LightningModule):
    """Lightning module that trains a torchvision model and reports throughput.

    The module holds its dataset and builds its own training DataLoader. It
    counts the images seen during an epoch and, when the epoch ends, logs and
    prints the number of images processed per second.

    Args:
        dataset: Dataset handed to the training DataLoader.
        model_name: Attribute name looked up on ``torchvision.models``; the
            resulting callable is invoked with ``weights=None``.
        batch_size: Batch size for the training DataLoader.
        num_workers: Worker count for the training DataLoader.
    """

    def __init__(
            self,
            dataset: torch.utils.data.Dataset,
            model_name: str,
            batch_size: int,
            num_workers: int,
    ):
        super().__init__()

        # Network and loss.
        self.model = getattr(torchvision.models, model_name)(weights=None)
        self.loss_fn = torch.nn.CrossEntropyLoss()

        # DataLoader settings.
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Per-epoch throughput bookkeeping.
        self.epoch_start_time = None
        self.epoch_images = 0

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters with lr=1e-3."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        return optimizer

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        """Build the training DataLoader, starting the epoch timer if unset."""
        timer_not_started = self.epoch_start_time is None
        if timer_not_started:
            self.epoch_start_time = time.perf_counter()

        loader = torch.utils.data.DataLoader(
            self.dataset,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )
        return loader

    def forward(self, imgs):
        """Run the wrapped model on ``imgs`` and return its output."""
        outputs = self.model(imgs)
        return outputs

    def training_step(self, batch, batch_idx):
        """Compute and log the loss for one batch, counting its images."""
        images, targets = batch
        self.epoch_images += len(images)
        loss = self.loss_fn(self.forward(images), targets)
        self.log('train_loss', loss)
        return loss

    def on_train_epoch_end(self):
        """Log and print the epoch's throughput, then reset the counters."""
        now = time.perf_counter()
        t = now - self.epoch_start_time
        self.log('throughput', self.epoch_images / t)
        print(f'{self.epoch_images} images in {t:.2f}s = {self.epoch_images / t:.2f} images/sec')

        # Start measuring the next epoch from here.
        self.epoch_images = 0
        self.epoch_start_time = time.perf_counter()


def load_image(sample):
    """Convert one sample into an ``(image_tensor, label)`` pair.

    Args:
        sample: Mapping with a ``'.jpg'`` entry that ``PIL.Image.open`` can
            read and a ``'.cls'`` entry whose ``read()`` result parses as an
            integer.

    Returns:
        A tuple of the image converted by ``ToTensor`` and the integer label.
    """
    to_tensor = torchvision.transforms.ToTensor()
    # Force three channels: grayscale or CMYK JPEGs would otherwise yield
    # tensors whose channel count differs from the rest of the batch.
    image = Image.open(sample['.jpg']).convert('RGB')
    label = int(sample['.cls'].read())
    return (to_tensor(image), label)


if __name__ == '__main__':

    # Connection settings passed to both datapipes below. The region is read
    # from TOS_REGION and falls back to 'cn-beijing' when the variable is unset.
    kwargs = {
        'endpoint_url': os.environ.get("TOS_ENDPOINT"),
        'credentials_provider': EnvCredentialsProvider(),
        'region': os.environ.get("TOS_REGION", 'cn-beijing'),
    }

    dataset_uri = 'tos://{your-bucket}/{your-dataset}/'

    # Pipeline: list files -> sharding filter -> open streams -> read tar
    # members -> group them webdataset-style -> map each sample via load_image.
    file_lister = FSSpecFileLister(root=dataset_uri, **kwargs)
    file_lister = file_lister.sharding_filter()
    iterable_dataset = FSSpecFileOpener(file_lister, mode='rb', **kwargs)
    iterable_dataset = iterable_dataset.load_from_tar().webdataset().map(load_image)

    L.seed_everything(21, True)
    trainer = L.Trainer(
        max_epochs=3,
        precision='16-mixed',
        enable_checkpointing=False,
    )

    model = VisionModel(
        iterable_dataset,
        model_name='resnet50',
        batch_size=64,
        num_workers=1,
    )

    # Report the wall-clock duration of the whole fit() call in seconds.
    start = time.perf_counter()
    trainer.fit(model)
    end = time.perf_counter()
    print(end - start)