通过 batch Inference job 接口,创建批量推理任务。详细介绍请参见批量推理说明。
批量推理任务支持jsonl格式的文件,具体说明见数据文件格式说明。
{"custom_id": "request-1", "body": {"messages": [{"role": "user", "content": "天空为什么这么蓝?"}],"max_tokens": 1000,"top_p":1}} {"custom_id": "request-2", "body": {"messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "天空为什么这么蓝?"}],"max_tokens": 1000}}
为避免因为文件格式错误,导致批量推理任务失败,可以使用下面脚本进行校验。
import json def check_jsonl_file(file_path): with open(file_path, "r", encoding="utf-8") as file: total = 0 custom_id_set = set() for line in file: if line.strip() == "": continue try: line_dict = json.loads(line) except json.decoder.JSONDecodeError: raise Exception(f"批量推理输入文件格式错误,第{total+1}行非json数据") if not line_dict.get("custom_id"): raise Exception(f"批量推理输入文件格式错误,第{total+1}行custom_id不存在") if line_dict.get("custom_id") in custom_id_set: raise Exception( f"批量推理输入文件格式错误,custom_id={line_dict.get('custom_id', '')}存在重复" ) else: custom_id_set.add(line_dict.get("custom_id")) if not isinstance(line_dict.get("body", ""), dict): raise Exception( f"批量推理输入文件格式错误,custom_id={line_dict.get('custom_id', '')}的body非json字符串" ) total += 1 return total # 替换<YOUR_JSONL_FILE>为你的JSONL文件路径 file_path = "<YOUR_JSONL_FILE>" total_lines = check_jsonl_file(file_path) print(f"文件中有效JSON数据的行数为: {total_lines}")
您需要将批量处理任务的输入文件传入到TOS中,您可以选择下面方式上传文件:
为了方便您实现,我们提供了一些上传文件的python脚本,请参见2.TOS上传数据。
完成上传后,您需要记录几个信息.
<INPUT_BUCKET_NAME>
<INPUT_FILE>
<OUTPUT_BUCKET_NAME>
<OUTPUT_DIR>
TOS 中文件夹以“/”结尾,如
output/
。
获取了上面的信息后,您可以使用接口CreateBatchInferenceJob - 创建批量推理任务来创建批量推理任务。
其中绿色字体,需要您进行替换,其中X-Content-Sha256
、Authorization
这两部分需要您签名和计算完成,具体请参见API签名调用指南。
POST /?Action=CreateBatchInferenceJob&Version=2024-01-01 HTTP/1.1 Host: open.volcengineapi.com Content-Type: application/json; charset=UTF-8 X-Date: 20240514T132743Z X-Content-Sha256: 287e874e******d653b44d21e Authorization: HMAC-SHA256 Credential=Adfks******wekfwe/20240514/cn-beijing/ark/request, SignedHeaders=host;x-content-sha256;x-date, Signature=47a7d934ff7b37c03938******cd7b8278a40a1057690c401e92246a0e41085f { "Name": "你的批量推理任务", "Description": "你的批量推理任务", "ModelReference": { "FoundationModel": { "Name": "doubao-pro-32k", "ModelVersion": "240615" } }, "InputFileTosLocation": { "BucketName": "<INPUT_BUCKET_NAME>", "ObjectKey": "<INPUT_FILE>" }, "OutputDirTosLocation": { "BucketName": "<OUTPUT_BUCKET_NAME>", "ObjectKey": "<OUTPUT_DIR>" }, "ProjectName":"default", "CompletionWindow": "1d", "Tags": [ { "Key": "test_key", "Value": "test_value" } ] }
下面是通过 Python 实现签名以及创建批量推理任务的工程文件。
<YOUR_ACCOUNT_ID>
为您的账号ID:获取账号ID,可以单击官网右上角头像获取。self.model = "doubao-pro-32k"
,self.model_version = "240615"
为您需要使用的模型名称以及模型版本,具体支持批量推理的模型见适用模型。import json from typing import Dict, Any import aiohttp import backoff from common import utils import os class BatchInferenceClient: def __init__(self): """ 初始化BatchInferenceClient类的实例。 该方法设置了一些默认属性,如重试次数、超时时间、访问密钥(AK/SK)、账号ID、API版本、服务域名、区域和基础参数。 访问密钥(AK/SK)从环境变量中获取,以提高安全性。 基础参数包括API版本和账号ID,这些参数在每次请求中都会用到。 """ # 设置重试次数为3次 self._retry = 3 # 设置请求超时时间为60秒 self._timeout = aiohttp.ClientTimeout(60) # Access Key访问火山云资源的秘钥,可从访问控制-API访问密钥获取获取 # 为了保障安全,强烈建议您不要在代码中明文维护您的AKSK # 从环境变量中获取访问密钥(AK) self.ak = os.environ.get("YOUR_ACCESS_KEY") # 从环境变量中获取秘密密钥(SK) self.sk = os.environ.get("YOUR_SECRET_KEY") # 设置模型名称 self.model = "doubao-pro-32k" # 设置模型版本 self.model_version = "240615" # 需要替换为您的账号id,可从火山引擎官网点击账号头像,弹出框中找到,复制“账号ID”后的一串数字 self.account_id = "<YOUR_ACCOUNT_ID>" # 设置API版本 self.version = "2024-01-01" # 设置服务域名 self.domain = "open.volcengineapi.com" # 设置区域 self.region = "cn-beijing" # 设置服务名称 self.service = "ark" # 设置基础参数,包括API版本和账号ID self.base_param = {"Version": self.version, "X-Account-Id": self.account_id} async def _call(self, url, headers, req: Dict[str, Any]): """ 异步调用指定URL的HTTP POST请求,并处理请求的重试逻辑。 :param url: 请求的目标URL。 :param headers: 请求的HTTP头部信息。 :param req: 请求的JSON格式数据。 :return: 响应的JSON数据。 :raises Exception: 如果请求失败或解析响应失败,抛出异常。 """ @backoff.on_exception( backoff.expo, Exception, factor=0.1, max_value=5, max_tries=self._retry ) async def _retry_call(body): """ 内部函数,用于发送HTTP POST请求,并处理请求的重试逻辑。 :param body: 请求的JSON格式数据。 :return: 响应的JSON数据。 :raises Exception: 如果请求失败或解析响应失败,抛出异常。 """ async with aiohttp.request( method="POST", url=url, json=body, headers=headers, timeout=self._timeout, ) as response: try: return await response.json() except Exception as e: raise e try: return await _retry_call(req) except Exception as e: raise e async def create_batch_inference_job( self, input_bucket_name,output_bucket_name, input_object_key, output_object_key: str ): """ 异步创建批量推理任务。 :param input_bucket_name: 任务JSONL文件存储的桶名称。 :param input_bucket_name: 任务结果存储的桶名称。 :param input_object_key: 输入文件的对象键。 :param output_object_key: 输出文件的对象键。 :return: 响应的JSON数据。 :raises Exception: 如果请求失败或解析响应失败,抛出异常。 """ action = "CreateBatchInferenceJob" canonicalQueryString = "Action={}&Version={}&X-Account-Id={}".format( action, self.version, self.account_id ) url = "https://" + self.domain + "/?" + canonicalQueryString extra_param = { "Action": action, "ProjectName": "default", "Name": "just_test", "ModelReference": { "FoundationModel": {"Name": self.model, "ModelVersion": self.model_version}, }, "InputFileTosLocation": { "BucketName": input_bucket_name, "ObjectKey": input_object_key, }, "OutputDirTosLocation": { "BucketName": output_bucket_name, "ObjectKey": output_object_key, }, "CompletionWindow": "3d", } param = self.base_param | extra_param payloadSign = utils.get_hmac_encode16(json.dumps(param)) headers = utils.get_hashmac_headers( self.domain, self.region, self.service, canonicalQueryString, "POST", "/", "application/json; charset=utf-8", payloadSign, self.ak, self.sk, ) return await self._call(url, headers, param) async def ListBatchInferenceJobs(self, phases=None): """ 异步列出批量推理任务。 :param phases: 任务阶段列表,默认为空列表。 :return: 响应的JSON数据。 :raises Exception: 如果请求失败或解析响应失败,抛出异常。 """ # 如果phases为None,则初始化为空列表 if phases is None: phases = [] # 设置操作名称为ListBatchInferenceJobs action = "ListBatchInferenceJobs" # 构建规范查询字符串,包含操作名称、API版本和账号ID canonicalQueryString = "Action={}&Version={}&X-Account-Id={}".format( action, self.version, self.account_id ) # 构建请求URL url = "https://" + self.domain + "/?" + canonicalQueryString # 构建额外参数,包括操作名称、项目名称和过滤器 extra_param = { "Action": action, "ProjectName": "default", "Filter": {"Phases": phases}, } # 合并基础参数和额外参数 param = self.base_param | extra_param # 计算请求体的签名 payloadSign = utils.get_hmac_encode16(json.dumps(param)) # 获取请求头,包含签名信息 headers = utils.get_hashmac_headers( self.domain, self.region, self.service, canonicalQueryString, "POST", "/", "application/json; charset=utf-8", payloadSign, self.ak, self.sk, ) # 调用_call方法发送请求并返回响应 return await self._call(url, headers, param)
修改工程文件中的main.py文件:根据示例代码中的绿色字体和注释,为您的配置。
import uvloop from bytedance_ark_batch_inference.client import BatchInferenceClient from bytedance_tos.tos_client import TosClient async def main(): tos_client = TosClient() batch_inference_client = BatchInferenceClient() # 上传您您本地或者远程存储的 JSONL 文件,至 TOS 桶中 # 替换为您的 TOS 桶名称 input_bucket_name = "<INPUT_BUCKET_NAME>" # 替换为您希望上传至TOS桶的路径和文件名 object_key = "<INPUT_FILE>" # 替换<YOUR_JSONL_FILE>为您本地或者远程存储的JSONL文件路径 tos_client.put_object_from_file(input_bucket_name, object_key, "<YOUR_JSONL_FILE>") # 创建批量推理任务,并将结果存储至您指定的 TOS 桶中 # 替换为替换为您希望存储批量推理结果 TOS 桶名称 output_bucket_name = "<OUTPUT_BUCKET_NAME>" # 替换为您希望存储批量推理结果的路径,例如:output/ output_key = "<OUTPUT_DIR>" response = await batch_inference_client.create_batch_inference_job(input_bucket_name,output_bucket_name, object_key, output_key) response = await batch_inference_client.ListBatchInferenceJobs(['Running']) print('done') if __name__ == '__main__': uvloop.run(main())
完成配置,在终端中运行python main.py
,运行脚本。