
Mini SGLang (Part 1) - Architecture, Engine & Request Flow


Key Concepts

Before diving in, here are the key terms you’ll encounter:

| Concept | Description |
| --- | --- |
| Prefill | Process all input tokens at once, compute full KV cache |
| Decode | Generate 1 token/step, reuse cached KV |
| Paged KV Cache | Allocate KV cache in fixed-size pages, not contiguous blocks |
| Page Table | Maps (request_id, token_pos) → page_index in KV cache pool |
| CUDA Graph | Pre-captured GPU ops for decode (fixed shapes only) |
| Prefix Cache | Reuse KV cache across requests with same prefix |
| ZMQ IPC | Inter-process communication via Unix sockets (Push/Pull) |
| Chunked Prefill | Split large inputs into multiple batches (≤ prefill_budget) |

1. Architecture & Startup Flow

We start by understanding how Mini SGLang organizes its processes and how the system boots up.

Mini SGLang uses a multi-process architecture with ZMQ for inter-process communication.

1.1 System Architecture

┌────────────────────────────────────────────────────────────┐
│                   MAIN PROCESS (FastAPI)                   │
│                     FrontendManager                        │
│       - HTTP endpoints (/v1/chat/completions)              │
│       - Send/receive via ZMQ                               │
└────────────────────────────────────────────────────────────┘
      │ ZMQ (ipc:///tmp/minisgl_4)        ▲ ZMQ (ipc:///tmp/minisgl_3)
      ▼                                   │
┌─────────────────┐              ┌─────────────────┐
│  TOKENIZER      │              │  DETOKENIZER    │
│  text → tokens  │              │  tokens → text  │
└─────────────────┘              └─────────────────┘
      │ ZMQ (ipc:///tmp/minisgl_0)        ▲ ZMQ (ipc:///tmp/minisgl_1)
      └────────────┐          ┌──────────┘
                   ▼          │
┌────────────────────────────────────────────────────────────┐
│                     SCHEDULER PROCESS                      │
│  ┌──────────────────────────────────────────────────────┐  │
│  │ Scheduler                                            │  │
│  │  ├── PrefillManager (pending requests)               │  │
│  │  ├── DecodeManager (running requests)                │  │
│  │  ├── CacheManager (KV cache pages)                   │  │
│  │  └── TableManager (request slots)                    │  │
│  └──────────────────────────────────────────────────────┘  │
│  ┌──────────────────────────────────────────────────────┐  │
│  │ Engine                                               │  │
│  │  ├── Model (Qwen3/Llama)                             │  │
│  │  ├── KVCache (paged memory)                          │  │
│  │  ├── AttnBackend (FlashInfer)                        │  │
│  │  └── GraphRunner (CUDA graphs for decode)            │  │
│  └──────────────────────────────────────────────────────┘  │
└────────────────────────────────────────────────────────────┘

| Channel | Address | Direction | Message Type |
| --- | --- | --- | --- |
| Frontend → Tokenizer | ipc:///tmp/minisgl_4 | Push/Pull | TokenizeMsg |
| Tokenizer → Scheduler | ipc:///tmp/minisgl_0 | Push/Pull | UserMsg |
| Scheduler → Detokenizer | ipc:///tmp/minisgl_1 | Push/Pull | DetokenizeMsg |
| Detokenizer → Frontend | ipc:///tmp/minisgl_3 | Push/Pull | UserReply |
| Scheduler Broadcast | ipc:///tmp/minisgl_2 | Pub/Sub | (for TP > 1) |

1.2 Startup Flow

python -m minisgl --model "Qwen/Qwen3-0.6B"

Phase 1: Entry Point & Parse Arguments

📁 python/minisgl/__main__.py
📁 python/minisgl/server/launch.py:40 → launch_server()
📁 python/minisgl/server/args.py:54 → parse_args()
ServerArgs(
    model_path="Qwen/Qwen3-0.6B",
    dtype=torch.bfloat16,
    tp_info=DistributedInfo(rank=0, size=1),
    server_host="127.0.0.1",
    server_port=1919,
)

Phase 2: Create FrontendManager (Main Process)

📁 python/minisgl/server/api_server.py:408 → run_api_server()

FrontendManager initializes its ZMQ connections: a PUSH socket toward the tokenizer (ipc:///tmp/minisgl_4) and a PULL socket that receives replies from the detokenizer (ipc:///tmp/minisgl_3).
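
A minimal pyzmq sketch of that wiring (the socket variable names and bind/connect split are illustrative, not the actual FrontendManager attributes):

import zmq

ctx = zmq.Context()

# Frontend -> Tokenizer: the tokenizer listens on minisgl_4, so the frontend connects
send_tokenizer = ctx.socket(zmq.PUSH)
send_tokenizer.connect("ipc:///tmp/minisgl_4")

# Detokenizer -> Frontend: the frontend is the listener on minisgl_3, so it binds
recv_detokenizer = ctx.socket(zmq.PULL)
recv_detokenizer.bind("ipc:///tmp/minisgl_3")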

Phase 3: Spawn Worker Processes

📁 python/minisgl/server/launch.py:47 → start_subprocess()

Uses multiprocessing.Process with spawn method:

# 1. Spawn Scheduler
mp.Process(
    target=_run_scheduler,
    args=(new_args, ack_queue),
    name="minisgl-TP0-scheduler",
).start()

# 2. Spawn Detokenizer
mp.Process(
    target=tokenize_worker,
    kwargs={
        "addr": "ipc:///tmp/minisgl_1",      # Listen here
        "backend_addr": "ipc:///tmp/minisgl_0",
        "frontend_addr": "ipc:///tmp/minisgl_3",  # Send to frontend
    },
    name="minisgl-detokenizer-0",
).start()

# 3. Spawn Tokenizer
mp.Process(
    target=tokenize_worker,
    kwargs={
        "addr": "ipc:///tmp/minisgl_4",      # Listen here
        "backend_addr": "ipc:///tmp/minisgl_0",   # Send to scheduler
    },
    name="minisgl-tokenizer-0",
).start()

Phase 4: Scheduler Process Initialization

📁 python/minisgl/server/launch.py:16 → _run_scheduler()
📁 python/minisgl/scheduler/scheduler.py:80 → Scheduler.__init__()

This runs in a SEPARATE PROCESS!

# 1. Create Engine (loads model, creates KV cache)
self.engine = Engine(config)

# 2. Initialize ZMQ I/O
self.io = SchedulerIO(
    recv_addr="ipc:///tmp/minisgl_0",   # Receive from tokenizer
    send_addr="ipc:///tmp/minisgl_1",   # Send to detokenizer
)

# 3. Create managers
self.prefill_manager = PrefillManager(...)
self.decode_manager = DecodeManager()
self.cache_manager = CacheManager(...)
self.table_manager = TableManager(...)

# 4. Send acknowledgment
ack_queue.put("Scheduler is ready")

Phase 5: Tokenizer/Detokenizer Initialization

📁 python/minisgl/tokenizer/server.py:30 → tokenize_worker()

This runs in SEPARATE PROCESS(es)!

# 1. Create ZMQ connections
ctx = zmq.Context()
send_backend = ctx.socket(zmq.PUSH)   # To scheduler
send_frontend = ctx.socket(zmq.PUSH)  # To frontend (detokenizer only)
recv_listener = ctx.socket(zmq.PULL)  # Listen for messages

# 2. Load HuggingFace tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)

# 3. Create managers
tokenize_manager = TokenizeManager(tokenizer)
detokenize_manager = DetokenizeManager(tokenizer)

# 4. Send acknowledgment
ack_queue.put("Tokenize server 0 is ready")

Phase 6: Wait for All Workers & Start Server

📁 python/minisgl/server/launch.py:110
📁 python/minisgl/server/api_server.py:454
# Main process waits for acks
for _ in range(num_tokenizers + 2):
    logger.info(ack_queue.get())
# Output:
# [INFO] Scheduler is ready
# [INFO] Tokenize server 0 is ready

# Start Uvicorn
uvicorn.run(app, host="127.0.0.1", port=1919)
# [INFO] Uvicorn running on http://127.0.0.1:1919

Phase 7: Scheduler Main Loop

📁 python/minisgl/scheduler/scheduler.py:288 → run_forever()
@torch.inference_mode()
def run_forever(self) -> NoReturn:
    data = None
    while True:
        data = self.overlap_loop(data)
        # Continuously:
        # 1. Receive messages from tokenizer
        # 2. Schedule prefill/decode batches
        # 3. Execute forward pass
        # 4. Send results to detokenizer
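
For intuition, a heavily simplified, non-overlapping version of this loop might look like the sketch below; the manager and IO method names are partly assumptions for illustration, not the exact Mini SGLang API.

import torch

@torch.inference_mode()
def scheduler_loop(sched) -> None:
    """Toy scheduler loop: poll, schedule, forward, emit (no overlap scheduling)."""
    while True:
        # 1. Drain newly tokenized requests from the tokenizer socket
        for msg in sched.io.recv_all():
            sched.prefill_manager.add_one_req(msg)

        # 2. Prefer a prefill batch (bounded by prefill_budget), else a decode batch
        batch = (sched.prefill_manager.schedule_next_batch()
                 or sched.decode_manager.schedule_next_batch())
        if batch is None:
            continue  # nothing to run: keep polling

        # 3. Run the forward pass and sample the next token for each request
        output = sched.engine.forward_batch(batch)

        # 4. Push sampled tokens to the detokenizer; finished requests are freed
        sched.io.send(output)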

1.3 Process Summary

| Process Name | Main Function | ZMQ Role |
| --- | --- | --- |
| Main (FastAPI) | run_api_server() | Send to tokenizer, recv from detokenizer |
| minisgl-TP0-scheduler | run_forever() | Recv from tokenizer, send to detokenizer |
| minisgl-tokenizer-N | tokenize_worker() | Recv from frontend, send to scheduler |
| minisgl-detokenizer-0 | tokenize_worker() | Recv from scheduler, send to frontend |

2. Engine Initialization

With the multi-process architecture in place, let’s zoom into the Scheduler process - the brain of the system. When it starts, the Engine is initialized with these key parameters:

2.1 Key Parameters Overview

| Parameter | Default | Description | Determined By |
| --- | --- | --- | --- |
| max_running_req | 256 | Max concurrent requests in system | Config |
| max_seq_len | 40960 | Max tokens per request (input + output) | Model's RoPE max_position |
| prefill_budget | 8192 | Max tokens per prefill batch | Config |
| num_kv_pages | ~180K | Total KV cache pages available | GPU memory |
| cuda_graph_max_bs | 160 | Max batch size for CUDA graph | GPU memory |

How these relate:

┌─────────────────────────────────────────────────────────────────────┐
│                         CAPACITY LIMITS                             │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  max_running_req = 256                                              │
│  └── How many requests can run simultaneously                       │
│      └── Each request gets 1 slot in TableManager                   │
│      └── Page table shape: (257, max_seq_len)                       │
│                                                                     │
│  max_seq_len = 40960                                                │
│  └── Max context length per request                                 │
│      └── input_tokens + output_tokens ≤ 40960                       │
│      └── From model config (RoPE max_position)                      │
│                                                                     │
│  prefill_budget = 8192                                              │
│  └── Max tokens processed in one prefill batch                      │
│      └── Controls latency (don't block decode too long)             │
│      └── Large inputs are chunked                                   │
│                                                                     │
│  num_kv_pages ≈ 180874                                              │
│  └── Total memory for KV cache                                      │
│      └── Shared across ALL requests                                 │
│      └── Each token needs 1 page (112KB for Qwen3-0.6B)             │
│                                                                     │
│  cuda_graph_max_bs = 160                                            │
│  └── Largest batch size with pre-captured CUDA graph                │
│      └── Decode batches > 160 run without graph (slower)            │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

Example capacity calculation:

# With Qwen3-0.6B on 24GB GPU:
num_kv_pages = 180874
max_running_req = 256
max_seq_len = 40960

# Theoretical max (if all requests use max length):
# 256 requests × 40960 tokens = 10485760 pages needed
# But we only have 180874 pages!

# Practical limit:
# 180874 pages / 256 requests ≈ 706 tokens/request average
# OR fewer long requests + more short requests

Prefill budget limits tokens per batch to control memory usage and latency. Large inputs are chunked:

prefill_budget = 8192 (max tokens per prefill batch)

input_len=5000   → 1 batch
input_len=10000  → 2 batches: 8192 + 1808
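
A sketch of the budget arithmetic behind those chunk sizes (illustrative only, not the actual PrefillManager code):

def split_into_prefill_chunks(input_len: int, prefill_budget: int = 8192) -> list[int]:
    """Split an input of input_len tokens into chunks of at most prefill_budget tokens."""
    chunks = []
    remaining = input_len
    while remaining > 0:
        chunk = min(remaining, prefill_budget)
        chunks.append(chunk)
        remaining -= chunk
    return chunks

print(split_into_prefill_chunks(5000))    # [5000]        -> 1 batch
print(split_into_prefill_chunks(10000))   # [8192, 1808]  -> 2 batches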

2.2 KV Cache & Page Table Relationship

Problem: LLM inference needs to store Key-Value cache for each processed token. If we allocate contiguous memory for each request → memory fragmentation & waste.

Solution: Paged KV Cache - divide memory into fixed-size pages, allocate on-demand.

Note: In Mini SGLang, page_size = 1, meaning 1 page = 1 token. This simplifies the mapping: number of pages = number of tokens. Production systems like vLLM/SGLang often use larger page sizes (e.g., 16 tokens/page) for efficiency.

┌───────────────────────────────────────────────────────────────────────┐
│                            GPU MEMORY                                 │
├───────────────────────────────────────────────────────────────────────┤
│                                                                       │
│  KV Cache Pool (180874 pages)                                         │
│  ┌───────┬───────┬───────┬───────┬───────┬───────┬───────────────┐    │
│  │ Page 0│ Page 1│ Page 2│ Page 3│ Page 4│ Page 5│...Page 180873 │    │
│  │ 112KB │ 112KB │ 112KB │ 112KB │ 112KB │ 112KB │     112KB     │    │
│  └───────┴───────┴───────┴───────┴───────┴───────┴───────────────┘    │
│      ▲       ▲       ▲       ▲       ▲       ▲                        │
│      │       │       │       │       │       │                        │
│      └───────┴───────┴───────┴───────┴───────┘                        │
│              │              │                                         │
│    ┌─────────┘              └─────────┐                               │
│    │                                  │                               │
├────┼──────────────────────────────────┼───────────────────────────────┤
│    │      Page Table (257 × 40960)    │                               │
│    │                                  │                               │
│  ┌─┴────────────────────────────────┬─┴─────────────────────────────┐ │
│  │ req_idx │ pos 0 │ pos 1 │ pos 2 │ pos 3 │ pos 4 │ ...            │ │
│  ├─────────┼───────┼───────┼───────┼───────┼───────┼────────────────┤ │
│  │ Req 0   │   0   │   1   │   2   │  -1   │  -1   │ (3 tokens)     │ │
│  │ Req 1   │   3   │   4   │   5   │  -1   │  -1   │ (3 tokens)     │ │
│  │ ...     │  ...  │  ...  │  ...  │  ...  │  ...  │                │ │
│  │ Req 256 │ dummy │ dummy │ dummy │ dummy │ dummy │ (CUDA graph)   │ │
│  └─────────┴───────┴───────┴───────┴───────┴───────┴────────────────┘ │
│                                                                       │
└───────────────────────────────────────────────────────────────────────┘

Page Table[req_idx][token_pos] = page_index in KV Cache Pool

How It Works

1. KV Cache Pool - Pre-allocated GPU memory:

# Each page stores KV for 1 token (page_size=1) across all layers
cache_per_page = (
    2                    # key + value
    * head_dim           # 128
    * num_kv_heads       # 8
    * page_size          # 1 token
    * dtype.itemsize     # 2 bytes (bfloat16)
    * num_layers         # 28 layers
)
# = 2 * 128 * 8 * 1 * 2 * 28 = 114688 bytes ≈ 112 KB/page

num_pages = available_gpu_memory // cache_per_page
# Example: 19.32 GiB → 180874 pages

2. Page Table - Maps (request, position) → page:

# Shape: (max_running_req + 1, max_seq_len) = (257, 40960)
page_table = torch.zeros((257, 40960), dtype=torch.int32, device="cuda")

# When we need KV for token position 5 of request 0:
page_idx = page_table[0, 5]  # → page index in KV Cache Pool
kv_data = kv_cache_pool[page_idx]  # → actual KV tensors

3. Allocation Flow:

Request arrives (input_len=16):

    ├── CacheManager.allocate(16)
    │   └── Take 16 free pages from pool: [0, 1, 2, ..., 15]

    ├── TableManager.allocate()
    │   └── Take 1 free slot: slot=0

    └── Update page_table[0, 0:16] = [0, 1, 2, ..., 15]

Decode step (generate 1 token):

    ├── CacheManager.allocate(1)
    │   └── Take 1 free page: [16]

    └── Update page_table[0, 16] = 16

Final state for request 0:
    page_table[0] = [0, 1, 2, ..., 16, -1, -1, ...]
                     └── 17 pages allocated ──┘
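
A minimal free-list sketch of this bookkeeping (illustrative; the real CacheManager/TableManager keep the page table on the GPU and also track ref-counts for cached prefixes):

import torch

NUM_PAGES = 180874
MAX_RUNNING_REQ = 256
MAX_SEQ_LEN = 40960

free_pages = list(range(NUM_PAGES))          # free-list of KV cache pages
free_slots = list(range(MAX_RUNNING_REQ))    # free-list of request slots
page_table = torch.full((MAX_RUNNING_REQ + 1, MAX_SEQ_LEN), -1, dtype=torch.int32)

def allocate_request(input_len: int) -> int:
    """Reserve a table slot and one page per input token; return the slot index."""
    slot = free_slots.pop()                                  # pops from the end, e.g. 255
    pages = [free_pages.pop(0) for _ in range(input_len)]
    page_table[slot, :input_len] = torch.tensor(pages, dtype=torch.int32)
    return slot

def allocate_decode_page(slot: int, pos: int) -> None:
    """Reserve one page for the token at position `pos` during a decode step."""
    page_table[slot, pos] = free_pages.pop(0)

slot = allocate_request(16)      # prefill: 16 pages for 16 input tokens
allocate_decode_page(slot, 16)   # decode step: 1 page for the newly generated token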

4. Why Paged?

| Approach | Memory Usage | Flexibility |
| --- | --- | --- |
| Contiguous | Pre-allocate max_seq_len per request → waste | Fixed size |
| Paged | Allocate on-demand → efficient | Dynamic growth |

Example: 2 requests with max_seq_len=1000:

Contiguous:
┌──────────────────────────────────────┐
│ Req 0: [████████░░░░░░░░░░░░░░░░░░░] │  ← 200 tokens, waste 800
│ Req 1: [██████████████░░░░░░░░░░░░░] │  ← 500 tokens, waste 500
└──────────────────────────────────────┘
Total waste: 1300 slots

Paged:
┌──────────────────────────────────────┐
│ Pool: [████████████████████████░░░░] │  ← 700 pages used
│ Req 0 pages: [0,1,2,...,199]         │
│ Req 1 pages: [200,201,...,699]       │
└──────────────────────────────────────┘
Total waste: 0 (pages allocated on-demand)

CUDA Graph Capture

Pre-captured for decode phase only (fixed shapes):

batch_sizes = [1, 2, 4] + list(range(8, 161, 8))
# = [1, 2, 4, 8, 16, 24, 32, ..., 160]
# Total: 23 graphs

# During inference, batch is padded to nearest graph size
# Example: batch_size=5 → pad to 8

Why decode only? Prefill has variable sequence lengths, CUDA graphs need fixed shapes.
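
A bare-bones sketch of the capture/replay-with-padding pattern (the real GraphRunner manages static buffers for the whole model and one graph per batch size; decode_step below is just a stand-in layer):

import bisect
import torch

# Decode batch sizes with pre-captured graphs (23 total)
BATCH_SIZES = [1, 2, 4] + list(range(8, 161, 8))

def pad_to_graph_bs(bs: int) -> int:
    """Round a decode batch size up to the nearest pre-captured size (eager above 160)."""
    if bs > BATCH_SIZES[-1]:
        return bs                                           # too large: run without a graph
    return BATCH_SIZES[bisect.bisect_left(BATCH_SIZES, bs)]

assert pad_to_graph_bs(5) == 8
assert pad_to_graph_bs(160) == 160

# Capture/replay (needs a GPU)
if torch.cuda.is_available():
    hidden = 1024
    decode_step = torch.nn.Linear(hidden, hidden, device="cuda")
    static_in = torch.zeros(8, hidden, device="cuda")       # fixed-shape input buffer (bs=8)

    for _ in range(3):                                      # warm up lazy allocations
        decode_step(static_in)
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):                           # record the kernels once
        static_out = decode_step(static_in)

    real_batch = torch.randn(5, hidden, device="cuda")      # actual decode batch of 5
    static_in.zero_()
    static_in[:5].copy_(real_batch)                         # "pad" by copying into the buffer
    graph.replay()                                          # re-launch the recorded kernels
    result = static_out[:5]                                 # drop the padded rows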


3. First Request

Now that we understand the Engine’s memory layout and CUDA graphs, let’s trace a single request through the entire system - from HTTP arrival to response streaming.

3.1 Overview

┌─────────────────────────────────────────────────────────────────────┐
│                        REQUEST LIFECYCLE                            │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  Step 1: HTTP Request                                               │
│     └── Client → FastAPI → Tokenizer (text → tokens)                │
│                                                                     │
│  Step 2: Prefill                                                    │
│     └── Process ALL input tokens at once                            │
│     └── Allocate KV cache pages, compute KV, sample 1st token       │
│                                                                     │
│  Step 3: Decode Loop (repeat until done)                            │
│     └── Process 1 token per step (using CUDA graph)                 │
│     └── Allocate 1 page, compute KV, sample next token              │
│                                                                     │
│  Step 4: Stream Response                                            │
│     └── Detokenizer → FastAPI → Client (SSE streaming)              │
│                                                                     │
│  Step 5: Cleanup                                                    │
│     └── Free resources (table slot, KV pages)                       │
│     └── Optionally cache prefix for reuse                           │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

3.2 Request Flow

Example: "What is 2+2?" with max_tokens=16, temperature=0.0

Step 1: HTTP → Frontend → Tokenizer

POST /v1/chat/completions


FrontendManager:
    - Assign uid=0 (first request gets uid=0)
    - Create TokenizeMsg(uid=0, text="What is 2+2?", sampling_params=...)

    ▼ ZMQ Push (ipc:///tmp/minisgl_4)

Tokenizer Process:
    - Apply chat template + encode → input_ids (16 tokens)
    - Create UserMsg(uid=0, input_ids=tensor([...]), sampling_params=...)

    ▼ ZMQ Push (ipc:///tmp/minisgl_0)

Why 16 tokens? The raw text “What is 2+2?” is only ~5 tokens, but after applying chat template:

<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n
<|im_start|>user\nWhat is 2+2?<|im_end|>\n
<|im_start|>assistant\n

The full prompt becomes ~16 tokens (varies by model’s chat template).
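
You can verify the count with the HuggingFace tokenizer along these lines (exact numbers depend on the model's chat template and tokenizer version):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
messages = [{"role": "user", "content": "What is 2+2?"}]

# Apply the chat template, including the trailing assistant generation prompt
input_ids = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)

print(len(tok.encode("What is 2+2?")))  # ~5 tokens for the raw text
print(len(input_ids))                   # ~16 tokens after the chat template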

Step 2: Scheduler Prefill

Scheduler receives UserMsg:

    ├── PrefillManager.add_one_req(msg)
    │   └── pending_list.append(PendingReq(...))

    ├── Schedule prefill batch:
    │   ├── CacheManager.match_req() → cached_len=0 (no prefix cache hit)
    │   ├── TableManager.allocate() → slot=255
    │   └── CacheManager.allocate(16) → 16 pages

    └── Engine.forward_batch(phase="prefill"):
        ├── Direct model forward (no CUDA graph)
        ├── Process all 16 input tokens in parallel
        ├── Compute KV cache for all 16 positions
        └── Sample first output token

Key values explained:

| Value | Explanation |
| --- | --- |
| cached_len=0 | No prefix cache hit (first request with this prompt) |
| slot=255 | TableManager allocates from the end: free_slots=[0,1,...,255] → pop last → 255 |
| 16 pages | 1 page per input token (page_size=1), so 16 tokens need 16 pages |
| no CUDA graph | Prefill has variable seq_len, can't use pre-captured graphs |

What happens inside Engine.forward_batch(phase="prefill"):

Input: [tok_0, tok_1, ..., tok_15]  (16 tokens)

1. Embedding → X shape: (16, d_model)

2. For each layer (×28):
   ┌──────────────────────────────────────────────────────────────────────┐
   │ a) Compute Q, K, V for ALL 16 tokens in parallel                     │
   │                                                                      │
   │ b) Write K, V to the 16 pages allocated by CacheManager:             │
   │    kv_cache[layer, 0, page_0:page_15] = K                            │
   │    kv_cache[layer, 1, page_0:page_15] = V                            │
   │                                                                      │
   │ c) Self-attention with causal mask:                                  │
   │    scores = Q × K^T / √d_k      shape: (16, 16)                      │
   │         K_0  K_1  K_2  ... K_15                                      │
   │    Q_0   ✓    ✗    ✗   ...  ✗   ← Each token only attends            │
   │    Q_1   ✓    ✓    ✗   ...  ✗      to previous tokens                │
   │    ...                                                               │
   │    Q_15  ✓    ✓    ✓   ...  ✓   ← Last token sees all                │
   │                                                                      │
   │ d) FFN → next layer                                                  │
   └──────────────────────────────────────────────────────────────────────┘

3. Sample from LAST position only:
   logits = X[15] × W_vocab  → tok_16 = first output token
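
As a toy illustration of steps 2c and 3 (single head, random weights, PyTorch's built-in attention rather than the FlashInfer path Mini SGLang actually uses):

import torch
import torch.nn.functional as F

d_model, vocab = 64, 1000
X = torch.randn(16, d_model)            # hidden states for the 16 input tokens
W_vocab = torch.randn(d_model, vocab)   # output projection (toy sizes)

# Causal self-attention over all 16 tokens at once -> (16, 16) score matrix internally
q = k = v = X.unsqueeze(0).unsqueeze(0)                        # (1, 1, 16, d_model)
attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Sample only from the LAST position's logits (greedy, i.e. temperature=0.0)
logits = X[15] @ W_vocab                # (vocab,)
tok_16 = torch.argmax(logits).item()    # first output token
print(tok_16)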

Memory state after prefill:

KV Cache pages: 180874 - 16 = 180858 free
Page table[slot=255, pos 0:16] = [page_0, page_1, ..., page_15]

KV Cache content (per layer):
  K_cache[page_0:page_15] = [K_0, K_1, ..., K_15]   ← 16 key vectors
  V_cache[page_0:page_15] = [V_0, V_1, ..., V_15]   ← 16 value vectors

Output: tok_16 sampled (first output token) → Ready for decode!

Step 3: Decode Loop (×15)

For each decode step:

    ├── DecodeManager.schedule_next_batch()
    │   └── Create batch with running requests

    ├── CacheManager.allocate(1) → 1 new page for new token

    ├── Engine.forward_batch(phase="decode"):
    │   ├── CUDA graph replay (batch_size=1)
    │   ├── Process only 1 new token (last generated)
    │   └── Attention reads from cached KV (16 + n positions)

    └── Send DetokenizeMsg to detokenizer

        ▼ ZMQ Push (ipc:///tmp/minisgl_1)

What happens inside Engine.forward_batch(phase="decode"):

Decode step 1 (generate token 17):

┌─────────────────────────────────────────────────────────────────────────┐
│ Prefill (16 tokens):              Decode (1 new token):                 │
│ ────────────────────              ─────────────────────                 │
│ Q = [x_0..x_15] × W_q             Q_16 = x_16 × W_q  ← ONLY 1 token     │
│ K = [x_0..x_15] × W_k             K_16 = x_16 × W_k  ← ONLY 1 token     │
│ V = [x_0..x_15] × W_v             V_16 = x_16 × W_v  ← ONLY 1 token     │
│                                                                         │
│ KV Cache (from prefill): K_cached = [K_0..K_15], V_cached = [V_0..V_15] │
│                                                                         │
│ Decode computation:                                                     │
│   1. Compute K_16, V_16, Q_16       ← Only 3 projections (not 48!)      │
│   2. Write to cache: K[16] = K_16, V[16] = V_16                         │
│   3. Attention: Q_16 × [K_0..K_16]^T  ← Read ALL 17 K vectors           │
│   4. Output: softmax(scores) × [V_0..V_16]                              │
│   5. Sample token 17                                                    │
└─────────────────────────────────────────────────────────────────────────┘

Why decode is MEMORY-BOUND (not compute-bound):

Saved by KV cache (not recomputed):
  K_0..K_15, V_0..V_15 → read from cache, not recomputed
  → Saved: 16 × 2 = 32 projection ops per layer

Still required (grows with seq_len):
  Attention = Q_16 × [K_0..K_16]^T = 17 dot products
  Step 1:  17 ops → Step 2: 18 ops → ... → Step 15: 31 ops

The bottleneck:
  - Must READ entire KV cache each step (seq_len × hidden_dim × 2)
  - But only COMPUTE for 1 token → low arithmetic intensity
  - GPU waits for memory → this is why batching helps throughput!
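
A rough back-of-the-envelope with Qwen3-0.6B-like dimensions (the numbers and the "one KV head group" simplification are illustrative) makes the low arithmetic intensity concrete:

# Rough arithmetic-intensity estimate for one decode step of one request
num_layers   = 28
num_kv_heads = 8
head_dim     = 128
dtype_bytes  = 2            # bfloat16
seq_len      = 1000         # tokens already in the KV cache

# Bytes READ from the KV cache this step (keys + values, all layers)
kv_bytes = 2 * seq_len * num_kv_heads * head_dim * dtype_bytes * num_layers

# FLOPs of the attention matmuls for the single new query (QK^T and scores x V)
attn_flops = 2 * 2 * seq_len * num_kv_heads * head_dim * num_layers

print(f"KV bytes read  : {kv_bytes / 1e6:.1f} MB")
print(f"Attention FLOPs: {attn_flops / 1e6:.1f} MFLOP")
print(f"FLOPs per byte : {attn_flops / kv_bytes:.2f}")   # ~1 -> memory-bound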

Why 15 decode steps?

max_tokens = 16 (total output tokens requested)

Prefill:  samples 1 token  → remain_len = 16 - 1 = 15
Decode 1: samples 1 token  → remain_len = 15 - 1 = 14
Decode 2: samples 1 token  → remain_len = 14 - 1 = 13
...
Decode 15: samples 1 token → remain_len = 1 - 1 = 0 → STOP

Total output: 1 (prefill) + 15 (decode) = 16 tokens ✓

Memory timeline:

Note: Pages are allocated to compute KV, not to store sampled tokens. The sampled token’s KV is computed in the NEXT step.

| Step | Pages alloc | Free pages | KV computed for pos | Token sampled at pos |
| --- | --- | --- | --- | --- |
| Before prefill | - | 180874 | - | - |
| After prefill | 16 | 180858 | 0-15 (input) | 16 (1st output) |
| After decode 1 | 1 | 180857 | 16 | 17 (2nd output) |
| After decode 2 | 1 | 180856 | 17 | 18 (3rd output) |
| ... | ... | ... | ... | ... |
| After decode 15 | 1 | 180843 | 30 | 31 (16th output) |
| Total | 31 | - | 31 positions | 16 output tokens |

Why 31 pages for 32 tokens (16 input + 16 output)? Pages are allocated only for positions whose KV is actually computed. The 16th output token (position 31) is sampled and returned, but generation stops immediately, so its KV is never computed and no page is allocated for it: 16 input + 15 output positions = 31 pages.

Why does the CUDA graph work for decode? Every decode step processes exactly 1 token per request, so the tensor shapes are fixed; only the batch size varies, and that is handled by padding the batch up to the nearest pre-captured graph size.

Step 4: Detokenizer → Frontend → Client

Detokenizer receives DetokenizeMsg:

    ├── Decode token_id → text (incremental detokenization)

    └── Create UserReply(uid=0, text="...", finished=False/True)

        ▼ ZMQ Push (ipc:///tmp/minisgl_3)

FrontendManager:
    └── Stream response to HTTP client (SSE)


data: {"choices": [{"delta": {"content": "The"}}]}
data: {"choices": [{"delta": {"content": " answer"}}]}
data: {"choices": [{"delta": {"content": " is"}}]}
data: {"choices": [{"delta": {"content": " 4"}}]}
...
data: [DONE]
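
One simple way to produce those incremental text deltas is to re-decode the growing sequence and emit only the new suffix (a toy sketch; the real DetokenizeManager is presumably more efficient and handles special tokens and partial UTF-8 explicitly):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

generated_ids: list[int] = []
emitted_len = 0

def on_new_token(token_id: int) -> str:
    """Return only the newly appended text for this token (streamed as an SSE delta)."""
    global emitted_len
    generated_ids.append(token_id)
    text = tok.decode(generated_ids, skip_special_tokens=True)
    delta, emitted_len = text[emitted_len:], len(text)
    return delta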

Step 5: Completion & Cleanup

When finished=True (remain_len <= 0):

    ├── DecodeManager.remove_req(req)
    │   └── running_reqs: 1 → 0

    ├── TableManager.free(slot=255)
    │   └── free_slots: [0,...,254] → [0,...,255] (256 available)

    ├── CacheManager.free_and_cache_finished_req()
    │   └── Option 1: Free all 31 pages back to pool
    │   └── Option 2: Cache prefix (16 input pages) for future reuse

    └── Scheduler returns to idle, waiting for new requests

Final memory state:

If prefix cached:
  - Free pages: 180843 + 15 = 180858 (output pages freed)
  - Cached: 16 pages (input prefix for potential reuse)

If not cached:
  - Free pages: 180843 + 31 = 180874 (all pages freed)

Stopping Conditions

Generation stops when ANY condition is true:

finished = (
    req.remain_len <= 0           # Reached max_tokens
    or next_token == eos_token_id  # Model output EOS token
    or device_len >= max_seq_len   # Hit context length limit (40960)
)

| Condition | Example |
| --- | --- |
| remain_len <= 0 | Generated 16 tokens as requested |
| next_token == EOS | Model naturally ends response |
| device_len >= max_seq_len | Input + output hits 40960 tokens |

4. Second Request - Prefix Cache Hit

The first request completed and its KV cache was stored in the RadixTree. What happens when an identical request arrives? This is where Prefix Caching shines.

4.1 Overview

┌─────────────────────────────────────────────────────────────────────────┐
│                         REQUEST 2 FLOW                                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  ┌──────────────┐     ┌──────────────┐     ┌──────────────┐             │
│  │   Tokenize   │────>│  Cache Match │────>│   Prefill    │             │
│  │  (16 tokens) │     │ (15 cached!) │     │  (1 token)   │             │
│  └──────────────┘     └──────────────┘     └──────────────┘             │
│                              │                    │                     │
│                              ▼                    ▼                     │
│                    ┌──────────────────────────────────────┐             │
│                    │           KV Cache                   │             │
│                    │  ┌─────────────────┬────────────┐    │             │
│                    │  │ Cached (15 pgs) │ New (1 pg) │    │             │
│                    │  │   [REUSED]      │ [COMPUTED] │    │             │
│                    │  └─────────────────┴────────────┘    │             │
│                    └──────────────────────────────────────┘             │
│                                      │                                  │
│                                      ▼                                  │
│                           ┌──────────────────┐                          │
│                           │   Decode Loop    │                          │
│                           │  (15 iterations) │                          │
│                           └──────────────────┘                          │
│                                      │                                  │
│                                      ▼                                  │
│                           ┌──────────────────┐                          │
│                           │  Update Cache    │                          │
│                           └──────────────────┘                          │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

| Metric | First Request | Second Request | Savings |
| --- | --- | --- | --- |
| cached_len | 0 | 15 | - |
| extend_len (tokens to compute) | 16 | 1 | 94% |
| Pages allocated (prefill) | 16 | 1 | 94% |

Key Insight: Second request computes only 1 token instead of 16!

4.2 Why Prefix Cache Works

RadixTree State (before request 2)

                    ROOT

              [t0, t1, ..., t14]  ← 15 tokens cached from request 1

                    [t15]         ← Last input token

              [d0, d1, ..., d14]  ← 15 decode tokens

Cache Match Logic

Request 2 input: [t0, t1, ..., t14, t15]  (16 tokens)

Cache searches: input_ids[: input_len - 1] = input_ids[:15]
                ↑________________________↑
                Match! 15 tokens found in RadixTree

Why exclude the last token?

# scheduler/cache.py:31
handle, match_indices = self.manager.match_prefix(req.input_ids[: input_len - 1])

Even on a full-prefix hit, the prefill forward must process at least one token to produce the logits from which the first output token is sampled, so the match deliberately stops one token short of the input; the last input token is always computed.

4.3 Request Flow

Step 1: Cache Match

CacheManager.match_req()
├── Search RadixTree for input_ids[:15]
├── Found match → cached_len = 15
└── Lock cached pages (prevent eviction)
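
Conceptually, the prefix match is a walk down a token trie. A toy version (the real RadixTree stores compressed edges, per-node page indices, ref-counts, and eviction metadata):

class TrieNode:
    def __init__(self) -> None:
        self.children: dict[int, "TrieNode"] = {}   # token_id -> child node
        self.page: int | None = None                # KV page holding this token's KV
        self.ref_count = 0                          # locks the node against eviction

def match_prefix(root: TrieNode, token_ids: list[int]) -> list[int]:
    """Return the KV page indices of the longest cached prefix of token_ids."""
    pages, node = [], root
    for tid in token_ids:
        child = node.children.get(tid)
        if child is None or child.page is None:
            break
        pages.append(child.page)
        node = child
    return pages                                    # len(pages) == cached_len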

Step 2: Allocate Resources

extend_len = input_len - cached_len = 16 - 15 = 1

TableManager.allocate() → slot = 255
CacheManager.allocate(1) → 1 page (only for the last token!)

Step 3: Prefill Forward

┌─────────────────────────────────────────────────────────────────┐
│ Input: [t15]  (ONLY 1 token)                                    │
│ KV Cache: REUSE 15 tokens from cache                            │
│ Forward: Compute attention for 1 token with cached KV           │
│ Output: logits for t15 → sample first output token              │
└─────────────────────────────────────────────────────────────────┘

Attention shapes:
  Q: [1, num_heads, head_dim]       ← 1 new token
  K: [16, num_heads, head_dim]      ← 15 cached + 1 new
  V: [16, num_heads, head_dim]      ← 15 cached + 1 new

Step 4: Decode Loop (15 iterations)

Same as first request - allocate 1 page per step.

Step 5: Cleanup & Cache Update

TableManager.free(slot=255)
CacheManager.free_and_cache_finished_req()
├── Unlock old cached pages
├── Insert new sequence into RadixTree
└── Free redundant pages

4.4 Deep Dive: Cache Update Mechanics

Understanding cached_len Evolution

Warning: cached_len = 15 does NOT mean RadixTree only has 15 tokens!

  • RadixTree has 31 tokens (from request 1: 16 input + 15 decode)
  • cached_len = 15 means “15 tokens of request 2’s INPUT match the tree”
                        cached_len    device_len    Status
                        ──────────    ──────────    ──────
Initial (cache match):      15            16        Ready for prefill
After prefill:              16            17        First output sampled
After decode 1:             17            18
...
After decode 15:            31            32        remain_len=0 → DONE

Page Ownership After Cleanup

When free_and_cache_finished_req is called with 31 pages:

┌─────────────────────────────────────────────────────────────────────────┐
│                        31 Pages of Request 2                            │
├─────────────────┬───────────────────┬───────────────────────────────────┤
│   0-14 (15)     │    15-23 (9)      │          24-30 (7)                │
├─────────────────┼───────────────────┼───────────────────────────────────┤
│ From cache      │ Allocated but     │ Allocated, NEW tokens             │
│ (locked)        │ REDUNDANT         │ (different decode output)         │
├─────────────────┼───────────────────┼───────────────────────────────────┤
│ → UNLOCK        │ → FREE            │ → ADD TO CACHE                    │
│ (decr ref_count)│ (return to pool)  │ (new RadixTree nodes)             │
└─────────────────┴───────────────────┴───────────────────────────────────┘
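
A sketch of that three-way split (the tree.unlock / tree.insert calls are hypothetical names standing in for the RadixTree API, and the lengths are passed in for clarity):

def cleanup_finished_request(
    pages: list[int],        # the request's pages in token order (31 here)
    token_ids: list[int],    # the full token sequence (input + output)
    cached_len: int,         # 15: prefix pages borrowed from the RadixTree
    redundant_len: int,      # 9: pages whose tokens the tree already holds
    free_pool: list[int],
    tree,                    # RadixTree-like object (hypothetical API)
) -> None:
    # 1. Unlock the borrowed prefix so those nodes become evictable again
    tree.unlock(pages[:cached_len])
    # 2. Return redundant pages to the pool: their tokens are already cached
    free_pool.extend(pages[cached_len:cached_len + redundant_len])
    # 3. Insert the genuinely new suffix into the tree for future reuse
    new_start = cached_len + redundant_len
    tree.insert(token_ids[new_start:], pages[new_start:])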

RadixTree After Request 2

                ROOT

          [t0...t14] (15)         ← Shared (ref_count++)

              [t15] (1)           ← Shared

          [d0...d7] (8)           ← Shared (if same output)

           ┌──────┴──────┐
    [d8...d14]        [d8'...d14']
    (Request 1)       (Request 2, new nodes)

4.5 Performance Summary

┌─────────────────────────────────────────────────────────────────┐
│                    PREFILL COMPUTATION                          │
├─────────────────────────────────────────────────────────────────┤
│ First Request:                                                  │
│   Compute: 16 tokens × 28 layers = 448 layer forwards           │
│   Memory: Allocate 16 pages                                     │
│                                                                 │
│ Second Request:                                                 │
│   Compute: 1 token × 28 layers = 28 layer forwards              │
│   Memory: Allocate 1 page (reuse 15 from cache)                 │
│   Speedup: ~16x faster prefill!                                 │
└─────────────────────────────────────────────────────────────────┘

Partial cache hit:

Request 1: "What is 2 + 2?"  → cached_len = 0
Request 2: "What is 2 + 2?"  → cached_len = 15 (full match)
Request 3: "What is 2 + 3?"  → cached_len = 12 (partial: "What is 2 + ")

Next up, we’ll cover:

  • Batching multiple requests
  • Overlap scheduling
  • Tensor parallelism (TP > 1)
