You need to enable JavaScript to run this app.
导航
Embeddings
最近更新时间:2024.12.30 21:55:28 | 首次发布时间:2024.08.12 14:42:27

前提条件


示例代码

说明

示例代码中的 <YOUR_ENDPOINT_ID> 需要替换为您在平台上创建的推理接入点 ID;各示例均通过环境变量 ARK_API_KEY 读取 API Key,运行前请先设置该环境变量。

向量化

package main

import (
    "context"
    "encoding/json"
    "fmt"
    "os"

    "github.com/volcengine/volcengine-go-sdk/service/arkruntime"
    "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
)

// main sends a single-string embeddings request to the endpoint named in
// Model and prints the full response as JSON.
//
// The API key is read from the ARK_API_KEY environment variable.
func main() {
	client := arkruntime.NewClientWithApiKey(
		os.Getenv("ARK_API_KEY"),
	)
	ctx := context.Background()

	fmt.Println("----- embeddings request -----")
	req := model.EmbeddingRequestStrings{
		Input: []string{
			"花椰菜又称菜花、花菜,是一种常见的蔬菜。",
		},
		Model: "<YOUR_ENDPOINT_ID>",
	}

	resp, err := client.CreateEmbeddings(ctx, req)
	if err != nil {
		fmt.Printf("embeddings error: %v\n", err)
		return
	}

	// Report a marshalling failure instead of printing an empty string.
	s, err := json.Marshal(resp)
	if err != nil {
		fmt.Printf("marshal response error: %v\n", err)
		return
	}
	fmt.Println(string(s))
}

doubao-embedding 模型调用示例代码

package main

import (
    "context"
    "fmt"
    "math"
    "os"

    "github.com/volcengine/volcengine-go-sdk/service/arkruntime"
    "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
)

// main embeds a retrieval query and a candidate document with a doubao
// embedding endpoint, then prints their similarity at several truncated
// dimensions (2048, 1024, 512, in that order).
//
// The API key is read from the ARK_API_KEY environment variable.
func main() {
	client := arkruntime.NewClientWithApiKey(
		os.Getenv("ARK_API_KEY"),
		// Replace ${BASE_URL} with your Ark API base URL, or remove this
		// option to use the SDK default as the other samples on this page do.
		arkruntime.WithBaseUrl("${BASE_URL}"),
	)
	ctx := context.Background()

	fmt.Println("----- doubao embeddings request -----")
	// The query is prefixed with a retrieval instruction; the document is
	// embedded as-is.
	queryInstruction := "为这个句子生成表示以用于检索相关文章:"
	query := "天是什么颜色?"
	document := "天空呈现颜色主要与“瑞利散射”现象有关,具体形成过程如下:太阳光是由红、橙、黄、绿、蓝、靛、紫等多种颜色的光混合而成的。大气中存在着无数的气体分子和其他微粒。当太阳光进入地球大气层时,波长较长的红光、橙光、黄光能穿透大气层,直接射到地面,而波长较短的蓝、紫、靛等色光,很容易被悬浮在空气中的微粒阻挡,从而使光线散射向四方。其中蓝光波长较短,散射作用更强,因此我们眼睛看到的天空主要呈现蓝色。在一些特殊情况下,如傍晚或早晨,阳光斜射角度大,通过大气层的路径较长,蓝光等短波长光被散射得更多,而红光等长波长光散射损失较少,这时天空可能会呈现橙红色等其他颜色。"
	req := model.EmbeddingRequestStrings{
		Input: []string{
			queryInstruction + query,
			document,
		},
		Model: "<YOUR_ENDPOINT_ID>",
	}

	resp, err := client.CreateEmbeddings(ctx, req)
	if err != nil {
		fmt.Printf("embeddings error: %v\n", err)
		return
	}
	// Two inputs were sent; guard before indexing [0] and [1] below.
	if len(resp.Data) < 2 {
		fmt.Printf("embeddings error: expected 2 embeddings, got %d\n", len(resp.Data))
		return
	}

	// Each embedding is truncated to dim components and re-normalized to unit
	// length, so the dot product is the cosine similarity at that dimension.
	for _, dim := range []int{2048, 1024, 512} {
		embeddings := normalize(resp, dim)
		score := matmulVector(embeddings[0].Embedding, embeddings[1].Embedding)
		fmt.Printf("product: %f\n", score)
	}
}

