You need to enable JavaScript to run this app.
导航
创建批量推理任务
最近更新时间:2024.12.16 21:13:44首次发布时间:2024.12.16 21:13:44

介绍

通过 batch Inference job 接口,创建批量推理任务。详细介绍请参见批量推理说明

前提条件

使用步骤

1. 准备需批量处理的文件

批量推理任务支持jsonl格式的文件,具体说明见数据文件格式说明

{"custom_id": "request-1", "body": {"messages": [{"role": "user", "content": "天空为什么这么蓝?"}],"max_tokens": 1000,"top_p":1}}
{"custom_id": "request-2", "body": {"messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "天空为什么这么蓝?"}],"max_tokens": 1000}}

为避免因为文件格式错误,导致批量推理任务失败,可以使用下面脚本进行校验。

import json

def check_jsonl_file(file_path):
    with open(file_path, "r", encoding="utf-8") as file:
        total = 0
        custom_id_set = set()
        for line in file:
            if line.strip() == "":
                continue
            try:
                line_dict = json.loads(line)
            except json.decoder.JSONDecodeError:
                raise Exception(f"批量推理输入文件格式错误,第{total+1}行非json数据")
            if not line_dict.get("custom_id"):
                raise Exception(f"批量推理输入文件格式错误,第{total+1}行custom_id不存在")
            if line_dict.get("custom_id") in custom_id_set:
                raise Exception(
                    f"批量推理输入文件格式错误,custom_id={line_dict.get('custom_id', '')}存在重复"
                )
            else:
                custom_id_set.add(line_dict.get("custom_id"))
            if not isinstance(line_dict.get("body", ""), dict):
                raise Exception(
                    f"批量推理输入文件格式错误,custom_id={line_dict.get('custom_id', '')}的body非json字符串"
                )
            total += 1
    return total


# 替换<YOUR_JSONL_FILE>为你的JSONL文件路径
file_path = "<YOUR_JSONL_FILE>"
total_lines = check_jsonl_file(file_path)
print(f"文件中有效JSON数据的行数为: {total_lines}")

2.上传需批量处理的文件

您需要将批量处理任务的输入文件传入到TOS中,您可以选择下面方式上传文件:

为了方便您实现,我们提供了一些上传文件的python脚本,请参见2.TOS上传数据

完成上传后,您需要记录几个信息.

  • 输入文件的TOS bucket名称:<INPUT_BUCKET_NAME>
  • 输入文件在TOS中的路径:<INPUT_FILE>
  • 输出文件的TOS bucket名称:<OUTPUT_BUCKET_NAME>
  • 推理结果存放的文件夹路径:<OUTPUT_DIR>

TOS 中文件夹以“/”结尾,如output/

3.创建批量推理任务

通过API

获取了上面的信息后,您可以使用接口CreateBatchInferenceJob - 创建批量推理任务来创建批量推理任务。
其中绿色字体,需要您进行替换,其中X-Content-Sha256Authorization这两部分需要您签名和计算完成,具体请参见API签名调用指南

POST /?Action=CreateBatchInferenceJob&Version=2024-01-01 HTTP/1.1
Host: open.volcengineapi.com
Content-Type: application/json; charset=UTF-8
X-Date: 20240514T132743Z
X-Content-Sha256: 287e874e******d653b44d21e
Authorization: HMAC-SHA256 Credential=Adfks******wekfwe/20240514/cn-beijing/ark/request, SignedHeaders=host;x-content-sha256;x-date, Signature=47a7d934ff7b37c03938******cd7b8278a40a1057690c401e92246a0e41085f
{
    "Name": "你的批量推理任务",
    "Description": "你的批量推理任务",
    "ModelReference": {
      "FoundationModel": {
        "Name": "doubao-pro-32k",
        "ModelVersion": "240615"
      }
    },
    "InputFileTosLocation": {
      "BucketName": "<INPUT_BUCKET_NAME>",
      "ObjectKey": "<INPUT_FILE>"
    },
    "OutputDirTosLocation": {
      "BucketName": "<OUTPUT_BUCKET_NAME>",
      "ObjectKey": "<OUTPUT_DIR>"
    },
    "ProjectName":"default",
    "CompletionWindow": "1d",
    "Tags": [
      {
        "Key": "test_key",
        "Value": "test_value"
      }
    ]
}

通过Python

下面是通过 Python 实现签名以及创建批量推理任务的工程文件。

batch_inference.zip
未知大小

在这之前,您还是要配置对 client.py 以及 main.py 文件进行修改(参考下面代码中标记绿色的部分以及说明):

  • 配置Access Key到环境变量中:获取您的Access Key,配置环境变量可以参考2.配置 API Key 到环境变量
  • 替换<YOUR_ACCOUNT_ID>为您的账号ID:获取账号ID,可以单击官网右上角头像获取。
  • 替换模型:更改self.model = "doubao-pro-32k"self.model_version = "240615"为您需要使用的模型名称以及模型版本,具体支持批量推理的模型见适用模型
import json
from typing import Dict, Any
import aiohttp
import backoff
from common import utils
import os

class BatchInferenceClient:
    def __init__(self):
        """
        初始化BatchInferenceClient类的实例。
        该方法设置了一些默认属性,如重试次数、超时时间、访问密钥(AK/SK)、账号ID、API版本、服务域名、区域和基础参数。
        访问密钥(AK/SK)从环境变量中获取,以提高安全性。
        基础参数包括API版本和账号ID,这些参数在每次请求中都会用到。
        """
        # 设置重试次数为3次
        self._retry = 3
        # 设置请求超时时间为60秒
        self._timeout = aiohttp.ClientTimeout(60)
        # Access Key访问火山云资源的秘钥,可从访问控制-API访问密钥获取获取
        # 为了保障安全,强烈建议您不要在代码中明文维护您的AKSK
        # 从环境变量中获取访问密钥(AK)
        self.ak = os.environ.get("YOUR_ACCESS_KEY")
        # 从环境变量中获取秘密密钥(SK)
        self.sk = os.environ.get("YOUR_SECRET_KEY")
        # 设置模型名称
        self.model = "doubao-pro-32k"
        # 设置模型版本
        self.model_version = "240615"
        # 需要替换为您的账号id,可从火山引擎官网点击账号头像,弹出框中找到,复制“账号ID”后的一串数字
        self.account_id = "<YOUR_ACCOUNT_ID>"
        # 设置API版本
        self.version = "2024-01-01"
        # 设置服务域名
        self.domain = "open.volcengineapi.com"
        # 设置区域
        self.region = "cn-beijing"
        # 设置服务名称
        self.service = "ark"
        # 设置基础参数,包括API版本和账号ID
        self.base_param = {"Version": self.version, "X-Account-Id": self.account_id}

    async def _call(self, url, headers, req: Dict[str, Any]):
        """
        异步调用指定URL的HTTP POST请求,并处理请求的重试逻辑。
        :param url: 请求的目标URL。
        :param headers: 请求的HTTP头部信息。
        :param req: 请求的JSON格式数据。
        :return: 响应的JSON数据。
        :raises Exception: 如果请求失败或解析响应失败,抛出异常。
        """
        @backoff.on_exception(
            backoff.expo, Exception, factor=0.1, max_value=5, max_tries=self._retry
        )
        async def _retry_call(body):
            """
            内部函数,用于发送HTTP POST请求,并处理请求的重试逻辑。
            :param body: 请求的JSON格式数据。
            :return: 响应的JSON数据。
            :raises Exception: 如果请求失败或解析响应失败,抛出异常。
            """
            async with aiohttp.request(
                method="POST",
                url=url,
                json=body,
                headers=headers,
                timeout=self._timeout,
            ) as response:
                try:
                    return await response.json()
                except Exception as e:
                    raise e

        try:
            return await _retry_call(req)
        except Exception as e:
            raise e


    async def create_batch_inference_job(
        self, input_bucket_name,output_bucket_name, input_object_key, output_object_key: str
    ):
        """
        异步创建批量推理任务。
        :param input_bucket_name: 任务JSONL文件存储的桶名称。
        :param input_bucket_name: 任务结果存储的桶名称。
        :param input_object_key: 输入文件的对象键。
        :param output_object_key: 输出文件的对象键。
        :return: 响应的JSON数据。
        :raises Exception: 如果请求失败或解析响应失败,抛出异常。
        """
        action = "CreateBatchInferenceJob"
        canonicalQueryString = "Action={}&Version={}&X-Account-Id={}".format(
            action, self.version, self.account_id
        )
        url = "https://" + self.domain + "/?" + canonicalQueryString
        extra_param = {
            "Action": action,
            "ProjectName": "default",
            "Name": "just_test",
            "ModelReference": {
                "FoundationModel": {"Name": self.model, "ModelVersion": self.model_version},
            },
            "InputFileTosLocation": {
                "BucketName": input_bucket_name,
                "ObjectKey": input_object_key,
            },
            "OutputDirTosLocation": {
                "BucketName": output_bucket_name,
                "ObjectKey": output_object_key,
            },
            "CompletionWindow": "3d",
        }
        param = self.base_param | extra_param


        payloadSign = utils.get_hmac_encode16(json.dumps(param))
        headers = utils.get_hashmac_headers(
            self.domain,
            self.region,
            self.service,
            canonicalQueryString,
            "POST",
            "/",
            "application/json; charset=utf-8",
            payloadSign,
            self.ak,
            self.sk,
        )
        return await self._call(url, headers, param)

    async def ListBatchInferenceJobs(self, phases=None):
        """
        异步列出批量推理任务。
        :param phases: 任务阶段列表,默认为空列表。
        :return: 响应的JSON数据。
        :raises Exception: 如果请求失败或解析响应失败,抛出异常。
        """
        # 如果phases为None,则初始化为空列表
        if phases is None:
            phases = []
        # 设置操作名称为ListBatchInferenceJobs
        action = "ListBatchInferenceJobs"
        # 构建规范查询字符串,包含操作名称、API版本和账号ID
        canonicalQueryString = "Action={}&Version={}&X-Account-Id={}".format(
            action, self.version, self.account_id
        )
        # 构建请求URL
        url = "https://" + self.domain + "/?" + canonicalQueryString
        # 构建额外参数,包括操作名称、项目名称和过滤器
        extra_param = {
            "Action": action,
            "ProjectName": "default",
            "Filter": {"Phases": phases},
        }
        # 合并基础参数和额外参数
        param = self.base_param | extra_param

        # 计算请求体的签名
        payloadSign = utils.get_hmac_encode16(json.dumps(param))
        # 获取请求头,包含签名信息
        headers = utils.get_hashmac_headers(
            self.domain,
            self.region,
            self.service,
            canonicalQueryString,
            "POST",
            "/",
            "application/json; charset=utf-8",
            payloadSign,
            self.ak,
            self.sk,
        )
        # 调用_call方法发送请求并返回响应
        return await self._call(url, headers, param)

修改工程文件中的main.py文件:根据示例代码中的绿色字体和注释,为您的配置。

import uvloop
from bytedance_ark_batch_inference.client import BatchInferenceClient
from bytedance_tos.tos_client import TosClient

async def main():
    tos_client = TosClient()
    batch_inference_client = BatchInferenceClient()
    # 上传您您本地或者远程存储的 JSONL 文件,至 TOS 桶中
    # 替换为您的 TOS 桶名称
    input_bucket_name = "<INPUT_BUCKET_NAME>"
    # 替换为您希望上传至TOS桶的路径和文件名
    object_key = "<INPUT_FILE>"
    # 替换<YOUR_JSONL_FILE>为您本地或者远程存储的JSONL文件路径
    tos_client.put_object_from_file(input_bucket_name, object_key, "<YOUR_JSONL_FILE>")

    # 创建批量推理任务,并将结果存储至您指定的 TOS 桶中
    # 替换为替换为您希望存储批量推理结果 TOS 桶名称
    output_bucket_name = "<OUTPUT_BUCKET_NAME>"
    # 替换为您希望存储批量推理结果的路径,例如:output/
    output_key = "<OUTPUT_DIR>"
    response = await batch_inference_client.create_batch_inference_job(input_bucket_name,output_bucket_name, object_key, output_key)
    response = await batch_inference_client.ListBatchInferenceJobs(['Running'])

    print('done')

if __name__ == '__main__':
    uvloop.run(main())

完成配置,在终端中运行python main.py,运行脚本。

相关文档