前言
Weaviate 是一种低延迟的向量搜索引擎,支持不同的媒体类型(文本、图像等)。它提供语义搜索、问答提取、分类、可定制模型(PyTorch/TensorFlow/Keras)等功能。Weaviate 从头开始使用 Go 构建,可以存储对象和向量,允许将向量搜索与结构化过滤器和云原生数据库的容错性结合起来。通过 GraphQL、REST 和各种客户端编程语言都可以访问它。
1. Weaviate 简介
Weaviate 是一种开源的类型向量搜索引擎数据库。
Weaviate 允许您以类属性的方式存储 JSON 文档,同时将机器学习向量附加到这些文档上,以在向量空间中表示它们。
Weaviate 可以独立使用,也可以与各种模块一起使用,这些模块可以为您进行向量化并扩展核心功能。
Weaviate 具有 GraphQL-API,以便轻松访问您的数据。
使用 `pip install weaviate-client` 命令安装 Python SDK。
2.代码
2.1 使用 Docker 启动 Weaviate
docker run -p 8080:8080 -p 50051:50051 cr.weaviate.io/semitechnologies/weaviate:1.26.1
2.2 代码
import weaviate
import requests
import json
import pandas as pd
#from weaviate.connect import ConnectionParams
import os
import weaviate
from M3eEmbeddingGenerator import M3eEmbeddingGenerator
from tqdm import tqdm
# from langchain.vectorstores import Weaviate
class WeaviateDataManage():
    """Manage a Weaviate vector store: schema creation, bulk import and vector search.

    Expected ``config`` keys:
        url        -- base URL of the Weaviate instance (e.g. "http://host:8080")
        batch_size -- batch size used for bulk imports
        class_name -- name of the Weaviate class (collection) this instance targets
    """
    def __init__(self, config):
        self.config = config
        self.url = self.config['url']
        self.batch_size = self.config['batch_size']
        self.class_name = self.config['class_name']
        self.weaviate_client = self.init_client()
        self.embedding_generator = self.get_embedding_generator()

    def init_client(self):
        """Create a Weaviate client and configure auto-batching.

        Raises:
            ConnectionError: when the Weaviate server is unreachable.
        """
        try:
            client = weaviate.Client(
                url=self.url,
                timeout_config=(50, 60),
                startup_period=None
            )
        except requests.exceptions.ConnectionError:
            raise ConnectionError("Vector database connection error")
        client.batch.configure(
            # `batch_size` takes an `int` to enable auto-batching
            # (`None` is used for manual batching)
            batch_size=self.batch_size,
            # dynamically update the `batch_size` based on import speed
            dynamic=True,
            # retry this many times on timeouts
            timeout_retries=3,
        )
        print(client.is_ready())
        return client

    def _response_key(self):
        """Return the key under which GraphQL responses list this class.

        Weaviate normalizes class names by capitalizing the first letter,
        so a class configured as "user_manual" appears as "User_manual"
        in query responses.
        """
        return self.class_name[0].upper() + self.class_name[1:]

    def get_embedding_generator(self):
        """Build the sentence-embedding generator backed by a Triton-served M3e model."""
        model_url = 'flyme-aigc.flyme.com/Triton-Server'
        model_name = "M3e_large_onnx"
        tokenizer_model_name = "./m3e-large"
        return M3eEmbeddingGenerator(model_url, model_name, tokenizer_model_name)

    def sentence_embedding(self, sentence):
        """Return the embedding for a single sentence as a plain list of floats."""
        embedding_result = self.embedding_generator.generate_embedding(sentence)
        return embedding_result[0].tolist()

    def create_class(self, class_name):
        """Create a Weaviate class with cosine distance and two properties.

        Returns the server response on success, or None when creation fails
        (e.g. the class already exists).
        """
        class_obj = {
            'class': class_name,
            'vectorIndexConfig': {
                'distance': 'cosine',
            },
            'properties': [
                {
                    'name': 'sentence_id',
                    'dataType': ['int'],
                    'description': 'The ID of the sentence'
                },
                {
                    'name': 'all_content',
                    'dataType': ['text'],
                    'description': 'The content of the sentence'
                }
            ]
        }
        try:
            res = self.weaviate_client.schema.create_class(class_obj)
            print(f"Class {class_name} created.")
            return res
        except weaviate.exceptions.UnexpectedStatusCodeException as e:
            print(f"Error creating class {class_name}: {e}")
            return None

    def del_text(self, uuid):
        """Delete a single object by its UUID."""
        # Fixed: target the configured class name instead of the
        # module-level CLASS_NAME global, so the instance is self-contained.
        self.weaviate_client.data_object.delete(uuid=uuid, class_name=self.class_name)

    def delete(self, class_name):
        """Delete an entire class (and every object stored in it)."""
        self.weaviate_client.schema.delete_class(class_name)

    def text_exists(self, uuid):
        """Return True when an object whose `doc_id` equals *uuid* exists.

        NOTE(review): the schema created by `create_class` defines only
        `sentence_id` and `all_content` — there is no `doc_id` property.
        Confirm the filter path against the actual schema in use.
        """
        result = self.weaviate_client.query.get(self.class_name).with_additional(["id"]).with_where({
            "path": ["doc_id"],
            "operator": "Equal",
            "valueText": uuid,
        }).with_limit(1).do()
        if "errors" in result:
            raise ValueError(f"Error during query: {result['errors']}")
        # Fixed: responses key on the *capitalized* class name; the previous
        # lookup with the raw lowercase CLASS_NAME would raise KeyError.
        entries = result["data"]["Get"][self._response_key()]
        return len(entries) > 0

    def get_embedding(self, query_list):
        """Embed the "query" text of every item.

        Args:
            query_list: iterable of dicts with "query" and "id" keys.

        Returns:
            Dict of parallel lists: "all_content", "embedding", "id".
        """
        res = {"all_content": [], "embedding": [], "id": []}
        # Single pass: embed and collect in one loop (the original looped twice).
        for d in tqdm(query_list):
            res["all_content"].append(d["query"])
            res["embedding"].append(self.sentence_embedding(d["query"]))
            res["id"].append(d["id"])
        return res

    def write(self, data):
        """Bulk-import embedded sentences into the configured class.

        Args:
            data: dict of parallel lists as produced by `get_embedding`.

        Returns:
            True on completion.
        """
        df = pd.DataFrame(data)
        # Fixed: reuse the client created in __init__ instead of opening a
        # second connection, and target the configured class name.
        with self.weaviate_client.batch(batch_size=self.batch_size) as batch:
            for i in range(df.shape[0]):
                print('importing data: {}'.format(i + 1))
                properties = {
                    'sentence_id': int(df.id[i]),   # sentence id: 1, 2, 3, ...
                    'all_content': df.all_content[i],  # sentence text
                }
                batch.add_data_object(
                    properties,
                    class_name=self.class_name,
                    vector=df.embedding[i]  # precomputed sentence embedding
                )
        print('import completed')
        return True

    def search(self, query, number):
        """Vector search: return the top *number* matches for *query*.

        Returns:
            List of dicts with keys "query", "content" and "distance".
        """
        query_embedding = self.sentence_embedding(query)
        near_vector = {
            'vector': query_embedding
        }
        response = (
            self.weaviate_client.query
            .get(self.class_name, ['sentence_id', 'all_content'])  # class name + fields to return
            .with_near_vector(near_vector)   # nearest-neighbor search on the query vector
            .with_limit(number)              # top-K result count
            .with_additional(['distance'])   # include the distance in the response
            .do()
        )
        # Fixed: derive the response key from the configured class name
        # instead of the hard-coded "User_manual".
        hits = response["data"]["Get"][self._response_key()]
        return [
            {
                "query": query,
                "content": hit["all_content"],
                "distance": hit["_additional"]["distance"],
            }
            for hit in hits
        ]