// normalize returns one model.Embedding per entry of resp.Data, in order.
// Each vector is cut down to its first dim components (when longer than dim)
// and then rescaled to unit L2 norm by normalizeVector. Object and Index are
// copied through unchanged.
//
// NOTE(review): using only the leading dimensions presumes the model was
// trained so that prefixes remain meaningful embeddings — confirm with the
// model documentation.
func normalize(resp model.EmbeddingResponse, dim int) []model.Embedding {
	result := []model.Embedding{}
	for i := range resp.Data {
		item := resp.Data[i]

		vec := item.Embedding
		if dim < len(vec) {
			vec = vec[:dim]
		}

		result = append(result, model.Embedding{
			Object:    item.Object,
			Embedding: normalizeVector(vec),
			Index:     item.Index,
		})
	}
	return result
}

// normalizeVector returns vector scaled to unit L2 (Euclidean) norm.
//
// Squared components are summed in float64. Summing in float32 overflows to
// +Inf for large components, which would collapse the result to all zeros,
// and it also loses precision over long vectors. Only the final norm is
// converted back to float32.
//
// If the norm is zero (an all-zero or empty input), vector itself is returned,
// because it cannot be scaled to unit length. Otherwise a newly allocated
// slice is returned and vector is left unmodified.
func normalizeVector(vector []float32) []float32 {
	var sumSquares float64
	for _, v := range vector {
		f := float64(v)
		sumSquares += f * f
	}
	norm := float32(math.Sqrt(sumSquares))

	if norm == 0 {
		return vector
	}

	// The output length is known up front, so allocate it once.
	normalized := make([]float32, len(vector))
	for i, v := range vector {
		normalized[i] = v / norm
	}
	return normalized
}

// matmulVector returns the dot product of vector1 and vector2, summed over
// the indices of vector1. For unit-length inputs (such as the output of
// normalizeVector) this equals their cosine similarity.
//
// vector2 must be at least as long as vector1, otherwise indexing it panics.
// Any extra trailing elements of vector2 are ignored.
func matmulVector(vector1 []float32, vector2 []float32) float32 {
	var acc float32
	for i, a := range vector1 {
		acc += a * vector2[i]
	}
	return acc
}

设置自定义 header

package main

import (
    "context"
    "fmt"
    "io"
    "os"

    "github.com/volcengine/volcengine-go-sdk/service/arkruntime"
    "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
    "github.com/volcengine/volcengine-go-sdk/volcengine"
)


func main() {
    client := arkruntime.NewClientWithApiKey(
       os.Getenv("ARK_API_KEY"),
    )
    ctx := context.Background()

    fmt.Println("----- embeddings request -----")
    req := model.EmbeddingRequestStrings{
       Input: []string{
          "花椰菜又称菜花、花菜,是一种常见的蔬菜。",
       },
       Model: "<YOUR_ENDPOINT_ID>",
    }

    resp, err := client.CreateEmbeddings(
        ctx,
        req,
        arkruntime.WithCustomHeader(model.ClientRequestHeader, "20240627112200D3FE1CFF83C5D5523085"),
    )
    if err != nil {
       fmt.Printf("embeddings error: %v\n", err)
       return
    }

    s, _ := json.Marshal(resp)
    fmt.Println(string(s))
}

异常处理

package main

import (
    "context"
    "errors"
    "fmt"
    "io"
    "os"

    "github.com/volcengine/volcengine-go-sdk/service/arkruntime"
    "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
    "github.com/volcengine/volcengine-go-sdk/volcengine"
)

func main() {
    client := arkruntime.NewClientWithApiKey(
       os.Getenv("ARK_API_KEY"),
    )
    ctx := context.Background()

    fmt.Println("----- embeddings request -----")
    req := model.EmbeddingRequestStrings{
       Input: []string{
          "花椰菜又称菜花、花菜,是一种常见的蔬菜。",
       },
       Model: "<YOUR_ENDPOINT_ID>",
    }

    resp, err := client.CreateEmbeddings(ctx, req)
    if err != nil {
       apiErr := &model.APIError{}
       if errors.As(err, &apiErr) {
          fmt.Printf("embeddings error: %v\n", apiErr)
       }
       return
    }
    s, _ := json.Marshal(resp)
    fmt.Println(string(s))
}