Anthropic-cookbook:Skills:Contextual-embeddings:Guide.ipynb at Main · Anthropics
In traditional RAG, documents are typically split into smaller chunks for efficient
retrieval. While this approach works well for many applications, it can lead to
problems when individual chunks lack sufficient context. Contextual
Embeddings solve this problem by adding relevant context to each chunk before
embedding. This method improves the quality of each embedded chunk,
allowing for more accurate retrieval and thus better overall performance.
Averaged across all data sources we tested, Contextual Embeddings reduced
the top-20-chunk retrieval failure rate by 35%.
The same chunk-specific context can also be used with BM25 search to further
improve retrieval performance. We introduce this technique in the “Contextual
BM25” section.
In this guide, we'll demonstrate how to build and optimize a Contextual Retrieval
system using a dataset of 9 codebases as our knowledge base. We'll walk
through:
1. Basic RAG: setting up a baseline retrieval pipeline to measure performance against.
2. Contextual Embeddings: what it is, why it works, and how prompt caching
makes it practical for production use cases.
Additional Notes:
Prompt caching is helpful in managing costs when using this retrieval method.
This feature is currently available on Anthropic's 1P API, and is coming soon to
our 3P partner environments in AWS Bedrock and GCP Vertex. We know that
many of our customers leverage AWS Knowledge Bases and GCP Vertex AI APIs
when building RAG solutions, and this method can be used on either platform
with a bit of customization. Consider reaching out to Anthropic or your
AWS/GCP account team for guidance on this!
To make it easier to use this method on Bedrock, the AWS team has provided us
with code that you can use to implement a Lambda function that adds context
to each document. If you deploy this Lambda function, you can select it as a
custom chunking option when configuring a Bedrock Knowledge Base. You can
find this code in contextual-rag-lambda-function . The main lambda
function code is in lambda_function.py .
Table of Contents
1. Setup
2. Basic RAG
3. Contextual Embeddings
4. Contextual BM25
5. Reranking
Setup
We'll need a few libraries, including anthropic, voyageai, cohere, elasticsearch, numpy, and tqdm.
You'll also need API keys from Anthropic, Voyage AI, and Cohere.
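If you're starting from a fresh environment, here is a minimal install sketch (these are just the packages imported in the cells below; pin versions to match your setup):

%pip install anthropic voyageai cohere elasticsearch numpy tqdm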
In [2]: import os
import anthropic

client = anthropic.Anthropic(
    # This is the default and can be omitted
    api_key=os.getenv("ANTHROPIC_API_KEY"),
)
In [4]: import os
import pickle
import json
import numpy as np
import voyageai
from typing import List, Dict, Any
from tqdm import tqdm
class VectorDB:
    def __init__(self, name: str, api_key=None):
        if api_key is None:
            api_key = os.getenv("VOYAGE_API_KEY")
        self.client = voyageai.Client(api_key=api_key)
        self.name = name
        self.embeddings = []
        self.metadata = []
        self.query_cache = {}
        self.db_path = f"./data/{name}/vector_db.pkl"

    def load_data(self, dataset: List[Dict[str, Any]]):
        texts_to_embed = []
        metadata = []
        total_chunks = sum(len(doc['chunks']) for doc in dataset)

        # Collect every chunk (and its metadata) across all documents
        with tqdm(total=total_chunks, desc="Processing chunks") as pbar:
            for doc in dataset:
                for chunk in doc['chunks']:
                    texts_to_embed.append(chunk['content'])
                    metadata.append({
                        'doc_id': doc['doc_id'],
                        'original_uuid': doc['original_uuid'],
                        'chunk_id': chunk['chunk_id'],
                        'original_index': chunk['original_index'],
                        'content': chunk['content']
                    })
                    pbar.update(1)

        self._embed_and_store(texts_to_embed, metadata)
        self.save_db()

    def _embed_and_store(self, texts: List[str], data: List[Dict[str, Any]]):
        batch_size = 128
        result = [
            self.client.embed(
                texts[i : i + batch_size],
                model="voyage-2"
            ).embeddings
            for i in range(0, len(texts), batch_size)
        ]
        self.embeddings = [embedding for batch in result for embedding in batch]
        self.metadata = data

    def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
        # Embed the query (with a small cache so repeated queries are free)
        if query in self.query_cache:
            query_embedding = self.query_cache[query]
        else:
            query_embedding = self.client.embed([query], model="voyage-2").embeddings[0]
            self.query_cache[query] = query_embedding

        if not self.embeddings:
            raise ValueError("No data loaded in the vector database.")

        # Similarity via dot product, then take the top k indices
        similarities = np.dot(self.embeddings, query_embedding)
        top_indices = np.argsort(similarities)[::-1][:k]

        top_results = []
        for idx in top_indices:
            result = {
                "metadata": self.metadata[idx],
                "similarity": float(similarities[idx]),
            }
            top_results.append(result)
        return top_results

    def save_db(self):
        data = {
            "embeddings": self.embeddings,
            "metadata": self.metadata,
            "query_cache": json.dumps(self.query_cache),
        }
        os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
        with open(self.db_path, "wb") as file:
            pickle.dump(data, file)

    def load_db(self):
        if not os.path.exists(self.db_path):
            raise ValueError("Vector database file not found. Use load_data to create a new database.")
        with open(self.db_path, "rb") as file:
            data = pickle.load(file)
        self.embeddings = data["embeddings"]
        self.metadata = data["metadata"]
        self.query_cache = json.loads(data["query_cache"])

    def validate_embedded_chunks(self):
        unique_contents = set()
        for meta in self.metadata:
            unique_contents.add(meta['content'])

        print("Validation results:")
        print(f"Total embedded chunks: {len(self.metadata)}")
        print(f"Unique embedded contents: {len(unique_contents)}")

        if len(self.metadata) != len(unique_contents):
            print("Warning: There may be duplicate chunks in the embedded data.")
        else:
            print("All embedded chunks are unique.")
Basic RAG
To get started, we'll set up a basic RAG pipeline using a bare-bones approach,
sometimes called 'Naive RAG' in the industry. A basic RAG pipeline includes the
following 3 steps:
1. Chunk each document into smaller pieces
2. Embed each chunk and store it in a vector database
3. At query time, embed the query and retrieve the most similar chunks
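As a minimal usage sketch of the VectorDB class defined above (the dataset path, database name, and query are illustrative assumptions; any dataset in the same chunked format will work):

# Load a chunked-codebase dataset and build the baseline vector index
with open('data/codebase_chunks.json', 'r') as f:   # hypothetical path for illustration
    dataset = json.load(f)

base_db = VectorDB("base_db")
base_db.load_data(dataset)

# Retrieve the top-k chunks for a sample query
results = base_db.search("How is the vector database persisted to disk?", k=5)
for r in results:
    print(round(r['similarity'], 3), r['metadata']['doc_id'], r['metadata']['chunk_id'])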
# Excerpt from the per-query retrieval evaluation loop: gather the golden
# chunk contents, then count how many appear in the top-k retrieved documents.
    golden_contents.append(golden_chunk['content'].strip())

if not golden_contents:
    print(f"Warning: No golden contents found for query: {query}")
    continue

# Count how many golden chunks are in the top k retrieved documents
chunks_found = 0
for golden_content in golden_contents:
    for doc in retrieved_docs[:k]:
        retrieved_content = doc['metadata'].get('original_content', doc['metadata'].get('content', '')).strip()
        if retrieved_content == golden_content:
            chunks_found += 1
            break
Contextual Embeddings
With basic RAG, each embedded chunk contains a potentially useful piece of
information, but these chunks lack context. With Contextual Embeddings, we add
more context to each text chunk before embedding it. Specifically, we use Claude
to create a concise context that situates the chunk within the overall document.
In the case of our codebases dataset, we can provide both the chunk and the full
file that the chunk was found within to an LLM, and have it produce the context.
Then, we combine this context and the raw text chunk into a single text block
prior to creating each embedding.
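Concretely, the string that gets embedded is just the raw chunk with its generated context appended. A tiny sketch (both strings are made up for illustration):

chunk_text = "def save_db(self):\n    ...\n    pickle.dump(data, file)"
generated_context = (
    "This chunk is from the VectorDB class and shows how embeddings, metadata, "
    "and the query cache are persisted to disk as a pickle file."
)

# This combined block, not the bare chunk, is what we embed
text_to_embed = f"{chunk_text}\n\n{generated_context}"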
Prompt caching also makes this much more cost effective. Creating contextual
embeddings requires us to pass the same document to the model for every
chunk we want to generate extra context for. With prompt caching, we can write
the overall document to the cache once, and then, because we're doing our
ingestion job all in sequence, we can simply read the document from the cache as
we generate context for each chunk within that document (the information you
write to the cache has a 5 minute time to live). This means that the first time we
pass a document to the model, we pay a bit more to write it to the cache, but for
each subsequent API call that contains that document, we receive a 90% discount
on all of the input tokens read from the cache. Assuming 800 token chunks, 8k
token documents, 50 token context instructions, and 100 tokens of context per
chunk, the cost to generate contextualized chunks is $1.02 per million document
tokens.
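As a rough check of that figure, here's a back-of-the-envelope sketch. The per-token prices are assumptions based on Claude 3 Haiku pricing at the time of writing ($0.25/MTok input, $1.25/MTok output, roughly $0.30/MTok cache writes, $0.03/MTok cache reads), and the result lands within a few cents of the $1.02 above depending on how the cache-write premium is counted:

# Per 8k-token document with 800-token chunks -> 10 chunks per document
doc_tokens, chunk_tokens, instruction_tokens, context_tokens = 8000, 800, 50, 100
chunks_per_doc = doc_tokens // chunk_tokens

cache_write = doc_tokens * 0.30 / 1e6                         # document written to the cache once
cache_read = doc_tokens * (chunks_per_doc - 1) * 0.03 / 1e6   # re-read from cache for the remaining chunks
fresh_input = chunks_per_doc * (chunk_tokens + instruction_tokens) * 0.25 / 1e6
output = chunks_per_doc * context_tokens * 1.25 / 1e6

cost_per_doc = cache_write + cache_read + fresh_input + output
print(f"~${cost_per_doc * 1e6 / doc_tokens:.2f} per million document tokens")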
When you load data into your ContextualVectorDB below, you'll see in logs just
how big this impact is.
Warning: some smaller embedding models have a fixed input token limit.
Contextualizing the chunk makes it longer, so if you notice much worse
performance from contextualized embeddings, the contextualized chunk is
likely getting truncated.
CHUNK_CONTEXT_PROMPT = """
Here is the chunk we want to situate within the whole document
<chunk>
{chunk_content}
</chunk>
Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk.
Answer only with the succinct context and nothing else.
"""
# Helper to read one JSON object per line from the evaluation set
def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
    with open(file_path, 'r') as f:
        return [json.loads(line) for line in f]

jsonl_data = load_jsonl('data/evaluation_set.jsonl')

# Example usage
doc_content = jsonl_data[0]['golden_documents'][0]['content']
chunk_content = jsonl_data[0]['golden_chunks'][0]['content']
In [318]: import os
import pickle
import json
import numpy as np
import voyageai
from typing import List, Dict, Any
from tqdm import tqdm
import anthropic
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

class ContextualVectorDB:
    def __init__(self, name: str, voyage_api_key=None, anthropic_api_key=None):
        if voyage_api_key is None:
            voyage_api_key = os.getenv("VOYAGE_API_KEY")
        if anthropic_api_key is None:
            anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
        self.voyage_client = voyageai.Client(api_key=voyage_api_key)
        self.anthropic_client = anthropic.Anthropic(api_key=anthropic_api_key)
        self.name = name
        self.embeddings = []
        self.metadata = []
        self.query_cache = {}
        self.db_path = f"./data/{name}/contextual_vector_db.pkl"

        self.token_counts = {
            'input': 0,
            'output': 0,
            'cache_read': 0,
            'cache_creation': 0
        }
        self.token_lock = threading.Lock()

    def situate_context(self, doc: str, chunk: str):
        DOCUMENT_CONTEXT_PROMPT = """
        <document>
        {doc_content}
        </document>
        """

        CHUNK_CONTEXT_PROMPT = """
        Here is the chunk we want to situate within the whole document
        <chunk>
        {chunk_content}
        </chunk>
        Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk.
        Answer only with the succinct context and nothing else.
        """

        response = self.anthropic_client.beta.prompt_caching.messages.create(
            model="claude-3-haiku-20240307",
            max_tokens=1000,
            temperature=0.0,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": DOCUMENT_CONTEXT_PROMPT.format(doc_content=doc),
                            "cache_control": {"type": "ephemeral"}  # we will make use of prompt caching for the full documents
                        },
                        {
                            "type": "text",
                            "text": CHUNK_CONTEXT_PROMPT.format(chunk_content=chunk),
                        },
                    ]
                },
            ],
            extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"}
        )
        return response.content[0].text, response.usage

    def load_data(self, dataset: List[Dict[str, Any]], parallel_threads: int = 1):
        if os.path.exists(self.db_path):
            print("Loading vector database from disk.")
            self.load_db()
            return

        texts_to_embed = []
        metadata = []
        total_chunks = sum(len(doc['chunks']) for doc in dataset)

        def process_chunk(doc, chunk):
            # for each chunk, produce the context with Claude and track token usage
            contextualized_text, usage = self.situate_context(doc['content'], chunk['content'])
            with self.token_lock:
                self.token_counts['input'] += usage.input_tokens
                self.token_counts['output'] += usage.output_tokens
                self.token_counts['cache_read'] += usage.cache_read_input_tokens
                self.token_counts['cache_creation'] += usage.cache_creation_input_tokens

            return {
                # append the context to the original text chunk
                'text_to_embed': f"{chunk['content']}\n\n{contextualized_text}",
                'metadata': {
                    'doc_id': doc['doc_id'],
                    'original_uuid': doc['original_uuid'],
                    'chunk_id': chunk['chunk_id'],
                    'original_index': chunk['original_index'],
                    'original_content': chunk['content'],
                    'contextualized_content': contextualized_text
                }
            }

        print(f"Processing {total_chunks} chunks with {parallel_threads} threads")
        with ThreadPoolExecutor(max_workers=parallel_threads) as executor:
            futures = [
                executor.submit(process_chunk, doc, chunk)
                for doc in dataset
                for chunk in doc['chunks']
            ]
            for future in tqdm(as_completed(futures), total=total_chunks, desc="Processing chunks"):
                result = future.result()
                texts_to_embed.append(result['text_to_embed'])
                metadata.append(result['metadata'])

        self._embed_and_store(texts_to_embed, metadata)
        self.save_db()

        # these logs show how much of the document input was served from the prompt cache
        print(f"Total input tokens (uncached): {self.token_counts['input']}")
        print(f"Total output tokens: {self.token_counts['output']}")
        print(f"Total input tokens written to cache: {self.token_counts['cache_creation']}")
        print(f"Total input tokens read from cache: {self.token_counts['cache_read']}")

    # we use voyage AI here for embeddings. Read more here: https://docs.voyag
    def _embed_and_store(self, texts: List[str], data: List[Dict[str, Any]]):
        batch_size = 128
        result = [
            self.voyage_client.embed(
                texts[i : i + batch_size],
                model="voyage-2"
            ).embeddings
            for i in range(0, len(texts), batch_size)
        ]
        self.embeddings = [embedding for batch in result for embedding in batch]
        self.metadata = data

    def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
        if query in self.query_cache:
            query_embedding = self.query_cache[query]
        else:
            query_embedding = self.voyage_client.embed([query], model="voyage-2").embeddings[0]
            self.query_cache[query] = query_embedding

        if not self.embeddings:
            raise ValueError("No data loaded in the vector database.")

        similarities = np.dot(self.embeddings, query_embedding)
        top_indices = np.argsort(similarities)[::-1][:k]

        top_results = []
        for idx in top_indices:
            result = {
                "metadata": self.metadata[idx],
                "similarity": float(similarities[idx]),
            }
            top_results.append(result)
        return top_results

    def save_db(self):
        data = {
            "embeddings": self.embeddings,
            "metadata": self.metadata,
            "query_cache": json.dumps(self.query_cache),
        }
        os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
        with open(self.db_path, "wb") as file:
            pickle.dump(data, file)

    def load_db(self):
        if not os.path.exists(self.db_path):
            raise ValueError("Vector database file not found. Use load_data to create a new database.")
        with open(self.db_path, "rb") as file:
            data = pickle.load(file)
        self.embeddings = data["embeddings"]
        self.metadata = data["metadata"]
        self.query_cache = json.loads(data["query_cache"])
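A minimal usage sketch of the class above (the database name, thread count, and query are illustrative; dataset is the same chunked data used for the basic VectorDB earlier):

# Build the contextual vector DB; the token-count logs show the prompt-caching savings
contextual_db = ContextualVectorDB("contextual_db")
contextual_db.load_data(dataset, parallel_threads=5)

# Retrieve using the contextualized embeddings
results = contextual_db.search("Where is the query cache serialized?", k=20)
print(results[0]['metadata']['contextualized_content'])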
Contextual BM25
Contextual embeddings are an improvement on traditional semantic-search RAG,
but we can improve performance further. In this section we'll show how to use
contextual embeddings and contextual BM25 together. While you can see
performance gains by pairing these techniques without the added context,
applying context to both methods reduces the top-20-chunk retrieval failure
rate by 42%.
One difference from a typical BM25 search is that, for each chunk, we'll run the
BM25 search over both the chunk content and the additional context we
generated in the previous section. From there, we'll use a technique called
reciprocal rank fusion to merge the results of the BM25 search with our semantic
search results. This lets us perform a hybrid search across both our BM25 corpus
and our vector DB and return the most relevant documents for a given query.
In the function below, we give you the option to weight the semantic search and
BM25 results as you merge them with reciprocal rank fusion. By default, we set
these to 0.8 for the semantic search results and 0.2 for the BM25 results. We'd
encourage you to experiment with different values here.
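Here's a minimal sketch of that weighted reciprocal rank fusion step in isolation (the helper name, the rank-smoothing constant, and the example chunk ids are illustrative; the full hybrid retrieval function appears in the cells below):

from typing import List, Tuple, Dict

def weighted_rrf(semantic_ids: List[Tuple], bm25_ids: List[Tuple],
                 semantic_weight: float = 0.8, bm25_weight: float = 0.2,
                 smoothing: int = 60) -> List[Tuple]:
    # Each list is ordered best-first; score each chunk id by 1/(smoothing + rank), weighted per source
    scores: Dict[Tuple, float] = {}
    for rank, chunk_id in enumerate(semantic_ids):
        scores[chunk_id] = scores.get(chunk_id, 0.0) + semantic_weight * (1.0 / (smoothing + rank + 1))
    for rank, chunk_id in enumerate(bm25_ids):
        scores[chunk_id] = scores.get(chunk_id, 0.0) + bm25_weight * (1.0 / (smoothing + rank + 1))
    # Highest fused score first
    return sorted(scores, key=scores.get, reverse=True)

# Example: chunk ids are (doc_id, original_index) pairs, as in the retrieval code below
fused = weighted_rrf([("doc_1", 0), ("doc_2", 3)], [("doc_2", 3), ("doc_9", 7)])
print(fused)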
In [369]: import os
import json
from typing import List, Dict, Any
from tqdm import tqdm
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk

class ElasticsearchBM25:
    def __init__(self, index_name: str = "contextual_bm25_index"):
        self.es_client = Elasticsearch("http://localhost:9200")
        self.index_name = index_name
        self.create_index()

    def create_index(self):
        index_settings = {
            "settings": {
                "analysis": {"analyzer": {"default": {"type": "english"}}},
                "similarity": {"default": {"type": "BM25"}},
                "index.queries.cache.enabled": False  # Disable query cache
            },
            "mappings": {
                "properties": {
                    "content": {"type": "text", "analyzer": "english"},
                    "contextualized_content": {"type": "text", "analyzer": "english"},
                    "doc_id": {"type": "keyword", "index": False},
                    "chunk_id": {"type": "keyword", "index": False},
                    "original_index": {"type": "integer", "index": False},
                }
            },
        }
        if not self.es_client.indices.exists(index=self.index_name):
            self.es_client.indices.create(index=self.index_name, body=index_settings)
            print(f"Created index: {self.index_name}")
# Excerpt from the hybrid retrieval function: run semantic search and BM25,
# then combine the two ranked lists of chunk ids before scoring them.
# Semantic search
semantic_results = db.search(query, k=num_chunks_to_recall)
ranked_chunk_ids = [(result['metadata']['doc_id'], result['metadata']['original_index']) for result in semantic_results]

# Combine results
chunk_ids = list(set(ranked_chunk_ids + ranked_bm25_chunk_ids))
chunk_id_to_score = {}
# Excerpt from the end-to-end evaluation: warm up Elasticsearch, score each
# query against its golden chunks, then report stats and clean up the index.
try:
    # Warm-up queries
    warm_up_queries = original_data[:10]
    for query_item in warm_up_queries:
        _ = retrieve_advanced(query_item['query'], db, es_bm25, k)

    total_score = 0
    total_semantic_count = 0
    total_bm25_count = 0
    total_results = 0

    for query_item in tqdm(original_data, desc="Evaluating retrieval"):
        query = query_item['query']
        golden_chunk_uuids = query_item['golden_chunk_uuids']

        golden_contents = []
        for doc_uuid, chunk_index in golden_chunk_uuids:
            golden_doc = next((doc for doc in query_item['golden_documents'] if doc.get('original_uuid') == doc_uuid), None)
            if golden_doc:
                golden_chunk = next((chunk for chunk in golden_doc['chunks'] if chunk.get('index') == chunk_index), None)
                if golden_chunk:
                    golden_contents.append(golden_chunk['content'].strip())

        if not golden_contents:
            print(f"Warning: No golden contents found for query: {query}")
            continue

        # assumes retrieve_advanced returns (docs, semantic_count, bm25_count)
        retrieved_docs, semantic_count, bm25_count = retrieve_advanced(query, db, es_bm25, k)

        chunks_found = 0
        for golden_content in golden_contents:
            for doc in retrieved_docs[:k]:
                retrieved_content = doc['chunk']['original_content'].strip()
                if retrieved_content == golden_content:
                    chunks_found += 1
                    break

        total_score += chunks_found / len(golden_contents)
        total_semantic_count += semantic_count
        total_bm25_count += bm25_count
        total_results += len(retrieved_docs)

    total_queries = len(original_data)
    average_score = total_score / total_queries
    pass_at_n = average_score * 100
    semantic_percentage = (total_semantic_count / total_results) * 100 if total_results > 0 else 0
    bm25_percentage = (total_bm25_count / total_results) * 100 if total_results > 0 else 0

    results = {
        "pass_at_n": pass_at_n,
        "average_score": average_score,
        "total_queries": total_queries
    }

    print(f"Pass@{k}: {pass_at_n:.2f}%")
    print(f"Average Score: {average_score:.2f}")
    print(f"Total queries: {total_queries}")
    print(f"Percentage of results from semantic search: {semantic_percentage:.2f}%")
    print(f"Percentage of results from BM25: {bm25_percentage:.2f}%")
finally:
    # Delete the Elasticsearch index
    if es_bm25.es_client.indices.exists(index=es_bm25.index_name):
        es_bm25.es_client.indices.delete(index=es_bm25.index_name)
        print(f"Deleted Elasticsearch index: {es_bm25.index_name}")
Reranking
Below, we'll demonstrate only the re-ranking step (skipping the hybrid search
technique for now). You'll see that we retrieve 10x as many documents as the
final number of documents we want to return, then use a re-ranking model from
Cohere to select the 10 most relevant results from that list. Adding the
re-ranking step delivers a modest additional gain in performance. In our case,
Pass@10 improves from 92.81% to 94.79%.
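The cell below is an excerpt from that re-ranking retrieval function. As a sketch of the setup it assumes (the client construction, sample query, over-retrieval factor, and the choice of which text to hand the re-ranker are illustrative):

import cohere

co = cohere.Client(os.getenv("COHERE_API_KEY"))

k = 10
query = "How is the query cache persisted?"   # sample query for illustration

# Over-retrieve from the contextual vector DB, then let the re-ranker pick the best k
semantic_results = contextual_db.search(query, k=k*10)
documents = [res['metadata']['original_content'] for res in semantic_results]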
response = co.rerank(
    model="rerank-english-v3.0",
    query=query,
    documents=documents,
    top_n=k
)
time.sleep(0.1)  # light rate limiting between rerank calls

final_results = []
for r in response.results:
    original_result = semantic_results[r.index]
    final_results.append({
        "chunk": original_result['metadata'],
        "score": r.relevance_score
    })
return final_results
# Excerpt from the re-ranking evaluation loop: the same golden-chunk scoring
# as before, but retrieved docs now carry their metadata under 'chunk'.
    golden_contents = []
    for doc_uuid, chunk_index in golden_chunk_uuids:
        golden_doc = next((doc for doc in query_item['golden_documents'] if doc.get('original_uuid') == doc_uuid), None)
        if golden_doc:
            golden_chunk = next((chunk for chunk in golden_doc['chunks'] if chunk.get('index') == chunk_index), None)
            if golden_chunk:
                golden_contents.append(golden_chunk['content'].strip())

    if not golden_contents:
        print(f"Warning: No golden contents found for query: {query}")
        continue

    chunks_found = 0
    for golden_content in golden_contents:
        for doc in retrieved_docs[:k]:
            retrieved_content = doc['chunk']['original_content'].strip()
            if retrieved_content == golden_content:
                chunks_found += 1
                break