if __name__ == '__main__':
    # --- configuration ----------------------------------------------------
    WEAVIATE_ENDPOINT = "http://10.10.10.10:8080"
    WEAVIATE_BATCH_SIZE = 100
    CLASS_NAME = "user_manual"
    wea_config = {
        "url": WEAVIATE_ENDPOINT,
        "batch_size": WEAVIATE_BATCH_SIZE,
        "class_name": CLASS_NAME,
    }
    # Instantiate the data manager (connects to Weaviate on construction).
    data_manager = WeaviateDataManage(wea_config)
    # Recreate the class from scratch: drop any existing copy, then create it.
    data_manager.delete(CLASS_NAME)
    data_manager.create_class(CLASS_NAME)
    # Load the corpus: one JSON object per line (JSONL).
    file_path = "/home/zhenhengdong/WORk/Intelligent_customer_service/用户手册/Datasets/Using_data/0806_data_3.jsonl"
    data = []
    # Fixed: explicit UTF-8 encoding (content is Chinese text) and streaming
    # line iteration instead of readlines() materializing the whole file.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            data.append(json.loads(line))
    # Embed every record and bulk-import the vectors.
    embedding_data = data_manager.get_embedding(data)
    data_manager.write(embedding_data)
    # Run a sample top-5 vector search.
    query = "如何拨打电话"
    number_of_results = 5
    results = data_manager.search(query, number_of_results)
    print(results)
Reference: