批量推理是一种高效的推理模式,可通过控制台操作或 OpenAPI 创建异步任务对请求进行处理。批量推理支持多种模型,且支持与对应模型在线推理一致的参数配置选项。
本实践教程将通过实际示例,带您逐步了解批量推理的使用方法。示例基于 Python,以 Doubao-pro-32k/240615 模型为例,展示如何进行批量推理,主要涉及到两个部分:
请参考 输入文件格式说明 准备包含您要进行推理的请求的输入文件,并确保账户已启用火山引擎对象存储 TOS 服务。
以下是一个输入文件示例,包含 2 个请求:
{"custom_id": "request-1", "body": {"messages": [{"role": "user", "content": "天空为什么这么蓝?"}],"max_tokens": 1000,"top_p":1}} {"custom_id": "request-2", "body": {"messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "天空为什么这么蓝?"}],"max_tokens": 1000}}
import json


def validate_input_file(input_file_name):
    """Validate a batch-inference JSONL input file and count its requests.

    Blank lines are skipped. Every other line must be a JSON object that
    carries a non-empty, unique ``custom_id`` and a ``body`` that is itself
    a JSON object.

    :param input_file_name: Full path of the JSONL input file.
    :return: Number of valid request lines found in the file.
    :raises ValueError: If a line is not valid JSON, is not a JSON object,
        lacks ``custom_id``, repeats a ``custom_id``, or has a non-object
        ``body``.
    """
    custom_id_set = set()
    total = 0
    with open(input_file_name, 'r', encoding='utf-8') as file:
        # Use the physical line number so that error messages stay accurate
        # even when the file contains blank lines.
        for line_no, line in enumerate(file, start=1):
            if not line.strip():
                continue
            try:
                line_dict = json.loads(line)
            except json.JSONDecodeError as err:
                raise ValueError(f"批量推理任务失败,第{line_no}行非json数据") from err
            # A line such as `[1, 2]` or `"text"` parses as JSON but has no
            # `.get`; reject it explicitly instead of crashing later.
            if not isinstance(line_dict, dict):
                raise ValueError(f"批量推理任务失败,第{line_no}行非json对象")
            custom_id = line_dict.get('custom_id')
            if not custom_id:
                raise ValueError(f"批量推理任务失败,第{line_no}行custom_id不存在")
            if custom_id in custom_id_set:
                raise ValueError(f"批量推理任务失败,custom_id={custom_id}存在重复")
            custom_id_set.add(custom_id)
            if not isinstance(line_dict.get('body', ''), dict):
                raise ValueError(f"批量推理任务失败,custom_id={custom_id}的body非json字符串")
            total += 1
    return total
数据上传可选择通过 TOS 控制台或者 OpenAPI:
# Option 1: install the TOS SDK with pip.
pip install tos

# Option 2: install from source. Download the matching SDK release from
# GitHub, unpack it, change into the extracted directory, then run:
python3 setup.py install
import os

import tos

# The AK/SK pair is the credential used to access Volcengine resources; it is
# issued under "Access Control -> API access keys".
# For safety, never keep the AK/SK in source code as plain text; they are read
# from the environment here.
ak = os.getenv("YOUR_ACCESS_KEY")
sk = os.getenv("YOUR_SECRET_KEY")

# Region and endpoint must match the region that hosts the bucket.
# Example for North China 2 (Beijing): endpoint tos-cn-beijing.volces.com,
# region cn-beijing.
region = "cn-beijing"
endpoint = "tos-cn-beijing.volces.com"

client = tos.TosClientV2(ak, sk, endpoint, region)
bucket_name = "demo-bucket-test"
# Object name; for example_object.txt under example_dir, use
# example_dir/example_object.txt.
object_key = "data.jsonl"
# Full path of the local file to upload.
file_name = "/usr/local/data.jsonl"

# (message template, attribute name) pairs printed for a server-side failure.
# The request id pinpoints the failing call, so keep it in your logs.
_server_error_fields = (
    ('fail with server error, code: {}', 'code'),
    ('error with request id: {}', 'request_id'),
    ('error with message: {}', 'message'),
    ('error with http code: {}', 'status_code'),
    ('error with ec: {}', 'ec'),
    ('error with request url: {}', 'request_url'),
)

try:
    # Every bucket/object operation goes through a TosClientV2 instance.
    client = tos.TosClientV2(ak, sk, endpoint, region)
    # Upload the local file into the target bucket.
    client.put_object_from_file(bucket_name, object_key, file_name)
except tos.exceptions.TosClientError as e:
    # Client-side failure: usually invalid request parameters or a network
    # problem.
    print('fail with client error, message:{}, cause: {}'.format(e.message, e.cause))
except tos.exceptions.TosServerError as e:
    # Server-side failure: details are available on the exception.
    for template, attr in _server_error_fields:
        print(template.format(getattr(e, attr)))
except Exception as e:
    print('fail with unknown error: {}'.format(e))
# Resumable (checkpoint) upload demo.
bucket_name = "demo-bucket-test"
# Object name; for example_object.txt under example_dir, use
# example_dir/example_object.txt.
object_key = "data.jsonl"
# Full path of the local file, e.g. /usr/local/testfile.txt.
filename = "/usr/local/testfile.txt"

try:
    # Every bucket/object operation goes through a TosClientV2 instance.
    client = tos.TosClientV2(ak, sk, endpoint, region)

    # The annotation is qualified with the already-imported `tos` module: a
    # bare `DataTransferType` is undefined in this snippet and would raise
    # NameError while the function is being defined.
    def percentage(
        consumed_bytes: int,
        total_bytes: int,
        rw_once_bytes: int,
        type: tos.DataTransferType,
    ):
        """Progress callback: print the percentage transferred so far."""
        if total_bytes:
            rate = int(100 * float(consumed_bytes) / float(total_bytes))
            print(
                "rate:{}, consumed_bytes:{},total_bytes{}, rw_once_bytes:{}, type:{}".format(
                    rate, consumed_bytes, total_bytes, rw_once_bytes, type
                )
            )

    # Optional arguments of upload_file:
    #   part_size         - size of each part
    #   enable_checkpoint - turn the resumable-upload feature on or off
    #   task_num          - number of worker threads
    client.upload_file(
        bucket_name,
        object_key,
        filename,
        # Number of threads used for the upload (default 1).
        task_num=3,
        # Part size in bytes (default 20 MB); 5 MB here.
        part_size=1024 * 1024 * 5,
        # Progress callback.
        data_transfer_listener=percentage,
    )
except tos.exceptions.TosClientError as e:
    # Client-side failure: usually invalid request parameters or a network
    # problem.
    print("fail with client error, message:{}, cause: {}".format(e.message, e.cause))
except tos.exceptions.TosServerError as e:
    # Server-side failure: details are available on the exception.
    print("fail with server error, code: {}".format(e.code))
    # The request id pinpoints the failing call; keep it in your logs.
    print("error with request id: {}".format(e.request_id))
    print("error with message: {}".format(e.message))
    print("error with http code: {}".format(e.status_code))
    print("error with ec: {}".format(e.ec))
    print("error with request url: {}".format(e.request_url))
except Exception as e:
    print("fail with unknown error: {}".format(e))
参考批量推理教程,在“批量推理”页面,点击左上角 创建批量任务 按钮跳转至创建页。
POST /?Action=CreateBatchInferenceJob&Version=2024-01-01 HTTP/1.1 Host: open.volcengineapi.com Content-Type: application/json; charset=UTF-8 X-Date: 20240514T132743Z X-Content-Sha256: 287e874e******d653b44d21e Authorization: HMAC-SHA256 Credential=Adfks******wekfwe/20240514/cn-beijing/ark/request, SignedHeaders=host;x-content-sha256;x-date, Signature=47a7d934ff7b37c03938******cd7b8278a40a1057690c401e92246a0e41085f { "Name": "批量推理任务", "Description": "这是一个批量推理任务", "ModelReference": { "FoundationModel": { "Name": "doubao-pro-32k", "ModelVersion": "240615" } }, "InputFileTosLocation": { "BucketName": "demo-bucket-test", "ObjectKey": "data.jsonl" }, "OutputDirTosLocation": { "ObjectKey": "output/", "BucketName": "demo-bucket-test" }, "ProjectName":"default", "CompletionWindow": "1d", "Tags": [ { "Key": "test_key", "Value": "test_value" } ] }
下面为您提供了一种封装 TOS SDK 和批量推理 API 的方式,以便您可以方便地执行自己的批量推理任务。
参考上传本地文件:普通上传构造TOS Client,如示例代码:
import tos
import os


class TosClient:
    """Thin wrapper around ``tos.TosClientV2`` that reports errors by printing.

    None of the public methods raise on a TOS failure: the error details are
    printed and the method returns ``None``.

    Documentation: https://www.volcengine.com/docs/6349/92786
    """

    def __init__(self):
        """Build the underlying TOS client from environment credentials."""
        # The AK/SK pair is the credential used to access Volcengine
        # resources; it is issued under "Access Control -> API access keys".
        # For safety, never keep the AK/SK in source code as plain text.
        self.ak = os.environ.get("YOUR_ACCESS_KEY")
        self.sk = os.environ.get("YOUR_SECRET_KEY")
        # Endpoint and region must match the region that hosts the bucket.
        # Example for North China 2 (Beijing): endpoint
        # tos-cn-beijing.volces.com, region cn-beijing.
        self.endpoint = "tos-cn-beijing.volces.com"
        self.region = "cn-beijing"
        self.client = tos.TosClientV2(self.ak, self.sk, self.endpoint, self.region)

    @staticmethod
    def _report_error(e):
        """Print diagnostics for an exception raised by a TOS operation."""
        if isinstance(e, tos.exceptions.TosClientError):
            # Client-side failure: usually invalid request parameters or a
            # network problem.
            print('fail with client error, message:{}, cause: {}'.format(e.message, e.cause))
        elif isinstance(e, tos.exceptions.TosServerError):
            # Server-side failure: details are available on the exception.
            print('fail with server error, code: {}'.format(e.code))
            # The request id pinpoints the failing call; keep it in your logs.
            print('error with request id: {}'.format(e.request_id))
            print('error with message: {}'.format(e.message))
            print('error with http code: {}'.format(e.status_code))
            print('error with ec: {}'.format(e.ec))
            print('error with request url: {}'.format(e.request_url))
        else:
            print('fail with unknown error: {}'.format(e))

    def create_bucket(self, bucket_name, acl=None):
        """Create a bucket.

        :param bucket_name: Name of the bucket to create.
        :param acl: Bucket ACL. Defaults to public read/write to keep the
            historical behaviour of this helper.
        """
        # NOTE(review): a public read/write bucket lets anyone read and
        # overwrite its objects; pass a stricter ACL for real inference data.
        if acl is None:
            acl = tos.ACLType.ACL_Public_Read_Write
        try:
            self.client.create_bucket(bucket_name, acl=acl)
        except Exception as e:
            self._report_error(e)

    def put_object_from_file(self, bucket_name, object_key, file_path):
        """Upload a local file as an object.

        :param bucket_name: Target bucket.
        :param object_key: Key the object is stored under.
        :param file_path: Path of the local file to upload.
        :return: The SDK's upload result, or ``None`` when the upload failed.
        """
        try:
            return self.client.put_object_from_file(bucket_name, object_key, file_path)
        except Exception as e:
            self._report_error(e)
            return None

    def get_object(self, bucket_name, object_name):
        """Download an object into memory, print it and return it.

        :param bucket_name: Bucket that holds the object.
        :param object_name: Key of the object to download.
        :return: The complete object content, or ``None`` when the download
            failed.
        """
        try:
            object_stream = self.client.get_object(bucket_name, object_name)
            # object_stream can also be iterated to read the data piece by
            # piece; read() loads the complete content into memory at once.
            content = object_stream.read()
            print(content)
            return content
        except Exception as e:
            self._report_error(e)
            return None

    def close_client(self):
        """Close the underlying TOS client once it is no longer needed."""
        try:
            self.client.close()
        except Exception as e:
            self._report_error(e)
import json
from typing import Dict, Any
import aiohttp
import backoff
from common import utils
import os


class BatchInferenceClient:
    """Minimal async client for the Ark batch-inference OpenAPI actions."""

    def __init__(self):
        """Set up credentials, model selection and request defaults.

        The access keys (AK/SK) are read from environment variables rather
        than being kept in source code. ``base_param`` holds the API version
        and account id, which are sent with every request.
        """
        # Maximum number of attempts per request.
        self._retry = 3
        # Total request timeout of 60 seconds.
        self._timeout = aiohttp.ClientTimeout(total=60)
        # The AK/SK pair is the credential used to access Volcengine
        # resources; it is issued under "Access Control -> API access keys".
        # For safety, never keep the AK/SK in source code as plain text.
        self.ak = os.environ.get("YOUR_ACCESS_KEY")
        self.sk = os.environ.get("YOUR_SECRET_KEY")
        # Foundation model name and version used for the batch job.
        self.model = "doubao-pro-32k"
        self.model_version = "240615"
        # Replace with your own account id: on the Volcengine site, click the
        # account avatar and copy the digits after "Account ID".
        self.account_id = "<YOUR_ACCOUNT_ID>"
        # API version, service host, region and service name.
        self.version = "2024-01-01"
        self.domain = "open.volcengineapi.com"
        self.region = "cn-beijing"
        self.service = "ark"
        # Parameters included in every request body.
        self.base_param = {"Version": self.version, "X-Account-Id": self.account_id}

    async def _call(self, url, headers, req: Dict[str, Any]):
        """POST ``req`` as JSON to ``url`` and return the decoded JSON reply.

        Any exception triggers a retry with exponential backoff, up to
        ``self._retry`` attempts; the last exception is then propagated.

        :param url: Target URL.
        :param headers: HTTP headers, including the signature headers.
        :param req: JSON-serialisable request body.
        :return: The response body decoded from JSON.
        :raises Exception: If every attempt fails or the reply is not JSON.
        """

        @backoff.on_exception(
            backoff.expo, Exception, factor=0.1, max_value=5, max_tries=self._retry
        )
        async def _retry_call(body):
            async with aiohttp.request(
                method="POST",
                url=url,
                json=body,
                headers=headers,
                timeout=self._timeout,
            ) as response:
                return await response.json()

        return await _retry_call(req)

    async def _post_action(self, action, extra_param: Dict[str, Any]):
        """Sign and send one OpenAPI action.

        :param action: OpenAPI action name, e.g. ``CreateBatchInferenceJob``.
        :param extra_param: Action-specific body fields.
        :return: The response body decoded from JSON.
        """
        canonical_query_string = "Action={}&Version={}&X-Account-Id={}".format(
            action, self.version, self.account_id
        )
        url = "https://" + self.domain + "/?" + canonical_query_string
        # Key order is kept as base params, Action, then the action-specific
        # fields, because the signature is computed over the serialised body.
        param = self.base_param | {"Action": action} | extra_param
        # NOTE(review): the helpers in `common.utils` are not visible here;
        # they are assumed to hash the payload and build the signed headers.
        payload_sign = utils.get_hmac_encode16(json.dumps(param))
        headers = utils.get_hashmac_headers(
            self.domain,
            self.region,
            self.service,
            canonical_query_string,
            "POST",
            "/",
            "application/json; charset=utf-8",
            payload_sign,
            self.ak,
            self.sk,
        )
        return await self._call(url, headers, param)

    async def create_batch_inference_job(
        self,
        bucket_name,
        input_object_key,
        output_object_key: str,
        *,
        name="just_test",
        project_name="default",
        completion_window="3d",
    ):
        """Create a batch-inference job.

        :param bucket_name: TOS bucket holding the input file and the output
            directory.
        :param input_object_key: Object key of the input file.
        :param output_object_key: Object key of the output directory.
        :param name: Job name.
        :param project_name: Project the job belongs to.
        :param completion_window: Completion window passed to the API.
        :return: The response body decoded from JSON.
        :raises Exception: If the request fails or the reply is not JSON.
        """
        extra_param = {
            "ProjectName": project_name,
            "Name": name,
            "ModelReference": {
                "FoundationModel": {"Name": self.model, "ModelVersion": self.model_version},
            },
            "InputFileTosLocation": {
                "BucketName": bucket_name,
                "ObjectKey": input_object_key,
            },
            "OutputDirTosLocation": {
                "BucketName": bucket_name,
                "ObjectKey": output_object_key,
            },
            "CompletionWindow": completion_window,
        }
        return await self._post_action("CreateBatchInferenceJob", extra_param)

    async def list_batch_inference_jobs(self, phases=None, *, project_name="default"):
        """List batch-inference jobs, optionally filtered by phase.

        :param phases: Phases to filter on; ``None`` means an empty filter
            list.
        :param project_name: Project whose jobs are listed.
        :return: The response body decoded from JSON.
        :raises Exception: If the request fails or the reply is not JSON.
        """
        if phases is None:
            phases = []
        extra_param = {
            "ProjectName": project_name,
            "Filter": {"Phases": phases},
        }
        return await self._post_action("ListBatchInferenceJobs", extra_param)

    # Backward-compatible alias for the original CamelCase method name.
    ListBatchInferenceJobs = list_batch_inference_jobs
构造完成 TOS Client 和 BatchInference Client 后,您可以直接调用 TOS Client 上传您的批量推理文件,然后调用 BatchInference Client 创建批量推理任务:
import uvloop
from bytedance_ark_batch_inference.client import BatchInferenceClient
from bytedance_tos.tos_client import TosClient


async def main():
    """Upload the input file, create a batch job, then list running jobs."""
    tos_client = TosClient()
    batch_inference_client = BatchInferenceClient()

    # Upload the input file to TOS.
    bucket_name = "demo-bucket-test"
    object_key = "data.jsonl"
    try:
        tos_client.put_object_from_file(bucket_name, object_key, "data.jsonl")
    finally:
        # The TOS client is only needed for the upload; release it afterwards.
        tos_client.close_client()

    # Create the batch job. The output key must already exist in the bucket.
    output_key = "output/"
    create_response = await batch_inference_client.create_batch_inference_job(bucket_name, object_key, output_key)
    print(create_response)

    # List the jobs that are currently running.
    list_response = await batch_inference_client.ListBatchInferenceJobs(['Running'])
    print(list_response)
    print('done')


if __name__ == '__main__':
    uvloop.run(main())
您可以参考火山方舟的 API 说明(创建批量推理任务 API),使用 GetBatchInferenceJob 接口查询批量推理任务的状态。任务状态说明如下:
Status Code | 状态 | 描述 |
Queued | 排队中 | 任务由于账号下并发任务数达到上限等原因需排队等候 |
Running | 运行中 | 任务正在运行中 |
Completed | 完成 | 所有请求已经处理完毕,任务已完成 |
Terminating | 终止中 | 由于到期等系统原因或手动终止,任务当前处于终止中状态 |
Terminated | 已终止 | 任务已被取消 |
Failed | 失败 | 输入文件校验失败或其他原因导致任务失败 |
在批量推理任务运行结束后,可点击 查看结果 按钮或在详情页的「文件信息」中点击跳转至 TOS 页面查看并下载输出文件。
结果文件的 TOS 存储路径如下:
bucket_name = "demo-bucket-test"
# Object name; for example_object.txt under example_dir, use
# example_dir/example_object.txt.
object_key = "output/bi-20241111191820-tfjbg/output/results.jsonl"
# Full path of the local file, e.g. /usr/local/testfile.txt.
file_path = "/usr/local/testfile.txt"

try:
    # Every bucket/object operation goes through a TosClientV2 instance.
    client = tos.TosClientV2(ak, sk, endpoint, region)

    # The annotation is qualified with the already-imported `tos` module: a
    # bare `DataTransferType` is undefined in this snippet and would raise
    # NameError while the function is being defined.
    def percentage(consumed_bytes, total_bytes, rw_once_bytes, type: tos.DataTransferType):
        """Progress callback: print the percentage transferred so far."""
        if total_bytes:
            rate = int(100 * float(consumed_bytes) / float(total_bytes))
            print("rate:{}, consumed_bytes:{},total_bytes{}, rw_once_bytes:{}, type:{}".format(rate, consumed_bytes, total_bytes, rw_once_bytes, type))

    client.download_file(
        bucket_name,
        object_key,
        file_path,
        # Optional: part size used for the download (default 20 MB).
        part_size=1024 * 1024 * 20,
        # Optional: number of threads downloading parts (default 1).
        task_num=3,
        # Optional: progress callback.
        data_transfer_listener=percentage,
    )
except tos.exceptions.TosClientError as e:
    # Client-side failure: usually invalid request parameters or a network
    # problem.
    print('fail with client error, message:{}, cause: {}'.format(e.message, e.cause))
except tos.exceptions.TosServerError as e:
    # Server-side failure: details are available on the exception.
    print('fail with server error, code: {}'.format(e.code))
    # The request id pinpoints the failing call; keep it in your logs.
    print('error with request id: {}'.format(e.request_id))
    print('error with message: {}'.format(e.message))
    print('error with http code: {}'.format(e.status_code))
    print('error with ec: {}'.format(e.ec))
    print('error with request url: {}'.format(e.request_url))
except Exception as e:
    print('fail with unknown error: {}'.format(e))