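Batch inference is an efficient inference mode: you create an asynchronous task, through the console or OpenAPI, to process a set of requests. It supports multiple models and accepts the same parameter configuration options as online inference for the corresponding model.
Batch inference offers higher rate limits and larger daily throughput, and the maximum run time of each batch inference task can be configured up to 28 days. It is especially suited to scenarios that process large volumes of data and do not need results in real time, such as log analysis, offline data prediction, and evaluation.
This tutorial walks you step by step through using batch inference with practical examples. The examples are written in Python and use the Doubao-pro-32k/240615 model to show how to run batch inference. It mainly covers two parts: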
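Refer to the Input File Format Description to prepare an input file containing the requests you want to run inference on, and make sure your account has enabled TOS, the Volcengine object storage service.
Below is an example input file containing 2 requests: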
    {"custom_id": "request-1", "body": {"messages": [{"role": "user", "content": "Why is the sky so blue?"}],"max_tokens": 1000,"top_p":1}}
    {"custom_id": "request-2", "body": {"messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Why is the sky so blue?"}],"max_tokens": 1000}}
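An input file in this shape can also be generated programmatically. The following is a minimal sketch, assuming each request only needs a single user message; the function names `build_request_lines` and `write_input_file` and the `request-N` numbering scheme are illustrative choices, not part of any SDK.

```python
import json
from typing import Iterable, List


def build_request_lines(prompts: Iterable[str], max_tokens: int = 1000) -> List[str]:
    """Turn each prompt into one JSON line with a unique custom_id."""
    lines = []
    for index, prompt in enumerate(prompts, start=1):
        request = {
            # custom_id must be unique within the file
            "custom_id": f"request-{index}",
            "body": {
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": max_tokens,
            },
        }
        # ensure_ascii=False keeps non-ASCII prompts readable in the file
        lines.append(json.dumps(request, ensure_ascii=False))
    return lines


def write_input_file(path: str, prompts: Iterable[str]) -> int:
    """Write one request per line (JSON Lines) and return the request count."""
    lines = build_request_lines(prompts)
    with open(path, "w", encoding="utf-8") as f:
        for line in lines:
            f.write(line + "\n")
    return len(lines)
```

For example, `write_input_file("data.jsonl", ["Why is the sky so blue?"])` writes a one-request file to the current directory.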
Use the following code to pre-check the input data:
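    import json


    def validate_input_file(input_file_name):
        """Validate the input file and return the number of requests it contains.

        input_file_name: full path of the input file.
        """
        total = 0
        custom_id_set = set()
        with open(input_file_name, 'r', encoding='utf-8') as file:
            for line_no, line in enumerate(file, start=1):
                # Skip blank lines
                if line.strip() == '':
                    continue
                try:
                    line_dict = json.loads(line)
                except json.decoder.JSONDecodeError:
                    raise Exception(f"Batch inference task failed: line {line_no} is not valid JSON")
                if not isinstance(line_dict, dict):
                    raise Exception(f"Batch inference task failed: line {line_no} is not a JSON object")
                custom_id = line_dict.get('custom_id')
                if not custom_id:
                    raise Exception(f"Batch inference task failed: line {line_no} has no custom_id")
                if custom_id in custom_id_set:
                    raise Exception(f"Batch inference task failed: custom_id={custom_id} is duplicated")
                custom_id_set.add(custom_id)
                if not isinstance(line_dict.get('body', ''), dict):
                    raise Exception(f"Batch inference task failed: body of custom_id={custom_id} is not a JSON object")
                total += 1
        return total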
You can upload the data through either the TOS console or OpenAPI:
    # Install with pip
    pip install tos
    # Install from source: download the corresponding SDK release from GitHub, unpack it, enter the directory, and run the following command.
    python3 setup.py install
    import os
    import tos

    # The AK/SK pair is the credential for accessing Volcengine resources; get it from Access Control > API Access Keys
    # For security, we strongly recommend that you do not keep your AK/SK in plain text in code
    ak = os.environ.get("YOUR_ACCESS_KEY")
    sk = os.environ.get("YOUR_SECRET_KEY")
    # Set endpoint and region to the values for the region where your bucket is located.
    # For China North 2 (Beijing), the endpoint is tos-cn-beijing.volces.com and the region is cn-beijing.
    endpoint = "tos-cn-beijing.volces.com"
    region = "cn-beijing"
    client = tos.TosClientV2(ak, sk, endpoint, region)
    bucket_name = "demo-bucket-test"
    # Object name; for the file example_object.txt under example_dir, use example_dir/example_object.txt
    object_key = "data.jsonl"
    # Local file path
    file_name = "/usr/local/data.jsonl"
    try:
        # Create a TosClientV2 object; all bucket and object operations go through TosClientV2
        client = tos.TosClientV2(ak, sk, endpoint, region)
        # Upload the local file to the target bucket
        # file_name is the full path of the local file.
        client.put_object_from_file(bucket_name, object_key, file_name)
    except tos.exceptions.TosClientError as e:
        # The operation failed with a client-side error, usually invalid request parameters or a network problem
        print('fail with client error, message:{}, cause: {}'.format(e.message, e.cause))
    except tos.exceptions.TosServerError as e:
        # The operation failed with a server-side error; the response carries the details
        print('fail with server error, code: {}'.format(e.code))
        # The request id pinpoints the problem; we strongly recommend saving it in your logs
        print('error with request id: {}'.format(e.request_id))
        print('error with message: {}'.format(e.message))
        print('error with http code: {}'.format(e.status_code))
        print('error with ec: {}'.format(e.ec))
        print('error with request url: {}'.format(e.request_url))
    except Exception as e:
        print('fail with unknown error: {}'.format(e))
Note
    # Resumable upload demo
    from tos import DataTransferType

    bucket_name = "demo-bucket-test"
    # Object name; for the file example_object.txt under example_dir, use example_dir/example_object.txt
    object_key = "data.jsonl"
    # Full path of the local file, for example /usr/local/testfile.txt
    filename = "/usr/local/testfile.txt"
    try:
        # Create a TosClientV2 object; all bucket and object operations go through TosClientV2
        client = tos.TosClientV2(ak, sk, endpoint, region)

        def percentage(
            consumed_bytes: int,
            total_bytes: int,
            rw_once_bytes: int,
            type: DataTransferType,
        ):
            # Progress callback: print the percentage transferred so far
            if total_bytes:
                rate = int(100 * float(consumed_bytes) / float(total_bytes))
                print(
                    "rate:{}, consumed_bytes:{},total_bytes{}, rw_once_bytes:{}, type:{}".format(
                        rate, consumed_bytes, total_bytes, rw_once_bytes, type
                    )
                )

        # The optional part_size parameter sets the part size
        # The enable_checkpoint parameter turns the resumable upload feature on or off
        # The task_num parameter sets the number of threads
        client.upload_file(
            bucket_name,
            object_key,
            filename,
            # Number of threads used by the resumable upload, 1 by default
            task_num=3,
            # Part size of the resumable upload, 20 MB by default (5 MB here)
            part_size=1024 * 1024 * 5,
            # Progress callback for the resumable upload
            data_transfer_listener=percentage,
        )
    except tos.exceptions.TosClientError as e:
        # The operation failed with a client-side error, usually invalid request parameters or a network problem
        print("fail with client error, message:{}, cause: {}".format(e.message, e.cause))
    except tos.exceptions.TosServerError as e:
        # The operation failed with a server-side error; the response carries the details
        print("fail with server error, code: {}".format(e.code))
        # The request id pinpoints the problem; we strongly recommend saving it in your logs
        print("error with request id: {}".format(e.request_id))
        print("error with message: {}".format(e.message))
        print("error with http code: {}".format(e.status_code))
        print("error with ec: {}".format(e.ec))
        print("error with request url: {}".format(e.request_url))
    except Exception as e:
        print("fail with unknown error: {}".format(e))
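One way to combine the two upload styles shown above is to pick the method from the file size. This is only a sketch: the 100 MB threshold is a hypothetical value chosen for illustration and does not come from the TOS documentation, and `choose_upload_method` and `upload` are illustrative helper names. The `client` argument is expected to be a `tos.TosClientV2` instance like the one created earlier.

```python
import os

# Hypothetical threshold, chosen for illustration only
RESUMABLE_THRESHOLD_BYTES = 100 * 1024 * 1024


def choose_upload_method(file_size: int, threshold: int = RESUMABLE_THRESHOLD_BYTES) -> str:
    """Return the name of the client method to use for a file of this size."""
    return "upload_file" if file_size >= threshold else "put_object_from_file"


def upload(client, bucket_name: str, object_key: str, file_path: str):
    """Upload file_path with the method picked by choose_upload_method."""
    method = choose_upload_method(os.path.getsize(file_path))
    if method == "upload_file":
        # Resumable multipart upload, as in the demo above
        return client.upload_file(bucket_name, object_key, file_path)
    # Simple single-request upload
    return client.put_object_from_file(bucket_name, object_key, file_path)
```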
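There are 3 ways to create a batch inference task; choose whichever suits your needs.
Following the batch inference tutorial, on the "Batch Inference" page, click the Create Batch Task button in the upper-left corner to go to the creation page.
The following quotas apply at the account level:
To adjust a quota, apply through the Quota Center.
Create the batch inference task by following the Create Batch Inference Job API. A sample request: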
    POST /?Action=CreateBatchInferenceJob&Version=2024-01-01 HTTP/1.1
    Host: open.volcengineapi.com
    Content-Type: application/json; charset=UTF-8
    X-Date: 20240514T132743Z
    X-Content-Sha256: 287e874e******d653b44d21e
    Authorization: HMAC-SHA256 Credential=Adfks******wekfwe/20240514/cn-beijing/ark/request, SignedHeaders=host;x-content-sha256;x-date, Signature=47a7d934ff7b37c03938******cd7b8278a40a1057690c401e92246a0e41085f

    {
        "Name": "Batch inference task",
        "Description": "This is a batch inference task",
        "ModelReference": {
            "FoundationModel": {
                "Name": "doubao-pro-32k",
                "ModelVersion": "240615"
            }
        },
        "InputFileTosLocation": {
            "BucketName": "demo-bucket-test",
            "ObjectKey": "data.jsonl"
        },
        "OutputDirTosLocation": {
            "ObjectKey": "output/",
            "BucketName": "demo-bucket-test"
        },
        "ProjectName": "default",
        "CompletionWindow": "1d",
        "Tags": [
            {
                "Key": "test_key",
                "Value": "test_value"
            }
        ]
    }
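The request body and two of the headers in the sample can be produced with the standard library alone. The sketch below assumes that `X-Content-Sha256` is the hex-encoded SHA-256 digest of the raw request body, as the header name suggests; computing the `Authorization` signature is not covered here. The helper names `x_date`, `content_sha256`, and `build_create_job_body` are illustrative.

```python
import hashlib
import json
from datetime import datetime, timezone


def x_date(now: datetime) -> str:
    """Format a timestamp in UTC like the X-Date header, e.g. 20240514T132743Z."""
    return now.astimezone(timezone.utc).strftime("%Y%m%dT%H%M%SZ")


def content_sha256(body: bytes) -> str:
    """Hex SHA-256 of the raw body (assumed to be the X-Content-Sha256 value)."""
    return hashlib.sha256(body).hexdigest()


def build_create_job_body(bucket: str, input_key: str, output_dir: str,
                          completion_window: str = "1d") -> bytes:
    """Serialize a CreateBatchInferenceJob body with the fields from the sample."""
    payload = {
        "Name": "Batch inference task",
        "ModelReference": {
            "FoundationModel": {"Name": "doubao-pro-32k", "ModelVersion": "240615"}
        },
        "InputFileTosLocation": {"BucketName": bucket, "ObjectKey": input_key},
        "OutputDirTosLocation": {"BucketName": bucket, "ObjectKey": output_dir},
        "ProjectName": "default",
        "CompletionWindow": completion_window,
    }
    return json.dumps(payload).encode("utf-8")
```

The digest has to be computed over exactly the bytes that are sent, so serialize the body once and reuse those bytes for both hashing and the request.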
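The following shows one way to wrap the TOS SDK and the batch inference API so that you can run your own batch inference tasks conveniently.
Construct a TOS client by following Upload Local File: Simple Upload, as in this sample code: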
    import os

    import tos


    def _print_tos_error(e):
        # Print the details of a failed TOS operation
        if isinstance(e, tos.exceptions.TosClientError):
            # Client-side error, usually invalid request parameters or a network problem
            print('fail with client error, message:{}, cause: {}'.format(e.message, e.cause))
        elif isinstance(e, tos.exceptions.TosServerError):
            # Server-side error; the response carries the details
            print('fail with server error, code: {}'.format(e.code))
            # The request id pinpoints the problem; we strongly recommend saving it in your logs
            print('error with request id: {}'.format(e.request_id))
            print('error with message: {}'.format(e.message))
            print('error with http code: {}'.format(e.status_code))
            print('error with ec: {}'.format(e.ec))
            print('error with request url: {}'.format(e.request_url))
        else:
            print('fail with unknown error: {}'.format(e))


    class TosClient:
        # Documentation: https://www.volcengine.com/docs/6349/92786
        def __init__(self):
            # The AK/SK pair is the credential for accessing Volcengine resources; get it from Access Control > API Access Keys
            # For security, we strongly recommend that you do not keep your AK/SK in plain text in code
            self.ak = os.environ.get("YOUR_ACCESS_KEY")
            self.sk = os.environ.get("YOUR_SECRET_KEY")
            # Set endpoint and region to the values for the region where your bucket is located.
            # For China North 2 (Beijing), the endpoint is tos-cn-beijing.volces.com and the region is cn-beijing.
            self.endpoint = "tos-cn-beijing.volces.com"
            self.region = "cn-beijing"
            self.client = tos.TosClientV2(self.ak, self.sk, self.endpoint, self.region)

        def create_bucket(self, bucket_name):
            try:
                # Create the bucket with a public read-write ACL
                # (anyone can read and write it; consider a stricter ACL for sensitive data)
                self.client.create_bucket(bucket_name, acl=tos.ACLType.ACL_Public_Read_Write)
            except Exception as e:
                _print_tos_error(e)

        def put_object_from_file(self, bucket_name, object_key, file_path):
            try:
                # Upload the local file at file_path as an object
                self.client.put_object_from_file(bucket_name, object_key, file_path)
            except Exception as e:
                _print_tos_error(e)

        def get_object(self, bucket_name, object_name):
            try:
                # Download the object from the TOS bucket into memory
                object_stream = self.client.get_object(bucket_name, object_name)
                # object_stream is an iterator, so the data can be read piece by piece:
                # for content in object_stream:
                #     print(content)
                # You can also call read() to load the complete data into memory at once
                print(object_stream.read())
            except Exception as e:
                _print_tos_error(e)

        def close_client(self):
            try:
                # Close the TosClient once it is no longer needed
                self.client.close()
            except Exception as e:
                _print_tos_error(e)
Construct a BatchInferenceClient. The sample code wraps two APIs, CreateBatchInferenceJob and ListBatchInferenceJobs; you can add more by following the Create Batch Inference Job API documentation.
    import json
    import os
    from typing import Dict, Any

    import aiohttp
    import backoff

    # common.utils is a signing helper module that is not shown in this tutorial
    from common import utils


    class BatchInferenceClient:
        def __init__(self):
            self._retry = 3
            self._timeout = aiohttp.ClientTimeout(total=60)
            # The AK/SK pair is the credential for accessing Volcengine resources; get it from Access Control > API Access Keys
            # For security, we strongly recommend that you do not keep your AK/SK in plain text in code
            self.ak = os.environ.get("YOUR_ACCESS_KEY")
            self.sk = os.environ.get("YOUR_SECRET_KEY")
            # Replace with your account ID: on the Volcengine website, click your account avatar
            # and copy the string of digits after "Account ID" in the pop-up
            self.account_id = 'Your_ACCOUNT_ID'
            self.version = '2024-01-01'
            self.domain = 'open.volcengineapi.com'
            self.region = 'cn-beijing'
            self.service = 'ark'
            self.base_param = {'Version': self.version, 'X-Account-Id': self.account_id}

        async def _call(self, url, headers, req: Dict[str, Any]):
            # Retry failed calls with exponential backoff, at most self._retry attempts
            @backoff.on_exception(backoff.expo, Exception, factor=0.1, max_value=5, max_tries=self._retry)
            async def _retry_call(body):
                async with aiohttp.request(method='POST', url=url, json=body, headers=headers,
                                           timeout=self._timeout) as response:
                    return await response.json()

            return await _retry_call(req)

        async def _post_action(self, action, extra_param: Dict[str, Any]):
            # Build the signed request shared by all actions and send it
            canonicalQueryString = 'Action={}&Version={}&X-Account-Id={}'.format(
                action, self.version, self.account_id)
            url = 'https://' + self.domain + '/?' + canonicalQueryString
            # The | dict merge operator requires Python 3.9 or later
            param = self.base_param | {'Action': action} | extra_param
            payloadSign = utils.get_hmac_encode16(json.dumps(param))
            headers = utils.get_hashmac_headers(self.domain, self.region, self.service,
                                                canonicalQueryString, 'POST', '/',
                                                'application/json; charset=utf-8',
                                                payloadSign, self.ak, self.sk)
            return await self._call(url, headers, param)

        async def create_batch_inference_job(self, bucket_name, input_object_key, output_object_key: str):
            extra_param = {
                'ProjectName': 'default',
                'Name': 'just_test',
                'ModelReference': {
                    'FoundationModel': {
                        'Name': 'doubao-pro-32k',
                        'ModelVersion': '240615'
                    }
                },
                'InputFileTosLocation': {
                    'BucketName': bucket_name,
                    'ObjectKey': input_object_key
                },
                'OutputDirTosLocation': {
                    'BucketName': bucket_name,
                    'ObjectKey': output_object_key
                },
                'CompletionWindow': '3d'
            }
            return await self._post_action('CreateBatchInferenceJob', extra_param)

        async def ListBatchInferenceJobs(self, phases=None):
            if phases is None:
                phases = []
            extra_param = {
                'ProjectName': 'default',
                'Filter': {
                    'Phases': phases
                }
            }
            return await self._post_action('ListBatchInferenceJobs', extra_param)
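To make the retry behaviour in `_call` easier to reason about, here is a standard-library sketch of the same idea: wait `factor * 2**n` seconds between attempts, capped at `max_value`, and give up after `max_tries` attempts. It is an approximation written for illustration, not the `backoff` library's implementation; in particular it omits the random jitter that `backoff` applies by default. The names `backoff_delays` and `call_with_retry` are illustrative.

```python
import asyncio
from typing import Awaitable, Callable, List, TypeVar

T = TypeVar("T")


def backoff_delays(factor: float, max_value: float, max_tries: int) -> List[float]:
    """Delay before each retry: factor * 2**n, capped at max_value."""
    return [min(factor * (2 ** n), max_value) for n in range(max_tries - 1)]


async def call_with_retry(func: Callable[[], Awaitable[T]], factor: float = 0.1,
                          max_value: float = 5, max_tries: int = 3) -> T:
    """Await func(), retrying on any exception with exponential delays."""
    if max_tries < 1:
        raise ValueError("max_tries must be at least 1")
    delays = backoff_delays(factor, max_value, max_tries)
    for attempt in range(max_tries):
        try:
            return await func()
        except Exception:
            # Re-raise once the last allowed attempt has failed
            if attempt == max_tries - 1:
                raise
            await asyncio.sleep(delays[attempt])
    raise RuntimeError("unreachable")
```

With the values used in `_call` (`factor=0.1`, `max_value=5`, `max_tries=3`), the un-jittered delays are 0.1 s and 0.2 s.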
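Once the TOS Client and the BatchInference Client are constructed, you can call the TOS Client directly to upload your batch inference file, then call the BatchInference Client to create a batch inference task. After that, you can call ListBatchInferenceJobs to view the batch inference tasks that are running.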
    import uvloop

    from bytedance_ark_batch_inference.client import BatchInferenceClient
    from bytedance_tos.tos_client import TosClient


    async def main():
        tos_client = TosClient()
        batch_inference_client = BatchInferenceClient()
        # Upload the input file
        bucket_name = "demo-bucket-test"
        object_key = "data.jsonl"
        tos_client.put_object_from_file(bucket_name, object_key, "data.jsonl")
        # Create the batch job; the output key must already exist
        output_key = "output/"
        response = await batch_inference_client.create_batch_inference_job(bucket_name, object_key, output_key)
        print(response)
        # List the batch jobs that are currently running
        response = await batch_inference_client.ListBatchInferenceJobs(['Running'])
        print(response)
        print('done')


    if __name__ == '__main__':
        uvloop.run(main())
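You can use the GetBatchInferenceJob API, described in the Volcengine Ark API documentation (see the Create Batch Inference Job API), to check the status of a batch inference task you have created. Refer to the example to get the status information of a batch inference task:
The batch inference task statuses and their descriptions are as follows: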
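Status Code | Status | Description |
---|---|---|
Queued | Queued | The task has to wait in the queue, for example because the number of concurrent tasks under the account has reached the limit |
Running | Running | The task is running |
Completed | Completed | All requests have been processed and the task is complete |
Terminating | Terminating | The task is being terminated, either by the system (for example on expiry) or manually |
Terminated | Terminated | The task has been cancelled |
Failed | Failed | The task failed because input file validation failed or for another reason |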
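When polling for the status from code, it helps to separate statuses that can still change from those that cannot. The sketch below reads Completed, Terminated, and Failed as final and the other three as in progress; that grouping is an interpretation of the table above, not something the API documentation is quoted as stating. `get_status` is a hypothetical callable that you would implement on top of GetBatchInferenceJob, and the default interval and poll count are arbitrary.

```python
import time
from typing import Callable

# Grouping inferred from the status table above
FINAL_STATUSES = {"Completed", "Terminated", "Failed"}
IN_PROGRESS_STATUSES = {"Queued", "Running", "Terminating"}


def is_final(status: str) -> bool:
    """Return True if the task can no longer change status."""
    if status in FINAL_STATUSES:
        return True
    if status in IN_PROGRESS_STATUSES:
        return False
    raise ValueError(f"unknown status: {status}")


def wait_for_job(get_status: Callable[[], str], interval_seconds: float = 30.0,
                 max_polls: int = 100,
                 sleep: Callable[[float], None] = time.sleep) -> str:
    """Poll get_status until a final status is seen, then return it."""
    for _ in range(max_polls):
        status = get_status()
        if is_final(status):
            return status
        sleep(interval_seconds)
    raise TimeoutError(f"task not finished after {max_polls} polls")
```

Passing `sleep` as a parameter keeps the loop easy to test without real waiting.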
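You can also check the status of the batch inference tasks you created in the Volcengine Ark console, on the batch inference list page or the task details page.
After a batch inference task finishes, click the View Results button, or click the link under "File Information" on the details page, to go to the TOS page where you can view and download the output file.
The TOS storage path of the result file is as follows: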
    import tos
    from tos import DataTransferType

    bucket_name = "demo-bucket-test"
    # Object name; for the file example_object.txt under example_dir, use example_dir/example_object.txt
    object_key = "output/bi-20241111191820-tfjbg/output/results.jsonl"
    # Full path of the local file, for example /usr/local/testfile.txt
    file_path = "/usr/local/testfile.txt"
    try:
        # Create a TosClientV2 object; all bucket and object operations go through TosClientV2
        client = tos.TosClientV2(ak, sk, endpoint, region)

        def percentage(consumed_bytes, total_bytes, rw_once_bytes, type: DataTransferType):
            # Progress callback: print the percentage transferred so far
            if total_bytes:
                rate = int(100 * float(consumed_bytes) / float(total_bytes))
                print("rate:{}, consumed_bytes:{},total_bytes{}, rw_once_bytes:{}, type:{}".format(
                    rate, consumed_bytes, total_bytes, rw_once_bytes, type))

        client.download_file(bucket_name, object_key, file_path,
                             # Optional part_size sets the part size for the download, 20 MB by default
                             part_size=1024 * 1024 * 20,
                             # Optional task_num sets the number of threads downloading parts, 1 by default
                             task_num=3,
                             # Optional data_transfer_listener sets the progress callback
                             data_transfer_listener=percentage)
    except tos.exceptions.TosClientError as e:
        # The operation failed with a client-side error, usually invalid request parameters or a network problem
        print('fail with client error, message:{}, cause: {}'.format(e.message, e.cause))
    except tos.exceptions.TosServerError as e:
        # The operation failed with a server-side error; the response carries the details
        print('fail with server error, code: {}'.format(e.code))
        # The request id pinpoints the problem; we strongly recommend saving it in your logs
        print('error with request id: {}'.format(e.request_id))
        print('error with message: {}'.format(e.message))
        print('error with http code: {}'.format(e.status_code))
        print('error with ec: {}'.format(e.ec))
        print('error with request url: {}'.format(e.request_url))
    except Exception as e:
        print('fail with unknown error: {}'.format(e))
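After downloading, you will usually want to match each result to the request that produced it. The output file format is not described in this section, so the sketch below rests on an assumption: that `results.jsonl` holds one JSON object per line and that each object carries the `custom_id` of its request. Check the output file format documentation before relying on it. The function name `index_results_by_custom_id` is illustrative.

```python
import json
from typing import Dict, Iterable


def index_results_by_custom_id(lines: Iterable[str]) -> Dict[str, dict]:
    """Map custom_id to its parsed result record, skipping blank lines."""
    results = {}
    for line in lines:
        if not line.strip():
            continue
        record = json.loads(line)
        # Assumption: every result record has a custom_id field
        results[record["custom_id"]] = record
    return results
```

A file object can be passed directly, for example `with open("/usr/local/testfile.txt", encoding="utf-8") as f: results = index_results_by_custom_id(f)`.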