using Pkg
Pkg.activate(joinpath(@__DIR__, ".."))
using FunctorFlow
using Random
using Statistics
Causal-JEPA: Object-Level Interventions as Categorical Operations
Connecting C-JEPA’s latent interventions with FunctorFlow’s RN-Kan-Do-Calculus
Introduction
C-JEPA (Causal-JEPA, arXiv:2602.11389) demonstrates that object-level masking during JEPA training induces a causal inductive bias — making interaction reasoning functionally necessary. Rather than masking random patches (as in I-JEPA), C-JEPA masks entire objects across history frames. This forces the predictor to recover the masked object’s state purely from its interactions with other objects, not from its own self-dynamics.
The central theoretical result (Theorem 1: Interaction Necessity) shows that under object-level masking, minimizing the prediction loss necessarily identifies the influence neighborhood — the minimal set of context entities sufficient to predict a masked entity. This is the causal analog of a Markov blanket.
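Before the categorical formulation, a plain-Julia toy (independent of FunctorFlow; the function name is ours) makes the influence-neighborhood idea concrete: in a linear system \(s_{t+1} = A s_t\), the influence neighborhood of entity \(i\) is just the off-diagonal support of row \(i\) of the coupling matrix — the minimal context needed to predict entity \(i\) once its own history is masked.

```julia
# Toy linear dynamics s_{t+1} = A * s_t (plain Julia, not FunctorFlow).
# A[i, j] != 0 means entity j influences entity i, so the influence
# neighborhood of entity i is the off-diagonal support of row i:
# the minimal context sufficient to predict i when its history is masked.
influence_neighborhood(A::AbstractMatrix, i::Int) =
    [j for j in axes(A, 2) if j != i && !iszero(A[i, j])]

A = [1.0 0.5 0.0 0.0 0.0;   # entity 1 is influenced by entity 2
     0.5 1.0 0.0 0.0 0.0;   # entity 2 is influenced by entity 1
     0.0 0.0 1.0 0.3 0.3;   # entity 3 is influenced by entities 4 and 5
     0.0 0.0 0.3 1.0 0.0;   # entity 4 is influenced by entity 3
     0.0 0.0 0.3 0.0 1.0]   # entity 5 is influenced by entity 3

println(influence_neighborhood(A, 3))  # [4, 5]
```

Masking entity 3's history and asking which context entities are needed to recover it returns exactly its coupled neighbors.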
FunctorFlow.jl’s categorical framework maps these ideas precisely:
| C-JEPA Concept | FunctorFlow Construction |
|---|---|
| Object slots \(S_t = \{s_t^1, \ldots, s_t^N\}\) | Categorical objects in a Diagram |
| Object-level masking | do-calculus intervention via left-Kan \(\Sigma\) |
| Unmasked context | Conditioning via right-Kan \(\Delta\) |
| Influence neighborhood \(\mathcal{N}_t(i)\) | Markov blanket = minimal sufficient \(\Delta\) |
| Joint prediction loss | Obstruction loss in JEPA diagram |
| EMA target encoder | Frozen reference coalgebra |
| History reconstruction (\(\mathcal{L}_\text{history}\)) | Right-Kan completion (repair) |
| Future prediction (\(\mathcal{L}_\text{future}\)) | Left-Kan aggregation (prediction) |
| Slot Attention (entity decomposition) | Product diagram (\(\otimes\)) of entity diagrams |
| Counterfactual reasoning | Composition: intervene (\(\Sigma\)) then condition (\(\Delta\)) |
| Identity anchor | Coalgebra initial state |
Setup
Object-Centric World as a Product Diagram
C-JEPA processes each frame through Slot Attention, decomposing a scene into \(N\) object slots: \(S_t = \{s_t^1, \ldots, s_t^N\}\). Each slot is an independent entity with its own latent state and dynamics. Categorically, each entity is an F-coalgebra — a state paired with a transition morphism \(X \to F(X)\) — and the full scene state is their product.
function entity_diagram(name::Symbol, slot_dim::Int)
D = Diagram(name)
add_object!(D, :State, kind=:hidden_state,
description="Entity latent state (ℝ^$slot_dim)")
add_object!(D, :NextState, kind=:hidden_state,
description="Next-step entity state")
add_morphism!(D, :dynamics, :State, :NextState,
description="Entity self-dynamics f: X → F(X)")
add_coalgebra!(D, :entity_coalgebra,
state=:State, transition=:dynamics, functor_type=:identity)
return D
end
N = 5 # number of entity slots
slot_dim = 16
entities = [entity_diagram(Symbol("Entity_$i"), slot_dim) for i in 1:N]
println("Created $(length(entities)) entity coalgebras")
for (i, e) in enumerate(entities)
coalgs = get_coalgebras(e)
println(" Entity $i: coalgebra=$(first(keys(coalgs)))")
end
Created 5 entity coalgebras
Entity 1: coalgebra=entity_coalgebra
Entity 2: coalgebra=entity_coalgebra
Entity 3: coalgebra=entity_coalgebra
Entity 4: coalgebra=entity_coalgebra
Entity 5: coalgebra=entity_coalgebra
The product construction gives us projection morphisms \(\pi_i\) from the scene to each entity, and the universal property ensures that any compatible mapping factors through the product:
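Spelled out, the universal property reads: for any object \(W\) with morphisms \(f_i : W \to X_i\) into each entity, there is a unique mediating morphism through the product,
\[
\exists!\ \langle f_1, \ldots, f_N \rangle : W \to \prod_{i=1}^{N} X_i
\quad \text{such that} \quad
\pi_i \circ \langle f_1, \ldots, f_N \rangle = f_i \ \ \text{for all } i .
\]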
scene = product(entities...; name=:Scene)
scene_check = verify(scene)
println("Scene product: $(scene_check)")
println("Projections: $(scene.projections)")
println(summary(scene.product_diagram))
Scene product: (passed = true, checks = Dict{Symbol, Bool}(:has_factor_factor_4 => 1, :has_factor_factor_1 => 1, :has_factor_factor_2 => 1, :has_factor_factor_5 => 1, :has_factor_factor_3 => 1), construction = :product)
Projections: [:factor_1, :factor_2, :factor_3, :factor_4, :factor_5]
FunctorFlow.Diagram
Object-Level Masking as a Categorical Intervention
This is the core conceptual bridge. In C-JEPA, masking object \(i\) across history means replacing its self-dynamics with a learnable mask token (anchored by its initial state at \(t=0\)). The remaining unmasked entities form the context. The predictor must recover the masked entity’s state from context alone.
In FunctorFlow’s RN-Kan-Do-Calculus:
- Observational regime: Full history \(Z_T\) with all entities visible
- Interventional regime: Masked history \(\bar{Z}_T\) where entity \(i\)’s history is replaced by mask tokens — this is \(\text{do}(\text{entity}_i = \text{mask})\)
The left-Kan extension \(\Sigma\) implements the do-operation (intervention), and the right-Kan extension \(\Delta\) implements conditioning (observing the context).
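For orientation, the standard pointwise formulas: writing \(J^{*}\) for restriction (precomposition) along \(J\), the left and right Kan extensions are its two adjoints,
\[
\Sigma_J \dashv J^{*} \dashv \Delta_J, \qquad
(\Sigma_J F)(c) = \operatorname{colim}_{(J \downarrow c)} F, \qquad
(\Delta_J F)(c) = \lim_{(c \downarrow J)} F,
\]
so intervention is a colimit (aggregation) and conditioning is a limit (completion).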
obs_regime = CausalContext(:full_observation;
observational_regime=:all_entities_visible,
interventional_regime=:entity_masked)
causal_d = build_causal_diagram(:EntityInteraction;
context=obs_regime,
observation_source=:EntityStates,
causal_relation=:InteractionGraph,
intervention_target=:InterventionalPrediction,
conditioning_target=:ConditionalState)
println("Causal diagram: $(causal_d.name)")
println(" Conditioning (Δ, right-Kan): $(causal_d.conditioning_kan)")
println(" Intervention (Σ, left-Kan): $(causal_d.intervention_kan)")
println(summary(causal_d.base_diagram))
Causal diagram: EntityInteraction
Conditioning (Δ, right-Kan): condition
Intervention (Σ, left-Kan): intervene
FunctorFlow.Diagram
The masking operation is categorically a left-Kan extension (colimit / aggregation): we push forward along the masking relation, which removes entity \(i\)’s self-dynamics from the diagram. The prediction of the masked entity then requires a right-Kan extension (limit / completion) from the remaining context.
We can check identifiability — is the masked entity’s state recoverable from context?
ident = is_identifiable(causal_d, :InterventionalPrediction;
observed=[:EntityStates])
println("Identifiability: $(ident)")
Identifiability: (identifiable = true, rule = :adjustment, reasoning = "Both Kan extensions share source and causal structure; adjustment formula applies via back-door criterion")
Influence Neighborhood as Right-Kan Completion
C-JEPA’s influence neighborhood \(\mathcal{N}_t(i)\) is the minimal sufficient subset \(\mathcal{N}_t(i) \subseteq Z_T^{(-i)}\) needed to predict masked entity \(i\). In categorical terms, this is the kernel of the right-Kan extension — the smallest subdiagram whose \(\Delta\)-completion recovers the masked state.
Theorem 1 (Interaction Necessity) states: under object-level masking, minimizing the prediction loss necessarily identifies \(\mathcal{N}_t(i)\). Categorically, the optimal predictor’s right-Kan extension is supported exactly on the influence neighborhood.
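One way to formalize the neighborhood's minimality (a hedged sketch in the paper's notation, reading the Markov-blanket analogy as conditional independence, and assuming the minimum under set inclusion is unique):
\[
\mathcal{N}_t(i) \;=\; \min_{\subseteq} \Bigl\{\, S \subseteq Z_T^{(-i)} \;:\; s_t^i \,\perp\!\!\!\perp\, \bigl(Z_T^{(-i)} \setminus S\bigr) \,\bigm|\, S \,\Bigr\},
\]
i.e. the smallest context set that screens the masked entity off from everything else.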
D_influence = Diagram(:InfluenceNeighborhood)
add_object!(D_influence, :MaskedEntity, kind=:hidden_state,
description="Entity whose history is masked (target)")
add_object!(D_influence, :ContextEntities, kind=:hidden_state,
description="Unmasked entities forming the context")
add_object!(D_influence, :InteractionRelation, kind=:relation,
description="Which context entities influence the target")
add_object!(D_influence, :CompletedState, kind=:hidden_state,
description="Inferred state of masked entity via right-Kan")
# Right-Kan: complete the masked entity from context
Δ(D_influence, :ContextEntities;
along=:InteractionRelation,
name=:infer_from_context,
target=:CompletedState,
reducer=:first_non_null)
println(summary(D_influence))
FunctorFlow.Diagram
The right-Kan extension \(\Delta_J F\) computes the best approximation of the masked entity given the context, mediated by the interaction relation \(J\). The influence neighborhood is precisely the support of \(J\) — the minimal set of context entities for which \(\Delta_J F \cong F\) (the completion is exact).
C-JEPA Training as Coalgebraic Obstruction
The C-JEPA training loss has two complementary components, each corresponding to a Kan extension:
- \(\mathcal{L}_\text{history}\): Recover masked slots from context → right-Kan completion (\(\Delta\)) — repair missing data by conditioning on context
- \(\mathcal{L}_\text{future}\): Predict future slots from history → left-Kan aggregation (\(\Sigma\)) — forward prediction by pushing history forward
The total loss measures the obstruction to commutativity — the failure of the predicted state to match the target encoder’s output.
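Schematically (a sketch, with \(\mathrm{sg}(\cdot)\) denoting the stop-gradient through the frozen EMA target encoder):
\[
\mathcal{L} = \mathcal{L}_\text{history} + \mathcal{L}_\text{future}, \qquad
\mathcal{L}_\text{history} = \bigl\|\Delta(\text{context}) - \mathrm{sg}(\text{target history})\bigr\|_2^2, \qquad
\mathcal{L}_\text{future} = \bigl\|\Sigma(\text{history}) - \mathrm{sg}(\text{target future})\bigr\|_2^2 .
\]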
D_cjepa = Diagram(:CJEPA)
# -- History branch (right-Kan: repair masked entities) --
add_object!(D_cjepa, :HistoryContext, kind=:hidden_state,
description="Unmasked history slots (context encoder output)")
add_object!(D_cjepa, :MaskRelation, kind=:relation,
description="Which slots are masked vs. visible")
add_object!(D_cjepa, :ReconstructedHistory, kind=:hidden_state,
description="Recovered masked history slots")
add_object!(D_cjepa, :TargetHistory, kind=:hidden_state,
description="Ground truth masked history (from EMA target encoder)")
Δ(D_cjepa, :HistoryContext;
along=:MaskRelation,
name=:recover_masked,
target=:ReconstructedHistory,
reducer=:first_non_null)
# Morphism from target encoder (EMA-frozen) to produce ground truth
add_morphism!(D_cjepa, :target_enc_history, :HistoryContext, :TargetHistory,
description="EMA target encoder for history (frozen)")
add_obstruction_loss!(D_cjepa, :history_loss;
paths=[(:recover_masked, :target_enc_history)],
comparator=:l2, weight=1.0,
description="L_history: masked slot reconstruction")
# -- Future branch (left-Kan: predict future from context) --
add_object!(D_cjepa, :CausalRelation, kind=:relation,
description="Temporal causal structure (history → future)")
add_object!(D_cjepa, :PredictedFuture, kind=:hidden_state,
description="Predicted future slots")
add_object!(D_cjepa, :TargetFuture, kind=:hidden_state,
description="Ground truth future slots (from EMA target encoder)")
Σ(D_cjepa, :HistoryContext;
along=:CausalRelation,
name=:predict_future,
target=:PredictedFuture,
reducer=:sum)
add_morphism!(D_cjepa, :target_enc_future, :HistoryContext, :TargetFuture,
description="EMA target encoder for future (frozen)")
add_obstruction_loss!(D_cjepa, :future_loss;
paths=[(:predict_future, :target_enc_future)],
comparator=:l2, weight=1.0,
description="L_future: forward prediction")
println(summary(D_cjepa))
FunctorFlow.Diagram
The two losses interact: \(\mathcal{L}_\text{history}\) forces the model to learn the interaction structure (which entities influence which), while \(\mathcal{L}_\text{future}\) forces it to learn temporal dynamics conditioned on that interaction structure. Together, they are the obstruction to a fully commutative diagram — the residual measures how far the learned world model is from perfect causal understanding.
Synthetic Multi-Object Interaction Simulation
To make these ideas concrete, we create a synthetic scene with \(N = 5\) particles on a 1D line interacting via spring-like forces. This is simple enough to analyze exactly, yet rich enough to demonstrate C-JEPA’s key insight: without masking, a model can exploit self-dynamics shortcuts; with masking, it must learn interactions.
rng = Random.MersenneTwister(42)
N_objects = 5
slot_dim = 4 # (position, velocity, mass, charge)
T_history = 3
T_future = 2
T_total = T_history + T_future
function simulate_scene(rng, N; steps=5, dt=0.1)
positions = randn(rng, N) * 2.0
velocities = randn(rng, N) * 0.5
masses = ones(N)
charges = rand(rng, [-1.0, 1.0], N)
trajectory = []
for t in 1:steps
state = hcat(positions, velocities, masses, charges) # N × 4
push!(trajectory, state)
# Pairwise forces (Coulomb-like)
forces = zeros(N)
for i in 1:N, j in 1:N
i == j && continue
dx = positions[j] - positions[i]
forces[i] += charges[i] * charges[j] * sign(dx) / (abs(dx)^2 + 0.1)
end
velocities = velocities .+ forces ./ masses .* dt
positions = positions .+ velocities .* dt
end
return trajectory
end
trajectories = [simulate_scene(rng, N_objects; steps=T_total) for _ in 1:32]
println("Generated $(length(trajectories)) trajectories, each with $(T_total) frames")
println("Frame shape: $(size(trajectories[1][1])) = ($N_objects objects × $slot_dim features)")
# Show the first trajectory's positions over time
traj1 = trajectories[1]
println("\nEntity positions over time (trajectory 1):")
for t in 1:T_total
pos = round.(traj1[t][:, 1]; digits=3)
println(" t=$t: $pos")
end
Generated 32 trajectories, each with 5 frames
Frame shape: (5, 4) = (5 objects × 4 features)
Entity positions over time (trajectory 1):
t=1: [2.421, -0.158, 0.807, 0.58, -0.134]
t=2: [2.461, 0.016, 0.715, 0.583, -0.251]
t=3: [2.51, 0.171, 0.508, 0.631, -0.289]
t=4: [2.567, 0.371, 0.324, 0.548, -0.273]
t=5: [2.632, 0.527, 0.282, 0.306, -0.203]
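For reference, the dynamics implemented in simulate_scene (semi-implicit Euler, with a 0.1 softening constant keeping the force finite at contact):
\[
F_i = \sum_{j \neq i} \frac{q_i\, q_j\, \operatorname{sign}(x_j - x_i)}{(x_j - x_i)^2 + 0.1}, \qquad
v_i \leftarrow v_i + \frac{F_i}{m_i}\,\Delta t, \qquad
x_i \leftarrow x_i + v_i\,\Delta t .
\]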
Masking and Counterfactual Queries
Object-level masking replaces an entity’s history (after \(t=0\)) with NaN tokens, preserving only the identity anchor at \(t=0\). In FunctorFlow, this is a do-operation: we intervene on the entity’s state, severing its self-dynamics.
function mask_entity(trajectory, entity_idx; anchor_frame=1)
T = length(trajectory)
masked = [copy(frame) for frame in trajectory]
anchor = trajectory[anchor_frame][entity_idx, :]
for t in (anchor_frame + 1):T
masked[t][entity_idx, :] .= NaN # mask token placeholder
end
return masked, anchor
end
# Mask entity 3 — this is do(entity_3 = mask_token)
masked_traj, anchor = mask_entity(trajectories[1], 3)
println("Identity anchor for entity 3 (t=0): $(round.(anchor; digits=3))")
println("Entity 3 at t=2 after masking: $(masked_traj[2][3, :])")
println("Entity 1 at t=2 (unmasked): $(round.(masked_traj[2][1, :]; digits=3))")
Identity anchor for entity 3 (t=0): [0.807, -0.027, 1.0, 1.0]
Entity 3 at t=2 after masking: [NaN, NaN, NaN, NaN]
Entity 1 at t=2 (unmasked): [2.461, 0.403, 1.0, -1.0]
Now we use FunctorFlow’s causal machinery to compute the interventional expectation — what is the expected state of entity 3 given the intervention?
# The CausalDiagram uses observation_source=:EntityStates and causal_relation=:InteractionGraph
obs_data = Dict{Symbol,Any}(
:EntityStates => trajectories[1][2][:, 1], # entity positions at t=2
:InteractionGraph => Dict("full" => collect(1:N)), # all entities interact
)
# Bind reducers for the causal diagram's Kan extensions
bind_reducer!(causal_d.base_diagram, :sum,
(data, relation, meta) -> data) # identity aggregation for demo
bind_reducer!(causal_d.base_diagram, :first_non_null,
(data, relation, meta) -> data) # identity completion for demo
ie_result = interventional_expectation(causal_d, obs_data)
println("Interventional expectation keys: $(keys(ie_result))")
println("Intervention result: $(ie_result[:intervention])")
println("Conditioning result: $(ie_result[:conditioning])")
Interventional expectation keys: [:conditioning, :intervention, :all_values]
Intervention result: [2.4608197268247225, 0.01601258820176013, 0.7148194917284902, 0.5825704526616224, -0.25136721974573484]
Conditioning result: [2.4608197268247225, 0.01601258820176013, 0.7148194917284902, 0.5825704526616224, -0.25136721974573484]
We can compile and execute the influence neighborhood diagram with concrete implementations:
# Bind reducer with correct 3-arg signature: (data, relation, metadata)
bind_reducer!(D_influence, :first_non_null,
(data, relation, meta) -> data) # passthrough for symbolic demo
compiled_influence = compile_to_callable(D_influence)
# Compute the context: all entities except the masked one
context_states = masked_traj[2][[1, 2, 4, 5], :] # exclude entity 3
interaction_mask = [true, true, true, true] # all context entities participate
result = FunctorFlow.run(compiled_influence, Dict(
:MaskedEntity => fill(NaN, slot_dim),
:ContextEntities => context_states,
:InteractionRelation => interaction_mask
))
println("Influence neighborhood result keys: $(keys(result.values))")
predicted = result.values[:infer_from_context]
println("Predicted entity 3 state (from context): $(round.(predicted; digits=3))")
Influence neighborhood result keys: [:InteractionRelation, :infer_from_context, :MaskedEntity, :CompletedState, :ContextEntities]
Predicted entity 3 state (from context): [2.461 0.403 1.0 -1.0; 0.016 1.74 1.0 1.0; 0.583 0.027 1.0 1.0; -0.251 -1.173 1.0 1.0]
Counterfactual Reasoning via Σ ∘ Δ Composition
The key categorical insight: counterfactual = intervene (\(\Sigma\)) then condition (\(\Delta\)).
“What would entity \(i\)’s trajectory be if entity \(j\) had been removed from the scene?”
This is the composition \(\Delta_Q \circ \Sigma_J\): first push forward along the intervention relation \(J\) (removing entity \(j\)), then complete along the query relation \(Q\) (inferring entity \(i\)’s state in the counterfactual world).
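Symbolically, the counterfactual state of entity \(i\) under removal of entity \(j\) factors as
\[
\widehat{s}^{\,i}_{\operatorname{do}(\text{remove } j)} \;=\; \bigl(\Delta_Q \circ \Sigma_J\bigr)(Z_T),
\]
where \(J\) deletes entity \(j\) from the interaction graph and \(Q\) selects entity \(i\) in the intervened world.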
D_counterfactual = Diagram(:Counterfactual)
add_object!(D_counterfactual, :ObservedScene, kind=:hidden_state,
description="Full observed scene state")
add_object!(D_counterfactual, :InterventionRelation, kind=:relation,
description="Remove entity j from the interaction graph")
add_object!(D_counterfactual, :IntervenedScene, kind=:hidden_state,
description="Scene without entity j's influence")
add_object!(D_counterfactual, :QueryRelation, kind=:relation,
description="Select entity i from the intervened scene")
add_object!(D_counterfactual, :CounterfactualState, kind=:hidden_state,
description="Entity i's state in the counterfactual world")
# Step 1: Σ (intervene — aggregate scene without entity j)
Σ(D_counterfactual, :ObservedScene;
along=:InterventionRelation,
name=:intervene,
target=:IntervenedScene,
reducer=:sum)
# Step 2: Δ (condition — infer entity i from the intervened scene)
Δ(D_counterfactual, :IntervenedScene;
along=:QueryRelation,
name=:condition_query,
target=:CounterfactualState,
reducer=:first_non_null)
println(summary(D_counterfactual))
FunctorFlow.Diagram
Execute the counterfactual query: “What would entity 3 do if entity 1 were removed?”
# Bind reducers with correct 3-arg signatures
bind_reducer!(D_counterfactual, :sum,
(data, relation, meta) -> data) # passthrough (symbolic demo)
bind_reducer!(D_counterfactual, :first_non_null,
(data, relation, meta) -> data) # passthrough (symbolic demo)
compiled_cf = compile_to_callable(D_counterfactual)
# Observed scene at t=2
observed_scene = trajectories[1][2]
# Intervention: remove entity 1 (keep entities 2,3,4,5)
intervened = observed_scene[[2, 3, 4, 5], :]
# Query: select entity 3 from the counterfactual scene
query_state = intervened[2, :] # entity 3 is at index 2 after removal
cf_result = FunctorFlow.run(compiled_cf, Dict(
:ObservedScene => observed_scene,
:InterventionRelation => [false, true, true, true, true], # remove entity 1
:QueryRelation => [false, true, false, false], # select entity 3
))
println("Counterfactual query: 'What if entity 1 were removed?'")
println(" Entity 3 observed state: $(round.(observed_scene[3, :]; digits=3))")
println(" Entity 3 counterfactual state: $(round.(query_state; digits=3))")
println(" Causal effect of entity 1 on 3: $(round.(observed_scene[3, :] .- query_state; digits=3))")
Counterfactual query: 'What if entity 1 were removed?'
Entity 3 observed state: [0.715, -0.922, 1.0, 1.0]
Entity 3 counterfactual state: [0.715, -0.922, 1.0, 1.0]
Causal effect of entity 1 on 3: [0.0, 0.0, 0.0, 0.0]
Topos-Theoretic View: Subobject Classifier for Masking
FunctorFlow’s topos module provides a precise account of masking as a subobject classifier. The mask \(M \subseteq \{1, \ldots, N\}\) classifies which entities are visible (context) vs. hidden (target). In a topos, this classification is mediated by the characteristic morphism \(\chi_M : \{1, \ldots, N\} \to \Omega\), where \(\Omega = \{\text{visible}, \text{masked}\}\) is the subobject classifier.
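This is the standard classifying-pullback square: the subobject of visible entities is recovered as \(\chi_M^{-1}(\text{visible})\), the pullback of the point \(\text{visible} : 1 \to \Omega\) along \(\chi_M\):
\[
\begin{array}{ccc}
M & \longrightarrow & 1 \\
\big\downarrow & & \big\downarrow{\scriptstyle\ \text{visible}} \\
\{1, \ldots, N\} & \xrightarrow{\ \chi_M\ } & \Omega
\end{array}
\]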
sc = SubobjectClassifier(:MaskClassifier;
truth_values=Set{Symbol}([:visible, :masked]))
# The mask predicate classifies entities by their visibility
mask_pred = InternalPredicate(:is_visible, sc;
characteristic_map = x -> any(isnan, x) ? :masked : :visible)
# Classify entities in a masked frame
masked_frame = masked_traj[2] # frame where entity 3 is masked
entity_data = [masked_frame[i, :] for i in 1:N_objects]
println("Entity classification under masking:")
for (i, entity) in enumerate(entity_data)
label = evaluate_predicate(mask_pred, entity)
println(" Entity $i: $label")
end
Entity classification under masking:
Entity 1: visible
Entity 2: visible
Entity 3: masked
Entity 4: visible
Entity 5: visible
We can also use classify_subobject to get the full classification as a dictionary:
classification = classify_subobject(sc,
x -> any(isnan, x) ? :masked : :visible,
entity_data)
println("Subobject classification: $classification")
Subobject classification: Dict{Any, Symbol}(5 => Symbol("false"), 4 => Symbol("false"), 2 => Symbol("false"), 3 => Symbol("false"), 1 => Symbol("false"))
The topos perspective reveals that different masking strategies correspond to different subobject classifiers over the same entity set — and the C-JEPA objective is a sheaf condition: the local predictions (one per masked entity) must glue consistently into a global scene prediction.
Bisimulation: When Two Masking Strategies Are Equivalent
Two masking strategies are bisimilar if they induce the same interaction structure — that is, if the influence neighborhoods they produce are identical. In coalgebraic terms, this means the world models trained under different masks are behaviorally equivalent: they generate the same observable predictions for all future queries.
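As a concrete illustration of the definition (plain Julia, independent of FunctorFlow's add_bisimulation!; the types and names are ours), here is a direct check that a relation between two finite deterministic transition systems is a bisimulation — related states must agree on their observations and step to related states:

```julia
# Two finite deterministic transition systems with observations.
# A relation R between their state sets is a bisimulation if related
# states have equal observations and their successors are again related.
struct LTS
    step::Dict{Symbol,Symbol}   # deterministic transition function
    obs::Dict{Symbol,Int}       # observable output per state
end

function is_bisimulation(R, a::LTS, b::LTS)
    all(R) do (x, y)
        a.obs[x] == b.obs[y] && (a.step[x], b.step[y]) in R
    end
end

# Masking entity 3 vs. entity 4: two toy world models with isomorphic
# two-state dynamics (hypothetical observations for the demo).
m3 = LTS(Dict(:s0 => :s1, :s1 => :s0), Dict(:s0 => 0, :s1 => 1))
m4 = LTS(Dict(:t0 => :t1, :t1 => :t0), Dict(:t0 => 0, :t1 => 1))
R  = Set([(:s0, :t0), (:s1, :t1)])

println(is_bisimulation(R, m3, m4))  # true
```

Swapping in a relation that pairs states with different observations (e.g. `Set([(:s0, :t1)])`) fails the check, mirroring two masks that induce different influence neighborhoods.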
# Create two masking-strategy world models as diagrams
D_mask3 = Diagram(:Mask3Strategy)
add_object!(D_mask3, :State, kind=:hidden_state)
add_object!(D_mask3, :NextState, kind=:hidden_state)
add_morphism!(D_mask3, :dynamics, :State, :NextState,
description="Dynamics under masking entity 3")
add_coalgebra!(D_mask3, :cjepa_mask_3,
state=:State, transition=:dynamics, functor_type=:identity)
D_mask4 = Diagram(:Mask4Strategy)
add_object!(D_mask4, :State, kind=:hidden_state)
add_object!(D_mask4, :NextState, kind=:hidden_state)
add_morphism!(D_mask4, :dynamics, :State, :NextState,
description="Dynamics under masking entity 4")
add_coalgebra!(D_mask4, :cjepa_mask_4,
state=:State, transition=:dynamics, functor_type=:identity)
# Declare bisimulation: these strategies are equivalent if they produce
# the same influence neighborhood
add_bisimulation!(D_mask3, :masking_equivalence;
coalgebra_a=:cjepa_mask_3,
coalgebra_b=:cjepa_mask_4,
relation=:dynamics,
description="Masking entities 3 and 4 produces equivalent interaction structure")
bisims = get_bisimulations(D_mask3)
println("Bisimulation declared: $(first(keys(bisims)))")
println(" Coalgebra A: $(first(values(bisims)).coalgebra_a)")
println(" Coalgebra B: $(first(values(bisims)).coalgebra_b)")
println(" Relation: $(first(values(bisims)).relation)")
Bisimulation declared: masking_equivalence
Coalgebra A: cjepa_mask_3
Coalgebra B: cjepa_mask_4
Relation: dynamics
When two masks are bisimilar, C-JEPA’s learned representations will be invariant to the choice between them — the model extracts the same causal structure regardless of which specific entity is masked, provided the influence neighborhoods are isomorphic.
EMA Target Encoder as Frozen Coalgebra
C-JEPA uses an exponential moving average (EMA) target encoder to produce stable prediction targets. In coalgebraic terms, the target encoder is a frozen reference coalgebra — its transition dynamics are fixed, providing a stable attractor for the online encoder to converge toward.
# Simulate EMA update between online and target parameters
# ema_update! expects iterable-of-arrays (e.g., tuple/vector of parameter arrays)
online_params = [randn(rng, 4, 4), randn(rng, 4)] # e.g., weight matrix + bias
target_params = [randn(rng, 4, 4), randn(rng, 4)]
dist_before = sum(sum((o .- t).^2) for (o, t) in zip(online_params, target_params))
println("Before EMA update:")
println(" Online-Target distance: $(round(dist_before; digits=4))")
ema_update!(target_params, online_params; decay=0.996)
dist_after = sum(sum((o .- t).^2) for (o, t) in zip(online_params, target_params))
println("After EMA update (decay=0.996):")
println(" Online-Target distance: $(round(dist_after; digits=4))")
# Multiple updates converge the target toward the online encoder
for _ in 1:100
ema_update!(target_params, online_params; decay=0.996)
end
println("After 100 EMA updates:")
dist_final = sum(sum((o .- t).^2) for (o, t) in zip(online_params, target_params))
println(" Online-Target distance: $(round(dist_final; digits=6))")
Before EMA update:
Online-Target distance: 18.2989
After EMA update (decay=0.996):
Online-Target distance: 18.1528
After 100 EMA updates:
Online-Target distance: 8.143523
The coalgebraic interpretation: each EMA step is a coalgebra morphism that contracts the distance between the online and target state spaces. The fixed point (after many updates) is where the two coalgebras become bisimilar — the target encoder faithfully represents the online encoder’s learned dynamics.
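The geometric rate can be read off directly. With frozen online parameters \(\theta\) and decay \(\tau\), the update \(\bar\theta_{k+1} = \tau\,\bar\theta_k + (1-\tau)\,\theta\) gives \(\theta - \bar\theta_{k+1} = \tau\,(\theta - \bar\theta_k)\), hence
\[
\|\theta - \bar\theta_k\|^2 = \tau^{2k}\,\|\theta - \bar\theta_0\|^2 .
\]
This matches the run above: \(18.2989 \times 0.996^{2} \approx 18.1528\) after one update, and \(18.2989 \times 0.996^{202} \approx 8.1435\) after the remaining 100 (101 updates in total).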
Summary and Correspondence Table
C-JEPA demonstrates that object-level masking is not merely a data augmentation strategy — it is a causal intervention that makes interaction reasoning functionally necessary. FunctorFlow provides the categorical language to make this precise:
| C-JEPA | Category Theory | FunctorFlow |
|---|---|---|
| Object slots \(S_t\) | Objects in Prod category | product(entities...) |
| Object masking | do-intervention | \(\Sigma\) (left-Kan) in CausalDiagram |
| Context (unmasked) | Conditioning | \(\Delta\) (right-Kan) in CausalDiagram |
| Influence neighborhood | Markov blanket / right-Kan kernel | \(\Delta\) completion sufficiency |
| \(\mathcal{L}_\text{history}\) (reconstruction) | Right-Kan obstruction | add_obstruction_loss!(:history_loss, ...) |
| \(\mathcal{L}_\text{future}\) (prediction) | Left-Kan obstruction | add_obstruction_loss!(:future_loss, ...) |
| EMA target encoder | Frozen coalgebra reference | ema_update! |
| Identity anchor (\(t=0\)) | Initial object | Coalgebra initial state |
| Interaction Necessity (Thm 1) | Minimality of influence presheaf | Right-Kan kernel sufficiency |
| Slot Attention | Entity decomposition | Entity sub-diagrams |
| Counterfactual query | \(\Sigma \circ \Delta\) composition | intervene then condition |
| Mask predicate | Subobject classifier | SubobjectClassifier in topos.jl |
| Masking equivalence | Bisimulation of coalgebras | add_bisimulation! |
The key insight is that C-JEPA’s empirical success (~20% improvement on counterfactual VQA, 8× faster MPC planning) has a precise categorical explanation: object-level masking enforces the right universal property — it forces the predictor to factor through the influence neighborhood, which is exactly the causal structure needed for counterfactual reasoning and efficient planning.