You need to enable JavaScript to run this app.
导航
Ray快速入门
最近更新时间:2024.09.02 16:02:04首次发布时间:2024.09.02 16:02:04

本章节从Core模块和Data模块介绍下如何快速使用Ray。也可以参考官网Getting Started

准备工作

部署Ray的环境。建议在EMR on VKE产品中部署Ray服务,参考EMR官网进行环境部署。也可以按照官网执行pip install ray方式进行手动部署。

Ray Core模块

Ray Core 提供的 API 将 Python 任务横向扩展到集群上,最关键的 API 是两个计算接口和一个数据接口。

  • Task:一种无状态的并行计算单元,面向函数(Function)的接口,用于定义一个函数,该函数可以在集群中分布式地执行。在函数调用时,使用Ray的@ray.remote装饰器将普通的Python函数转换为Task:

    import ray
    
    @ray.remote
    def add(x, y):
        return x + y
    
    # Call the function remotely.
    result_id = add.remote(1, 2)
    # Fetch the result.
    result = ray.get(result_id)  # result = 3
    
  • Actor:一种有状态的并行计算单元,面向类(Class)的接口,用于定义一个类,该类可以在集群中分布式地执行。使用Ray的@ray.remote装饰器将普通的Python类转换为Actor。适用于需要持久化状态、长时间运行。
    示例

    import ray
    
    @ray.remote
    class Counter(object):
        def __init__(self):
            self.value = 0
    
        def increment(self):
            self.value += 1
            return self.value
    
    # Create an actor from this class.
    counter = Counter.remote()
    
    # Call the actor.
    
  • Object:分布式的对象,对象不可变(Immutable),用于在 Task 和 Actor 之间传递数据。采用ray.get()ray.put()进行数据的加载和获取。

Ray Data模块

Ray Data是基于Ray Core的数据处理框架,专为 ML 设计的数据处理库,主要解决模型训练或推理相关的数据准备与处理问题,被称为数据的最后一公里问题。详细信息参考官网离线批量推理用于 ML 训练的数据预处理和摄取

Ray Data 对数据提供了一个抽象的类:Dataset,一种分布式数据集合。在 Dataset 上提供了常见的大数据处理的原语,覆盖了数据处理的大部分阶段,例如:

  • Loading data:数据加载,如读取 Parquet 文件等。
  • Transforming data:数据转换操作,如 map_batches()
  • Consuming data:数据消费,如使用 take_batch() 和 iter_batches() 等方法访问数据。
  • Saving data:数据保存,如调用 write_parquet() 将数据以Parquet格式保存到对象存储TOS中。

示例

from typing import Dict
import ray
from ray.data import Dataset

# 初始化Ray
ray.init()

def process_batch(batch: Dict[str, int]):
    # 对每个批次中的数据进行平方计算
    return {"result": [x ** 2 for x in batch["id"]]}


# 使用range方法生成数据
ds = ray.data.range(20)

# 使用map_batches方法并行处理数据
transformed_ds = ds.map_batches(process_batch, batch_size=4)

# 在这里可以按批次进行打印
for batch in transformed_ds.iter_batches(batch_size=5):
    print(batch)

# 将结果保存到本地
transformed_ds.write_parquet("local:///tmp/res/")