You need to enable JavaScript to run this app.
导航
批量推理任务最佳实践
最近更新时间:2024.12.16 23:09:27首次发布时间:2024.11.21 15:36:08

批量推理是一种高效的推理模式,可通过控制台操作或 OpenAPI 创建异步任务对请求进行处理。批量推理支持多种模型,且支持与对应模型在线推理一致的参数配置选项。
批量推理提供更高的速率限制和更大的每日吞吐,每个批量推理任务的运行时间最长可配置至28天,特别适用于需要处理大量数据且无需实时获取结果的场景,例如日志分析、离线数据预测和评测等。
本实践教程将通过实际示例,带您了解并逐步了解批量推理的使用方法。示例基于 Python,以 Doubao-pro-32k/240615 模型为例,展示如何进行批量推理,主要涉及到两个部分:

  • 调整数据格式以适配批量推理任务后,使用 TOS (对象存储服务)进行数据上传与存储
  • 通过控制台或批量推理 OpenAPI 发起批量推理任务

1. 准备批量推理文件

请参考 输入文件格式说明 准备包含您要进行推理的请求的输入文件,并确保账户已启用火山引擎对象存储 TOS 服务
以下是一个输入文件示例,包含 2 个请求:

data.jsonl
未知大小

{"custom_id": "request-1", "body": {"messages": [{"role": "user", "content": "天空为什么这么蓝?"}],"max_tokens": 1000,"top_p":1}}
{"custom_id": "request-2", "body": {"messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "天空为什么这么蓝?"}],"max_tokens": 1000}}

使用如下代码对输入数据进行前置检测

# input_file_name-文件的完整路径
with open(input_file_name, 'r', encoding='utf-8') as file:
    total = 0
    custom_id_set = set()
    for line in file:
        if line.strip() == '':
            continue
        try:
            line_dict = json.loads(line)
        except json.decoder.JSONDecodeError:
            raise Exception(f"批量推理任务失败,第{total+1}行非json数据")
        if not line_dict.get('custom_id'):
            raise Exception(f"批量推理任务失败,第{total+1}行custom_id不存在")
        if line_dict.get('custom_id') in custom_id_set:
            raise Exception(f"批量推理任务失败,custom_id={line_dict.get('custom_id', '')}存在重复")
        else:
            custom_id_set.add(line_dict.get('custom_id'))
        if not isinstance(line_dict.get('body', ''), dict):
            raise Exception(f"批量推理任务失败,custom_id={line_dict.get('custom_id', '')}的body非json字符串")
        total += 1
    return total

2.TOS上传数据

数据上传

数据上传可选择通过 TOS 控制台或者 OpenAPI:

  • 上传文件前,请确保您已经创建存储桶。具体步骤,请参见创建存储桶
  • 文件名的命名规范,请参见文件说明
  • 简单上传方式最大能够上传 5GB 文件。
  • 使用控制台分片上传最大能够上传 50GB 文件。
    • 如果您的文件大于 50GB,您可以使用 API 上传,该方式最大支持上传 48.8TB 文件(每个分片最大 5GB,分片数量最多 10000)。具体步骤请参见UploadPart

通过控制台

控制台快速入门--对象存储-火山引擎

通过 SDK

快速入门(Python SDK)--对象存储-火山引擎

安装及初始化

# 通过 pip 安装
pip install tos

# 通过源码安装, 从 Github 下载相应版本的 SDK 包,解压后进入目录,执行如下命令。
python3 setup.py install
import os
import tos

# AKSK是访问火山云资源的秘钥,可从访问控制-API访问密钥获取
# 为了保障安全,强烈建议您不要在代码中明文维护您的AKSK
ak = os.environ.get("YOUR_ACCESS_KEY")
sk = os.environ.get("YOUR_SECRET_KEY")
# your endpoint 和 your region 填写Bucket 所在区域对应的Endpoint。
# 以华北2(北京)为例,your endpoint 填写 tos-cn-beijing.volces.com,your region 填写 cn-beijing。
endpoint = "tos-cn-beijing.volces.com"
region = "cn-beijing"

client = tos.TosClientV2(ak, sk, endpoint, region)

上传本地文件:普通上传

普通上传(Python SDK)--对象存储-火山引擎

bucket_name = "demo-bucket-test"
# 对象名称,例如 example_dir 下的 example_object.txt 文件,则填写为 example_dir/example_object.txt
object_key = "data.jsonl"
# 本地文件路径
file_name = "/usr/local/data.jsonl"
try:
    # 创建 TosClientV2 对象,对桶和对象的操作都通过 TosClientV2 实现
    client = tos.TosClientV2(ak, sk, endpoint, region)
    # 将本地文件上传到目标桶中
    # file_name为本地文件的完整路径。
    client.put_object_from_file(bucket_name, object_key, file_name)
except tos.exceptions.TosClientError as e:
    # 操作失败,捕获客户端异常,一般情况为非法请求参数或网络异常
    print('fail with client error, message:{}, cause: {}'.format(e.message, e.cause))
except tos.exceptions.TosServerError as e:
    # 操作失败,捕获服务端异常,可从返回信息中获取详细错误信息
    print('fail with server error, code: {}'.format(e.code))
    # request id 可定位具体问题,强烈建议日志中保存
    print('error with request id: {}'.format(e.request_id))
    print('error with message: {}'.format(e.message))
    print('error with http code: {}'.format(e.status_code))
    print('error with ec: {}'.format(e.ec))
    print('error with request url: {}'.format(e.request_url))
except Exception as e:
    print('fail with unknown error: {}'.format(e))

分片上传

分片上传(Python SDK)--对象存储-火山引擎
断点续传(Python SDK)--对象存储-火山引擎

说明

  • 文件大小超过5G需要走分片上传,网络稳定的情况下可以直接走分片上传,网络不稳定建议走断点续传。
  • 每个分片大小为 4MB - 5GB,网络环境越稳定分片大小可选越大。建议分片大小为50 MB - 1GB。
# 断点续传demo
bucket_name = "demo-bucket-test"
# 对象名称,例如 example_dir 下的 example_object.txt 文件,则填写为 example_dir/example_object.txt
object_key = "data.jsonl"
# 本地文件完整路径,例如usr/local/testfile.txt
filename = "/usr/local/testfile.txt"
try:
    # 创建 TosClientV2 对象,对桶和对象的操作都通过 TosClientV2 实现
    client = tos.TosClientV2(ak, sk, endpoint, region)


    def percentage(
        consumed_bytes: int,
        total_bytes: int,
        rw_once_bytes: int,
        type: DataTransferType,
    ):
        if total_bytes:
            rate = int(100 * float(consumed_bytes) / float(total_bytes))
            print(
                "rate:{}, consumed_bytes:{},total_bytes{}, rw_once_bytes:{}, type:{}".format(
                    rate, consumed_bytes, total_bytes, rw_once_bytes, type
                )
            )


    # 可通过part_size可选参数指定分片大小
    # 通过enable_checkpoint参数开启和关闭断点续传特性
    # 通过task_num设置线程数
    client.upload_file(
        bucket_name,
        object_key,
        filename,
        # 设置断点续传执行线程数,默认为1
        task_num=3,
        # 设置断点续传分片大小,默认20mb
        part_size=1024 * 1024 * 5,
        # 设置断点续传进度条回调函数
        data_transfer_listener=percentage,
    )
except tos.exceptions.TosClientError as e:
    # 操作失败,捕获客户端异常,一般情况为非法请求参数或网络异常
    print("fail with client error, message:{}, cause: {}".format(e.message, e.cause))
except tos.exceptions.TosServerError as e:
    # 操作失败,捕获服务端异常,可从返回信息中获取详细错误信息
    print("fail with server error, code: {}".format(e.code))
    # request id 可定位具体问题,强烈建议日志中保存
    print("error with request id: {}".format(e.request_id))
    print("error with message: {}".format(e.message))
    print("error with http code: {}".format(e.status_code))
    print("error with ec: {}".format(e.ec))
    print("error with request url: {}".format(e.request_url))
except Exception as e:
    print("fail with unknown error: {}".format(e))

3.创建及执行批量推理任务

您有3种方式创建批量推理任务,您可以根据自己的需求灵活选择。

  • 通过火山方舟控制台可视化创建。
  • 调用 API 创建。
  • 通过SDK 创建。

通过控制台

参考批量推理教程,在“批量推理”页面,点击左上角 创建批量任务 按钮跳转至创建页。

  • 账户维度有如下配额:

    • 一个项目下 7 天内提交的批量推理任务数量最多为 500,超出时将暂时无法提交新的任务。
    • 一个项目下同时处于「运行中」状态的批量推理任务数量最多为 3,超出时其余任务将处于「排队中」状态等待运行。
      • 请注意,账号下实际同时运行的任务数量会受到平台总体资源的限制和任务调度策略的影响。

    如希望调整配额,您可以前往配额中心进行申请。

通过 API

参考创建批量推理任务API进行批量推理任务的创建,参考示例如下:

POST /?Action=CreateBatchInferenceJob&Version=2024-01-01 HTTP/1.1
Host: open.volcengineapi.com
Content-Type: application/json; charset=UTF-8
X-Date: 20240514T132743Z
X-Content-Sha256: 287e874e******d653b44d21e
Authorization: HMAC-SHA256 Credential=Adfks******wekfwe/20240514/cn-beijing/ark/request, SignedHeaders=host;x-content-sha256;x-date, Signature=47a7d934ff7b37c03938******cd7b8278a40a1057690c401e92246a0e41085f
{
    "Name": "批量推理任务",
    "Description": "这是一个批量推理任务",
    "ModelReference": {
      "FoundationModel": {
        "Name": "doubao-pro-32k",
        "ModelVersion": "240615"
      }
    },
    "InputFileTosLocation": {
      "BucketName": "demo-bucket-test",
      "ObjectKey": "data.jsonl"
    },
    "OutputDirTosLocation": {
      "ObjectKey": "output/",
      "BucketName": "demo-bucket-test"
    },
    "ProjectName":"default",
    "CompletionWindow": "1d",
    "Tags": [
      {
        "Key": "test_key",
        "Value": "test_value"
      }
    ]
}

通过 SDK

下面为您提供提供了一种封装 TOS SDK 和批量推理 API 的方式,以便您可以方便地执行自己的批量推理任务。

batch_inference.zip
未知大小

构造 TOS Client

参考上传本地文件:普通上传构造TOS Client,如示例代码:

import tos
import os

class TosClient:
    # 文档链接:https://www.volcengine.com/docs/6349/92786
    def __init__(self):
        # AKSK是访问火山云资源的秘钥,可从访问控制-API访问密钥获取
        # 为了保障安全,强烈建议您不要在代码中明文维护您的AKSK
        self.ak = os.environ.get("YOUR_ACCESS_KEY")
        self.sk = os.environ.get("YOUR_SECRET_KEY")
        
        # your endpoint 和 your region 填写Bucket 所在区域对应的Endpoint。# 以华北2(北京)为例,your endpoint 填写 tos-cn-beijing.volces.com,your region 填写 cn-beijing。
        self.endpoint = "tos-cn-beijing.volces.com"
        self.region = "cn-beijing"
        self.client = tos.TosClientV2(self.ak, self.sk, self.endpoint, self.region)

    def create_bucket(self, bucket_name):
        try:
            # 设置桶存储桶读写权限
            self.client.create_bucket(bucket_name, acl=tos.ACLType.ACL_Public_Read_Write)
        except tos.exceptions.TosClientError as e:
            # 操作失败,捕获客户端异常,一般情况为非法请求参数或网络异常
            print('fail with client error, message:{}, cause: {}'.format(e.message, e.cause))
        except tos.exceptions.TosServerError as e:
            # 操作失败,捕获服务端异常,可从返回信息中获取详细错误信息
            print('fail with server error, code: {}'.format(e.code))
            # request id 可定位具体问题,强烈建议日志中保存
            print('error with request id: {}'.format(e.request_id))
            print('error with message: {}'.format(e.message))
            print('error with http code: {}'.format(e.status_code))
            print('error with ec: {}'.format(e.ec))
            print('error with request url: {}'.format(e.request_url))
        except Exception as e:
            print('fail with unknown error: {}'.format(e))

    def put_object_from_file(self, bucket_name, object_key, file_path):
        try:
            # 通过字符串方式添加 Object
            res = self.client.put_object_from_file(bucket_name, object_key, file_path)
        except tos.exceptions.TosClientError as e:
            # 操作失败,捕获客户端异常,一般情况为非法请求参数或网络异常
            print('fail with client error, message:{}, cause: {}'.format(e.message, e.cause))
        except tos.exceptions.TosServerError as e:
            # 操作失败,捕获服务端异常,可从返回信息中获取详细错误信息
            print('fail with server error, code: {}'.format(e.code))
            # request id 可定位具体问题,强烈建议日志中保存
            print('error with request id: {}'.format(e.request_id))
            print('error with message: {}'.format(e.message))
            print('error with http code: {}'.format(e.status_code))
            print('error with ec: {}'.format(e.ec))
            print('error with request url: {}'.format(e.request_url))
        except Exception as e:
            print('fail with unknown error: {}'.format(e))

    def get_object(self, bucket_name, object_name):
        try:
            # 从TOS bucket中下载对象到内存中
            object_stream = self.client.get_object(bucket_name, object_name)
            # object_stream 为迭代器可迭代读取数据
            # for content in object_stream:
            #     print(content)
            # 您也可调用 read()方法一次在内存中获取完整的数据
            print(object_stream.read())
        except tos.exceptions.TosClientError as e:
            # 操作失败,捕获客户端异常,一般情况为非法请求参数或网络异常
            print('fail with client error, message:{}, cause: {}'.format(e.message, e.cause))
        except tos.exceptions.TosServerError as e:
            # 操作失败,捕获服务端异常,可从返回信息中获取详细错误信息
            print('fail with server error, code: {}'.format(e.code))
            # request id 可定位具体问题,强烈建议日志中保存
            print('error with request id: {}'.format(e.request_id))
            print('error with message: {}'.format(e.message))
            print('error with http code: {}'.format(e.status_code))
            print('error with ec: {}'.format(e.ec))
            print('error with request url: {}'.format(e.request_url))
        except Exception as e:
            print('fail with unknown error: {}'.format(e))

    def close_client(self):
        try:
            # 执行相关操作后,将不再使用的TosClient关闭
            self.client.close()
        except tos.exceptions.TosClientError as e:
            # 操作失败,捕获客户端异常,一般情况为非法请求参数或网络异常
            print('fail with client error, message:{}, cause: {}'.format(e.message, e.cause))
        except tos.exceptions.TosServerError as e:
            # 操作失败,捕获服务端异常,可从返回信息中获取详细错误信息
            print('fail with server error, code: {}'.format(e.code))
            # request id 可定位具体问题,强烈建议日志中保存
            print('error with request id: {}'.format(e.request_id))
            print('error with message: {}'.format(e.message))
            print('error with http code: {}'.format(e.status_code))
            print('error with ec: {}'.format(e.ec))
            print('error with request url: {}'.format(e.request_url))
        except Exception as e:
            print('fail with unknown error: {}'.format(e))

构造BatchInference Client

构造一个BatchInferenceClient,示例代码包含了CreateBatchInferenceJobListBatchInferenceJobs两个接口,您也可以根据创建批量推理任务API添加更多API。

import json
from typing import Dict, Any

import aiohttp
import backoff

from common import utils

class BatchInferenceClient:
    def __init__(self):
        self._retry = 3
        self._timeout = aiohttp.ClientTimeout(60)
        # AKSK是访问火山云资源的秘钥,可从访问控制-API访问密钥获取
        # 为了保障安全,强烈建议您不要在代码中明文维护您的AKSK
        self.ak = os.environ.get("YOUR_ACCESS_KEY")
        self.sk = os.environ.get("YOUR_SECRET_KEY")
        # 需要替换为您的账号id,可从火山引擎官网点击账号头像,弹出框中找到,复制“账号ID”后的一串数字
        self.account_id = 'Your_ACCOUNT_ID'
        self.version = '2024-01-01'
        self.domain = 'open.volcengineapi.com'
        self.region = 'cn-beijing'
        self.service = 'ark'
        self.base_param = {'Version': self.version, 'X-Account-Id': self.account_id}

    async def _call(self, url, headers, req: Dict[str, Any]):
        @backoff.on_exception(backoff.expo, Exception, factor=0.1, max_value=5, max_tries=self._retry)
        async def _retry_call(body):
            async with aiohttp.request(method='POST', url=url, json=body, headers=headers,
                                       timeout=self._timeout) as response:
                try:
                    return await response.json()
                except Exception as e:
                    raise e

        try:
            return await _retry_call(req)
        except Exception as e:
            raise e

    async def create_batch_inference_job(self, bucket_name, input_object_key, output_object_key: str):
        action = 'CreateBatchInferenceJob'
        canonicalQueryString = 'Action={}&Version={}&X-Account-Id={}'.format(
            action, self.version, self.account_id)
        url = 'https://' + self.domain + '/?' + canonicalQueryString
        extra_param = {
            'Action': action,
            'ProjectName': 'default',
            'Name': 'just_test',
            'ModelReference': {
                'FoundationModel': {
                    'Name': 'doubao-pro-32k',
                    'ModelVersion': '240615'
                }
            },
            'InputFileTosLocation': {
                'BucketName': bucket_name,
                'ObjectKey': input_object_key
            },
            'OutputDirTosLocation': {
                'BucketName': bucket_name,
                'ObjectKey': output_object_key
            },
            'CompletionWindow': '3d'
        }
        param = self.base_param | extra_param

        payloadSign = utils.get_hmac_encode16(json.dumps(param))
        headers = utils.get_hashmac_headers(self.domain, self.region, self.service, canonicalQueryString,
                                            'POST', '/', 'application/json; charset=utf-8', payloadSign, self.ak,
                                            self.sk)
        return await self._call(url, headers, param)

    async def ListBatchInferenceJobs(self, phases=None):
        if phases is None:
            phases = []
        action = 'ListBatchInferenceJobs'
        canonicalQueryString = 'Action={}&Version={}&X-Account-Id={}'.format(
            action, self.version, self.account_id)
        url = 'https://' + self.domain + '/?' + canonicalQueryString
        extra_param = {
            'Action': action,
            'ProjectName': 'default',
            'Filter': {
                'Phases': phases
            }
        }
        param = self.base_param | extra_param

        payloadSign = utils.get_hmac_encode16(json.dumps(param))
        headers = utils.get_hashmac_headers(self.domain, self.region, self.service, canonicalQueryString,
                                            'POST', '/', 'application/json; charset=utf-8', payloadSign, self.ak,
                                            self.sk)
        return await self._call(url, headers, param)

执行批量推理任务

构造完成TOS ClientBatchInference Client后,您可以直接调用TOS Client上传您的批量推理文件。然后调用BatchInference Client创建一个批量推理任务,之后可以调用ListBatchInferenceJobs查看运行中的批量推理任务。

import uvloop
from bytedance_ark_batch_inference.client import BatchInferenceClient
from bytedance_tos.tos_client import TosClient

async def main():
    tos_client = TosClient()
    batch_inference_client = BatchInferenceClient()

    # put object
    bucket_name = "demo-bucket-test"
    object_key = "data.jsonl"
    tos_client.put_object_from_file(bucket_name, object_key, "data.jsonl")

    # create batch job
    # output key should be existed
    output_key = "output/"
    response = await batch_inference_client.create_batch_inference_job(bucket_name, object_key, output_key)
    response = await batch_inference_client.ListBatchInferenceJobs(['Running'])

    print('done')

if __name__ == '__main__':
    uvloop.run(main())

4.查看批量推理任务的状态

调用 API

您可以参考火山方舟的 API 说明创建批量推理任务API,使用GetBatchInferenceJob接口查看已创建的批量推理任务的状态,参考示例,获取批量推理任务的状态信息:
批量推理任务状态与对应描述如下:

Status Code

状态

描述

Queued

排队中

任务由于账号下并发任务数达到上限等原因需排队等候

Running

运行中

任务正在运行中

Completed

完成

所有请求已经处理完毕,任务已完成

Terminating

终止中

由于到期等系统原因或手动终止,任务当前处于终止中状态

Terminated

已终止

任务已被取消

Failed

失败

输入文件校验失败或其他原因导致任务失败

通过控制台

您也可以通过火山方舟控制台,在批量推理列表页或任务详情页查询创建的批量推理任务的状态信息。

5.获取批量推理结果

控制台

在批量推理任务运行结束后,可点击 查看结果 按钮或在详情页的「文件信息」中点击跳转至 TOS 页面查看并下载输出文件。
结果文件的 TOS 存储路径如下:

  • 结果文件:tos://demo-bucket-test/output/bi-2024XXXXXX/output/results.jsonl
  • 错误信息文件:tos://demo-bucket-test/output/bi-2024XXXXXX/error/errors.jsonl

TOS下载结果

普通下载(Python SDK)--对象存储-火山引擎
断点续传下载(Python SDK)--对象存储-火山引擎

bucket_name = "demo-bucket-test"
# 对象名称,例如 example_dir 下的 example_object.txt 文件,则填写为 example_dir/example_object.txt
object_key = "output/bi-20241111191820-tfjbg/output/results.jsonl"
# 本地文件完整路径,例如usr/local/testfile.txt
file_path = "/usr/local/testfile.txt"
try:
    # 创建 TosClientV2 对象,对桶和对象的操作都通过 TosClientV2 实现
    client = tos.TosClientV2(ak, sk, endpoint, region)
    def percentage(consumed_bytes, total_bytes, rw_once_bytes, type: DataTransferType):
        if total_bytes:
            rate = int(100 * float(consumed_bytes) / float(total_bytes))
            print("rate:{}, consumed_bytes:{},total_bytes{}, rw_once_bytes:{}, type:{}".format(rate, consumed_bytes,
                                                                                               total_bytes,
                                                                                               rw_once_bytes, type))
    client.download_file(bucket_name, object_key, file_path,
                         # 通过可选参数part_size配置下载时分片大小,默认为20mb
                         part_size=1024 * 1024 * 20,
                         # 通过可选参数task_num配置下载分片的线程数,默认为1
                         task_num=3,
                         # 通过可选参数data_transfer_listener配置进度条
                         data_transfer_listener=percentage)
except tos.exceptions.TosClientError as e:
    # 操作失败,捕获客户端异常,一般情况为非法请求参数或网络异常
    print('fail with client error, message:{}, cause: {}'.format(e.message, e.cause))
except tos.exceptions.TosServerError as e:
    # 操作失败,捕获服务端异常,可从返回信息中获取详细错误信息
    print('fail with server error, code: {}'.format(e.code))
    # request id 可定位具体问题,强烈建议日志中保存
    print('error with request id: {}'.format(e.request_id))
    print('error with message: {}'.format(e.message))
    print('error with http code: {}'.format(e.status_code))
    print('error with ec: {}'.format(e.ec))
    print('error with request url: {}'.format(e.request_url))
except Exception as e:
    print('fail with unknown error: {}'.format(e))

相关文档

批量推理:如何使用控制台管理批量推理任务。
批量推理:批量推理的常见问题。