using Pkg
Pkg.activate(joinpath(@__DIR__, "..", ".."))
using FunctorFlow
using Random
using Statistics

Left-Kan / Right-Kan Duality for Language Modeling
Prediction via Σ and repair via Δ as an adjoint pair
Introduction
The adjunction \(\Sigma_J \dashv \Delta_J\) is one of the deepest patterns in category theory. Given a functor \(J \colon \mathcal{A} \to \mathcal{B}\), the left Kan extension \(\Sigma_J\) and right Kan extension \(\Delta_J\) form an adjoint pair:
\[ \Sigma_J \dashv \Delta_J \]
In FunctorFlow, this duality has a concrete computational interpretation:
- \(\Sigma\) (left-Kan): aggregate information along a relation — the universal prediction operation. Categorically, \(\Sigma_J(F)(c) = \mathrm{colim}_{j \to c}\, F(j)\): the best summary of all sources that can attend to target \(c\).
- \(\Delta\) (right-Kan): complete partial information along a compatibility relation — the universal repair operation. Categorically, \(\Delta_J(F)(c) = \lim_{c \to j}\, F(j)\): the most compatible completion consistent with all constraints on \(c\).
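Before any diagrams, the two operations can be sketched in a few lines of plain Julia over Dict-valued data. This is an illustration only, not the FunctorFlow API; the helper names `sigma` and `delta` are ours:

```julia
# Standalone illustration of the two Kan operations on Dict-valued data.
# NOT the FunctorFlow API — `sigma` and `delta` are hypothetical helpers.

# sigma: Σ_J(F)(c) = colim_{j→c} F(j), realized as a reduction over sources.
sigma(F, rel; reducer=sum) =
    Dict(c => reducer(F[j] for j in srcs) for (c, srcs) in rel)

# delta: Δ_J(F)(c) = lim_{c→j} F(j), realized as "first compatible value":
# each target takes the first non-null value among its related sources.
function delta(F, rel)
    out = Dict{Any, Any}()
    for (c, srcs) in rel
        idx = findfirst(j -> F[j] !== nothing, srcs)
        out[c] = idx === nothing ? nothing : F[srcs[idx]]
    end
    return out
end
```

For example, with `F = Dict(:a => 1, :b => nothing, :c => 3)`, `sigma(F, Dict(:t => [:a, :c]))[:t]` reduces the two sources to `4`, while `delta(F, Dict(:b => [:a, :c]))[:b]` fills the missing slot with `1`.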
In language modeling, these dual operations correspond to two complementary tasks:
- Next-block prediction (\(\Sigma\)): given past context, aggregate it into a prediction for the next token block.
- Block denoising (\(\Delta\)): given a corrupted token block, repair it by completing the missing/corrupted positions.
Together, prediction and repair form a categorical auto-encoding loop:
- Unit \(\eta\colon \mathrm{Id} \to \Delta_J \circ \Sigma_J\) — predict, then repair, should recover the original.
- Counit \(\varepsilon\colon \Sigma_J \circ \Delta_J \to \mathrm{Id}\) — repair, then predict, should be faithful.
This vignette demonstrates both operations on a synthetic token sequence task, building all diagrams from scratch and using the structured_lm_duality block builder.
Setup
Synthetic Data
We create a simple synthetic corpus with clear block structure: sequences are built by concatenating randomly chosen 4-token patterns from a small vocabulary. This makes both prediction and repair tractable — the patterns are learnable.
vocab_size = 50
seq_len = 32
block_size = 4
n_patterns = 10
rng = Random.MersenneTwister(42)
# Generate a fixed set of base patterns
base_patterns = [rand(rng, 1:vocab_size, block_size) for _ in 1:n_patterns]
println("Base patterns:")
for (i, p) in enumerate(base_patterns)
println(" Pattern $i: $p")
end

Base patterns:
Pattern 1: [11, 45, 20, 33]
Pattern 2: [35, 36, 49, 38]
Pattern 3: [11, 6, 10, 19]
Pattern 4: [32, 46, 9, 46]
Pattern 5: [10, 21, 3, 36]
Pattern 6: [48, 9, 14, 39]
Pattern 7: [7, 10, 33, 9]
Pattern 8: [38, 9, 44, 37]
Pattern 9: [10, 40, 32, 40]
Pattern 10: [38, 12, 42, 16]
function generate_block_sequence(rng, base_patterns; seq_len=32, block_size=4)
n_blocks = seq_len ÷ block_size
blocks = [base_patterns[rand(rng, 1:length(base_patterns))] for _ in 1:n_blocks]
return vcat(blocks...)
end
function make_batch(rng, base_patterns; batch_size=16, seq_len=32, block_size=4)
seqs = [generate_block_sequence(rng, base_patterns; seq_len, block_size) for _ in 1:batch_size]
return hcat(seqs...) # seq_len × batch_size
end
batch = make_batch(rng, base_patterns; batch_size=8)
println("Batch shape: ", size(batch), " (seq_len × batch_size)")
println("First sequence: ", batch[:, 1])

Batch shape: (32, 8) (seq_len × batch_size)
First sequence: [48, 9, 14, 39, 10, 21, 3, 36, 35, 36 … 49, 38, 11, 6, 10, 19, 48, 9, 14, 39]
Left Kan: Block Prediction (Σ)
The left Kan extension \(\Sigma_J(F)\) computes the colimit — the universal aggregation of source values along a relation. In the language modeling setting:
- Source values: hidden-state embeddings at each position (simulated as one-hot token IDs).
- Relation: a causal mask — position \(j\) can attend to position \(c\) only if \(j \leq c\).
- Reducer: aggregates the attended values (here, :sum for simplicity).
- Result: a contextualized representation at each position that summarizes all past tokens.
We then decode this aggregated context into predictions for the next token block.
Building the prediction diagram
D_predict = Diagram(:LeftKanPredict)
# Objects
add_object!(D_predict, :HiddenStates; kind=:messages,
description="Token embeddings at each position")
add_object!(D_predict, :CausalRelation; kind=:relation,
description="Causal mask: which positions attend to which")
add_object!(D_predict, :Contextualized; kind=:contextualized_messages,
description="Aggregated context per position")
# Left-Kan: aggregate past context via causal relation
Σ(D_predict, :HiddenStates;
along=:CausalRelation,
name=:causal_aggregate,
target=:Contextualized,
reducer=:sum)
add_object!(D_predict, :BlockLogits; kind=:output,
description="Predicted next-block token offsets")
# Decode to block offset predictions
add_morphism!(D_predict, :decode_block, :Contextualized, :BlockLogits;
description="Map aggregated context to block-level predictions")
println(D_predict)

Diagram :LeftKanPredict ⟨4 objects, 1 morphisms, 1 Kan, 0 losses⟩
Executing the prediction
We build a causal relation as a dictionary mapping each position to the set of positions it can attend to (all earlier positions within the same block boundary).
seq = batch[:, 1]
n_blocks_seq = seq_len ÷ block_size
# Hidden states: map each position to its token ID
hidden_states = Dict(Symbol("pos_$i") => seq[i] for i in 1:seq_len)
# Causal relation: each block-boundary position attends to all positions
# in the current and previous blocks
causal_relation = Dict{Any, Any}()
for b in 1:n_blocks_seq
target_pos = b * block_size # last position of block b
target_key = Symbol("block_$b")
sources = [Symbol("pos_$i") for i in 1:target_pos]
causal_relation[target_key] = sources
end
println("Causal relation (block 1 attends to): ", causal_relation[:block_1])
println("Causal relation (block 3 attends to): ", causal_relation[:block_3])

Causal relation (block 1 attends to): [:pos_1, :pos_2, :pos_3, :pos_4]
Causal relation (block 3 attends to): [:pos_1, :pos_2, :pos_3, :pos_4, :pos_5, :pos_6, :pos_7, :pos_8, :pos_9, :pos_10, :pos_11, :pos_12]
compiled_predict = compile_to_callable(D_predict)
# Provide a simple decode function: identity (pass through the aggregated sum)
result_predict = FunctorFlow.run(compiled_predict,
Dict(:HiddenStates => hidden_states,
:CausalRelation => causal_relation);
morphisms=Dict(:decode_block => x -> x)
)
println("Σ aggregation result (causal_aggregate):")
agg = result_predict.values[:causal_aggregate]
for b in 1:min(4, n_blocks_seq)
key = Symbol("block_$b")
if haskey(agg, key)
println(" $key → $(agg[key]) (sum of token IDs in causal window)")
end
end

Σ aggregation result (causal_aggregate):
block_1 → 110 (sum of token IDs in causal window)
block_2 → 180 (sum of token IDs in causal window)
block_3 → 338 (sum of token IDs in causal window)
block_4 → 446 (sum of token IDs in causal window)
The left-Kan aggregation sums all token IDs in the causal window for each block boundary. In a real model, these would be dense embeddings and the reducer would be learned attention — but the categorical structure is the same.
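As a hedged sketch of what such a learned reducer might look like (the function `attn_reduce` and its dot-product scoring are illustrative assumptions, not FunctorFlow built-ins), softmax attention over dense source vectors could replace :sum like this:

```julia
using LinearAlgebra

# Illustrative learned-style reducer: softmax-weighted aggregation of source
# value vectors, scored against a query. Not a FunctorFlow built-in.
function attn_reduce(query::Vector{Float64},
                     keys::Vector{Vector{Float64}},
                     values::Vector{Vector{Float64}})
    scores = [dot(query, k) / sqrt(length(query)) for k in keys]
    w = exp.(scores .- maximum(scores))   # numerically stable softmax
    w ./= sum(w)
    return sum(w[i] .* values[i] for i in eachindex(values))
end
```

When all keys score equally, the reducer degenerates to the mean of the values, so :sum-style aggregation is recoverable as a special case.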
Right Kan: Block Denoising (Δ)
The right Kan extension \(\Delta_J(F)\) computes the limit — the most compatible completion of partial data. In the denoising setting:
- Source values: a corrupted token block where some positions have been replaced with nothing.
- Relation: a compatibility structure — each corrupted position checks its neighbors for valid values.
- Reducer: :first_non_null — fill each missing position with the first available compatible value.
- Result: a repaired block with missing positions filled in.
Building the denoising diagram
D_denoise = Diagram(:RightKanDenoise)
add_object!(D_denoise, :NoisyBlock; kind=:partial,
description="Token block with corrupted positions (nothing = missing)")
add_object!(D_denoise, :CompatibilityRelation; kind=:relation,
description="Which positions can fill which missing slots")
add_object!(D_denoise, :CompletedBlock; kind=:completed_values,
description="Repaired block")
# Right-Kan: complete partial/corrupted data
Δ(D_denoise, :NoisyBlock;
along=:CompatibilityRelation,
name=:repair,
target=:CompletedBlock,
reducer=:first_non_null)
println(D_denoise)

Diagram :RightKanDenoise ⟨3 objects, 0 morphisms, 1 Kan, 0 losses⟩
Corruption and repair
function corrupt_block(rng, block; noise_rate=0.3)
mask = rand(rng, length(block)) .> noise_rate
corrupted = Vector{Union{Int, Nothing}}(copy(block))
for i in eachindex(corrupted)
if !mask[i]
corrupted[i] = nothing
end
end
return corrupted, mask
end
# Take one block from the first sequence
original_block = seq[1:block_size]
corrupted, mask = corrupt_block(rng, original_block; noise_rate=0.5)
println("Original block: ", original_block)
println("Corruption mask: ", mask, " (true = kept, false = corrupted)")
println("Corrupted block: ", corrupted)

Original block: [48, 9, 14, 39]
Corruption mask: Bool[1, 1, 0, 1] (true = kept, false = corrupted)
Corrupted block: Union{Nothing, Int64}[48, 9, nothing, 39]
# Build noisy values: position → token or nothing
noisy_values = Dict(Symbol("t_$i") => corrupted[i] for i in 1:block_size)
# Compatibility: each position can look at all other positions for repair
compat_relation = Dict{Any, Any}()
for i in 1:block_size
others = [Symbol("t_$j") for j in 1:block_size if j != i]
compat_relation[Symbol("t_$i")] = others
end
compiled_denoise = compile_to_callable(D_denoise)
result_denoise = FunctorFlow.run(compiled_denoise,
Dict(:NoisyBlock => noisy_values,
:CompatibilityRelation => compat_relation))
repaired = result_denoise.values[:repair]
println("\nRight-Kan repair (Δ):")
for i in 1:block_size
key = Symbol("t_$i")
orig = original_block[i]
noisy = corrupted[i]
fixed = get(repaired, key, nothing)
status = noisy === nothing ? " ← REPAIRED" : ""
println(" $key: original=$orig, corrupted=$noisy, repaired=$fixed$status")
end
Right-Kan repair (Δ):
t_1: original=48, corrupted=48, repaired=9
t_2: original=9, corrupted=9, repaired=48
t_3: original=14, corrupted=nothing, repaired=48 ← REPAIRED
t_4: original=39, corrupted=39, repaired=48
The right-Kan extension fills each nothing slot with the first non-null value from compatible neighbors. This is the categorical version of masked token reconstruction — \(\Delta_J(F)(c)\) finds the value most consistent with all constraints pointing out of \(c\).
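:first_non_null is the simplest possible repair reducer. As an illustrative alternative (not a FunctorFlow built-in; the name `plurality_fill` is ours), a plurality vote over non-null compatible neighbors makes the "most consistent completion" reading more literal:

```julia
# Illustrative alternative reducer (not a FunctorFlow built-in): pick the
# most frequent non-null value among compatible neighbors; `nothing` if none.
function plurality_fill(vals)
    present = [v for v in vals if v !== nothing]
    isempty(present) && return nothing
    counts = Dict{eltype(present), Int}()
    for v in present
        counts[v] = get(counts, v, 0) + 1
    end
    best, bestcount = first(present), 0
    for (v, c) in counts
        if c > bestcount
            best, bestcount = v, c
        end
    end
    return best
end
```

For instance, `plurality_fill([nothing, 2, 2, 3])` yields `2`: the completion agreed on by the most constraints.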
The Duality Diagram
FunctorFlow provides structured_lm_duality, a block builder that combines both operations into a single diagram with a shared input:
D_dual = structured_lm_duality()
println(D_dual)

Diagram :StructuredLMDuality ⟨5 objects, 0 morphisms, 2 Kan, 0 losses⟩
The structured_lm_duality block internally:
- Creates a KET block (left-Kan prediction via \(\Sigma\)) under the :predict namespace.
- Creates a completion block (right-Kan repair via \(\Delta\)) under the :repair namespace.
- Aliases both to share a common :SharedInput object.
Manual construction
We can also build the duality manually to see the full structure:
D_manual = Diagram(:ManualDuality)
# Shared embedding space
add_object!(D_manual, :Tokens; kind=:hidden_state,
description="Token embeddings — shared input")
# --- Left branch: prediction via Σ ---
add_object!(D_manual, :CausalRelation; kind=:relation)
add_object!(D_manual, :PredictionContext; kind=:contextualized_messages)
add_object!(D_manual, :PredictedBlock; kind=:output)
Σ(D_manual, :Tokens;
along=:CausalRelation,
name=:predict_aggregate,
target=:PredictionContext,
reducer=:sum)
add_morphism!(D_manual, :predict_decode, :PredictionContext, :PredictedBlock)
# --- Right branch: repair via Δ ---
add_object!(D_manual, :NoisyTokens; kind=:partial)
add_object!(D_manual, :RepairRelation; kind=:relation)
add_object!(D_manual, :RepairedTokens; kind=:completed_values)
add_object!(D_manual, :DenoisedBlock; kind=:output)
Δ(D_manual, :NoisyTokens;
along=:RepairRelation,
name=:repair_complete,
target=:RepairedTokens,
reducer=:first_non_null)
add_morphism!(D_manual, :repair_decode, :RepairedTokens, :DenoisedBlock)
println(D_manual)

Diagram :ManualDuality ⟨8 objects, 2 morphisms, 2 Kan, 0 losses⟩
Executing both branches
compiled_dual = compile_to_callable(D_manual)
# Prepare inputs for both branches
tokens = Dict(Symbol("pos_$i") => seq[i] for i in 1:seq_len)
# Causal relation for prediction (block-level)
causal_rel = Dict{Any, Any}()
for b in 1:n_blocks_seq
target_pos = b * block_size
causal_rel[Symbol("block_$b")] = [Symbol("pos_$i") for i in 1:target_pos]
end
# Noisy tokens for repair (corrupt the second block)
block2 = seq[(block_size+1):(2*block_size)]
corrupted2, mask2 = corrupt_block(rng, block2; noise_rate=0.5)
noisy_tokens = Dict(Symbol("b2_$i") => corrupted2[i] for i in 1:block_size)
repair_rel = Dict{Any, Any}()
for i in 1:block_size
repair_rel[Symbol("b2_$i")] = [Symbol("b2_$j") for j in 1:block_size if j != i]
end
result_dual = FunctorFlow.run(compiled_dual,
Dict(:Tokens => tokens,
:CausalRelation => causal_rel,
:NoisyTokens => noisy_tokens,
:RepairRelation => repair_rel);
morphisms=Dict(
:predict_decode => x -> x,
:repair_decode => x -> x
)
)
println("=== Left-Kan (Σ): Prediction ===")
agg_dual = result_dual.values[:predict_aggregate]
for b in 1:min(3, n_blocks_seq)
key = Symbol("block_$b")
haskey(agg_dual, key) && println(" $key → $(agg_dual[key])")
end
println("\n=== Right-Kan (Δ): Repair ===")
rep_dual = result_dual.values[:repair_complete]
for i in 1:block_size
key = Symbol("b2_$i")
orig = block2[i]
noisy = corrupted2[i]
fixed = get(rep_dual, key, nothing)
status = noisy === nothing ? " ← REPAIRED" : ""
println(" $key: original=$orig, corrupted=$noisy, repaired=$fixed$status")
end

=== Left-Kan (Σ): Prediction ===
block_1 → 110
block_2 → 180
block_3 → 338
=== Right-Kan (Δ): Repair ===
b2_1: original=10, corrupted=10, repaired=21
b2_2: original=21, corrupted=21, repaired=10
b2_3: original=3, corrupted=3, repaired=10
b2_4: original=36, corrupted=nothing, repaired=10 ← REPAIRED
The two branches compute different things from the same conceptual space:
- The left branch (\(\Sigma\)) merges many source values into a single summary per target — it answers “what is the aggregated context?”.
- The right branch (\(\Delta\)) fills in missing values from compatible neighbors — it answers “what is the most consistent completion?”.
The Predict-Repair Loop
The categorical adjunction \(\Sigma_J \dashv \Delta_J\) predicts a pipeline:
- Start with a clean sequence \(x\).
- Predict the next block: \(\hat{y} = \Sigma_J(x)\) — aggregate context into a prediction.
- Corrupt the prediction: \(\tilde{y} = \mathrm{noise}(\hat{y})\) — simulate real-world noise.
- Repair the corrupted block: \(\hat{x} = \Delta_J(\tilde{y})\) — complete the partial data.
The unit of the adjunction says that \(\Delta_J \circ \Sigma_J \approx \mathrm{Id}\) — the predict-then-repair cycle should approximately recover the original data. Let us verify this numerically.
function predict_repair_loop(seq, block_idx; noise_rate=0.3, block_size=4)
# Step 1: Prediction via Σ (aggregate tokens up to this block)
D_loop = Diagram(:PredictRepairLoop)
add_object!(D_loop, :Src; kind=:messages)
add_object!(D_loop, :Rel; kind=:relation)
Σ(D_loop, :Src; along=:Rel, name=:predicted, reducer=:mean)
compiled_loop = compile_to_callable(D_loop)
target_end = block_idx * block_size
src_vals = Dict(Symbol("p_$i") => Float64(seq[i]) for i in 1:target_end)
rel_vals = Dict(:target => [Symbol("p_$i") for i in 1:target_end])
result_pred = FunctorFlow.run(compiled_loop, Dict(:Src => src_vals, :Rel => rel_vals))
predicted_mean = result_pred.values[:predicted][:target]
# Step 2: Corruption — replace some block positions with nothing
true_block = seq[(target_end - block_size + 1):target_end]
rng_loop = Random.MersenneTwister(block_idx)
corrupted_block, cmask = corrupt_block(rng_loop, true_block; noise_rate)
# Step 3: Repair via Δ (fill missing with predicted mean)
D_repair = Diagram(:RepairStep)
add_object!(D_repair, :Partial; kind=:partial)
add_object!(D_repair, :Compat; kind=:relation)
Δ(D_repair, :Partial; along=:Compat, name=:repaired, reducer=:first_non_null)
compiled_repair = compile_to_callable(D_repair)
# Inject predicted mean as a fallback source alongside the corrupted values
partial_vals = Dict{Any, Any}(Symbol("t_$i") => corrupted_block[i] for i in 1:block_size)
partial_vals[:predicted] = round(Int, predicted_mean)
compat_vals = Dict{Any, Any}()
for i in 1:block_size
neighbors = [Symbol("t_$j") for j in 1:block_size if j != i]
push!(neighbors, :predicted)
compat_vals[Symbol("t_$i")] = neighbors
end
result_rep = FunctorFlow.run(compiled_repair, Dict(:Partial => partial_vals, :Compat => compat_vals))
repaired = result_rep.values[:repaired]
repaired_block = [get(repaired, Symbol("t_$i"), nothing) for i in 1:block_size]
return (true_block=true_block, corrupted=corrupted_block,
predicted_mean=predicted_mean, repaired=repaired_block)
end
# Run for blocks 2 through 5
for b in 2:min(5, n_blocks_seq)
r = predict_repair_loop(seq, b; noise_rate=0.5)
n_corrupted = count(x -> x === nothing, r.corrupted)
n_recovered = sum(r.repaired[i] == r.true_block[i] for i in 1:block_size if r.corrupted[i] === nothing; init=0)
println("Block $b: true=$(r.true_block), predicted_mean=$(round(r.predicted_mean; digits=1)), " *
"corrupted=$n_corrupted positions, recovered=$n_recovered/$n_corrupted")
end

Block 2: true=[10, 21, 3, 36], predicted_mean=22.5, corrupted=2 positions, recovered=0/2
Block 3: true=[35, 36, 49, 38], predicted_mean=28.2, corrupted=3 positions, recovered=0/3
Block 4: true=[38, 12, 42, 16], predicted_mean=27.9, corrupted=2 positions, recovered=0/2
Block 5: true=[38, 9, 44, 37], predicted_mean=28.7, corrupted=0 positions, recovered=0/0
Even with this simple symbolic execution (integer token IDs, :mean aggregation, :first_non_null repair), the predict-repair loop exercises the full categorical structure — though, as the recovery counts show, these naive reducers are too coarse to reconstruct the exact tokens. In a neural setting, the reducers would be learned (attention for \(\Sigma\), compatibility scoring for \(\Delta\)), but the diagrammatic skeleton is identical.
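To see how far "compatibility scoring" can go on this synthetic task, here is a hedged sketch (the helper `pattern_repair` is ours, not library code) that scores each known base pattern by agreement with the surviving positions and fills missing slots from the best match:

```julia
# Illustrative compatibility-scoring repair (not a FunctorFlow reducer):
# score each candidate pattern by agreement with the non-null positions,
# then fill missing slots from the best-matching pattern.
function pattern_repair(corrupted, patterns)
    score(p) = count(i -> corrupted[i] !== nothing && corrupted[i] == p[i],
                     eachindex(corrupted))
    best = patterns[argmax([score(p) for p in patterns])]
    return [corrupted[i] === nothing ? best[i] : corrupted[i]
            for i in eachindex(corrupted)]
end
```

Because each block is drawn verbatim from a small pattern set, even one or two surviving tokens usually identify the pattern, so this reducer can achieve exact recovery where :first_non_null cannot.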
Categorical Interpretation
The adjunction \(\Sigma_J \dashv \Delta_J\) gives us:
Unit: \(\eta\colon \mathrm{Id} \to \Delta_J \circ \Sigma_J\)
Starting from a complete sequence, \(\Sigma\) aggregates it into a prediction, and \(\Delta\) repairs back toward the original. The unit measures how much information the predict-then-repair cycle preserves:
\[ x \xrightarrow{\eta_x} \Delta_J(\Sigma_J(x)) \]
If \(\eta\) is close to an isomorphism, the cycle is lossless — prediction captures enough information for exact repair.
Counit: \(\varepsilon\colon \Sigma_J \circ \Delta_J \to \mathrm{Id}\)
Starting from corrupted data, \(\Delta\) repairs it, and \(\Sigma\) re-aggregates the repaired version. The counit measures faithfulness:
\[ \Sigma_J(\Delta_J(\tilde{x})) \xrightarrow{\varepsilon_{\tilde{x}}} \tilde{x} \]
If \(\varepsilon\) is close to an isomorphism, repair-then-predict is faithful — the repaired data re-aggregates consistently.
In practice
This is the categorical version of auto-encoding:
- The unit η corresponds to the reconstruction objective (VAE / denoising autoencoder).
- The counit ε corresponds to the consistency objective (cycle consistency, round-trip loss).
- Training both branches jointly with an obstruction loss on the round-trip error enforces the adjunction.
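A minimal sketch of such a round-trip objective (the function names and the 0–1 mismatch loss are illustrative choices, not library API):

```julia
# Illustrative round-trip objective enforcing the adjunction (not library
# API): fraction of positions changed by each cycle.
mismatch(a, b) = count(a .!= b) / length(a)

# Weighted sum of the reconstruction (unit, Δ∘Σ) and consistency
# (counit, Σ∘Δ) errors, as in a denoising autoencoder with cycle loss.
adjunction_loss(x, roundtrip_x, y, roundtrip_y; lambda=0.5) =
    lambda * mismatch(x, roundtrip_x) + (1 - lambda) * mismatch(y, roundtrip_y)
```

A perfect adjunction (both cycles are identities on the data) drives this loss to zero; the symbolic runs above would score far from zero, which is exactly what a trained model would be asked to fix.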
# Verify numerically: predict then repair on the first few blocks
println("Predict-then-repair round-trip accuracy:")
for b in 2:min(6, n_blocks_seq)
r = predict_repair_loop(seq, b; noise_rate=0.5)
accuracy = mean(r.repaired[i] == r.true_block[i] for i in 1:block_size)
println(" Block $b: $(round(accuracy * 100; digits=1))% positions match original")
end

Predict-then-repair round-trip accuracy:
Block 2: 0.0% positions match original
Block 3: 0.0% positions match original
Block 4: 0.0% positions match original
Block 5: 0.0% positions match original
Block 6: 0.0% positions match original
Summary
| Aspect | Σ (Left-Kan) | Δ (Right-Kan) |
|---|---|---|
| Category theory | Colimit | Limit |
| Universal property | Best summary given all sources | Most compatible completion given all constraints |
| Computation | Aggregation (many → one) | Completion (partial → complete) |
| AI interpretation | Prediction | Repair / denoising |
| Default reducer | :sum | :first_non_null |
| Data flow | Many sources → single target | Incomplete data → filled data |
| LM example | Next-block prediction | Masked token reconstruction |
| FunctorFlow operator | Σ(D, src; along=rel, ...) | Δ(D, src; along=rel, ...) |
The key insight is that \(\Sigma\) and \(\Delta\) are not two separate mechanisms — they are two faces of the same categorical coin. FunctorFlow makes this explicit by giving them the same API shape (source, relation, reducer) while their default reducers encode the duality: :sum aggregates existing values, while :first_non_null fills in missing ones.
The structured_lm_duality block builder packages this pattern into a reusable component, but as we have shown, the duality can also be constructed manually for full control over the diagram topology.