MUX: Continuous Reasoning via Multiplexed Tokens

¹AITHYRA, ²University of Michigan, ³University of Oxford, ⁴TU Wien, ⁵KAIST
Equal advising
Overview of MUX method

MUX trains each latent reasoning token to represent a weighted linear superposition of a span of discrete reasoning subwords. The superposition is lossless by construction.

Abstract

Language models solve complex problems by articulating intermediate reasoning steps in natural language. While effective, this process is computationally bottlenecked: each reasoning step conveys only a single subword, and many steps are spent verbalizing a thought rather than carrying out computation.

We propose MUX, a simple method for high-bandwidth, compact reasoning that distills discrete reasoning into continuous multiplexed tokens in a latent space. Each latent token is trained to represent a weighted linear superposition (multiplexing) of a span of discrete reasoning subwords; this superposition is lossless by construction, so the span can be fully recovered (demultiplexing).

We prove that simple position-dependent weightings, such as suitable geometric decay, support lossless multiplexing, which in turn prevents shortcut behaviors caused by latent collapse. We further show that multiplexed reasoning can perform parallel exploration in problems that require search.

Across 32 evaluation settings spanning four language models, MUX outperforms strong latent reasoning baselines. Ablation and probing analyses further show that the learned latent tokens encode faithful and interpretable reasoning. Our results suggest that using lossless superpositions as local learning targets is a sufficient condition for achieving strong and efficient continuous latent reasoning.

Key Contributions

1. Latent reasoning via multiplexed tokens. A local distillation method for continuous latent reasoning based on multiplexed targets. For each latent token, we define a vocabulary-space target by taking a position-weighted linear superposition of the one-hot encodings of the subwords in its corresponding discrete reasoning span.
2. Lossless multiplexing. We identify simple classes of positional weightings (geometric, sinusoidal, rotary) that guarantee lossless multiplexing via a subset-sum separation condition. Lossless multiplexing prevents shortcut behaviors caused by latent collapse.
3. Parallel search via multiplexing. Multiplexed tokens are expressive enough to represent and update multiple hypotheses simultaneously, implementing each BFS step using a single latent token. Parallel search can naturally emerge from serial supervision.
4. State-of-the-art results. Best latent reasoning method across 32 mathematical reasoning settings spanning two training corpora, four language models, and four test sets. Surpasses strong discrete and continuous reasoning baselines on two search benchmarks.
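The subset-sum separation condition in item 2 can be checked numerically for a given weighting. The sketch below assumes the condition means that all 2^S subset sums of the positional weights are pairwise distinct; the paper's exact statement (for example, any required margin) may differ, and `min_subset_sum_gap` is a hypothetical helper. When the gap is positive, the total coefficient a subword receives identifies exactly which positions it occupies. Normalizing the weights rescales every gap by the same factor, so it does not change whether a gap is zero.

```python
from itertools import combinations


def min_subset_sum_gap(weights):
    """Return the smallest gap between the sums of two distinct subsets of `weights`.

    A gap of 0 means two different position sets yield the same coefficient,
    so the superposition cannot be demultiplexed unambiguously.
    Enumerates all 2**len(weights) subsets, so keep spans short.
    """
    sums = sorted(
        sum(weights[i] for i in subset)
        for k in range(len(weights) + 1)
        for subset in combinations(range(len(weights)), k)
    )
    return min(b - a for a, b in zip(sums, sums[1:]))


# Geometric decay with ratio 1/2: every subset sum is distinct.
print(min_subset_sum_gap([1.0, 0.5, 0.25, 0.125]))
# Uniform weights: positions {0} and {1} collide.
print(min_subset_sum_gap([1.0, 1.0, 1.0, 1.0]))
```

Uniform averaging is the degenerate case: it discards order, so «5+3=8» and «3+5=8» map to the same target.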

Method

Lossless multiplexing illustration

Lossless multiplexing of a span «5+3=8» through position-weighted linear superposition.

Given a discrete reasoning span $(r_1, \dots, r_S)$, MUX constructs a vocabulary-space target:

$$\mathrm{mux}(r) = \sum_{j=1}^{S} \alpha_j \, \mathrm{onehot}(r_j)$$

where the coefficients $\alpha_j$ are nonnegative position-dependent weights normalized to sum to one, so that the target lies on the vocabulary simplex. The model is trained to match each target with a KL-divergence loss on the output of a linear-softmax head.
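The target construction and loss above can be sketched in a few lines of NumPy. This is a minimal illustration, not the authors' code: it assumes geometric weights α_j ∝ γ^(j−1) with an illustrative γ = 0.5, toy integer token ids, and the KL direction KL(target || model), the usual choice for distillation. `mux_target` and `kl_to_target` are hypothetical names.

```python
import numpy as np


def mux_target(span, vocab_size, gamma=0.5):
    """Build sum_j alpha_j * onehot(r_j) with alpha_j proportional to gamma**j (0-indexed).

    Weights are normalized to sum to 1, so the target lies on the vocabulary simplex.
    """
    alpha = gamma ** np.arange(len(span), dtype=np.float64)
    alpha /= alpha.sum()
    target = np.zeros(vocab_size)
    # np.add.at accumulates correctly when the same subword occurs more than once.
    np.add.at(target, np.asarray(span), alpha)
    return target


def kl_to_target(target, logits):
    """KL(target || softmax(logits)) for logits produced by a linear LM head."""
    shifted = logits - logits.max()
    log_q = shifted - np.log(np.exp(shifted).sum())  # numerically stable log-softmax
    support = target > 0  # entries with zero target mass contribute 0 to the KL
    return float(np.sum(target[support] * (np.log(target[support]) - log_q[support])))


# Toy vocabulary: ids 0-9 are digits, 10 is "+", 11 is "=".
span = [5, 10, 3, 11, 8]  # «5+3=8»
target = mux_target(span, vocab_size=12)
logits = np.zeros(12)  # an untrained head that predicts the uniform distribution
print(target.sum(), kl_to_target(target, logits))
```

In training, `logits` would come from applying the linear head to the model's hidden state at the latent position; the uniform logits here only show that the loss is positive until the head matches the target.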

We prove that geometric, sinusoidal, and rotary weightings all support lossless multiplexing: the original span can be recovered exactly from the superposition.
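For geometric weights, demultiplexing can be done greedily. The sketch below assumes α_j ∝ γ^(j−1) with γ ≤ 1/2: each weight then exceeds the sum of all later weights (a super-increasing sequence), so for every subword with nonzero mass we can walk the positions from largest to smallest weight and claim a position whenever the remaining mass covers its weight. This is an illustration under that assumption, not the paper's proof; sinusoidal and rotary weightings would need a different decoder. `demux` is a hypothetical name.

```python
import numpy as np


def demux(target, span_len, gamma=0.5):
    """Recover (r_1, ..., r_S) from sum_j alpha_j * onehot(r_j), alpha_j ∝ gamma**j, gamma <= 0.5."""
    if not 0 < gamma <= 0.5:
        raise ValueError("greedy decoding assumes a super-increasing sequence (0 < gamma <= 0.5)")
    alpha = gamma ** np.arange(span_len, dtype=np.float64)
    alpha /= alpha.sum()
    span = [None] * span_len
    for token in np.flatnonzero(target > 0):
        remaining = target[token]  # sum of alpha_j over the positions this token occupies
        for j in range(span_len):  # alpha is decreasing, so the largest weight is visited first
            # Fixed tolerance for float rounding; adequate only for short spans.
            if remaining >= alpha[j] - 1e-9:
                span[j] = int(token)
                remaining -= alpha[j]
    return span


# Superpose «5+3=8» (toy ids; 10 is "+", 11 is "=") and recover it.
alpha = 0.5 ** np.arange(5)
alpha /= alpha.sum()
target = np.zeros(12)
np.add.at(target, [5, 10, 3, 11, 8], alpha)
print(demux(target, span_len=5))  # [5, 10, 3, 11, 8]
```

Repeated subwords are handled too: a token at positions {1, 2, 4} receives α_1 + α_2 + α_4, and the greedy pass assigns exactly those three positions.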

Positioning of MUX

Method Supervision Lossless Shortcut-free Train Eff. Infer. Eff. Interpretable
SFT-CoT Discrete
CODI Global
SIM-CoT Local
KaVa Local
MUX Local

Results

32/32: best latent reasoning method across all evaluation settings
15: settings where MUX outperforms discrete SFT-CoT
2.4–5.9×: fewer reasoning tokens than SFT-CoT

Mathematical Reasoning

Test accuracies (%). Underlined when MUX outperforms SFT-CoT. MUX reports ±1 std over 3 seeds.

Method       GSM8K-AUG                          GSM8K-AUG-NL
             ID    SVAMP  GSM-Hard  MultiArith  ID    SVAMP  GSM-Hard  MultiArith
GPT-2
SFT-CoT      44.1  41.8   9.8       90.7        34.2  36.9   7.1       88.7
CODI         43.7  42.9   9.9       92.8        34.1  30.8   6.8       58.9
SIM-CoT      42.6  42.6   9.4       92.8        30.9  27.5   6.5       53.9
MUX          48.1  45.0   10.6      93.0        37.4  36.7   8.9       72.4
LLaMA 3.2 1B-Instruct
SFT-CoT      61.6  66.7   15.6      99.3        53.2  62.9   13.3      98.5
Coconut      45.3  48.8   9.9       90.1        24.2  –      –         –
CODI         55.6  61.1   12.8      96.1        47.9  55.3   11.3      96.7
SIM-CoT      56.1  61.5   12.7      96.2        28.4  43.0   6.6       59.4
MUX          56.7  63.6   13.0      98.5        50.3  57.5   11.6      96.9

Scaling to Larger Models (GSM8K-AUG)

Method       LLaMA 3.2 3B                       LLaMA 3.1 8B
             ID    SVAMP  GSM-Hard  MultiArith  ID    SVAMP  GSM-Hard  MultiArith
SFT-CoT      71.5  71.0   17.0      98.3        71.7  73.1   16.5      98.3
CODI         60.8  73.3   14.3      98.7        61.1  78.1   15.5      99.5
SIM-CoT      62.3  74.9   14.6      98.8        64.1  79.4   16.3      100.0
MUX          65.0  77.1   15.2      100.0       68.1  80.1   17.1      100.0

Parallel Search

Search accuracies (%) averaged over 3 seeds.

Method       MNNS   Game of 24
No-CoT       68.4   74.4
SFT-CoT      84.6   84.3
Coconut      92.8   78.6
CoT2         98.9   85.0
MUX          99.6   88.7

Interpretability Analysis

Through probing analysis, we show that MUX latent tokens encode faithful and interpretable reasoning content. When latent tokens are projected through the LM head, their top decoded subwords closely match the aligned discrete reasoning spans.
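The LM-head projection used in this probing can be sketched as a logit-lens-style readout. This assumes a bias-free linear head with weight matrix `unembed` of shape (d, V); the names are hypothetical, and the paper's probe may include further steps such as a final normalization layer. Because softmax is monotonic, ranking raw logits yields the same top-k subwords as ranking probabilities.

```python
import numpy as np


def top_k_decoded(latent, unembed, k=5):
    """Project one latent token through a linear LM head and return its k top-scoring token ids."""
    logits = latent @ unembed  # shape (V,)
    return [int(i) for i in np.argsort(-logits)[:k]]


# Toy check with an identity head: a geometric superposition of «5+3=8»
# (ids 10 = "+", 11 = "=") decodes to its subwords in span order.
alpha = 0.5 ** np.arange(5)
latent = np.zeros(12)
latent[[5, 10, 3, 11, 8]] = alpha / alpha.sum()
print(top_k_decoded(latent, np.eye(12)))  # [5, 10, 3, 11, 8]
```

The identity head is only for illustration; a real LM head maps a hidden size d to a different vocabulary size V. The span order is recovered here because each subword appears once and earlier positions carry larger weights.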

Qualitative results, mathematical reasoning (GSM8K-AUG): top-5 LM-head-decoded subwords per latent token. MUX recovers interpretable reasoning content.
Qualitative results, parallel search (MNNS): latent tokens encode the full search frontier, demonstrating parallel exploration of multiple hypotheses in superposition.
Quantitative interpretability metrics, mathematical reasoning.
Quantitative interpretability metrics, parallel search (MNNS): trace and frontier metrics.
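The MNNS results indicate that a single latent token can hold an entire search frontier. As a toy illustration of why one vector suffices, the sketch below stores a BFS frontier as a uniform superposition over graph nodes and expands every frontier node in one step with the adjacency matrix. It does not model how the transformer computes the update and is not the paper's construction; `bfs_step`, the uniform weighting, and the example graph are assumptions.

```python
import numpy as np


def bfs_step(frontier, adjacency, visited):
    """Expand every node of a superposed frontier at once (one level of BFS).

    frontier:  (N,) nonnegative weights over nodes
    adjacency: (N, N) array with adjacency[u, v] = 1 for an edge u -> v
    visited:   (N,) boolean mask of nodes already reached
    Returns the next frontier (uniform weights over newly reached nodes) and the updated mask.
    """
    reached = (frontier @ adjacency > 0) & ~visited
    visited = visited | reached
    if not reached.any():
        return np.zeros_like(frontier), visited
    return reached / reached.sum(), visited


# Small tree: 0 -> {1, 2}, 1 -> {3}, 2 -> {4}.
adjacency = np.zeros((5, 5))
for u, v in [(0, 1), (0, 2), (1, 3), (2, 4)]:
    adjacency[u, v] = 1
frontier = np.eye(5)[0]  # start at node 0
visited = frontier > 0
for _ in range(2):
    frontier, visited = bfs_step(frontier, adjacency, visited)
    print(frontier)
```

Uniform weights suffice here because a frontier is an unordered set; ordered reasoning spans need position-dependent weights to stay lossless.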

Attention Analysis

Attention routing through continuous reasoning tokens on GSM8K-AUG. Attention flows from the answer token through the latent reasoning sequence, showing that latent tokens actively contribute to the final prediction.

BibTeX

@article{suleymanzade2025mux,
  author    = {Suleymanzade, Ayhan and Gozeten, Halil Alperen and Bronstein, Michael and Ceylan, \.{I}smail \.{I}lkan and Kim, Jinwoo},
  title     = {{MUX}: Continuous Reasoning via Multiplexed Tokens},
  year      = {2025},
}