The context windows of large language models (LLMs) are expanding at a remarkable pace.
YaRN offers an elegant method with essentially zero computational overhead:
by adjusting Rotary Position Embedding (RoPE), it extends the context length several-fold.
Before we can understand YaRN, we must first understand RoPE. RoPE groups the elements of a token embedding into pairs and rotates each pair in the complex plane; the rotation angle is proportional to the token's absolute position $m$ in the sequence, and each dimension pair $d$ has its own rotation frequency $\theta_d$.
| Symbol | Meaning |
|---|---|
| $f(x_m, m)_d$ | The rotated embedding for the token at position $m$, dimension $d$ |
| $e^{i m \theta_d}$ | Complex rotation by angle $m\theta_d$ (Euler's formula) |
| $x_m$ | Input embedding vector at position $m$ |
| $\theta_d = b^{-2d/D}$ | Rotation frequency for dimension $d$; decreases exponentially with $d$ |
| $b$ | Base constant (controls the frequency range) |
| $D$ | Total embedding dimension (e.g., 64) |
# RoPE: Rotary Position Embedding (one 2-D subspace)
from math import cos, sin

def rope_embed(x_even, x_odd, position_m, dim_d, base=10000, D=64):
    # 1. Rotation frequency for this dimension pair
    theta_d = base ** (-2 * dim_d / D)
    # 2. Rotation angle is proportional to the absolute position
    angle = position_m * theta_d
    cos_a, sin_a = cos(angle), sin(angle)
    # 3. Apply the complex rotation in the 2-D subspace
    x_even_new = x_even * cos_a - x_odd * sin_a
    x_odd_new = x_even * sin_a + x_odd * cos_a
    return x_even_new, x_odd_new
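The key property this buys is relative-position encoding: the inner product of a rotated query/key pair depends only on the offset $m - n$, not on the absolute positions. A quick self-contained check (`rotate` and `dot2` are local helpers for this sketch, not part of the code above):

```python
from math import cos, sin, isclose

def rotate(a, b, angle):
    # 2-D rotation by `angle`, as applied in each RoPE subspace
    return a * cos(angle) - b * sin(angle), a * sin(angle) + b * cos(angle)

def dot2(p, q):
    return p[0] * q[0] + p[1] * q[1]

theta = 10000 ** (-2 * 4 / 64)   # frequency of dimension pair d=4
q, k = (0.3, -1.2), (0.8, 0.5)   # toy query/key component pairs

# Rotate q at position m and k at position n, for two (m, n) pairs
# that share the same offset m - n = 7
a = dot2(rotate(*q, 10 * theta), rotate(*k, 3 * theta))
b = dot2(rotate(*q, 107 * theta), rotate(*k, 100 * theta))
assert isclose(a, b)  # the score depends only on m - n
```

Because rotations are orthogonal, $\langle R(m\theta)q,\, R(n\theta)k\rangle = \langle q,\, R((n-m)\theta)k\rangle$, which is exactly why RoPE attention scores encode relative distance.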
The key point is that different dimensions encode different information:
low dimensions (small $d$) rotate very fast and capture local relative distances;
high dimensions (large $d$) rotate very slowly, remaining almost static, and encode coarse absolute position information.
The rotation frequency $\theta_d$ determines how many tokens that dimension needs to complete one full $2\pi$ rotation. We call this span the wavelength, $\lambda_d = 2\pi/\theta_d$. We also define the ratio $r = L/\lambda_d$ of the pre-training context length $L$ to the wavelength.
| Symbol | Meaning |
|---|---|
| $\lambda_d = 2\pi/\theta_d$ | Wavelength: token span for one full $2\pi$ rotation |
| $\theta_d$ | Rotation frequency at dimension $d$ |
| $r = L/\lambda_d$ | Frequency ratio indicating local/global behavior |
| $L$ | Original pre-training context length |
| $r \gg 1$ | High-frequency dimensions (local detail) |
| $r \ll 1$ | Low-frequency dimensions (global position) |
# Wavelength and ratio computation
from math import pi

def wavelength(dim_d, base=10000, D=64):
    theta = base ** (-2 * dim_d / D)
    lam = 2 * pi / theta  # wavelength in tokens
    return lam

def freq_ratio(dim_d, L=4096):
    return L / wavelength(dim_d)  # r >> 1: high freq, r << 1: low freq
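Plugging in LLaMA-style defaults ($L = 4096$, $b = 10000$, $D = 64$) shows how extreme the spread across dimensions is; a self-contained recomputation:

```python
from math import pi

L, base, D = 4096, 10000, 64

for d in (0, 8, 31):  # highest-, mid-, and lowest-frequency dimension pairs
    theta = base ** (-2 * d / D)
    lam = 2 * pi / theta          # wavelength in tokens
    r = L / lam                   # pre-training frequency ratio
    print(f"d={d:2d}  wavelength={lam:10.1f}  r={r:9.4f}")
```

The fastest pair ($d=0$) completes a cycle every ~6.3 tokens ($r \approx 652$), while the slowest ($d=31$) needs ~47,000 tokens, far more than the training window ($r \approx 0.087$).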
This ratio determines how a dimension should be treated: when $r \gg 1$, the dimension completes many full cycles within the training window, making it high-frequency (local information); when $r \ll 1$, the dimension barely rotates at all, making it low-frequency (global position information).
To extend the context to $L' = s \cdot L$, the most direct approach is Position Interpolation (PI): shrink every position coordinate from $m$ to $m/s$. This is equivalent to compressing every frequency to $\theta_d/s$.
| Term | Meaning |
|---|---|
| $m \to m/s$ | Position interpolation: compress positions by scale $s$ |
| $\theta_d \to \theta_d/s$ | Equivalent frequency compression in every dimension |
| $s = L'/L$ | Target extension multiplier (e.g., 8, 16, 32) |
| Uniform scaling | Applies equally to all dimensions, including high-frequency ones |
| High-freq impact | Adjacent-token angle differences shrink too much |
| Consequence | Model loses local positional resolution |
# Position Interpolation — the naive approach
def position_interpolation(position_m, scale_s):
return position_m / scale_s # compress ALL positions by s
# Problem: high-freq dims lose resolution!
# Adjacent tokens: angle_diff = theta_d / s → too small
This is catastrophic for the high-frequency dimensions (low $d$). They exist precisely to distinguish tokens that are very close together; stretching them forcibly makes the relative angle difference between adjacent tokens far too small, leaving the model "near-sighted": it loses fine-grained relative-position resolution.
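To make the damage concrete, compare the adjacent-token angle gap in the fastest dimension before and after PI with $s = 16$ (a toy calculation):

```python
theta_0 = 1.0            # fastest dimension: theta = base**0 = 1 rad/token
s = 16                   # extension scale

gap_before = theta_0     # angle between adjacent tokens under plain RoPE
gap_after = theta_0 / s  # after PI: positions compressed by s

print(gap_before, gap_after)
```

After PI, sixteen adjacent tokens are squeezed into the angular span that a single token step used to occupy, which is exactly the lost resolution described above.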
Before introducing piecewise NTK, let's look at NTK-Aware Interpolation. Instead of scaling positions directly, it modifies the base $b$, so that the interpolation pressure is spread across all dimensions: $b' = b \cdot s^{D/(D-2)}$.
| Term | Meaning |
|---|---|
| $b' = b \cdot s^{D/(D-2)}$ | Updated RoPE base used for extended context |
| $b$ | Original RoPE base (usually 10000) |
| $s$ | Context extension multiplier |
| $D$ | Embedding dimension count |
| Exponent $D/(D-2)$ | Dimension-dependent correction term |
| Main idea | Reparameterize frequencies instead of raw positions |
# NTK-Aware: modify the base instead of positions
def ntk_aware_base(base, scale_s, D=64):
return base * scale_s ** (D / (D - 2))
# Spreads interpolation pressure across ALL dimensions
# High-freq dims: less affected (good!)
# Low-freq dims: more affected (intended)
The essence of YaRN is NTK-by-parts: don't scale everything uniformly. Introduce the ratio $r = L / \lambda_d$ and define two thresholds $\alpha=1, \beta=32$:
■ High frequency ($r > \beta$): wavelength far smaller than $L$. No interpolation at all; keep the original frequency.
■ Mid frequency ($\alpha < r < \beta$): blend smoothly using a ramp function.
■ Low frequency ($r < \alpha$): wavelength larger than $L$. Apply PI linear interpolation (divide by $s$).
| Term | Meaning |
|---|---|
| $\theta'_d = (1-\gamma)\,\theta_d/s + \gamma\,\theta_d$ | Final blended frequency for dimension $d$ |
| $\gamma(r)$ | Ramp factor controlling interpolation intensity |
| $\alpha = 1$, $\beta = 32$ | Thresholds separating interpolate/transition/keep zones |
| $\gamma = 0$ | Pure PI mode for very low-frequency dimensions |
| $\gamma = 1$ | Keep original frequency for high-frequency dimensions |
| Middle regime | Linear blend for smooth continuity and stability |
# YaRN Core: NTK-by-parts interpolation
from math import pi

def yarn_frequency(theta_d, L=4096, scale_s=4, alpha=1, beta=32):
    lam = 2 * pi / theta_d
    r = L / lam  # wavelength ratio
    # Ramp function γ(r)
    if r < alpha:
        gamma = 0.0  # low freq → full PI
    elif r > beta:
        gamma = 1.0  # high freq → keep original
    else:
        gamma = (r - alpha) / (beta - alpha)  # smooth blend
    # Blended frequency
    theta_new = (1 - gamma) * (theta_d / scale_s) + gamma * theta_d
    return theta_new
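Evaluating the blend for LLaMA-style defaults ($L = 4096$, $s = 16$, $\alpha = 1$, $\beta = 32$) exposes the three regimes directly; a self-contained sketch (the clamp expression is equivalent to the piecewise ramp):

```python
from math import pi

L, base, D, s, alpha, beta = 4096, 10000, 64, 16, 1, 32

for d in (0, 12, 31):  # high-, mid-, and low-frequency dimension pairs
    theta = base ** (-2 * d / D)
    r = L / (2 * pi / theta)
    gamma = min(1.0, max(0.0, (r - alpha) / (beta - alpha)))  # ramp, clamped
    theta_new = (1 - gamma) * theta / s + gamma * theta
    print(f"d={d:2d}  r={r:8.3f}  gamma={gamma:.3f}  theta_new/theta={theta_new / theta:.4f}")
```

$d=0$ keeps its frequency untouched ($\gamma=1$), $d=31$ is compressed by exactly $1/s$ ($\gamma=0$), and $d=12$ lands in between.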
Let's plug in concrete numbers (LLaMA-style settings: $L = 4096$, $b = 10000$, $D = 64$) and look at the three extreme cases:
1. Highest frequency ($d=0$): $\theta_0 = 1$, wavelength $\lambda_0 = 2\pi \approx 6.28$ tokens. The ratio $r \approx 652 \gg \beta$, so $\gamma=1$: the original frequency $\theta_0$ is used unchanged, preserving local information.
2. Lowest frequency ($d=31$): $\theta_{31}$ is tiny and the wavelength far exceeds $L$. The ratio $r \approx 0.087 \ll \alpha$, so $\gamma=0$: the frequency becomes $\theta_{31}/s$, i.e., full PI.
3. Mid frequency: a smooth, proportional blend of the two.
At inference time we don't know the final sequence length in advance. YaRN therefore uses a dynamic scaling factor, computed on the fly from the current sequence length $l'$: $s = \max(1, l'/L)$.
| Term | Meaning |
|---|---|
| $s = \max(1, l'/L)$ | Dynamic context scaling factor used at inference time |
| $l'$ | Current sequence length during autoregressive decoding |
| $L$ | Original pre-training max context length |
| $\max(1, \cdot)$ | Prevents shrinking frequencies for short sequences |
| $l' \le L$ | No context extension; default RoPE behavior is preserved |
| $l' > L$ | Activates YaRN interpolation/correction for long-context inference |
| Practical impact | Single checkpoint can flexibly handle mixed sequence lengths |
| Key advantage | No extra fine-tuning needed for each target length |
This means the model can adapt dynamically to different context lengths at inference time, without any fine-tuning at all.
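The dynamic factor itself is one line; a minimal sketch with $L = 4096$ (`dynamic_scale` is a hypothetical helper name, not from any library):

```python
def dynamic_scale(seq_len, L=4096):
    # s = max(1, l'/L): never shrink frequencies for short sequences
    return max(1.0, seq_len / L)

print(dynamic_scale(1024))    # within the training window: plain RoPE
print(dynamic_scale(131072))  # 32x beyond it: YaRN interpolation active
```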
Why is piecewise NTK so effective? (Intuition)
The core insight: high-frequency dimensions encode local distances (adjacent tokens). Their wavelengths are so short that within the pre-training window $L$ they have already completed countless full cycles; they have seen every possible relative position, and forcing interpolation on them would only destroy that fine granularity. Low-frequency dimensions, by contrast, encode global position: their wavelengths exceed $L$, so they never finished even a single full cycle during training. These are the dimensions that genuinely need PI-style interpolation when facing longer sequences. NTK-by-parts turns this intuition into math exactly.
YaRN also observes that after interpolation the distribution of attention logits shifts (longer average distances change the variance). To correct this temperature bias, YaRN applies a small but crucial scaling factor to the attention logits:
| Term | Meaning |
|---|---|
| $t = (0.1 \ln s + 1)^{-2}$ | Attention temperature correction factor; logits are divided by $t$ |
| $s$ | Current extension scale factor |
| $0.1 \ln s$ | Logarithmic temperature compensation term |
| Exponent $-2$ | Arises from scaling both $q$ and $k$ by $1/\sqrt{t} = 0.1 \ln s + 1$ |
| $s = 1$ | $t = 1$, so original attention behavior remains unchanged |
| $s > 1$ | $t < 1$; dividing the logits by $t$ sharpens attention to offset entropy growth over longer contexts |
# Attention temperature correction
from math import log

def attention_scale(scale_s):
    # YaRN recommends 1/sqrt(t) = 0.1 * ln(s) + 1
    t = (0.1 * log(scale_s) + 1) ** (-2)
    return t  # divide attention logits by t (i.e., scale q and k by 1/sqrt(t))
# s=1  → t = 1.0  (no change)
# s=16 → t ≈ 0.61 (logits amplified by 1/t ≈ 1.63)
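In practice many implementations (e.g., the YaRN path in HF Transformers) fold this correction into the RoPE tables by scaling both $q$ and $k$ by $1/\sqrt{t} = 0.1 \ln s + 1$, so the logits pick up the full factor $1/t$. A quick check of the equivalence:

```python
from math import log, isclose

s = 16
t = (0.1 * log(s) + 1) ** (-2)   # YaRN temperature
mscale = 0.1 * log(s) + 1        # per-vector scale applied to q and k

# Scaling q and k each by mscale multiplies q·k by mscale**2 == 1/t
assert isclose(mscale ** 2, 1 / t)
```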
Putting all of the components together, the end-to-end YaRN pipeline looks like this:
# Complete YaRN Pipeline
from math import pi, log, cos, sin

def rotate_2d(a, b, angle):
    # Plain 2-D rotation, as in standard RoPE
    return a * cos(angle) - b * sin(angle), a * sin(angle) + b * cos(angle)

def yarn_rope(x, position_m, L=4096, base=10000, D=64, alpha=1, beta=32, seq_len=8192):
    scale_s = max(1, seq_len / L)  # dynamic scaling
    for d in range(D // 2):
        theta = base ** (-2 * d / D)
        lam = 2 * pi / theta
        r = L / lam
        # NTK-by-parts ramp
        if r < alpha: gamma = 0.0   # → full PI interpolation
        elif r > beta: gamma = 1.0  # → keep original
        else: gamma = (r - alpha) / (beta - alpha)
        theta_new = (1 - gamma) * theta / scale_s + gamma * theta
        angle = position_m * theta_new
        # Apply rotation to this even/odd pair in place
        x[2*d], x[2*d+1] = rotate_2d(x[2*d], x[2*d+1], angle)
    # Attention temperature correction
    t = (0.1 * log(scale_s) + 1) ** (-2)
    return x, t
Below is a side-by-side comparison of the three main context extension methods, followed by real-world configurations:
| Feature | PI | NTK-Aware | YaRN |
|---|---|---|---|
| Approach | Scale all positions by 1/s | Modify base b | Per-dimension frequency adjustment |
| High-freq handling | ❌ Compressed (destroys local info) | ⚠️ Partially preserved | ✅ Fully preserved (γ=1) |
| Low-freq handling | ✅ Interpolated | ✅ Interpolated | ✅ PI interpolated (γ=0) |
| Fine-tuning required | 200-400 steps | 200-400 steps | ~400 steps (10x fewer tokens) |
| Dynamic inference | ❌ Fixed scale | ✅ Supported | ✅ Dynamic s = l'/L |
| Attention correction | ❌ None | ❌ None | ✅ Temperature scaling t |
| Perplexity at 128k | High (degraded) | Medium | Low (best) |
Here's how YaRN is configured in popular models:
LLaMA-2 7B (4k → 128k):
Mistral 7B (8k → 64k):
HuggingFace Transformers:
from transformers import AutoModelForCausalLM
# Load model with YaRN scaling
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
rope_scaling={
"type": "yarn",
"factor": 32.0,
"original_max_position_embeddings": 4096,
"attention_factor": 0.1,
"beta_fast": 32,
"beta_slow": 1,
}
)