Retrieval-Augmented Generation (RAG) has become the dominant pattern for building LLM applications with up-to-date, domain-specific knowledge. This post explores production-grade RAG architectures, from basic implementations to advanced optimization techniques.
## Why RAG?
LLMs have limitations that RAG addresses:
- Knowledge cutoff: Training data becomes stale
- Hallucinations: Generate plausible but incorrect information
- Domain-specific knowledge: Can’t know your private data
- Attribution: Can’t cite sources
RAG solves these by retrieving relevant documents and using them as context.
## Basic RAG Architecture
```python
from typing import List
from dataclasses import dataclass


@dataclass
class Document:
    id: str
    content: str
    metadata: dict
    embedding: List[float]


class BasicRAG:
    def __init__(self, vector_db, llm, embedding_model):
        self.vector_db = vector_db
        self.llm = llm
        self.embedding_model = embedding_model

    async def query(self, question: str, top_k: int = 3) -> dict:
        # 1. Embed the question
        question_embedding = self.embedding_model.encode(question)

        # 2. Retrieve relevant documents
        documents = await self.vector_db.search(
            question_embedding,
            limit=top_k
        )

        # 3. Build context from documents
        context = "\n\n".join([
            f"Source {i+1}: {doc.content}"
            for i, doc in enumerate(documents)
        ])

        # 4. Generate answer with context
        prompt = self._build_prompt(question, context)
        answer = await self.llm.generate(prompt)

        # 5. Return answer with sources
        return {
            "answer": answer,
            "sources": [doc.metadata for doc in documents]
        }

    def _build_prompt(self, question: str, context: str) -> str:
        return f"""Answer the question based on the following context. If the context doesn't contain relevant information, say so.
Context:
{context}
Question: {question}
Answer:"""
```
## Advanced Chunking Strategies
How you chunk documents dramatically affects RAG quality:
```python
class ChunkingStrategy:
    """Advanced document chunking for RAG.

    Assumes a tokenizer, a sentence splitter, and an embedding model are
    provided (see the helper sketch below).
    """

    def __init__(self, chunk_size: int = 512, overlap: int = 50, embedding_model=None):
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.embedding_model = embedding_model  # only needed for chunk_semantic

    def chunk_by_tokens(self, text: str) -> List[str]:
        """Simple token-based chunking with overlap"""
        tokens = self.tokenize(text)
        chunks = []

        for i in range(0, len(tokens), self.chunk_size - self.overlap):
            chunk_tokens = tokens[i:i + self.chunk_size]
            chunks.append(self.detokenize(chunk_tokens))

        return chunks

    def chunk_semantic(self, text: str) -> List[str]:
        """Semantic chunking based on topic shifts"""
        sentences = self.split_sentences(text)
        chunks = []
        current_chunk = []
        current_embedding = None

        for sentence in sentences:
            sent_embedding = self.embedding_model.encode(sentence)

            if current_embedding is None:
                current_chunk.append(sentence)
                current_embedding = sent_embedding
            else:
                # Check semantic similarity
                similarity = cosine_similarity(current_embedding, sent_embedding)

                if similarity < 0.7:  # Topic shift detected
                    # Start new chunk
                    chunks.append(" ".join(current_chunk))
                    current_chunk = [sentence]
                    current_embedding = sent_embedding
                else:
                    current_chunk.append(sentence)
                    # Update rolling average embedding
                    current_embedding = (current_embedding + sent_embedding) / 2

            # Check size limit (measured in characters here, unlike the token-based strategy)
            if len(" ".join(current_chunk)) > self.chunk_size:
                chunks.append(" ".join(current_chunk))
                current_chunk = []
                current_embedding = None

        if current_chunk:
            chunks.append(" ".join(current_chunk))

        return chunks

    def chunk_by_structure(self, markdown_text: str) -> List[dict]:
        """Structure-aware chunking for markdown"""
        chunks = []
        current_section = {"title": "", "content": [], "level": 0}

        lines = markdown_text.split('\n')
        for line in lines:
            # Check for header
            if line.startswith('#'):
                # Save previous section
                if current_section["content"]:
                    chunks.append({
                        "text": "\n".join(current_section["content"]),
                        "metadata": {
                            "title": current_section["title"],
                            "level": current_section["level"]
                        }
                    })

                # Start new section
                level = len(line) - len(line.lstrip('#'))
                title = line.lstrip('#').strip()
                current_section = {
                    "title": title,
                    "content": [],
                    "level": level
                }
            else:
                current_section["content"].append(line)

        # Save final section
        if current_section["content"]:
            chunks.append({
                "text": "\n".join(current_section["content"]),
                "metadata": {
                    "title": current_section["title"],
                    "level": current_section["level"]
                }
            })

        return chunks
```
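`ChunkingStrategy` leans on a tokenizer, a sentence splitter, an injected embedding model, and a module-level `cosine_similarity` that it does not define. A minimal sketch of those pieces follows, using deliberately naive stand-ins (whitespace tokens, punctuation-based sentence splits). In practice you would use your embedding model's own tokenizer and a proper sentence segmenter, and the embeddings are assumed to be numpy arrays so that averaging works.

```python
import re
from typing import List

import numpy as np


def cosine_similarity(a, b) -> float:
    """Cosine similarity between two embedding vectors (assumed numpy arrays)."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return float(np.dot(a, b) / denom) if denom else 0.0


class SimpleChunkingStrategy(ChunkingStrategy):
    """Placeholder helpers so ChunkingStrategy can run end to end."""

    def tokenize(self, text: str) -> List[str]:
        # Whitespace tokenization; swap in the embedding model's tokenizer for accurate sizes.
        return text.split()

    def detokenize(self, tokens: List[str]) -> str:
        return " ".join(tokens)

    def split_sentences(self, text: str) -> List[str]:
        # Naive split on sentence-ending punctuation.
        return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]
```

With an embedding model injected (`SimpleChunkingStrategy(embedding_model=my_embedder)`), `chunk_semantic` works as written; `chunk_by_tokens` and `chunk_by_structure` need no model at all.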
## Hybrid Search: Vector + Keyword
Combine semantic and keyword search for better retrieval:
```python
from rank_bm25 import BM25Okapi


class HybridRAG:
    """RAG with hybrid vector + keyword search"""

    def __init__(self, vector_db, llm, embedding_model):
        self.vector_db = vector_db
        self.llm = llm
        self.embedding_model = embedding_model
        self.bm25 = None
        self.documents = []

    def index_documents(self, documents: List[Document]):
        """Index documents for both vector and keyword search"""
        # Vector indexing
        for doc in documents:
            doc.embedding = self.embedding_model.encode(doc.content)
            self.vector_db.insert(doc)

        # Keyword indexing
        self.documents = documents
        tokenized = [doc.content.lower().split() for doc in documents]
        self.bm25 = BM25Okapi(tokenized)

    async def hybrid_search(
        self,
        query: str,
        top_k: int = 10,
        alpha: float = 0.5  # Weight: 0 = keyword only, 1 = vector only
    ) -> List[Document]:
        """Hybrid search combining vector and keyword scores"""
        # Vector search (results are assumed to carry a similarity score in [0, 1] as doc.score)
        query_embedding = self.embedding_model.encode(query)
        vector_results = await self.vector_db.search(query_embedding, limit=top_k * 2)

        # Keyword search
        tokenized_query = query.lower().split()
        bm25_scores = self.bm25.get_scores(tokenized_query)

        # Normalize and combine scores
        vector_scores = {doc.id: doc.score for doc in vector_results}
        max_bm25 = max(bm25_scores) if max(bm25_scores) > 0 else 1

        combined_scores = {}
        for idx, doc in enumerate(self.documents):
            vec_score = vector_scores.get(doc.id, 0)
            bm25_score = bm25_scores[idx] / max_bm25
            combined_scores[doc.id] = alpha * vec_score + (1 - alpha) * bm25_score

        # Sort by combined score and return top k
        sorted_docs = sorted(
            self.documents,
            key=lambda d: combined_scores.get(d.id, 0),
            reverse=True
        )

        return sorted_docs[:top_k]
```
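A quick usage sketch (the query strings are invented examples): exact identifiers and jargon tend to benefit from a lower `alpha` (more BM25 weight), while paraphrased natural-language questions tend to benefit from a higher one. Sweeping `alpha` against a labeled query set is the usual way to pick the blend; 0.5 is only a starting point.

```python
async def demo_hybrid_search(rag: HybridRAG, docs: List[Document]):
    rag.index_documents(docs)

    # Keyword-heavy query: exact tokens like an error code favor BM25, so weight it more.
    keyword_hits = await rag.hybrid_search("error code E4012 firmware update", top_k=5, alpha=0.3)

    # Paraphrased question: semantic match matters more, so weight the vector side.
    semantic_hits = await rag.hybrid_search("why does the device keep restarting overnight?", top_k=5, alpha=0.7)

    return keyword_hits, semantic_hits
```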
## Query Rewriting and Expansion
Improve retrieval by rewriting queries:
```python
class QueryOptimizer:
    """Optimize queries for better retrieval"""

    def __init__(self, llm):
        self.llm = llm

    async def rewrite_query(self, query: str) -> str:
        """Rewrite query for better retrieval"""
        prompt = f"""Rewrite the following question to be more specific and include relevant keywords that would help find the answer in a document database.
Original question: {query}
Rewritten question:"""
        rewritten = await self.llm.generate(prompt, temperature=0.3)
        return rewritten

    async def expand_query(self, query: str) -> List[str]:
        """Generate multiple query variations"""
        prompt = f"""Generate 3 different ways to ask this question, each emphasizing different aspects:
Question: {query}
Variations:
1."""
        variations = await self.llm.generate(prompt, temperature=0.7)

        # Parse variations
        expanded = [query]  # Include original
        for line in variations.split('\n'):
            if line.strip() and not line.startswith('Question:'):
                expanded.append(line.strip().lstrip('123.-) '))

        return expanded[:4]  # Return up to 4 total queries


class RewriteRAG(BasicRAG):
    """RAG with query rewriting.

    Relies on the _retrieve and _deduplicate helpers sketched below.
    """

    def __init__(self, vector_db, llm, embedding_model):
        super().__init__(vector_db, llm, embedding_model)
        self.query_optimizer = QueryOptimizer(llm)

    async def query(self, question: str, top_k: int = 3) -> dict:
        # Rewrite query
        rewritten = await self.query_optimizer.rewrite_query(question)

        # Expand to variations
        variations = await self.query_optimizer.expand_query(rewritten)

        # Retrieve for each variation
        all_docs = []
        for variation in variations:
            docs = await self._retrieve(variation, top_k=top_k)
            all_docs.extend(docs)

        # Deduplicate and re-rank
        unique_docs = self._deduplicate(all_docs)
        top_docs = unique_docs[:top_k]

        # Generate answer
        context = "\n\n".join([doc.content for doc in top_docs])
        prompt = self._build_prompt(question, context)
        answer = await self.llm.generate(prompt)

        return {
            "answer": answer,
            "sources": [doc.metadata for doc in top_docs],
            "query_variations": variations
        }
```
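`RewriteRAG` (and the classes that follow) call `_retrieve` and `_deduplicate`, which `BasicRAG` does not define above. A minimal sketch of both; attaching them to `BasicRAG` makes them available to every subclass. Deduplication here is by document id; near-duplicate chunks could also be collapsed by content hash.

```python
async def _retrieve(self, query: str, top_k: int = 3) -> List[Document]:
    """Embed the query and fetch the top_k nearest documents from the vector store."""
    query_embedding = self.embedding_model.encode(query)
    return await self.vector_db.search(query_embedding, limit=top_k)


def _deduplicate(self, docs: List[Document]) -> List[Document]:
    """Drop duplicate documents while preserving retrieval order."""
    seen, unique = set(), []
    for doc in docs:
        if doc.id not in seen:
            seen.add(doc.id)
            unique.append(doc)
    return unique


# Attach to BasicRAG (or simply define them inside the class).
BasicRAG._retrieve = _retrieve
BasicRAG._deduplicate = _deduplicate
```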
## Re-ranking Retrieved Documents
Re-rank documents for relevance:
```python
from sentence_transformers import CrossEncoder


class ReRankRAG(BasicRAG):
    """RAG with re-ranking stage (reuses BasicRAG's _retrieve and _build_prompt)"""

    def __init__(self, vector_db, llm, embedding_model):
        super().__init__(vector_db, llm, embedding_model)
        self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

    async def query(self, question: str, retrieve_k: int = 20, top_k: int = 3) -> dict:
        # 1. Initial retrieval (over-fetch)
        docs = await self._retrieve(question, top_k=retrieve_k)

        # 2. Re-rank with cross-encoder
        pairs = [[question, doc.content] for doc in docs]
        scores = self.reranker.predict(pairs)

        # Sort by reranker scores
        scored_docs = list(zip(docs, scores))
        scored_docs.sort(key=lambda x: x[1], reverse=True)

        # 3. Take top k after re-ranking
        top_docs = [doc for doc, score in scored_docs[:top_k]]

        # 4. Generate answer
        context = "\n\n".join([doc.content for doc in top_docs])
        prompt = self._build_prompt(question, context)
        answer = await self.llm.generate(prompt)

        return {
            "answer": answer,
            "sources": [doc.metadata for doc in top_docs],
            "rerank_scores": [float(score) for doc, score in scored_docs[:top_k]]
        }
```
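Because the cross-encoder reads each (query, passage) pair jointly, it is more accurate than the bi-encoder used for retrieval but also more expensive, and its cost grows linearly with `retrieve_k`. Seen in isolation (the example pairs are invented for illustration; higher score means more relevant):

```python
from sentence_transformers import CrossEncoder

reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

pairs = [
    ["how do I reset my password?",
     "To reset your password, open Settings > Security and choose Reset password."],
    ["how do I reset my password?",
     "Our headquarters are located in Berlin and Singapore."],
]

# predict() runs the pairs through the model in batches; batch_size is tunable.
scores = reranker.predict(pairs, batch_size=32)
for (query, passage), score in zip(pairs, scores):
    print(f"{score:.3f}  {passage[:50]}")
```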
## Multi-step RAG (Iterative Refinement)
```python
class IterativeRAG:
    """Multi-step RAG that refines retrieval based on the initial answer.

    Assumes the same dependencies as BasicRAG (vector_db, llm, embedding_model)
    plus _retrieve, _build_context, _build_prompt, _build_final_prompt and
    _deduplicate_sources helpers (a sketch of the generic ones follows below).
    """

    async def query(self, question: str, max_iterations: int = 3) -> dict:
        conversation_history = []
        retrieved_docs = []

        for i in range(max_iterations):
            # Retrieve documents, optionally conditioning on what was found so far
            docs = await self._retrieve(question, conversation_history)
            retrieved_docs.extend(docs)

            # Generate answer with current context
            context = self._build_context(retrieved_docs)
            answer = await self.llm.generate(
                self._build_prompt(question, context)
            )

            conversation_history.append({
                "iteration": i,
                "answer": answer,
                "docs": docs
            })

            # Check if answer is confident
            if self._is_confident(answer):
                break

            # Generate follow-up query if needed
            question = await self._generate_follow_up(question, answer)

        # Final answer with all retrieved context
        final_context = self._build_context(retrieved_docs)
        final_answer = await self.llm.generate(
            self._build_final_prompt(question, final_context, conversation_history)
        )

        return {
            "answer": final_answer,
            "iterations": conversation_history,
            "sources": self._deduplicate_sources(retrieved_docs)
        }

    def _is_confident(self, answer: str) -> bool:
        """Check if answer indicates confidence"""
        uncertain_phrases = [
            "i don't know",
            "not sure",
            "unclear",
            "need more information"
        ]
        return not any(phrase in answer.lower() for phrase in uncertain_phrases)

    async def _generate_follow_up(self, original_question: str, answer: str) -> str:
        """Generate follow-up query to fill knowledge gaps"""
        prompt = f"""The user asked: {original_question}
The current answer is: {answer}
Generate a follow-up question that would help provide a more complete answer:"""
        follow_up = await self.llm.generate(prompt, temperature=0.5)
        return follow_up
```
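Two of the helpers `IterativeRAG` assumes are generic enough to sketch here; `_build_final_prompt` and the history-aware `_retrieve` are where your own retrieval and prompting logic belongs, so they are left to you.

```python
def _build_context(self, docs: List[Document]) -> str:
    """Concatenate retrieved chunks into a single numbered context block."""
    return "\n\n".join(f"Source {i + 1}: {doc.content}" for i, doc in enumerate(docs))


def _deduplicate_sources(self, docs: List[Document]) -> List[dict]:
    """Return each document's metadata once, preserving order."""
    seen, sources = set(), []
    for doc in docs:
        if doc.id not in seen:
            seen.add(doc.id)
            sources.append(doc.metadata)
    return sources


# Attach to the class (or define them directly inside IterativeRAG).
IterativeRAG._build_context = _build_context
IterativeRAG._deduplicate_sources = _deduplicate_sources
```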
## Production Optimizations
```python
import asyncio
import hashlib


class ProductionRAG:
    """Production-optimized RAG with caching and batching.

    Composes the pieces above: _rewrite_query, _retrieve, _rerank,
    _generate_answer, _build_prompt and _build_context are the same
    helpers used by the earlier classes.
    """

    def __init__(self, vector_db, llm, embedding_model, cache):
        self.vector_db = vector_db
        self.llm = llm
        self.embedding_model = embedding_model
        self.cache = cache

    async def query(self, question: str, top_k: int = 3) -> dict:
        # Check cache
        cache_key = self._get_cache_key(question)
        cached = await self.cache.get(cache_key)
        if cached:
            return cached

        # Run query rewriting and first-pass retrieval in parallel
        rewritten_query, docs = await asyncio.gather(
            self._rewrite_query(question),
            self._retrieve(question, top_k=top_k * 2)
        )
        # (rewritten_query can drive a second retrieval pass; not used further in this sketch)

        # Re-rank
        top_docs = await self._rerank(question, docs, top_k)

        # Generate answer
        answer = await self._generate_answer(question, top_docs)

        result = {
            "answer": answer,
            "sources": [doc.metadata for doc in top_docs]
        }

        # Cache result
        await self.cache.set(cache_key, result, ttl=3600)

        return result

    def _get_cache_key(self, question: str) -> str:
        """Generate cache key for question (hashing is cheap, so no memoization needed)"""
        return hashlib.sha256(question.encode()).hexdigest()[:16]

    async def batch_query(self, questions: List[str]) -> List[dict]:
        """Process multiple queries efficiently"""
        # Batch embedding generation
        embeddings = self.embedding_model.encode(questions)

        # Parallel retrieval for all questions
        retrieval_tasks = [
            self.vector_db.search(emb, limit=10)
            for emb in embeddings
        ]
        all_docs = await asyncio.gather(*retrieval_tasks)

        # Batch LLM generation
        prompts = [
            self._build_prompt(q, self._build_context(docs))
            for q, docs in zip(questions, all_docs)
        ]
        answers = await self.llm.batch_generate(prompts)

        return [
            {
                "answer": answer,
                "sources": [doc.metadata for doc in docs[:3]]
            }
            for answer, docs in zip(answers, all_docs)
        ]
```
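`ProductionRAG` only assumes the cache exposes async `get(key)` and `set(key, value, ttl)`. In production this would typically be Redis or another shared store; for local testing, a minimal in-process stand-in with the same interface (the class name is invented for this sketch) might look like this:

```python
import time
from typing import Any, Optional


class InMemoryTTLCache:
    """Tiny dict-backed cache with the async get/set(ttl) interface ProductionRAG expects."""

    def __init__(self):
        self._store = {}  # key -> (expires_at, value)

    async def get(self, key: str) -> Optional[Any]:
        entry = self._store.get(key)
        if entry is None:
            return None
        expires_at, value = entry
        if time.monotonic() > expires_at:
            del self._store[key]
            return None
        return value

    async def set(self, key: str, value: Any, ttl: int = 3600) -> None:
        self._store[key] = (time.monotonic() + ttl, value)
```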
## Evaluation and Monitoring
```python
class RAGEvaluator:
    """Evaluate RAG system performance"""

    def __init__(self):
        self.metrics = []

    async def evaluate(self, rag_system, test_cases: List[dict]) -> dict:
        """
        Evaluate RAG on test cases.

        test_cases: [
            {
                "question": "...",
                "expected_answer": "...",
                "relevant_docs": [...]   # ids of the documents that should be retrieved
            }
        ]
        """
        results = []

        for test in test_cases:
            result = await rag_system.query(test["question"])

            # Evaluate retrieval quality
            retrieval_score = self._evaluate_retrieval(
                result["sources"],
                test["relevant_docs"]
            )

            # Evaluate answer quality (see the scorer sketch below)
            answer_score = self._evaluate_answer(
                result["answer"],
                test["expected_answer"]
            )

            results.append({
                "question": test["question"],
                "retrieval_score": retrieval_score,
                "answer_score": answer_score,
                "latency": result.get("latency", 0)
            })

        return {
            "avg_retrieval_score": sum(r["retrieval_score"] for r in results) / len(results),
            "avg_answer_score": sum(r["answer_score"] for r in results) / len(results),
            "avg_latency": sum(r["latency"] for r in results) / len(results),
            "results": results
        }

    def _evaluate_retrieval(self, retrieved: List[dict], relevant: List) -> float:
        """Retrieval F1: harmonic mean of precision and recall over document ids"""
        # Assumes each source's metadata includes its document id
        retrieved_ids = {doc.get("id") for doc in retrieved}
        relevant_ids = set(relevant)

        if not relevant_ids:
            return 1.0

        intersection = retrieved_ids & relevant_ids
        precision = len(intersection) / len(retrieved_ids) if retrieved_ids else 0
        recall = len(intersection) / len(relevant_ids)

        # F1 score
        if precision + recall == 0:
            return 0.0
        return 2 * (precision * recall) / (precision + recall)
```
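The `_evaluate_answer` scorer is left undefined above. A common approach is LLM-as-judge; a dependency-free placeholder that is good enough for smoke tests is simple token overlap against the expected answer (clearly not a substitute for a proper judge):

```python
def _evaluate_answer(self, answer: str, expected: str) -> float:
    """Rough token-overlap score in [0, 1]; swap in an LLM-as-judge for real evaluations."""
    answer_tokens = set(answer.lower().split())
    expected_tokens = set(expected.lower().split())
    if not expected_tokens:
        return 1.0
    return len(answer_tokens & expected_tokens) / len(expected_tokens)


# Attach to the class (or define it inside RAGEvaluator).
RAGEvaluator._evaluate_answer = _evaluate_answer
```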
## Conclusion
Production RAG requires:
- Smart chunking: semantic, structure-aware strategies
- Hybrid search: combine vector and keyword retrieval
- Query optimization: rewrite and expand queries
- Re-ranking: improve relevance with a cross-encoder
- Caching: reduce latency and cost
- Monitoring: track retrieval and answer quality
RAG is the foundation for most production LLM applications. Invest in getting it right.