04 — Cross-Encoder Rerank: Cost vs Quality
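Problem: Bi-encoder retrieval is fast but shallow, because the query and each passage are embedded independently. A cross-encoder scores each (query, passage) pair with joint attention, which usually gives better ordering at a higher compute cost per pair.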
In this notebook: Retrieve candidates with a bi-encoder (all four chunks of a toy corpus), rerank them with cross-encoder/ms-marco-MiniLM-L-6-v2, and compare the two orderings.
In [ ]:
import sys
from pathlib import Path

# Make the repo-local helpers under src/ importable when run from the repo root.
_REPO = Path.cwd().resolve()
if (_REPO / "src").is_dir():
    sys.path.insert(0, str(_REPO / "src"))

from rag_series_utils import chroma_path, get_client
from sentence_transformers import SentenceTransformer, CrossEncoder

# Toy corpus: four one-line facts about Kubernetes probes and autoscaling.
chunks = [
    "Kubernetes liveness probes restart unhealthy pods automatically.",
    "Readiness probes remove pods from service endpoints until traffic-ready.",
    "Startup probes cover slow-boot containers without false kills.",
    "Horizontal Pod Autoscaler scales based on CPU or custom metrics.",
]
query = "Which probe type stops sending traffic before the container is ready?"

# Bi-encoder for first-stage retrieval, cross-encoder for reranking.
bi = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
ce = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
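
# Start from a fresh collection so reruns do not pick up stale data.
p = chroma_path("nb04_rerank")
client = get_client(p)
try:
    client.delete_collection("c")
except Exception:
    pass  # most likely the collection did not exist yet
col = client.create_collection("c", metadata={"hnsw:space": "cosine"})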

# Index the chunks with bi-encoder embeddings.
emb = bi.encode(chunks, show_progress_bar=False).tolist()
col.add(ids=[str(i) for i in range(len(chunks))], documents=chunks, embeddings=emb)

# Stage 1: bi-encoder retrieval (returns every chunk, since the corpus is tiny).
qe = bi.encode(query, show_progress_bar=False).tolist()
res = col.query(query_embeddings=[qe], n_results=len(chunks))
order = res["documents"][0]

# Stage 2: score each (query, passage) pair jointly, then sort by descending score.
pairs = [[query, d] for d in order]
scores = ce.predict(pairs)
ranked = [order[i] for i in sorted(range(len(order)), key=lambda i: -scores[i])]

print("Bi-encoder order:", order)
print("Cross-encoder order:", ranked)
print("Correct doc first after rerank:", ranked[0].startswith("Readiness"))
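
Printing the two lists side by side makes it hard to see which passages actually moved. A minimal sketch of a per-document rank comparison follows. The helper name `rank_changes` is hypothetical, and the two sample orders are illustrative labels, not measured outputs of either model.

```python
def rank_changes(before, after):
    """Return (doc, old_rank, new_rank, shift) per doc, ordered by new rank.

    Ranks are 1-based; a positive shift means the doc moved up after reranking.
    """
    old_rank = {doc: i for i, doc in enumerate(before, start=1)}
    rows = []
    for new, doc in enumerate(after, start=1):
        old = old_rank[doc]
        rows.append((doc, old, new, old - new))
    return rows


# Illustrative orders only: short labels standing in for the four chunks.
before = ["liveness", "readiness", "startup", "hpa"]
after = ["readiness", "startup", "liveness", "hpa"]

changes = rank_changes(before, after)
for doc, old, new, shift in changes:
    print(f"{doc:<10} {old} -> {new} ({shift:+d})")
```

In the cell above, the same helper could be called as `rank_changes(order, ranked)` to compare the bi-encoder and cross-encoder orderings.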
Takeaways
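- Typical pattern: retrieve wide (e.g. 50–200 candidates), rerank them, and keep a narrow top set (5–20).
- Watch p95 latency and GPU memory; score (query, passage) pairs in batches rather than one at a time when possible.
- For very large corpora, consider late interaction models (e.g. ColBERT) as a middle ground between bi-encoders and cross-encoders.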
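
The first two takeaways can be sketched together: rerank a wide candidate list in batches, keep a narrow top set, and track p95 latency over repeated calls. This is a sketch under stated assumptions, not the notebook's implementation. `rerank_top_k`, `p95`, and `overlap_scorer` are hypothetical names, and `overlap_scorer` is a token-overlap stand-in for a real cross-encoder so the example runs without downloading a model.

```python
import math
import time


def rerank_top_k(query, candidates, score_pairs, keep=5, batch_size=32):
    """Score (query, passage) pairs in batches and keep the best `keep` passages."""
    scores = []
    for start in range(0, len(candidates), batch_size):
        batch = candidates[start:start + batch_size]
        scores.extend(score_pairs([(query, passage) for passage in batch]))
    # Stable sort: passages with equal scores keep their retrieval order.
    order = sorted(range(len(candidates)), key=lambda i: scores[i], reverse=True)
    return [candidates[i] for i in order[:keep]]


def p95(samples):
    """Nearest-rank 95th percentile of a non-empty list of numbers."""
    ordered = sorted(samples)
    rank = math.ceil(95 * len(ordered) / 100)
    return ordered[rank - 1]


def overlap_scorer(pairs):
    """Hypothetical stand-in for a cross-encoder: counts shared lowercase tokens."""
    return [len(set(q.lower().split()) & set(p.lower().split())) for q, p in pairs]


# Assumed setup: 50 synthetic candidates with one relevant passage buried at index 37.
candidates = [f"passage {i} about autoscaling" for i in range(50)]
candidates[37] = "readiness probes gate traffic until the pod is ready"
query = "which probes gate traffic"

latencies = []
for _ in range(20):
    t0 = time.perf_counter()
    top = rerank_top_k(query, candidates, overlap_scorer, keep=5, batch_size=16)
    latencies.append(time.perf_counter() - t0)

print(top[0])
print(f"p95 latency: {p95(latencies) * 1000:.3f} ms")
```

With the notebook's model, `score_pairs` could be `ce.predict`, which takes a list of pairs and accepts its own `batch_size` argument; the timings from the stand-in scorer say nothing about real cross-encoder latency.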