Table of contents
- Key Concepts
- 1. Architecture & Startup Flow
- 2. Engine Initialization
- 3. First Request
- 4. Second Request - Prefix Cache Hit
Key Concepts
Before diving in, here are the key terms you’ll encounter:
| Concept | Description |
|---|---|
| Prefill | Process all input tokens at once, compute full KV cache |
| Decode | Generate 1 token/step, reuse cached KV |
| Paged KV Cache | Allocate KV cache in fixed-size pages, not contiguous blocks |
| Page Table | Maps (request_id, token_pos) → page_index in KV cache pool |
| CUDA Graph | Pre-captured GPU ops for decode (fixed shapes only) |
| Prefix Cache | Reuse KV cache across requests with same prefix |
| ZMQ IPC | Inter-process communication via Unix sockets (Push/Pull) |
| Chunked Prefill | Split large inputs into multiple batches (≤ prefill_budget) |
1. Architecture & Startup Flow
We start by understanding how Mini SGLang organizes its processes and how the system boots up.
Mini SGLang uses a multi-process architecture with ZMQ for inter-process communication.
1.1 System Architecture
┌────────────────────────────────────────────────────────────┐
│ MAIN PROCESS (FastAPI) │
│ FrontendManager │
│ - HTTP endpoints (/v1/chat/completions) │
│ - Send/receive via ZMQ │
└────────────────────────────────────────────────────────────┘
│ ZMQ (ipc:///tmp/minisgl_4) ▲ ZMQ (ipc:///tmp/minisgl_3)
▼ │
┌─────────────────┐ ┌─────────────────┐
│ TOKENIZER │ │ DETOKENIZER │
│ text → tokens │ │ tokens → text │
└─────────────────┘ └─────────────────┘
│ ZMQ (ipc:///tmp/minisgl_0) ▲ ZMQ (ipc:///tmp/minisgl_1)
└────────────┐ ┌──────────┘
▼ │
┌────────────────────────────────────────────────────────────┐
│ SCHEDULER PROCESS │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ Scheduler │ │
│ │ ├── PrefillManager (pending requests) │ │
│ │ ├── DecodeManager (running requests) │ │
│ │ ├── CacheManager (KV cache pages) │ │
│ │ └── TableManager (request slots) │ │
│ └──────────────────────────────────────────────────────┘ │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ Engine │ │
│ │ ├── Model (Qwen3/Llama) │ │
│ │ ├── KVCache (paged memory) │ │
│ │ ├── AttnBackend (FlashInfer) │ │
│ │ └── GraphRunner (CUDA graphs for decode) │ │
│ └──────────────────────────────────────────────────────┘ │
└────────────────────────────────────────────────────────────┘
| Channel | Address | Direction | Message Type |
|---|---|---|---|
| Frontend → Tokenizer | ipc:///tmp/minisgl_4 | Push/Pull | TokenizeMsg |
| Tokenizer → Scheduler | ipc:///tmp/minisgl_0 | Push/Pull | UserMsg |
| Scheduler → Detokenizer | ipc:///tmp/minisgl_1 | Push/Pull | DetokenizeMsg |
| Detokenizer → Frontend | ipc:///tmp/minisgl_3 | Push/Pull | UserReply |
| Scheduler Broadcast | ipc:///tmp/minisgl_2 | Pub/Sub | (for TP > 1) |
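For illustration, here is roughly what one of these Push/Pull channels looks like in raw pyzmq. Mini SGLang wraps this in its own IO helpers; which side binds vs. connects and the pickle serialization are assumptions here, not the project's actual code:
# Illustrative Push/Pull over an ipc socket (not Mini SGLang's wrapper classes).
import pickle
import zmq

ctx = zmq.Context.instance()

# Receiver side (e.g., the scheduler listening on ipc:///tmp/minisgl_0)
pull = ctx.socket(zmq.PULL)
pull.bind("ipc:///tmp/minisgl_0")

# Sender side (e.g., the tokenizer pushing a serialized UserMsg)
push = ctx.socket(zmq.PUSH)
push.connect("ipc:///tmp/minisgl_0")
push.send(pickle.dumps({"uid": 0, "input_ids": [1, 2, 3]}))

msg = pickle.loads(pull.recv())  # {'uid': 0, 'input_ids': [1, 2, 3]}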
1.2 Startup Flow
python -m minisgl --model "Qwen/Qwen3-0.6B"
Phase 1: Entry Point & Parse Arguments
📁 python/minisgl/__main__.py
📁 python/minisgl/server/launch.py:40 → launch_server()
📁 python/minisgl/server/args.py:54 → parse_args()
ServerArgs(
model_path="Qwen/Qwen3-0.6B",
dtype=torch.bfloat16,
tp_info=DistributedInfo(rank=0, size=1),
server_host="127.0.0.1",
server_port=1919,
)
Phase 2: Create FrontendManager (Main Process)
📁 python/minisgl/server/api_server.py:408 → run_api_server()
FrontendManager initializes ZMQ connections:
- recv_from_detokenizer: Pull from ipc:///tmp/minisgl_3
- send_to_tokenizer: Push to ipc:///tmp/minisgl_4
Phase 3: Spawn Worker Processes
📁 python/minisgl/server/launch.py:47 → start_subprocess()
Uses multiprocessing.Process with spawn method:
# 1. Spawn Scheduler
mp.Process(
target=_run_scheduler,
args=(new_args, ack_queue),
name="minisgl-TP0-scheduler",
).start()
# 2. Spawn Detokenizer
mp.Process(
target=tokenize_worker,
kwargs={
"addr": "ipc:///tmp/minisgl_1", # Listen here
"backend_addr": "ipc:///tmp/minisgl_0",
"frontend_addr": "ipc:///tmp/minisgl_3", # Send to frontend
},
name="minisgl-detokenizer-0",
).start()
# 3. Spawn Tokenizer
mp.Process(
target=tokenize_worker,
kwargs={
"addr": "ipc:///tmp/minisgl_4", # Listen here
"backend_addr": "ipc:///tmp/minisgl_0", # Send to scheduler
},
name="minisgl-tokenizer-0",
).start()
Phase 4: Scheduler Process Initialization
📁 python/minisgl/server/launch.py:16 → _run_scheduler()
📁 python/minisgl/scheduler/scheduler.py:80 → Scheduler.__init__()
This runs in a SEPARATE PROCESS!
# 1. Create Engine (loads model, creates KV cache)
self.engine = Engine(config)
# 2. Initialize ZMQ I/O
self.io = SchedulerIO(
recv_addr="ipc:///tmp/minisgl_0", # Receive from tokenizer
send_addr="ipc:///tmp/minisgl_1", # Send to detokenizer
)
# 3. Create managers
self.prefill_manager = PrefillManager(...)
self.decode_manager = DecodeManager()
self.cache_manager = CacheManager(...)
self.table_manager = TableManager(...)
# 4. Send acknowledgment
ack_queue.put("Scheduler is ready")
Phase 5: Tokenizer/Detokenizer Initialization
📁 python/minisgl/tokenizer/server.py:30 → tokenize_worker()
Each of these runs in a SEPARATE PROCESS!
# 1. Create ZMQ connections
ctx = zmq.Context()
send_backend = ctx.socket(zmq.PUSH)   # To scheduler
send_frontend = ctx.socket(zmq.PUSH)  # To frontend (detokenizer only)
recv_listener = ctx.socket(zmq.PULL)  # Listen for incoming messages
# 2. Load HuggingFace tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
# 3. Create managers
tokenize_manager = TokenizeManager(tokenizer)
detokenize_manager = DetokenizeManager(tokenizer)
# 4. Send acknowledgment
ack_queue.put("Tokenize server 0 is ready")
Phase 6: Wait for All Workers & Start Server
📁 python/minisgl/server/launch.py:110
📁 python/minisgl/server/api_server.py:454
# Main process waits for acks
for _ in range(num_tokenizers + 2):
logger.info(ack_queue.get())
# Output:
# [INFO] Scheduler is ready
# [INFO] Tokenize server 0 is ready
# Start Uvicorn
uvicorn.run(app, host="127.0.0.1", port=1919)
# [INFO] Uvicorn running on http://127.0.0.1:1919
Phase 7: Scheduler Main Loop
📁 python/minisgl/scheduler/scheduler.py:288 → run_forever()
@torch.inference_mode()
def run_forever(self) -> NoReturn:
data = None
while True:
data = self.overlap_loop(data)
# Continuously:
# 1. Receive messages from tokenizer
# 2. Schedule prefill/decode batches
# 3. Execute forward pass
# 4. Send results to detokenizer
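To make this loop concrete, here is a toy, single-process rendition of its control flow. All names below are illustrative stand-ins (no ZMQ, no GPU, and none of the overlap optimization that the real overlap_loop performs); the prefill-first policy is an assumption:
# Toy sketch of one scheduler iteration: receive → schedule → forward → send.
from collections import deque

incoming: deque = deque()   # stands in for UserMsg arrivals from the tokenizer
pending: list = []          # PrefillManager's pending requests
running: list = []          # DecodeManager's running requests
outgoing: list = []         # stands in for DetokenizeMsg sends

def loop_once(forward_batch) -> None:
    # 1. Receive messages from the tokenizer
    while incoming:
        pending.append(incoming.popleft())
    # 2. Schedule: prefill waiting requests first (a common policy), else decode
    if pending:
        batch, phase = pending[:], "prefill"
        pending.clear()
        running.extend(batch)
    elif running:
        batch, phase = running[:], "decode"
    else:
        return  # idle: nothing to do this iteration
    # 3. Execute the forward pass, 4. hand sampled tokens to the detokenizer side
    outgoing.append(forward_batch(batch, phase))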
1.3 Process Summary
| Process Name | Main Function | ZMQ Role |
|---|---|---|
| Main (FastAPI) | run_api_server() | Send to tokenizer, recv from detokenizer |
| minisgl-TP0-scheduler | run_forever() | Recv from tokenizer, send to detokenizer |
| minisgl-tokenizer-N | tokenize_worker() | Recv from frontend, send to scheduler |
| minisgl-detokenizer-0 | tokenize_worker() | Recv from scheduler, send to frontend |
2. Engine Initialization
With the multi-process architecture in place, let’s zoom into the Scheduler process - the brain of the system. When it starts, the Engine is initialized with these key parameters:
2.1 Key Parameters Overview
| Parameter | Default | Description | Determined By |
|---|---|---|---|
| max_running_req | 256 | Max concurrent requests in system | Config |
| max_seq_len | 40960 | Max tokens per request (input + output) | Model’s RoPE max_position |
| prefill_budget | 8192 | Max tokens per prefill batch | Config |
| num_kv_pages | ~180K | Total KV cache pages available | GPU memory |
| cuda_graph_max_bs | 160 | Max batch size for CUDA graph | GPU memory |
How these relate:
┌─────────────────────────────────────────────────────────────────────┐
│ CAPACITY LIMITS │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ max_running_req = 256 │
│ └── How many requests can run simultaneously │
│ └── Each request gets 1 slot in TableManager │
│ └── Page table shape: (257, max_seq_len) │
│ │
│ max_seq_len = 40960 │
│ └── Max context length per request │
│ └── input_tokens + output_tokens ≤ 40960 │
│ └── From model config (RoPE max_position) │
│ │
│ prefill_budget = 8192 │
│ └── Max tokens processed in one prefill batch │
│ └── Controls latency (don't block decode too long) │
│ └── Large inputs are chunked │
│ │
│ num_kv_pages ≈ 180874 │
│ └── Total memory for KV cache │
│ └── Shared across ALL requests │
│ └── Each token needs 1 page (112KB for Qwen3-0.6B) │
│ │
│ cuda_graph_max_bs = 160 │
│ └── Largest batch size with pre-captured CUDA graph │
│ └── Decode batches > 160 run without graph (slower) │
│ │
└─────────────────────────────────────────────────────────────────────┘
Example capacity calculation:
# With Qwen3-0.6B on 24GB GPU:
num_kv_pages = 180874
max_running_req = 256
max_seq_len = 40960
# Theoretical max (if all requests use max length):
# 256 requests × 40960 tokens = 10485760 pages needed
# But we only have 180874 pages!
# Practical limit:
# 180874 pages / 256 requests ≈ 706 tokens/request average
# OR fewer long requests + more short requests
Prefill budget limits tokens per batch to control memory usage and latency. Large inputs are chunked:
prefill_budget = 8192 (max tokens per prefill batch)
input_len=5000 → 1 batch
input_len=10000 → 2 batches: 8192 + 1808
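A minimal sketch of that chunking rule (the helper name is made up; the real scheduler also interleaves decode batches between chunks):
# Sketch: how a long input is split under prefill_budget (illustrative only).
def chunk_sizes(input_len: int, prefill_budget: int = 8192) -> list[int]:
    """Return the token count of each prefill batch for one request."""
    chunks = []
    while input_len > 0:
        take = min(input_len, prefill_budget)
        chunks.append(take)
        input_len -= take
    return chunks

print(chunk_sizes(5000))   # [5000]        → 1 batch
print(chunk_sizes(10000))  # [8192, 1808]  → 2 batches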
2.2 KV Cache & Page Table Relationship
Problem: LLM inference needs to store Key-Value cache for each processed token. If we allocate contiguous memory for each request → memory fragmentation & waste.
Solution: Paged KV Cache - divide memory into fixed-size pages, allocate on-demand.
Note: In Mini SGLang, page_size = 1, meaning 1 page = 1 token. This simplifies the mapping: number of pages = number of tokens. Production systems like vLLM/SGLang often use larger page sizes (e.g., 16 tokens/page) for efficiency.
┌───────────────────────────────────────────────────────────────────────┐
│ GPU MEMORY │
├───────────────────────────────────────────────────────────────────────┤
│ │
│ KV Cache Pool (180874 pages) │
│ ┌───────┬───────┬───────┬───────┬───────┬───────┬───────────────┐ │
│ │ Page 0│ Page 1│ Page 2│ Page 3│ Page 4│ Page 5│...Page 180873 │ │
│ │ 112KB │ 112KB │ 112KB │ 112KB │ 112KB │ 112KB │ 112KB │ │
│ └───────┴───────┴───────┴───────┴───────┴───────┴───────────────┘ │
│ ▲ ▲ ▲ ▲ ▲ ▲ │
│ │ │ │ │ │ │ │
│ └───────┴───────┴───────┴───────┴───────┘ │
│ │ │ │
│ ┌─────────┘ └─────────┐ │
│ │ │ │
├────┼──────────────────────────────────┼───────────────────────────────┤
│ │ Page Table (257 × 40960) │ │
│ │ │ │
│ ┌─┴────────────────────────────────┬─┴─────────────────────────────┐ │
│ │ req_idx │ pos 0 │ pos 1 │ pos 2 │ pos 3 │ pos 4 │ ... │ │
│ ├─────────┼───────┼───────┼───────┼───────┼───────┼────────────────┤ │
│ │ Req 0 │ 0 │ 1 │ 2 │ -1 │ -1 │ (3 tokens) │ │
│ │ Req 1 │ 3 │ 4 │ 5 │ -1 │ -1 │ (3 tokens) │ │
│ │ ... │ ... │ ... │ ... │ ... │ ... │ │ │
│ │ Req 256 │ dummy │ dummy │ dummy │ dummy │ dummy │ (CUDA graph) │ │
│ └─────────┴───────┴───────┴───────┴───────┴───────┴────────────────┘ │
│ │
└───────────────────────────────────────────────────────────────────────┘
Page Table[req_idx][token_pos] = page_index in KV Cache Pool
How It Works
1. KV Cache Pool - Pre-allocated GPU memory:
# Each page stores KV for 1 token (page_size=1) across all layers
cache_per_page = (
2 # key + value
* head_dim # 128
* num_kv_heads # 8
* page_size # 1 token
* dtype.itemsize # 2 bytes (bfloat16)
* num_layers # 28 layers
)
# = 2 * 128 * 8 * 1 * 2 * 28 = 114688 bytes ≈ 112 KB/page
num_pages = available_gpu_memory // cache_per_page
# Example: 19.32 GiB → 180874 pages
2. Page Table - Maps (request, position) → page:
# Shape: (max_running_req + 1, max_seq_len) = (257, 40960)
page_table = torch.zeros((257, 40960), dtype=torch.int32, device="cuda")
# When we need KV for token position 5 of request 0:
page_idx = page_table[0, 5] # → page index in KV Cache Pool
kv_data = kv_cache_pool[page_idx] # → actual KV tensors
3. Allocation Flow:
Request arrives (input_len=16):
│
├── CacheManager.allocate(16)
│ └── Take 16 free pages from pool: [0, 1, 2, ..., 15]
│
├── TableManager.allocate()
│ └── Take 1 free slot: slot=0
│
└── Update page_table[0, 0:16] = [0, 1, 2, ..., 15]
Decode step (generate 1 token):
│
├── CacheManager.allocate(1)
│ └── Take 1 free page: [16]
│
└── Update page_table[0, 16] = 16
Final state for request 0:
page_table[0] = [0, 1, 2, ..., 16, -1, -1, ...]
└── 17 pages allocated ──┘
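A toy free-list version of this bookkeeping (illustrative only; the real CacheManager/TableManager operate on GPU tensors and cooperate with the RadixTree):
# Toy page/slot allocator mirroring the flow above.
import torch

MAX_REQ, MAX_SEQ, NUM_PAGES = 256, 40960, 180874

free_pages = list(range(NUM_PAGES))      # CacheManager's free pool
free_slots = list(range(MAX_REQ))        # TableManager's request slots
page_table = torch.full((MAX_REQ + 1, MAX_SEQ), -1, dtype=torch.int32)

def admit_request(input_len: int) -> int:
    slot = free_slots.pop()              # popped from the end, e.g. 255
    pages = [free_pages.pop(0) for _ in range(input_len)]
    page_table[slot, :input_len] = torch.tensor(pages, dtype=torch.int32)
    return slot

def decode_step(slot: int, pos: int) -> None:
    page_table[slot, pos] = free_pages.pop(0)   # 1 new page per generated token

slot = admit_request(16)   # prefill: 16 pages for 16 input tokens
decode_step(slot, 16)      # decode: 1 page for the token at position 16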
4. Why Paged?
| Approach | Memory Usage | Flexibility |
|---|---|---|
| Contiguous | Pre-allocate max_seq_len per request → waste | Fixed size |
| Paged | Allocate on-demand → efficient | Dynamic growth |
Example: 2 requests with max_seq_len=1000:
Contiguous:
┌──────────────────────────────────────┐
│ Req 0: [████████░░░░░░░░░░░░░░░░░░░] │ ← 200 tokens, waste 800
│ Req 1: [██████████████░░░░░░░░░░░░░] │ ← 500 tokens, waste 500
└──────────────────────────────────────┘
Total waste: 1300 slots
Paged:
┌──────────────────────────────────────┐
│ Pool: [████████████████████████░░░░] │ ← 700 pages used
│ Req 0 pages: [0,1,2,...,199] │
│ Req 1 pages: [200,201,...,699] │
└──────────────────────────────────────┘
Total waste: 0 (pages allocated on-demand)
CUDA Graph Capture
Pre-captured for decode phase only (fixed shapes):
batch_sizes = [1, 2, 4] + list(range(8, 161, 8))
# = [1, 2, 4, 8, 16, 24, 32, ..., 160]
# Total: 23 graphs
# During inference, batch is padded to nearest graph size
# Example: batch_size=5 → pad to 8
Why decode only? Prefill has variable sequence lengths, CUDA graphs need fixed shapes.
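For reference, the capture/replay pattern behind GraphRunner looks roughly like the standard PyTorch recipe below. The Linear layer and buffer sizes are stand-ins for a real decode step, and a CUDA GPU is required:
# Minimal CUDA graph capture/replay sketch (simplified; not GraphRunner itself).
import torch

model = torch.nn.Linear(1024, 1024, device="cuda").eval()  # stand-in for one decode step
static_in = torch.zeros(8, 1024, device="cuda")            # fixed batch_size=8 buffers
static_out = torch.empty(8, 1024, device="cuda")

# Warm up on a side stream so lazy init isn't captured (standard PyTorch recipe)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s), torch.no_grad():
    for _ in range(3):
        static_out.copy_(model(static_in))
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph), torch.no_grad():
    static_out.copy_(model(static_in))   # recorded once, with fixed shapes

# At inference: pad a batch of 5 up to the captured size 8, then replay
new_batch = torch.randn(5, 1024, device="cuda")
static_in.zero_()
static_in[:5].copy_(new_batch)
graph.replay()                           # re-runs the recorded kernels
result = static_out[:5]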
3. First Request
Now that we understand the Engine’s memory layout and CUDA graphs, let’s trace a single request through the entire system - from HTTP arrival to response streaming.
3.1 Overview
┌─────────────────────────────────────────────────────────────────────┐
│ REQUEST LIFECYCLE │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ Step 1: HTTP Request │
│ └── Client → FastAPI → Tokenizer (text → tokens) │
│ │
│ Step 2: Prefill │
│ └── Process ALL input tokens at once │
│ └── Allocate KV cache pages, compute KV, sample 1st token │
│ │
│ Step 3: Decode Loop (repeat until done) │
│ └── Process 1 token per step (using CUDA graph) │
│ └── Allocate 1 page, compute KV, sample next token │
│ │
│ Step 4: Stream Response │
│ └── Detokenizer → FastAPI → Client (SSE streaming) │
│ │
│ Step 5: Cleanup │
│ └── Free resources (table slot, KV pages) │
│ └── Optionally cache prefix for reuse │
│ │
└─────────────────────────────────────────────────────────────────────┘
3.2 Request Flow
Example: "What is 2+2?" with max_tokens=16, temperature=0.0
Step 1: HTTP → Frontend → Tokenizer
POST /v1/chat/completions
│
▼
FrontendManager:
- Assign uid=0 (first request gets uid=0)
- Create TokenizeMsg(uid=0, text="What is 2+2?", sampling_params=...)
│
▼ ZMQ Push (ipc:///tmp/minisgl_4)
│
Tokenizer Process:
- Apply chat template + encode → input_ids (16 tokens)
- Create UserMsg(uid=0, input_ids=tensor([...]), sampling_params=...)
│
▼ ZMQ Push (ipc:///tmp/minisgl_0)
Why 16 tokens? The raw text “What is 2+2?” is only ~5 tokens, but after applying the chat template:
<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n
<|im_start|>user\nWhat is 2+2?<|im_end|>\n
<|im_start|>assistant\n
The full prompt becomes ~16 tokens (varies by model’s chat template).
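You can reproduce this count with the HuggingFace tokenizer directly (the exact number depends on the model’s chat template, so treat ~16 as approximate):
# Roughly what the tokenizer process does with the chat template.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2+2?"},
]
input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,  # appends the "<|im_start|>assistant\n" suffix
    tokenize=True,
)
print(len(input_ids))  # ≈ 16 tokens for this prompt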
Step 2: Scheduler Prefill
Scheduler receives UserMsg:
│
├── PrefillManager.add_one_req(msg)
│ └── pending_list.append(PendingReq(...))
│
├── Schedule prefill batch:
│ ├── CacheManager.match_req() → cached_len=0 (no prefix cache hit)
│ ├── TableManager.allocate() → slot=255
│ └── CacheManager.allocate(16) → 16 pages
│
└── Engine.forward_batch(phase="prefill"):
├── Direct model forward (no CUDA graph)
├── Process all 16 input tokens in parallel
├── Compute KV cache for all 16 positions
└── Sample first output token
Key values explained:
| Value | Explanation |
|---|---|
| cached_len=0 | No prefix cache hit (first request with this prompt) |
| slot=255 | TableManager allocates from end: free_slots=[0,1,...,255] → pop last → 255 |
| 16 pages | 1 page per input token (page_size=1), so 16 tokens need 16 pages |
| no CUDA graph | Prefill has variable seq_len, can’t use pre-captured graphs |
What happens inside Engine.forward_batch(phase="prefill"):
Input: [tok_0, tok_1, ..., tok_15] (16 tokens)
1. Embedding → X shape: (16, d_model)
2. For each layer (×28):
┌──────────────────────────────────────────────────────────────────────┐
│ a) Compute Q, K, V for ALL 16 tokens in parallel │
│ │
│ b) Write K, V to the 16 pages allocated by CacheManager: │
│ kv_cache[layer, 0, page_0:page_15] = K │
│ kv_cache[layer, 1, page_0:page_15] = V │
│ │
│ c) Self-attention with causal mask: │
│ scores = Q × K^T / √d_k shape: (16, 16) │
│ K_0 K_1 K_2 ... K_15 │
│ Q_0 ✓ ✗ ✗ ... ✗ ← Each token only attends │
│ Q_1 ✓ ✓ ✗ ... ✗ to previous tokens │
│ ... │
│ Q_15 ✓ ✓ ✓ ... ✓ ← Last token sees all │
│ │
│ d) FFN → next layer │
└──────────────────────────────────────────────────────────────────────┘
3. Sample from LAST position only:
logits = X[15] × W_vocab → tok_16 = first output token
Memory state after prefill:
KV Cache pages: 180874 - 16 = 180858 free
Page table[slot=255, pos 0:16] = [page_0, page_1, ..., page_15]
KV Cache content (per layer):
K_cache[page_0:page_15] = [K_0, K_1, ..., K_15] ← 16 key vectors
V_cache[page_0:page_15] = [V_0, V_1, ..., V_15] ← 16 value vectors
Output: tok_16 sampled (first output token) → Ready for decode!
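A toy single-layer version of the attention walkthrough above, using PyTorch’s scaled_dot_product_attention with a causal mask (dimensions are illustrative; GQA, RoPE, and the 28-layer stack are omitted):
# Toy prefill: causal attention over all 16 tokens, sample from the LAST position.
import torch
import torch.nn.functional as F

T, d_model, vocab = 16, 64, 1000
x = torch.randn(T, d_model)                      # embeddings of the 16 input tokens

W_q, W_k, W_v = (torch.randn(d_model, d_model) for _ in range(3))
W_vocab = torch.randn(d_model, vocab)

Q, K, V = x @ W_q, x @ W_k, x @ W_v              # all 16 tokens projected in parallel
# Causal self-attention: token i only attends to positions <= i
out = F.scaled_dot_product_attention(
    Q.unsqueeze(0), K.unsqueeze(0), V.unsqueeze(0), is_causal=True
).squeeze(0)                                     # (16, d_model)

logits = out[-1] @ W_vocab                       # only the last position is sampled
tok_16 = int(torch.argmax(logits))               # greedy, since temperature=0.0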
Step 3: Decode Loop (×15)
For each decode step:
│
├── DecodeManager.schedule_next_batch()
│ └── Create batch with running requests
│
├── CacheManager.allocate(1) → 1 new page for new token
│
├── Engine.forward_batch(phase="decode"):
│ ├── CUDA graph replay (batch_size=1)
│ ├── Process only 1 new token (last generated)
│ └── Attention reads from cached KV (16 + n positions)
│
└── Send DetokenizeMsg to detokenizer
│
▼ ZMQ Push (ipc:///tmp/minisgl_1)
What happens inside Engine.forward_batch(phase="decode"):
Decode step 1 (generate token 17):
┌─────────────────────────────────────────────────────────────────────────┐
│ Prefill (16 tokens): Decode (1 new token): │
│ ──────────────────── ───────────────────── │
│ Q = [x_0..x_15] × W_q Q_16 = x_16 × W_q ← ONLY 1 token │
│ K = [x_0..x_15] × W_k K_16 = x_16 × W_k ← ONLY 1 token │
│ V = [x_0..x_15] × W_v V_16 = x_16 × W_v ← ONLY 1 token │
│ │
│ KV Cache (from prefill): K_cached = [K_0..K_15], V_cached = [V_0..V_15] │
│ │
│ Decode computation: │
│ 1. Compute K_16, V_16, Q_16 ← Only 3 projections (not 48!) │
│ 2. Write to cache: K[16] = K_16, V[16] = V_16 │
│ 3. Attention: Q_16 × [K_0..K_16]^T ← Read ALL 17 K vectors │
│ 4. Output: softmax(scores) × [V_0..V_16] │
│ 5. Sample token 17 │
└─────────────────────────────────────────────────────────────────────────┘
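The same step as a toy snippet: only three projections for the new token, but the attention read touches every cached K/V row (paging and batching omitted; shapes are illustrative):
# Toy decode step against a KV cache.
import torch

d_model = 64
W_q, W_k, W_v = (torch.randn(d_model, d_model) for _ in range(3))

K_cache = torch.randn(16, d_model)      # K_0..K_15 written during prefill
V_cache = torch.randn(16, d_model)      # V_0..V_15

x_new = torch.randn(1, d_model)         # hidden state of tok_16 (sampled in prefill)
q, k, v = x_new @ W_q, x_new @ W_k, x_new @ W_v   # 3 projections, for 1 token only

K_cache = torch.cat([K_cache, k])       # cache now holds K_0..K_16 (17 rows)
V_cache = torch.cat([V_cache, v])

# One query row against the entire cache: this read is the memory-bound part
scores = (q @ K_cache.T) / (d_model ** 0.5)       # (1, 17)
out = torch.softmax(scores, dim=-1) @ V_cache      # (1, d_model)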
Why decode is MEMORY-BOUND (not compute-bound):
Saved by KV cache (not recomputed):
K_0..K_15, V_0..V_15 → read from cache, not recomputed
→ Saved: 16 × 2 = 32 projection ops per layer
Still required (grows with seq_len):
Attention = Q_16 × [K_0..K_16]^T = 17 dot products
Step 1: 17 ops → Step 2: 18 ops → ... → Step 15: 31 ops
The bottleneck:
- Must READ entire KV cache each step (seq_len × hidden_dim × 2)
- But only COMPUTE for 1 token → low arithmetic intensity
- GPU waits for memory → this is why batching helps throughput!
Why 15 decode steps?
max_tokens = 16 (total output tokens requested)
Prefill: samples 1 token → remain_len = 16 - 1 = 15
Decode 1: samples 1 token → remain_len = 15 - 1 = 14
Decode 2: samples 1 token → remain_len = 14 - 1 = 13
...
Decode 15: samples 1 token → remain_len = 1 - 1 = 0 → STOP
Total output: 1 (prefill) + 15 (decode) = 16 tokens ✓
Memory timeline:
Note: Pages are allocated to compute KV, not to store sampled tokens. The sampled token’s KV is computed in the NEXT step.
| Step | Pages alloc | Free pages | KV computed for pos | Token sampled at pos |
|---|---|---|---|---|
| Before prefill | - | 180874 | - | - |
| After prefill | 16 | 180858 | 0-15 (input) | 16 (1st output) |
| After decode 1 | 1 | 180857 | 16 | 17 (2nd output) |
| After decode 2 | 1 | 180856 | 17 | 18 (3rd output) |
| … | … | … | … | … |
| After decode 15 | 1 | 180843 | 30 | 31 (16th output) |
| Total | 31 | - | 31 positions | 16 output tokens |
Why 31 pages for 32 tokens (16 input + 16 output)?
- The 16th output token (position 31) is sampled but never processed
- Its KV is never computed because generation stops (remain_len = 0)
- So we only need KV for positions 0-30 = 31 pages
Why CUDA graph works for decode?
- Decode always processes exactly 1 token per request
- Batch is padded to nearest captured size (1, 2, 4, 8, …)
- Fixed shapes → can replay pre-captured graph
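Picking the padded size is just a lookup into the captured sizes; a sketch (the helper is illustrative, not GraphRunner’s actual code):
# Choosing the padded batch size for CUDA-graph replay.
import bisect

GRAPH_SIZES = [1, 2, 4] + list(range(8, 161, 8))   # the 23 captured sizes

def padded_batch_size(bs: int) -> int | None:
    """Smallest captured size >= bs, or None if the batch must run eagerly."""
    i = bisect.bisect_left(GRAPH_SIZES, bs)
    return GRAPH_SIZES[i] if i < len(GRAPH_SIZES) else None

print(padded_batch_size(5))     # 8
print(padded_batch_size(160))   # 160
print(padded_batch_size(200))   # None → fall back to eager decode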
Step 4: Detokenizer → Frontend → Client
Detokenizer receives DetokenizeMsg:
│
├── Decode token_id → text (incremental detokenization)
│
└── Create UserReply(uid=0, text="...", finished=False/True)
│
▼ ZMQ Push (ipc:///tmp/minisgl_3)
│
FrontendManager:
└── Stream response to HTTP client (SSE)
│
▼
data: {"choices": [{"delta": {"content": "The"}}]}
data: {"choices": [{"delta": {"content": " answer"}}]}
data: {"choices": [{"delta": {"content": " is"}}]}
data: {"choices": [{"delta": {"content": " 4"}}]}
...
data: [DONE]
Step 5: Completion & Cleanup
When finished=True (remain_len <= 0):
│
├── DecodeManager.remove_req(req)
│ └── running_reqs: 1 → 0
│
├── TableManager.free(slot=255)
│ └── free_slots: [0,...,254] → [0,...,255] (256 available)
│
├── CacheManager.free_and_cache_finished_req()
│ └── Option 1: Free all 31 pages back to pool
│ └── Option 2: Cache prefix (16 input pages) for future reuse
│
└── Scheduler returns to idle, waiting for new requests
Final memory state:
If prefix cached:
- Free pages: 180843 + 15 = 180858 (output pages freed)
- Cached: 16 pages (input prefix for potential reuse)
If not cached:
- Free pages: 180843 + 31 = 180874 (all pages freed)
Stopping Conditions
Generation stops when ANY condition is true:
finished = (
req.remain_len <= 0 # Reached max_tokens
or next_token == eos_token_id # Model output EOS token
or device_len >= max_seq_len # Hit context length limit (40960)
)
| Condition | Example |
|---|---|
| remain_len <= 0 | Generated 16 tokens as requested |
| next_token == EOS | Model naturally ends response |
| device_len >= max_seq_len | Input + output hits 40960 tokens |
4. Second Request - Prefix Cache Hit
The first request completed and its KV cache was stored in the RadixTree. What happens when an identical request arrives? This is where Prefix Caching shines.
4.1 Overview
┌─────────────────────────────────────────────────────────────────────────┐
│ REQUEST 2 FLOW │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
│ │ Tokenize │────>│ Cache Match │────>│ Prefill │ │
│ │ (16 tokens) │ │ (15 cached!) │ │ (1 token) │ │
│ └──────────────┘ └──────────────┘ └──────────────┘ │
│ │ │ │
│ ▼ ▼ │
│ ┌──────────────────────────────────────┐ │
│ │ KV Cache │ │
│ │ ┌─────────────────┬────────────┐ │ │
│ │ │ Cached (15 pgs) │ New (1 pg) │ │ │
│ │ │ [REUSED] │ [COMPUTED] │ │ │
│ │ └─────────────────┴────────────┘ │ │
│ └──────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────┐ │
│ │ Decode Loop │ │
│ │ (15 iterations) │ │
│ └──────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────┐ │
│ │ Update Cache │ │
│ └──────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘
| Metric | First Request | Second Request | Savings |
|---|---|---|---|
| cached_len | 0 | 15 | - |
| extend_len (tokens to compute) | 16 | 1 | 94% |
| Pages allocated (prefill) | 16 | 1 | 94% |
Key Insight: The second request computes prefill for only 1 token instead of 16!
4.2 Why Prefix Cache Works
RadixTree State (before request 2)
ROOT
│
[t0, t1, ..., t14] ← 15 tokens cached from request 1
│
[t15] ← Last input token
│
[d0, d1, ..., d14] ← 15 decode tokens
Cache Match Logic
Request 2 input: [t0, t1, ..., t14, t15] (16 tokens)
Cache searches: input_ids[: input_len - 1] = input_ids[:15]
↑________________________↑
Match! 15 tokens found in RadixTree
Why exclude the last token?
# scheduler/cache.py:31
handle, match_indices = self.manager.match_prefix(req.input_ids[: input_len - 1])
- The last input token must still be processed to produce logits for the next token
- If it were cached too, prefill would have nothing to compute (and nothing to sample from)
- This is the standard pattern in LLM inference engines
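A minimal stand-in for this matching logic, using a flat list of cached token sequences instead of a RadixTree (illustrative only; token ids are made up):
# Toy prefix match: compare the new input (minus its last token) against
# previously cached token sequences and return the longest shared prefix.
def match_prefix(cached_seqs: list[list[int]], input_ids: list[int]) -> int:
    """Return cached_len for the new request."""
    query = input_ids[:-1]          # the last token must still be prefilled
    best = 0
    for seq in cached_seqs:
        n = 0
        for a, b in zip(seq, query):
            if a != b:
                break
            n += 1
        best = max(best, n)
    return best

req1 = list(range(100, 116))                  # 16 input tokens of request 1
cached = [req1 + list(range(200, 215))]       # 31 tokens stored after request 1
print(match_prefix(cached, req1))             # 15 → only 1 token left to prefill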
4.3 Request Flow
Step 1: Cache Match
CacheManager.match_req()
├── Search RadixTree for input_ids[:15]
├── Found match → cached_len = 15
└── Lock cached pages (prevent eviction)
Step 2: Allocate Resources
extend_len = input_len - cached_len = 16 - 15 = 1
TableManager.allocate() → slot = 255
CacheManager.allocate(1) → 1 page (only for the last token!)
Step 3: Prefill Forward
┌─────────────────────────────────────────────────────────────────┐
│ Input: [t15] (ONLY 1 token) │
│ KV Cache: REUSE 15 tokens from cache │
│ Forward: Compute attention for 1 token with cached KV │
│ Output: logits for t15 → sample first output token │
└─────────────────────────────────────────────────────────────────┘
Attention shapes:
Q: [1, num_heads, head_dim] ← 1 new token
K: [16, num_heads, head_dim] ← 15 cached + 1 new
V: [16, num_heads, head_dim] ← 15 cached + 1 new
Step 4: Decode Loop (15 iterations)
Same as first request - allocate 1 page per step.
Step 5: Cleanup & Cache Update
TableManager.free(slot=255)
CacheManager.free_and_cache_finished_req()
├── Unlock old cached pages
├── Insert new sequence into RadixTree
└── Free redundant pages
4.4 Deep Dive: Cache Update Mechanics
Understanding cached_len Evolution
Warning:
cached_len = 15 does NOT mean the RadixTree only has 15 tokens!
- The RadixTree has 31 tokens (from request 1: 16 input + 15 decode)
- cached_len = 15 means “15 tokens of request 2’s INPUT match the tree”
| Stage | cached_len | device_len | Status |
|---|---|---|---|
| Initial (cache match) | 15 | 16 | Ready for prefill |
| After prefill | 16 | 17 | First output sampled |
| After decode 1 | 17 | 18 | |
| ... | ... | ... | ... |
| After decode 15 | 31 | 32 | remain_len=0 → DONE |
Page Ownership After Cleanup
When free_and_cache_finished_req is called with 31 pages:
┌─────────────────────────────────────────────────────────────────────────┐
│ 31 Pages of Request 2 │
├─────────────────┬───────────────────┬───────────────────────────────────┤
│ 0-14 (15) │ 15-23 (9) │ 24-30 (7) │
├─────────────────┼───────────────────┼───────────────────────────────────┤
│ From cache │ Allocated but │ Allocated, NEW tokens │
│ (locked) │ REDUNDANT │ (different decode output) │
├─────────────────┼───────────────────┼───────────────────────────────────┤
│ → UNLOCK │ → FREE │ → ADD TO CACHE │
│ (decr ref_count)│ (return to pool) │ (new RadixTree nodes) │
└─────────────────┴───────────────────┴───────────────────────────────────┘
RadixTree After Request 2
ROOT
│
[t0...t14] (15) ← Shared (ref_count++)
│
[t15] (1) ← Shared
│
[d0...d7] (8) ← Shared (if same output)
│
┌──────┴──────┐
[d8...d14] [d8'...d14']
(Request 1) (Request 2, new nodes)
4.5 Performance Summary
┌─────────────────────────────────────────────────────────────────┐
│ PREFILL COMPUTATION │
├─────────────────────────────────────────────────────────────────┤
│ First Request: │
│ Compute: 16 tokens × 28 layers = 448 layer forwards │
│ Memory: Allocate 16 pages │
│ │
│ Second Request: │
│ Compute: 1 token × 28 layers = 28 layer forwards │
│ Memory: Allocate 1 page (reuse 15 from cache) │
│ Speedup: ~16x faster prefill! │
└─────────────────────────────────────────────────────────────────┘
Partial cache hit:
Request 1: "What is 2 + 2?" → cached_len = 0
Request 2: "What is 2 + 2?" → cached_len = 15 (full match)
Request 3: "What is 2 + 3?" → cached_len = 12 (partial: "What is 2 + ")
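Using the match_prefix stand-in from section 4.2, a partial hit works the same way, it just matches fewer tokens (token ids below are made up for illustration):
# Partial hit: only the shared prefix is reused, the rest is prefilled.
req2 = [1, 2, 3, 4, 5, 6, 7, 8]        # "What is 2 + 2?"-style input
req3 = [1, 2, 3, 4, 5, 6, 9, 10]       # same prefix, different ending

cached_len = match_prefix([req2], req3)   # 6 tokens shared
extend_len = len(req3) - cached_len       # only 2 tokens left to prefill
print(cached_len, extend_len)             # 6 2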
Next, we’ll cover:
- Batching multiple requests
- Overlap scheduling
- Tensor parallelism (TP > 1)