技术背景介绍
在人工智能的自然语言处理(NLP)任务中,生成文本嵌入是一个重要步骤。Amazon SageMaker 提供了一种简便的方法来托管自定义NLP模型,并通过其端点生成嵌入。这篇文章将介绍如何在SageMaker上托管Hugging Face模型,并通过自定义的类来进行嵌入生成。
核心原理解析
SageMaker 端点允许用户托管预训练的机器学习模型并通过API进行调用。通过使用SagemakerEndpointEmbeddings
类,我们可以与托管在SageMaker上的模型交互以生成文本嵌入。
为了处理批量请求,需要对inference.py
脚本中的predict_fn()
函数进行调整。将返回行由:
return {"vectors": sentence_embeddings[0].tolist()}
更改为:
return {"vectors": sentence_embeddings.tolist()}
这确保了返回的嵌入向量是批量化的。
代码实现演示(重点)
以下是一个完整的示例,展示如何配置和使用SageMaker端点来生成文本嵌入:
!pip3 install langchain boto3
import json
from typing import Dict, List
from langchain_community.embeddings import SagemakerEndpointEmbeddings
from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
class ContentHandler(EmbeddingsContentHandler):
content_type = "application/json"
accepts = "application/json"
def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:
input_str = json.dumps({"inputs": inputs, **model_kwargs})
return input_str.encode("utf-8")
def transform_output(self, output: bytes) -> List[List[float]]:
response_json = json.loads(output.read().decode("utf-8"))
return response_json["vectors"]
content_handler = ContentHandler()
embeddings = SagemakerEndpointEmbeddings(
endpoint_name="huggingface-pytorch-inference-2023-03-21-16-14-03-834",
region_name="us-east-1",
content_handler=content_handler,
)
query_result = embeddings.embed_query("foo")
doc_results = embeddings.embed_documents(["foo"])
print(doc_results)
代码说明:
ContentHandler
类用于定义输入和输出的转换方式。SagemakerEndpointEmbeddings
用于与SageMaker端点交互,生成嵌入。
应用场景分析
此实现特别适用于需要生成大规模文本嵌入以支持搜索、推荐系统以及上下文相关分析的应用场景。通过SageMaker的弹性计算能力,可以有效处理大量的嵌入生成请求。
实践建议
- 初始化优化: 确保使用的是最新版本的SageMaker和相关的Python包,如
boto3
。 - 安全性: 务必妥善管理和存储你的AWS凭证,建议使用IAM角色分离权限。
- 性能提升: 调整
inference.py
以适应批量处理请求,提升吞吐量。
如果遇到问题欢迎在评论区交流。
—END—