YaRN

Yet another RoPE extensioN method

The context windows of large language models (LLMs) are expanding at breakneck speed.
YaRN offers an elegant, essentially zero-overhead method:
by dynamically adjusting Rotary Position Embedding (RoPE), it extends the context length several-fold.

📖 TL;DR — What is YaRN?

Why do we need position encoding at all? Without it, a Transformer treats "The cat sat" and "sat The cat" identically. RoPE elegantly injects position by rotating each dimension at a unique frequency.

1. RoPE: Rotary Position Embedding Basics

Before understanding YaRN, we must first understand RoPE. RoPE groups the elements of a token embedding into pairs and rotates each pair in the complex plane; the rotation angle is proportional to the token's absolute position $m$ in the sequence. Each dimension $d$ has its own rotation frequency $\theta_d$:

$$f(x_m, m, \theta_d) = e^{im\theta_d} \cdot x_m, \qquad \theta_d = b^{-2d/|D|}$$
🔍 Symbol Breakdown
- $f(x_m, m, \theta_d)$: the rotated embedding for the token at position m, dimension d
- $e^{im\theta_d}$: complex rotation by angle m·θ_d (Euler's formula)
- $x_m$: input embedding vector at position m
- $\theta_d = b^{-2d/|D|}$: rotation frequency for dimension d; decreases exponentially with d
- $b = 10000$: base constant (controls the frequency range)
- $|D|$: total embedding dimension (e.g., 64)
# RoPE: Rotary Position Embedding
from math import cos, sin

def rope_embed(x_even, x_odd, position_m, dim_d, base=10000, D=64):
    # 1. Rotation frequency for this dimension
    theta_d = base ** (-2 * dim_d / D)

    # 2. Rotation angle: position × frequency
    angle = position_m * theta_d
    cos_a, sin_a = cos(angle), sin(angle)

    # 3. Apply the complex rotation in the 2D subspace
    x_even_new = x_even * cos_a - x_odd * sin_a
    x_odd_new  = x_even * sin_a + x_odd * cos_a

    return x_even_new, x_odd_new
📝 Code Walkthrough
  1. theta_d = base ** (-2 * dim_d / D) — Compute the rotation frequency. Low d → high frequency, high d → low frequency.
  2. angle = position_m * theta_d — The rotation angle is simply position × frequency.
  3. x_even_new, x_odd_new = ... — Apply 2D rotation matrix: [cos -sin; sin cos] to each pair of dimensions.

The key point is that different dimensions encode different information:
low dimensions (small d) rotate very fast, capturing local relative distances;
high dimensions (large d) rotate very slowly, staying almost still, and encode coarse absolute-position information.

RoPE Visualization (interactive figure: rotation speed per dimension, from d=0, fastest, to d=24, slowest, as position m advances)
Think of wavelength as "how many tokens fit in one full rotation cycle." Short wavelength = fast spinning = sensitive to nearby tokens. Long wavelength = slow spinning = tracks global position.
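This fast/slow behavior is easy to check numerically. A minimal sketch using the $\theta_d$ formula above (the helper name `turns` is mine, not from the original):

```python
from math import pi

def turns(position_m, dim_d, base=10000, D=64):
    """Full 2*pi rotations dimension dim_d has completed by position m."""
    theta = base ** (-2 * dim_d / D)
    return position_m * theta / (2 * pi)

# After 100 tokens, d=0 has spun through ~15.9 full turns,
# while d=24 has barely moved (~0.016 of a single turn)
print(turns(100, 0), turns(100, 24))
```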

2. Wavelength and Context Length

The rotation frequency $\theta_d$ determines how many tokens the dimension needs to complete one full $2\pi$ rotation within the sequence. We call this span the wavelength $\lambda_d$, and define $r(d)$ as the ratio of the pre-training context length $L$ to the wavelength:

$$\lambda_d = \frac{2\pi}{\theta_d}, \qquad r(d) = \frac{L}{\lambda_d}$$
🔍 Symbol Breakdown
- $\lambda_d$: wavelength, the token span for one full 2π rotation
- $\theta_d$: rotation frequency at dimension d
- $r(d)$: frequency ratio indicating local/global behavior
- $L$: original pre-training context length
- $r \gg 1$: high-frequency dimensions (local detail)
- $r \ll 1$: low-frequency dimensions (global position)
# Wavelength and ratio computation
from math import pi

def wavelength(dim_d, base=10000, D=64):
    theta = base ** (-2 * dim_d / D)
    lam = 2 * pi / theta                  # wavelength in tokens
    return lam

def freq_ratio(dim_d, L=4096):
    return L / wavelength(dim_d)           # r >> 1: high freq, r << 1: low freq
📝 Code Walkthrough
  1. wavelength() — Converts frequency to wavelength (tokens per cycle).
  2. freq_ratio() — Normalizes wavelength by context length L.
  3. r behavior — Large r means local/high-frequency; tiny r means global/low-frequency.

This ratio determines how the dimension is treated: when $r \gg 1$, the dimension completes many cycles within the training window → high frequency (local information); when $r \ll 1$, the dimension barely rotates → low frequency (global position information).

Wavelength Chart (interactive: wavelength per dimension, shown against the pre-training context length L = 4096)
Imagine stretching a rubber band with markings. PI stretches ALL markings uniformly — the fine markings become unreadable. That's exactly what happens to high-frequency position info.

3. The Disaster of Linear Interpolation (PI)

To extend the context to $s \cdot L$, the most direct approach is Position Interpolation (PI): shrink every position coordinate to $m/s$. This is equivalent to compressing every frequency to $\theta_d / s$:

$$g(m) = \frac{m}{s}, \qquad h(\theta_d) = \frac{\theta_d}{s}$$
🔍 Symbol Breakdown
- $g(m) = m/s$: position interpolation, compressing positions by scale s
- $h(\theta_d) = \theta_d/s$: the equivalent frequency compression in every dimension
- $s$: target extension multiplier (e.g., 8, 16, 32)
- Uniform scaling: applies equally to all dimensions, including high-frequency ones
- High-freq impact: adjacent-token angle differences shrink too much
- Consequence: the model loses local positional resolution
# Position Interpolation — the naive approach
def position_interpolation(position_m, scale_s):
    return position_m / scale_s    # compress ALL positions by s
    # Problem: high-freq dims lose resolution!
    # Adjacent tokens: angle_diff = theta_d / s → too small
📝 Code Walkthrough
  1. position_m / scale_s — Uniformly compresses every position index.
  2. Implicit effect — Equivalent to shrinking every θ_d by the same factor.
  3. Failure mode — High-frequency dimensions can no longer separate nearby tokens.

This is catastrophic for the high-frequency dimensions (low d). They exist precisely to distinguish tokens that are very close together; after this forced stretch, the relative angle difference between adjacent tokens becomes too small, making the model "nearsighted" and destroying its fine-grained relative-position resolution.
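To make this failure concrete, here is a minimal numeric check, assuming the LLaMA-style defaults b=10000 and |D|=64 used throughout (the helper name is mine):

```python
def adjacent_angle_diff(dim_d, scale_s, base=10000, D=64):
    """Angle separating two adjacent tokens in dimension dim_d under PI."""
    theta_d = base ** (-2 * dim_d / D)
    return theta_d / scale_s   # PI divides every frequency by s

# Fastest dimension (d=0): 1.0 rad between neighbors originally,
# only 0.0625 rad after 16x PI, so nearby tokens blur together
print(adjacent_angle_diff(0, 1), adjacent_angle_diff(0, 16))
```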

Position Interpolation Chart (interactive: how positions compress as the scale s varies)
Instead of stretching positions, NTK-Aware changes the "ruler" itself (the base). This distributes the stretching pressure more evenly across all frequency bands.

4. NTK-Aware Interpolation (Background)

Before introducing the piecewise NTK scheme, let's look at NTK-Aware Interpolation. Instead of scaling positions directly, it modifies the base $b$, spreading the interpolation pressure across all dimensions:

$$b' = b \cdot s^{|D|/(|D|-2)}$$
🔍 Symbol Breakdown
- $b'$: updated RoPE base used for the extended context
- $b$: original RoPE base (usually 10000)
- $s$: context extension multiplier
- $|D|$: embedding dimension count
- Exponent $|D|/(|D|-2)$: dimension-dependent correction term
- Main idea: reparameterize frequencies instead of raw positions
# NTK-Aware: modify the base instead of positions
def ntk_aware_base(base, scale_s, D=64):
    return base * scale_s ** (D / (D - 2))
    # Spreads interpolation pressure across ALL dimensions
    # High-freq dims: less affected (good!)
    # Low-freq dims: more affected (intended)
📝 Code Walkthrough
  1. ntk_aware_base() — Computes a new base b′ from s and D.
  2. Effect — Frequency shift becomes dimension-aware, less harsh than PI.
  3. Tradeoff — Better than PI, but still not fully selective like YaRN.
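A quick sketch comparing the two approaches per dimension (the helper name is mine; defaults as above). At d=0 the new base changes nothing, while at the slowest dimension (d=31 with |D|=64) the exponent works out to exactly 1/s, matching PI:

```python
def ntk_vs_pi(dim_d, base=10000, scale_s=16, D=64):
    """Frequency of one dimension under NTK-Aware vs plain PI."""
    theta = base ** (-2 * dim_d / D)
    new_base = base * scale_s ** (D / (D - 2))
    theta_ntk = new_base ** (-2 * dim_d / D)   # NTK-Aware frequency
    theta_pi = theta / scale_s                 # PI frequency
    return theta_ntk, theta_pi

print(ntk_vs_pi(0))    # d=0:  NTK-Aware keeps 1.0; PI crushes it to 0.0625
print(ntk_vs_pi(31))   # d=31: the two methods coincide (by design)
```

This is why NTK-Aware is "dimension-aware": the damage ramps up smoothly from zero at the fastest dimension to full PI at the slowest.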
YaRN's key insight: don't treat all frequencies the same. High-frequency dims already know local patterns — leave them alone. Only interpolate the low-frequency dims that actually need to handle longer contexts.

4.1 YaRN Core: NTK-by-parts Interpolation

The essence of YaRN is NTK-by-parts: do not scale everything with one blanket rule. Introduce the ratio $r = L / \lambda_d$ and define two thresholds $\alpha = 1, \beta = 32$:

■ High frequency ($r > \beta$): wavelength much smaller than L. No interpolation at all; keep the original frequency.
■ Mid frequency ($\alpha < r < \beta$): transition smoothly using a ramp function.
■ Low frequency ($r < \alpha$): wavelength larger than L. Apply PI linear interpolation (divide by s).

$$h(\theta_d) = (1-\gamma(r)) \cdot \frac{\theta_d}{s} + \gamma(r) \cdot \theta_d, \qquad \gamma(r) = \begin{cases} 0 & r < \alpha \\ 1 & r > \beta \\ \frac{r-\alpha}{\beta-\alpha} & \text{otherwise} \end{cases}$$
🔍 Symbol Breakdown
- $h(\theta_d)$: final blended frequency for dimension d
- $\gamma(r)$: ramp factor controlling interpolation intensity
- $\alpha, \beta$: thresholds separating the interpolate/transition/keep zones
- $\gamma = 0$: pure PI mode for very low-frequency dimensions
- $\gamma = 1$: keep the original frequency for high-frequency dimensions
- Middle regime: linear blend for smooth continuity and stability
# YaRN Core: NTK-by-parts interpolation
from math import pi

def yarn_frequency(theta_d, L=4096, scale_s=4, alpha=1, beta=32):
    lam = 2 * pi / theta_d
    r = L / lam                            # wavelength ratio

    # Ramp function γ(r)
    if r < alpha:
        gamma = 0.0                        # low freq → full PI
    elif r > beta:
        gamma = 1.0                        # high freq → keep original
    else:
        gamma = (r - alpha) / (beta - alpha)  # smooth blend

    # Blended frequency
    theta_new = (1 - gamma) * (theta_d / scale_s) + gamma * theta_d
    return theta_new
📝 Code Walkthrough
  1. r = L / λ — Measures whether a dimension behaves globally or locally.
  2. Piecewise γ — Sets full interpolate, full keep, or smooth transition.
  3. theta_new blend — Mixes θ_d/s and θ_d based on γ(r).

Let's examine three extreme cases with concrete numbers (using LLaMA as the example: $L = 4096$, $b = 10000$):

1. Highest frequency (d=0): $\theta_0 = 1.0$, wavelength $\lambda_0 \approx 6.28$. Ratio $r \approx 652 \gg \beta$, so $\gamma = 1$: the original frequency $\theta_0$ is used unchanged, preserving local information.
2. Lowest frequency (d=31): $\theta_{31} \approx 1.3 \times 10^{-4}$ is tiny, with a wavelength far larger than $L$. Ratio $r \approx 0.09 < \alpha$, so $\gamma = 0$: full PI interpolation is applied, giving $\theta_{31}/s$.
3. Mid frequency: blended smoothly in proportion to $\gamma(r)$.
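These regimes can be verified with a short script (a sketch reusing the formulas above; the helper name is mine):

```python
from math import pi

def regime(dim_d, L=4096, base=10000, D=64, alpha=1, beta=32):
    """Classify dimension dim_d into YaRN's three frequency zones."""
    theta = base ** (-2 * dim_d / D)
    r = L * theta / (2 * pi)          # r = L / wavelength
    if r > beta:
        return r, "keep"              # high frequency: gamma = 1
    if r < alpha:
        return r, "interpolate"       # low frequency: gamma = 0
    return r, "blend"                 # transition: 0 < gamma < 1

print(regime(0))    # r ≈ 652  → "keep"
print(regime(31))   # r ≈ 0.09 → "interpolate"
```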


4.5. Dynamic NTK and Intuition

At inference time we don't know the final sequence length in advance. YaRN therefore uses a dynamic scaling factor, recomputed on the fly from the current sequence length $l'$:

$$s = \max\left(1, \frac{l'}{L}\right)$$
🔍 Symbol Breakdown
- $s$: dynamic context scaling factor used at inference time
- $l'$: current sequence length during autoregressive decoding
- $L$: original pre-training max context length
- $\max(1, \cdot)$: prevents shrinking frequencies for short sequences
- $s = 1$: no context extension; default RoPE behavior is preserved
- $s > 1$: activates YaRN interpolation/correction for long-context inference
- Practical impact: a single checkpoint flexibly handles mixed sequence lengths
- Key advantage: no extra fine-tuning needed for each target length

This means the model can adapt to different context lengths dynamically at inference time, without any fine-tuning.
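A minimal sketch of this rule (the function name is mine):

```python
def dynamic_scale(current_len, L=4096):
    """Dynamic YaRN scale factor: grows with the sequence, never below 1."""
    return max(1.0, current_len / L)

assert dynamic_scale(1024) == 1.0     # inside the training window: plain RoPE
assert dynamic_scale(4096) == 1.0     # exactly at L: still unchanged
assert dynamic_scale(65536) == 16.0   # 64k tokens: s = 16
```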

Why is NTK-by-parts so effective? (Intuition)
The core insight: high-frequency dimensions encode local distances (adjacent tokens). Their wavelengths are tiny, so within the pre-training window $L$ they have already completed countless full cycles and have seen every possible relative position; forcibly interpolating them only destroys that fine granularity. Low-frequency dimensions encode global position: their wavelengths exceed $L$, so they never completed even one full cycle during training. These are the dimensions that genuinely need PI-style "interpolated extrapolation" when sequences get longer. NTK-by-parts turns exactly this intuition into math.

After interpolation, the attention scores' distribution changes (like adjusting the "temperature" of a softmax). YaRN adds a tiny correction factor to restore the original distribution shape.

5. Attention Scaling

YaRN observes that after interpolation the distribution of attention logits shifts (the larger average distances change the variance). To correct this temperature bias, YaRN multiplies the attention computation by a small but crucial scaling factor:

$$t = \left(0.1 \cdot \ln(s) + 1\right)^{-2}$$
🔍 Symbol Breakdown
- $t$: attention correction factor
- $s$: current extension scale factor
- $0.1\ln(s) + 1$: logarithmic temperature compensation term
- Exponent $-2$: stabilizes scaling as s grows large
- $s = 1$: t = 1, so the original attention behavior is unchanged
- $s > 1$: t decreases, cooling the shifted attention logits
# Attention temperature correction
from math import log

def attention_scale(scale_s):
    t = (0.1 * log(scale_s) + 1) ** (-2)
    return t    # multiply attention logits by sqrt(t)
    # s=1  → t = 1.0  (no change)
    # s=16 → t ≈ 0.61 (cool down attention)
📝 Code Walkthrough
  1. Compute t from s — Uses logarithmic scaling for smooth correction growth.
  2. Return factor — Apply to attention logits (typically via √t).
  3. Behavior — No change at s=1, stronger cooling as scale increases.
Attention Scaling Chart (interactive: correction factor t as a function of the scale s)

6. Complete YaRN Pipeline Summary

Putting all of the components above together, the end-to-end YaRN pipeline looks like this:

# Complete YaRN Pipeline
from math import pi, log, cos, sin

def rotate_2d(a, b, angle):
    # Rotate the pair (a, b) by the given angle
    return a * cos(angle) - b * sin(angle), a * sin(angle) + b * cos(angle)

def yarn_rope(x, position_m, L=4096, base=10000, D=64, alpha=1, beta=32, seq_len=8192):
    scale_s = max(1, seq_len / L)  # dynamic scaling
    
    for d in range(D // 2):
        theta = base ** (-2 * d / D)
        lam = 2 * pi / theta
        r = L / lam
        
        # NTK-by-parts
        if r < alpha:    gamma = 0.0       # → PI interpolation
        elif r > beta:   gamma = 1.0       # → keep original
        else:            gamma = (r - alpha) / (beta - alpha)
        
        theta_new = (1 - gamma) * theta / scale_s + gamma * theta
        angle = position_m * theta_new
        
        # Apply rotation
        x[2*d], x[2*d+1] = rotate_2d(x[2*d], x[2*d+1], angle)
    
    # Attention scaling
    t = (0.1 * log(scale_s) + 1) ** (-2)
    return x, t
📝 Code Walkthrough
  1. dynamic scale_s — Computes runtime scale from current sequence length.
  2. per-dimension NTK-by-parts — Calculates γ and blended frequency per dim.
  3. final output — Returns rotated embeddings plus attention correction t.
128k: context length supported seamlessly
10x: reduction in fine-tuning tokens
2.5x: reduction in required training steps

7. Method Comparison & Real-World Configs

How do PI, NTK-Aware, and YaRN actually compare? The table below summarizes the tradeoffs of the three main context extension methods side by side, followed by real-world configurations showing how YaRN is used in production.

Feature | PI | NTK-Aware | YaRN
Approach | Scale all positions by 1/s | Modify base b | Per-dimension frequency adjustment
High-freq handling | ❌ Compressed (destroys local info) | ⚠️ Partially preserved | ✅ Fully preserved (γ=1)
Low-freq handling | ✅ Interpolated | ✅ Interpolated | ✅ PI interpolated (γ=0)
Fine-tuning required | 200-400 steps | 200-400 steps | ~400 steps (10x fewer tokens)
Dynamic inference | ❌ Fixed scale | ✅ Supported | ✅ Dynamic s = l'/L
Attention correction | ❌ None | ❌ None | ✅ Temperature scaling t
Perplexity at 128k | High (degraded) | Medium | Low (best)
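The table's claims about frequency handling can be reproduced with a small comparison script (a sketch built from the formulas in earlier sections, with the usual defaults; the helper name is mine):

```python
from math import pi

def compare_methods(dim_d, L=4096, base=10000, D=64, s=16, alpha=1, beta=32):
    """Frequency assigned to dimension dim_d by PI, NTK-Aware, and YaRN."""
    theta = base ** (-2 * dim_d / D)
    pi_freq = theta / s                                   # PI: uniform 1/s
    ntk_freq = (base * s ** (D / (D - 2))) ** (-2 * dim_d / D)
    r = L * theta / (2 * pi)
    gamma = min(1.0, max(0.0, (r - alpha) / (beta - alpha)))
    yarn_freq = (1 - gamma) * pi_freq + gamma * theta     # NTK-by-parts blend
    return pi_freq, ntk_freq, yarn_freq

print(compare_methods(0))   # d=0: PI slows it 16x; NTK-Aware and YaRN keep 1.0
print(compare_methods(31))  # d=31: all three land near theta/s
```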

Real-World Configurations

Here's how YaRN is configured in popular models:

LLaMA-2 7B (4k → 128k):

rope_scaling:
  type: yarn
  factor: 32.0                            # s = 128k / 4k = 32
  original_max_position_embeddings: 4096
  attention_factor: 0.1                   # attention temperature
  beta_fast: 32                           # β threshold
  beta_slow: 1                            # α threshold

Mistral 7B (8k → 64k):

rope_scaling:
  type: yarn
  factor: 8.0                             # s = 64k / 8k = 8
  original_max_position_embeddings: 8192
  attention_factor: 0.1
  beta_fast: 32
  beta_slow: 1

HuggingFace Transformers:

from transformers import AutoModelForCausalLM

# Load model with YaRN scaling
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    rope_scaling={
        "type": "yarn",
        "factor": 32.0,
        "original_max_position_embeddings": 4096,
        "attention_factor": 0.1,
        "beta_fast": 32,
        "beta_slow": 1,
    }
)
📝 Code Walkthrough
  1. rope_scaling.type = "yarn" — Tells HuggingFace to use YaRN's NTK-by-parts interpolation instead of linear PI.
  2. factor = 32.0 — The scale factor s, extending context from 4k to 128k tokens.
  3. beta_fast / beta_slow — The β and α thresholds that define the three frequency zones.

8. Interactive Playground

The original page closes with an interactive playground: sliders for the scale factor s and the α/β thresholds, plotted against per-dimension frequency curves for PI (dashed), NTK-Aware (dotted), and YaRN (solid) across all 32 dimensions, with live counters for the keep/transition/interpolate zones. Its visual takeaway: YaRN hugs PI in the low-frequency zone and converges to the original frequencies in the high-frequency zone, and tuning α/β trades off preserving local semantics against extending long-range recall.
  7. Practical takeaway — Tune α/β to preserve local semantics while extending long-range recall.