JEPA as Categorical World Model
Joint Embedding Predictive Architecture via Coalgebras

```julia
using FunctorFlow
```
Introduction
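JEPA (Joint Embedding Predictive Architecture) is Yann LeCun’s proposed framework for building world models that predict in representation space rather than observation space. The motivation is that predicting future observations pixel by pixel forces a model to account for details that are inherently unpredictable. Predicting embeddings of future states instead lets the encoder discard such details and concentrate on what can actually be predicted.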
FunctorFlow.jl models JEPA as a coalgebraic construction, with the following correspondence between JEPA components and categorical structures:
| JEPA Component | Categorical Analog |
|---|---|
| Encoder | Coalgebra morphism |
| Predictor | Endofunctor dynamics |
| Target encoder (EMA) | Frozen reference coalgebra |
| Prediction loss | Obstruction to commutativity |
| World model | F-coalgebra (X → F(X)) |
| Optimal representation | Final coalgebra (Lambek’s lemma) |
Part 1: Coalgebra — World Models as State Transitions
For an endofunctor F, a coalgebra is a pair (X, α : X → F(X)): a state space X equipped with a transition structure α. This is the categorical formalization of a world model.
```julia
# Build a simple world model diagram
D = world_model_block(;
    name=:SimpleWorldModel,
    observation_object=:Pixels,
    latent_object=:Embedding,
    encoder_name=:encode,
    dynamics_name=:predict_next,
    decoder_name=:decode,
)
println("Objects: ", join(keys(D.objects), ", "))
println("Operations: ", join(keys(D.operations), ", "))
println()
# The coalgebra structure
coalgebras = get_coalgebras(D)
for (name, c) in coalgebras
    println(c)
end
```

```
Objects: Pixels, Embedding
Operations: encode, predict_next, decode, encode_then_predict, autoencoder

Coalgebra(:coalgebra, Embedding →_{identity} Embedding)
```
The coalgebra declares that predict_next is the world-model dynamics: given a latent state, it produces the next latent state. Because the functor here is the identity, the transition α is simply a map Embedding → Embedding.
```julia
# Bind concrete implementations
bind_morphism!(D, :encode, x -> x ./ 255.0)       # normalize
bind_morphism!(D, :predict_next, x -> x .+ 0.01)  # simple drift
bind_morphism!(D, :decode, x -> x .* 255.0)       # denormalize
compiled = compile_to_callable(D)
result = FunctorFlow.run(compiled, Dict(
    :Pixels => [128.0, 64.0, 200.0],
))
println("Encoded: ", round.(result.values[:encode]; digits=4))
println("Predicted:", round.(result.values[:predict_next]; digits=4))
# :autoencoder is decode ∘ encode, so this is a reconstruction of the input,
# not a decoded prediction
println("Decoded: ", round.(result.values[:autoencoder]; digits=4))
```

```
Encoded: [0.502, 0.251, 0.7843]
Predicted:[0.512, 0.261, 0.7943]
Decoded: [128.0, 64.0, 200.0]
```
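The bindings above can be chained into a multi-step rollout: encode once, iterate the coalgebra's transition, and decode each predicted latent state back to pixel space. The following sketch does this in plain Julia without FunctorFlow. The names px_encode, latent_step, px_decode and rollout are hypothetical; they only mirror the three bound functions. Unlike the :autoencoder output, which only reconstructs the input, each decoded prediction shifts every pixel by 0.01 × 255 = 2.55 per step.

```julia
# Plain-Julia versions of the three bindings above (no FunctorFlow needed)
px_encode(x) = x ./ 255.0      # Pixels → Embedding
latent_step(z) = z .+ 0.01     # Embedding → Embedding: the coalgebra transition α
px_decode(z) = z .* 255.0      # Embedding → Pixels

# Hypothetical helper: encode once, apply the transition `nsteps` times,
# and decode every predicted latent state back to pixel space.
function rollout(x, nsteps::Integer)
    z = px_encode(x)
    frames = Vector{Vector{Float64}}()
    for _ in 1:nsteps
        z = latent_step(z)
        push!(frames, px_decode(z))
    end
    return frames
end

pixels = [128.0, 64.0, 200.0]
println("Reconstruction: ", round.(px_decode(px_encode(pixels)); digits=4))
for (t, frame) in enumerate(rollout(pixels, 3))
    println("Prediction t+$t: ", round.(frame; digits=4))
end
```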
Part 2: JEPA Block — Prediction in Embedding Space
The JEPA block encodes the fundamental prediction-in-embedding pattern as a FunctorFlow diagram:
```
Observation ──context_encoder──→ ContextRepr ──predictor──→ PredictedRepr
                                                                ↕ prediction_loss (obstruction)
Target ──target_encoder───────────────────────────────────→ TargetRepr
```
```julia
# Build a JEPA block
D_jepa = jepa_block(;
    name=:ImageJEPA,
    observation_object=:Context,
    target_object=:MaskedRegion,
    context_repr=:CtxEmb,
    target_repr=:TgtEmb,
    comparator=:l2,
)
println("=== JEPA Diagram: $(D_jepa.name) ===")
println("\nObjects:")
for (name, obj) in D_jepa.objects
    println(" $(name) :: $(obj.kind)")
end
println("\nMorphisms:")
for (name, op) in D_jepa.operations
    if op isa Morphism
        role = get(op.metadata, :role, :unknown)
        net = get(op.metadata, :network, :none)
        detach = get(op.metadata, :detach, false)
        println(" $(name): $(op.source) -> $(op.target) [role=$(role), net=$(net), detach=$(detach)]")
    elseif op isa Composition
        println(" $(name): $(op.source) → $(op.target) [chain=$(op.chain)]")
    end
end
println("\nLosses:")
for (name, loss) in D_jepa.losses
    println(" $(name): paths=$(loss.paths), comparator=$(loss.comparator)")
end
println("\nCoalgebra:")
for (name, c) in get_coalgebras(D_jepa)
    println(" ", c)
end
```

```
=== JEPA Diagram: ImageJEPA ===

Objects:
 Context :: observation
 MaskedRegion :: observation
 CtxEmb :: representation
 TgtEmb :: representation

Morphisms:
 context_encoder: Context -> CtxEmb [role=encoder, net=online, detach=false]
 target_encoder: MaskedRegion -> TgtEmb [role=encoder, net=target, detach=true]
 predictor: CtxEmb -> TgtEmb [role=predictor, net=none, detach=false]
 prediction_path: Context → TgtEmb [chain=[:context_encoder, :predictor]]

Losses:
 prediction_loss: paths=[(:prediction_path, :target_encoder)], comparator=l2

Coalgebra:
 Coalgebra(:jepa_dynamics, CtxEmb →_{identity} CtxEmb)
```
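In the listing above, target_encoder carries detach=true while context_encoder and predictor do not. This presumably marks a stop-gradient. Gradients of the prediction loss then update the online encoder and the predictor but not the target encoder, which is updated by other means such as EMA (Part 8).

The toy sketch below makes this concrete with scalar weights and hand-written gradients. It assumes encoders x ↦ a·x and y ↦ b·y, a predictor z ↦ c·z, and a squared L2 loss. That parameterization and the name jepa_loss_and_grads are illustrative assumptions, not FunctorFlow's training code.

```julia
# Toy JEPA with scalar weights (illustrative assumption):
#   context encoder x ↦ a*x  (online, trainable)
#   target encoder  y ↦ b*y  (detached: receives no gradient)
#   predictor       z ↦ c*z  (trainable)
#   loss L = Σ (c*a*x - b*y)^2
function jepa_loss_and_grads(a, b, c, x, y)
    r = c .* a .* x .- b .* y        # residual: prediction minus (detached) target
    toy_loss = sum(abs2, r)
    grad_a = sum(2 .* r .* c .* x)   # ∂L/∂a
    grad_c = sum(2 .* r .* a .* x)   # ∂L/∂c
    grad_b = 0.0                     # stop-gradient: target branch treated as a constant
    return toy_loss, (a=grad_a, b=grad_b, c=grad_c)
end

toy_loss, toy_grads = jepa_loss_and_grads(0.5, 0.5, 1.0, [1.0, 2.0, 3.0], [4.0, 5.0, 6.0])
println("loss = ", toy_loss)
println("gradients = ", toy_grads)
```

With these inputs (a = b = 0.5, c = 1) the squared loss is 6.75. Its square root, √6.75 ≈ 2.598, matches the Euclidean-distance loss reported for Case 2 in Part 3.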
Part 3: JEPA Execution — The Obstruction Loss
The prediction loss is an obstruction loss. It measures how far the two paths of the encoder-predictor square (context_encoder followed by predictor, versus target_encoder) are from agreeing.
```julia
# Bind implementations: encoders as linear scaling
bind_morphism!(D_jepa, :context_encoder, x -> x .* 0.5)
bind_morphism!(D_jepa, :target_encoder, x -> x .* 0.5)  # same encoder (EMA would make it close)
bind_morphism!(D_jepa, :predictor, identity)            # trivial predictor
compiled_jepa = compile_to_callable(D_jepa)
# Case 1: Context and target are the same → perfect prediction
result1 = FunctorFlow.run(compiled_jepa, Dict(
    :Context => [1.0, 2.0, 3.0],
    :MaskedRegion => [1.0, 2.0, 3.0],  # same input
))
println("Case 1 (same input): loss = ", round(result1.losses[:prediction_loss]; digits=6))
println(" Predicted: ", result1.values[:prediction_path])
println(" Target: ", result1.values[:target_encoder])
# Case 2: Different context and target → non-zero loss.
# With the :l2 comparator the loss is the Euclidean distance
# ‖(0.5, 1.0, 1.5) - (2.0, 2.5, 3.0)‖ = √6.75 ≈ 2.598076
result2 = FunctorFlow.run(compiled_jepa, Dict(
    :Context => [1.0, 2.0, 3.0],
    :MaskedRegion => [4.0, 5.0, 6.0],  # different!
))
println("\nCase 2 (different input): loss = ", round(result2.losses[:prediction_loss]; digits=6))
println(" Predicted: ", result2.values[:prediction_path])
println(" Target: ", result2.values[:target_encoder])
```

```
Case 1 (same input): loss = 0.0
 Predicted: [0.5, 1.0, 1.5]
 Target: [0.5, 1.0, 1.5]

Case 2 (different input): loss = 2.598076
 Predicted: [0.5, 1.0, 1.5]
 Target: [2.0, 2.5, 3.0]
```
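When the loss is zero (Case 1), the two paths agree on that input, so the JEPA square commutes there. The encoder is an exact coalgebra morphism only if the loss vanishes for every input, not just for one pair.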
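Zero loss on one pair of inputs is weaker than commutativity of the square. A minimal plain-Julia sketch of the stronger check evaluates the obstruction ‖predictor(context_encoder(x)) − target_encoder(x)‖ on many random inputs and reports the worst case. The function max_obstruction and the random-sampling scheme are illustrative assumptions, not FunctorFlow functions. A sampled check can expose a failure to commute, but it cannot prove commutativity.

```julia
using LinearAlgebra, Random

# Worst-case L2 obstruction between the two paths of the square over sampled inputs
function max_obstruction(context_encoder, predictor, target_encoder, samples)
    maximum(norm(predictor(context_encoder(x)) - target_encoder(x)) for x in samples)
end

Random.seed!(0)
samples = [randn(3) for _ in 1:100]
enc = x -> x .* 0.5

# Same bindings as in the Part 3 code: the paths agree on every sample (obstruction 0.0)
println(max_obstruction(enc, identity, enc, samples))
# A predictor that rescales its input breaks commutativity (positive obstruction)
println(max_obstruction(enc, z -> 2 .* z, enc, samples))
```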
Part 4: KAN-JEPA — Merging KET and JEPA
KAN-JEPA uses a left Kan extension (Σ) as the predictor, merging the KET attention pattern with JEPA’s embedding-space prediction:
```julia
D_kan = kan_jepa_block(;
    name=:KanJEPA_Demo,
    observation_object=:Tokens,
    target_object=:MaskedTokens,
    context_repr=:TokenEmb,
    target_repr=:MaskedEmb,
    relation_object=:Neighborhood,
    reducer=:mean,
)
println("=== KAN-JEPA Diagram ===")
for (name, op) in D_kan.operations
    if op isa KanExtension
        println(" Σ($(op.source); along=$(op.along)) → $(op.target) [$(op.reducer)]")
    elseif op isa Morphism
        role = get(op.metadata, :role, :unknown)
        println(" $(name): $(op.source) → $(op.target) [$(role)]")
    end
end
println("\nThis is the natural fusion:")
println(" KET's attention = left-Kan aggregation (Σ)")
println(" JEPA's loss = obstruction to commutativity")
```

```
=== KAN-JEPA Diagram ===
 context_encoder: Tokens → TokenEmb [encoder]
 target_encoder: MaskedTokens → MaskedEmb [encoder]
 Σ(TokenEmb; along=Neighborhood) → MaskedEmb [mean]

This is the natural fusion:
 KET's attention = left-Kan aggregation (Σ)
 JEPA's loss = obstruction to commutativity
```
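One way to read a Σ with reducer=:mean is as a uniform-weight attention. Each target position receives the average of the context embeddings that the Neighborhood relation links to it. The plain-Julia sketch below assumes embeddings stored as the columns of a matrix, with the relation given as a list of source indices per target position. kan_sigma_mean is a hypothetical helper, and FunctorFlow's KanExtension may represent relations differently.

```julia
using Statistics

# emb: d × n matrix, one context embedding per column (token position).
# neighborhoods[j]: indices of the context positions related to target position j
# (assumed non-empty). Column j of the result is the mean of those embeddings.
function kan_sigma_mean(emb::AbstractMatrix, neighborhoods::Vector{Vector{Int}})
    d = size(emb, 1)
    out = zeros(d, length(neighborhoods))
    for (j, nbrs) in enumerate(neighborhoods)
        out[:, j] = vec(mean(emb[:, nbrs]; dims=2))
    end
    return out
end

emb = [1.0 2.0 3.0 4.0;
       0.0 1.0 0.0 1.0]              # 2-dim embeddings of 4 tokens
neighborhoods = [[1, 2], [2, 3, 4]]  # two masked positions and their contexts
pred = kan_sigma_mean(emb, neighborhoods)
println(pred)
```

Replacing the uniform mean with a softmax-weighted sum over each neighborhood would recover ordinary attention weights. This sketch keeps the mean to match reducer=:mean.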
Part 5: Hierarchical JEPA (H-JEPA)
H-JEPA uses multiple abstraction levels for multi-scale prediction. Fine levels handle short-range details; coarse levels handle long-range abstract planning.
```julia
D_hjepa = hjepa_block(;
    name=:MultiScaleJEPA,
    levels=[:fine, :medium, :coarse],
)
println("=== H-JEPA: 3 Abstraction Levels ===")
println("\nObjects per level:")
for level in [:fine, :medium, :coarse]
    obs = Symbol("$(level)__Obs_$(level)")
    repr = Symbol("$(level)__CtxRepr_$(level)")
    println(" $(level): $(obs) → $(repr)")
end
println("\nAbstraction morphisms:")
for (name, op) in D_hjepa.operations
    if op isa Morphism && get(op.metadata, :role, nothing) == :abstraction
        from = op.metadata[:from_level]
        to = op.metadata[:to_level]
        println(" $(name): $(from) → $(to)")
    end
end
println("\nPorts:")
for (name, p) in D_hjepa.ports
    println(" $(name) ($(p.direction)): $(p.ref)")
end
```

```
=== H-JEPA: 3 Abstraction Levels ===

Objects per level:
 fine: fine__Obs_fine → fine__CtxRepr_fine
 medium: medium__Obs_medium → medium__CtxRepr_medium
 coarse: coarse__Obs_coarse → coarse__CtxRepr_coarse

Abstraction morphisms:
 abstract_fine_to_medium: fine → medium
 abstract_medium_to_coarse: medium → coarse

Ports:
 input (INPUT): fine__Obs_fine
 fine_repr (OUTPUT): fine__CtxRepr_fine
 coarse_repr (OUTPUT): coarse__CtxRepr_coarse
```
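As a rough illustration of what an abstraction morphism between levels can do, the sketch below coarsens a sequence of fine-level latent states by average-pooling adjacent time steps. Each coarser level then covers the same time span with fewer, more abstract states. Average pooling and the helper abstract_pool are assumptions chosen for illustration. In hjepa_block the abstraction steps are ordinary morphisms of the diagram (role=:abstraction), and pooling is only one possible implementation.

```julia
using Statistics

# Hypothetical abstraction map: average-pool a d × T sequence of latent states
# over non-overlapping windows of `factor` time steps.
function abstract_pool(z::AbstractMatrix, factor::Integer)
    T = size(z, 2)
    T % factor == 0 || throw(ArgumentError("T=$T is not divisible by factor=$factor"))
    return hcat((vec(mean(z[:, (k-1)*factor+1:k*factor]; dims=2)) for k in 1:T÷factor)...)
end

fine = Float64[1 2 3 4 5 6 7 8;
               8 7 6 5 4 3 2 1]    # 2-dim states over 8 fine time steps
medium = abstract_pool(fine, 2)    # plays the role of abstract_fine_to_medium: 8 → 4 steps
coarse = abstract_pool(medium, 2)  # plays the role of abstract_medium_to_coarse: 4 → 2 steps
println("medium: ", medium)
println("coarse: ", coarse)
```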
Part 6: Energy-Based Cost Module
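The energy function measures compatibility in representation space. Following LeCun's decomposition, the cost module splits into intrinsic-cost terms IC_i, which are hard-wired and immutable, and trainable-critic terms TC_j, weighted by u_i and v_j respectively: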
\[C(s) = \sum_i u_i \cdot \text{IC}_i(s) + \sum_j v_j \cdot \text{TC}_j(s)\]
```julia
# Build an energy block with VICReg-style regularization
D_energy = energy_block(;
    name=:JEPAEnergy,
    energy_type=:l2,
    variance_weight=0.5,
    covariance_weight=0.1,
    collapse_strategy=VICREG,
)
cost_mods = get_cost_modules(D_energy)
cm = cost_mods[:cost]
println("=== Energy-Based Cost Module ===")
println("Intrinsic costs:")
for ic in cm.intrinsic_costs
    println(" ", ic)
end
println("\nCollapse prevention: VICReg (variance + covariance regularization)")
```

```
=== Energy-Based Cost Module ===
Intrinsic costs:
 IntrinsicCost(:prediction_cost, type=prediction, weight=1.0)
 IntrinsicCost(:variance_cost, type=variance, weight=0.5)
 IntrinsicCost(:covariance_cost, type=covariance, weight=0.1)

Collapse prevention: VICReg (variance + covariance regularization)
```
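The weighted sum C(s) can be sketched directly. The example reuses the three intrinsic-cost weights printed above (1.0, 0.5, 0.1). The per-term cost values and the single trainable-critic term are made-up placeholders, and total_cost is a hypothetical helper rather than FunctorFlow's CostModule API.

```julia
# C(s) = Σᵢ uᵢ·ICᵢ(s) + Σⱼ vⱼ·TCⱼ(s)
total_cost(u, ic, v, tc) = sum(u .* ic) + sum(v .* tc)

u  = [1.0, 0.5, 0.1]  # intrinsic weights: prediction, variance, covariance (from the block above)
ic = [0.2, 0.4, 1.0]  # placeholder intrinsic-cost values ICᵢ(s)
v  = [1.0]            # placeholder trainable-critic weight
tc = [0.3]            # placeholder trainable-critic value TCⱼ(s)
println("C(s) = ", total_cost(u, ic, v, tc))
```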
Energy Function Implementations
```julia
x = [1.0, 0.5, -0.3, 0.8]
y = [0.9, 0.6, -0.2, 0.7]
println("L2 energy: ", round(energy_l2(x, y); digits=6))
println("Cosine energy: ", round(energy_cosine(x, y); digits=6))
println("Smooth L1: ", round(energy_smooth_l1(x, y); digits=6))
# Self-similarity should give (near-)zero energy. The cosine value is not exactly 0,
# which is consistent with a small stabilizing epsilon in its denominator.
println("\nSelf-similarity:")
println(" L2(x, x) = ", energy_l2(x, x))
println(" Cosine(x, x) = ", round(energy_cosine(x, x); digits=10))
```

```
L2 energy: 0.04
Cosine energy: 0.007994
Smooth L1: 0.02

Self-similarity:
 L2(x, x) = 0.0
 Cosine(x, x) = 5.1e-9
```
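The three energies can each be written in a line or two. The sketch below uses hypothetical sketch_energy_* names and makes three assumptions. The L2 energy is the squared Euclidean distance. The cosine energy is 1 − cosine similarity, with ε = 1e-8 added to the denominator. The smooth L1 energy is Huber-style with β = 1.

These definitions were chosen because they reproduce the values printed above, including the 5.1e-9 self-cosine, which an ε in the denominator would explain. FunctorFlow's own definitions may differ.

```julia
using LinearAlgebra

# Hypothetical re-implementations that reproduce the numbers printed above
sketch_energy_l2(x, y) = sum(abs2, x .- y)   # squared Euclidean distance
sketch_energy_cosine(x, y; eps=1e-8) = 1 - dot(x, y) / (norm(x) * norm(y) + eps)
function sketch_energy_smooth_l1(x, y; beta=1.0)
    d = abs.(x .- y)
    # quadratic for |d| < beta, linear beyond (Huber-style)
    return sum(ifelse.(d .< beta, 0.5 .* d .^ 2 ./ beta, d .- 0.5 * beta))
end

x = [1.0, 0.5, -0.3, 0.8]
y = [0.9, 0.6, -0.2, 0.7]
println("L2:         ", round(sketch_energy_l2(x, y); digits=6))
println("Cosine:     ", round(sketch_energy_cosine(x, y); digits=6))
println("Smooth L1:  ", round(sketch_energy_smooth_l1(x, y); digits=6))
println("Cos(x, x):  ", round(sketch_energy_cosine(x, x); digits=10))
```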
VICReg Regularization
```julia
using Random
Random.seed!(42)
# Good representations: diverse (high variance, low correlation)
Z_good = randn(4, 10)  # 4 dims, 10 samples
var_good = variance_regularization(Z_good; gamma=1.0)
cov_good = covariance_regularization(Z_good)
# Collapsed representations: all identical
Z_bad = ones(4, 10)
var_bad = variance_regularization(Z_bad; gamma=1.0)
cov_bad = covariance_regularization(Z_bad)
println("Diverse representations:")
println(" Variance penalty: ", round(var_good; digits=4))
println(" Covariance penalty: ", round(cov_good; digits=4))
println("\nCollapsed representations:")
println(" Variance penalty: ", round(var_bad; digits=4), " (high = BAD)")
println(" Covariance penalty: ", round(cov_bad; digits=4))
```

```
Diverse representations:
 Variance penalty: 0.4483
 Covariance penalty: 0.7101

Collapsed representations:
 Variance penalty: 3.96 (high = BAD)
 Covariance penalty: 0.0
```

The collapsed batch scores zero on the covariance penalty because every dimension is constant, so all covariances vanish. The variance term is what detects collapse. The covariance term instead penalizes redundancy between dimensions, which is why the random batch, whose sample covariances are nonzero, still incurs a nonzero covariance penalty.
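For reference, the VICReg paper defines its two terms as follows. The variance term is a hinge on the regularized per-dimension standard deviation, v(Z) = (1/d) Σⱼ max(0, γ − √(Var(zⱼ) + ε)). The covariance term is the sum of squared off-diagonal covariances divided by d.

A plain-Julia sketch of those definitions follows; the vicreg_* names are hypothetical. It averages over the d dimensions as the paper does. The 3.96 printed above equals 4 × 0.99, which suggests FunctorFlow's variance_regularization sums over the four dimensions instead, though that is only an inference from the output.

```julia
using LinearAlgebra, Random, Statistics

# Z is d × n: d embedding dimensions (rows), n samples (columns).
# Variance term: hinge on the regularized per-dimension standard deviation.
function vicreg_variance(Z; gamma=1.0, eps=1e-4)
    stds = sqrt.(vec(var(Z; dims=2)) .+ eps)
    return mean(max.(0.0, gamma .- stds))
end

# Covariance term: squared off-diagonal covariances, divided by d.
function vicreg_covariance(Z)
    d = size(Z, 1)
    C = cov(Z; dims=2)  # d × d covariance across the n samples
    return (sum(abs2, C) - sum(abs2, diag(C))) / d
end

Random.seed!(42)
Z_diverse = randn(4, 10)
Z_collapsed = ones(4, 10)
for (label, Z) in [("diverse", Z_diverse), ("collapsed", Z_collapsed)]
    println(label, ": variance = ", round(vicreg_variance(Z); digits=4),
            ", covariance = ", round(vicreg_covariance(Z); digits=4))
end
```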
Part 7: Lean 4 Proof Certificates
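FunctorFlow.jl can generate Lean 4 proof certificates for JEPA diagrams. In the certificate below, native_decide checks the well-formedness of the exported diagram. The JEPA-specific theorems are then proved by appeal to lemmas from the FunctorFlowProofs library, such as LoweringArtifact.sound_of_check: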
```julia
D_proof = jepa_block(; name=:ProvedJEPA)
add_energy_function!(D_proof, :compat;
    domain=[:ContextRepr, :TargetRepr], energy_type=:l2)
add_bisimulation!(D_proof, :enc_equiv;
    coalgebra_a=:jepa_dynamics,
    coalgebra_b=:jepa_dynamics,
    relation=:predictor)
cert = render_jepa_certificate(D_proof)
println(cert)
```

```lean
-- Auto-generated by FunctorFlow.jl
namespace FunctorFlowProofs.Generated.ProvedJEPA
open FunctorFlowProofs in
def exportedDiagram : DiagramDecl := {
  name := "ProvedJEPA",
  objects := ["Observation", "Target", "ContextRepr", "TargetRepr"],
  operations := [{ name := "context_encoder", kind := OperationKind.morphism, refs := ["Observation", "ContextRepr"] }, { name := "target_encoder", kind := OperationKind.morphism, refs := ["Target", "TargetRepr"] }, { name := "predictor", kind := OperationKind.morphism, refs := ["ContextRepr", "TargetRepr"] }, { name := "prediction_path", kind := OperationKind.composition, refs := ["context_encoder", "predictor", "Observation", "TargetRepr"] }],
  ports := [{ name := "context_input", ref := "Observation", kind := "object", portType := "observation", direction := "input" }, { name := "target_input", ref := "Target", kind := "object", portType := "observation", direction := "input" }, { name := "context_embedding", ref := "ContextRepr", kind := "object", portType := "representation", direction := "output" }, { name := "target_embedding", ref := "TargetRepr", kind := "object", portType := "representation", direction := "output" }, { name := "prediction", ref := "prediction_path", kind := "operation", portType := "representation", direction := "output" }, { name := "loss", ref := "prediction_loss", kind := "loss", portType := "loss", direction := "output" }]
}

def exportedArtifact : LoweringArtifact := {
  diagram := exportedDiagram,
  resolvedRefs := true,
  portsClosed := true
}

theorem exportedArtifact_checks : exportedArtifact.check = true := by native_decide

theorem exportedArtifact_sound : exportedArtifact.Sound :=
  LoweringArtifact.sound_of_check exportedArtifact_checks

-- Coalgebra declarations
def coalgebra_jepa_dynamics : CoalgebraDecl := {
  name := "jepa_dynamics",
  state := "ContextRepr",
  transition := "predictor",
  functorType := "identity"
}

-- JEPA prediction-as-obstruction theorems
/-- The JEPA prediction loss is an obstruction to commutativity
    of the encoder/predictor square. When loss = 0, the square
    commutes and the encoder is an exact coalgebra morphism. -/
theorem jepa_prediction_loss_is_obstruction :
    exportedArtifact.lossIsObstruction "prediction_loss" := by
  exact LoweringArtifact.loss_obstruction_of_check exportedArtifact_checks

/-- When the JEPA prediction loss is zero, the encoder-predictor
    path commutes with the target encoder path, making the
    encoder a coalgebra morphism (structure-preserving map). -/
theorem jepa_exact_implies_coalgebra_morphism
    (h : ∀ l ∈ exportedArtifact.losses, l.value = 0) :
    exportedArtifact.CoalgebraExact := by
  exact LoweringArtifact.coalgebra_exact_of_zero_loss h

-- Bisimulation declarations
def bisim_enc_equiv : BisimulationDecl := {
  name := "enc_equiv",
  coalgebraA := "jepa_dynamics",
  coalgebraB := "jepa_dynamics",
  relation := "predictor"
}

/-- Two coalgebras are bisimilar iff they map to the same element
    in the final coalgebra — behavioral equivalence. -/
theorem bisimilar_iff_final_coalgebra_equal
    (A B : CoalgebraDecl) (R : BisimulationDecl)
    (h : R.isBisimulation A B) :
    A.finalImage = B.finalImage :=
  CoalgebraDecl.bisim_implies_final_eq h

-- Energy function declarations
def energy_compat : EnergyDecl := {
  name := "compat",
  domain := ["ContextRepr", "TargetRepr"],
  energyType := "l2"
}

/-- Energy functions are non-negative for L2 and cosine types. -/
theorem energy_nonneg (e : EnergyDecl)
    (h : e.energyType ∈ ["l2", "cosine"]) :
    0 ≤ e.evaluate := by
  exact EnergyDecl.nonneg_of_standard h

end FunctorFlowProofs.Generated.ProvedJEPA
```
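Key theorems generated:

- jepa_prediction_loss_is_obstruction: the JEPA loss is an obstruction to commutativity.
- jepa_exact_implies_coalgebra_morphism: zero loss implies that the encoder preserves structure, i.e. it is an exact coalgebra morphism.
- bisimilar_iff_final_coalgebra_equal: bisimilar encoders are behaviorally equivalent. Despite the "iff" in its name, the generated statement proves only one direction: a bisimulation implies equal images in the final coalgebra.
- energy_nonneg: energy functions of type l2 or cosine are non-negative.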
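For the L2 case, the content of energy_nonneg can be stated and proved directly in Lean 4 with Mathlib, independently of the generated FunctorFlowProofs library. The definition l2Energy below is a hypothetical standalone squared-L2 energy on vectors in ℝⁿ, not the EnergyDecl.evaluate used in the certificate. This is a sketch that assumes a recent Mathlib in which the ∑ notation is available after import Mathlib.

```lean
import Mathlib

/-- Squared-L2 energy between two vectors in ℝⁿ (standalone sketch). -/
noncomputable def l2Energy {n : ℕ} (x y : Fin n → ℝ) : ℝ :=
  ∑ i, (x i - y i) ^ 2

/-- A sum of squares is non-negative, so the L2 energy is too. -/
theorem l2Energy_nonneg {n : ℕ} (x y : Fin n → ℝ) : 0 ≤ l2Energy x y := by
  unfold l2Energy
  exact Finset.sum_nonneg (fun i _ => sq_nonneg (x i - y i))
```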
Part 8: EMA Update (Collapse Prevention)
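JEPA variants such as I-JEPA guard against representation collapse by making the target encoder an exponential moving average (EMA) of the online encoder's weights. This is combined with a stop-gradient on the target branch (the detach=true flag on target_encoder in Part 2):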
```julia
# Simulate EMA training dynamics
target_params = [Float32[0.1, 0.2, 0.3]]
online_params = [Float32[0.5, 0.6, 0.7]]
println("Before EMA:")
println(" Target: ", target_params[1])
println(" Online: ", online_params[1])
for step in 1:5
    ema_update!(target_params, online_params; decay=0.99)
end
println("\nAfter 5 EMA steps (decay=0.99):")
println(" Target: ", round.(target_params[1]; digits=4))
println(" Online: ", online_params[1], " (unchanged)")
println("\nTarget slowly tracks online → prevents collapse without negatives")
```

```
Before EMA:
 Target: Float32[0.1, 0.2, 0.3]
 Online: Float32[0.5, 0.6, 0.7]

After 5 EMA steps (decay=0.99):
 Target: Float32[0.1196, 0.2196, 0.3196]
 Online: Float32[0.5, 0.6, 0.7] (unchanged)

Target slowly tracks online → prevents collapse without negatives
```
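The EMA rule is θ_target ← d·θ_target + (1 − d)·θ_online with decay d. When the online parameters stay fixed, as in the simulation above, k steps shrink the gap geometrically: θ_target(k) = θ_online − dᵏ·(θ_online − θ_target(0)). For example, 0.5 − 0.99⁵ × 0.4 ≈ 0.1196, matching the first printed entry.

The plain-Julia sketch below checks the iterative update against that closed form. ema_step! is a hypothetical stand-in for ema_update!, whose exact implementation may differ.

```julia
# Hypothetical stand-in for ema_update!: in-place EMA of each parameter array
function ema_step!(target, online; decay=0.99)
    for (t, o) in zip(target, online)
        @. t = decay * t + (1 - decay) * o
    end
    return target
end

target_sketch = [Float32[0.1, 0.2, 0.3]]
online_sketch = [Float32[0.5, 0.6, 0.7]]
target0 = copy(target_sketch[1])
for _ in 1:5
    ema_step!(target_sketch, online_sketch; decay=0.99)
end

# Closed form for fixed online parameters: θ_online - decay^k * (θ_online - θ_target(0))
closed = online_sketch[1] .- 0.99f0^5 .* (online_sketch[1] .- target0)
println("iterative:   ", round.(target_sketch[1]; digits=4))
println("closed form: ", round.(closed; digits=4))
```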
Summary
| Construction | FunctorFlow Type | Purpose |
|---|---|---|
| Coalgebra | Coalgebra | World model (state → F(state)) |
| Coalgebra morphism | CoalgebraMorphism | Structure-preserving encoder |
| Final coalgebra | FinalCoalgebraWitness | Optimal representation (Lambek) |
| Bisimulation | Bisimulation | Behavioral equivalence of encoders |
| JEPA block | jepa_block() | Encoder/predictor/target triple |
| KAN-JEPA | kan_jepa_block() | Σ-attention as JEPA predictor |
| H-JEPA | hjepa_block() | Multi-scale nested coalgebras |
| Energy function | EnergyFunction | Compatibility measure in representation space |
| Cost module | CostModule | IC + TC decomposition |
| EMA update | ema_update!() | Collapse prevention |
The key insight: JEPA’s prediction loss is an obstruction to diagram commutativity. Minimizing it drives the encoder toward being an exact coalgebra morphism — a structure-preserving map between the observation world model and the latent world model.