Add proof of reaching definition analysis

This requires a few pieces:

* Make node tags use `Fin n` intead of natural numbers. This makes
  it possible to build a finite lattice over AST nodes, and also
  ensure automatic, total indexing from CFG nodes into the AST that
  created them. For this, use the elaborator to derive the ordering
  statements etc. where possible.
* Adjust the forward framework to enable proofs that don't just state
  correctness on the environment, but also on an arbitrary additional
  state accumulated from traversing the trace.
* State the reaching definition analysis's correctness in terms
  of this new framework.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-27 16:29:16 -05:00
parent 5737805125
commit b6b30958aa
20 changed files with 678 additions and 197 deletions

View File

@@ -9,13 +9,22 @@ namespace Forward
variable {L : Type} [FiniteHeightLattice L] {prog : Program} [E : StmtEvaluator L prog]
def evalStmtOrNone (s : prog.State) (o : Option BasicStmt) (hco : prog.code s = o)
(vs : VariableValues L prog) : VariableValues L prog :=
o.elimEq vs (fun bs h => E.eval s bs (hco.trans h))
lemma evalStmtOrNone_mono (s : prog.State) (o : Option BasicStmt)
(hco : prog.code s = o) : Monotone (evalStmtOrNone (L := L) s o hco) :=
elimEq_self_mono o (fun bs h vs => E.eval s bs (hco.trans h) vs)
(fun bs h => E.eval_mono s bs (hco.trans h))
def updateVariablesForState (s : prog.State) (sv : StateVariables L prog) :
VariableValues L prog :=
(prog.code s).foldl (fun vs bs => E.eval s bs vs) (variablesAt s sv)
evalStmtOrNone s (prog.code s) rfl (variablesAt s sv)
lemma updateVariablesForState_mono (s : prog.State) :
Monotone (updateVariablesForState (L := L) s) := fun _ _ hle =>
foldl_mono' (prog.code s) _ (E.eval_mono s ·) (variablesAt_le hle s)
evalStmtOrNone_mono s (prog.code s) rfl (variablesAt_le hle s)
def updateAll (sv : StateVariables L prog) : StateVariables L prog :=
FiniteMap.generalizedUpdate id updateVariablesForState
@@ -54,67 +63,99 @@ lemma joinForKey_initialState :
rw [joinForKey, prog.incoming_initialState_eq_nil]
rfl
variable [I : LatticeInterpretation L] [V : ValidStmtEvaluator L prog]
class ValidStateEvaluator (L : Type) [FiniteHeightLattice L] (prog : Program)
[E : StmtEvaluator L prog] [S : StateInterp L prog] where
step : (s : prog.State) {ρ₁ ρ₂ : Env} {bs : BasicStmt}
prog.code s = some bs EvalBasicStmt ρ₁ bs ρ₂ S.St ρ₁ S.St ρ₂
valid : (s : prog.State) {ρ₁ ρ₂ : Env} {bs : BasicStmt}
{vs : VariableValues L prog} {st : S.St ρ₁},
(hcode : prog.code s = some bs) (hbs : EvalBasicStmt ρ₁ bs ρ₂) vs ρ₁ st
E.eval s bs hcode vs ρ₂ (step s hcode hbs st)
botV_init : botV L prog [] S.init
instance [LatticeInterpretation L] [ValidStmtEvaluator L prog] :
ValidStateEvaluator L prog where
step := by intro _ _ _ _ _ _ _; exact PUnit.unit
valid := by intro _ _ _ _ _ _ hcode hbs hvs; exact ValidStmtEvaluator.valid hcode hbs hvs
botV_init := by intro k l _ v hmem; cases hmem
section
variable [S : StateInterp L prog] [V : ValidStateEvaluator L prog]
noncomputable def stepStmtOrNone (s : prog.State) {ρ₁ ρ₂ : Env} :
(o : Option BasicStmt) prog.code s = o EvalBasicStmtOpt ρ₁ o ρ₂
S.St ρ₁ S.St ρ₂
| none, _, .none, st => st
| some _, hco, .some hbs, st => V.step s hco hbs st
noncomputable def stepNode (s : prog.State) {ρ₁ ρ₂ : Env}
(h : EvalBasicStmtOpt ρ₁ (prog.code s) ρ₂) (st : S.St ρ₁) : S.St ρ₂ :=
stepStmtOrNone s (prog.code s) rfl h st
noncomputable def stepTraceState :
{s₁ s₂ : prog.State} {ρ₁ ρ₂ : Env}
Trace prog.cfg s₁ s₂ ρ₁ ρ₂ S.St ρ₁ S.St ρ₂
| s₁, _, _, _, .single hnode, st => stepNode s₁ hnode st
| s₁, _, _, _, .edge hnode _ subtr, st =>
stepTraceState subtr (stepNode s₁ hnode st)
omit [DecidableEq L] in
lemma eval_fold_valid {s : prog.State} {bss : List BasicStmt}
{vs : VariableValues L prog} {ρ₁ ρ₂ : Env}
(hbss : EvalBasicStmts ρ₁ bss ρ₂) (hvs : vs ρ₁) :
bss.foldl (fun vs bs => E.eval s bs vs) vs ρ₂ := by
induction hbss generalizing vs with
| nil => exact hvs
| cons hbs _ ih => exact ih (ValidStmtEvaluator.valid hbs hvs)
omit [DecidableEq L] in
lemma updateVariablesForState_matches {s : prog.State}
{sv : StateVariables L prog} {ρ₁ ρ₂ : Env}
(hbss : EvalBasicStmts ρ₁ (prog.code s) ρ₂)
(hvs : variablesAt s sv ρ₁) :
updateVariablesForState s sv ρ₂ :=
eval_fold_valid hbss hvs
lemma evalStmtOrNone_valid {s : prog.State} {ρ₁ ρ₂ : Env} {st : S.St ρ₁}
{vs : VariableValues L prog} (o : Option BasicStmt) (hco : prog.code s = o)
(he : EvalBasicStmtOpt ρ₁ o ρ₂) (hvs : vs ρ₁ st) :
evalStmtOrNone s o hco vs ρ₂ (stepStmtOrNone s o hco he st) := by
cases he with
| none => exact hvs
| some hbs => exact V.valid s hco hbs hvs
omit [DecidableEq L] in
lemma updateAll_matches {s : prog.State} {sv : StateVariables L prog}
{ρ₁ ρ₂ : Env} (hbss : EvalBasicStmts ρ₁ (prog.code s) ρ)
(hvs : variablesAt s sv ρ₁) :
variablesAt s (updateAll sv) ρ := by
{ρ₁ ρ₂ : Env} {st : S.St ρ}
(hnode : EvalBasicStmtOpt ρ₁ (prog.code s) ρ₂)
(hvs : variablesAt s sv ρ st) :
variablesAt s (updateAll sv) ρ₂ (stepNode s hnode st) := by
rw [variablesAt_updateAll]
exact updateVariablesForState_matches hbss hvs
exact evalStmtOrNone_valid (prog.code s) rfl hnode hvs
lemma stepTrace {s₁ : prog.State} {ρ₁ ρ₂ : Env}
(hjoin : joinForKey s₁ (result L prog) ρ₁)
(hbss : EvalBasicStmts ρ₁ (prog.code s₁) ρ₂) :
variablesAt s₁ (result L prog) ρ₂ := by
lemma stepTrace {s₁ : prog.State} {ρ₁ ρ₂ : Env} {st : S.St ρ₁}
(hjoin : joinForKey s₁ (result L prog) ρ₁ st)
(hnode : EvalBasicStmtOpt ρ₁ (prog.code s₁) ρ₂) :
variablesAt s₁ (result L prog) ρ₂ (stepNode s₁ hnode st) := by
rw [result_eq L prog]
refine updateAll_matches hbss ?_
refine updateAll_matches hnode ?_
rw [variablesAt_joinAll]
exact hjoin
lemma walkTrace {s₁ s₂ : prog.State} {ρ₁ ρ₂ : Env}
(hjoin : joinForKey s₁ (result L prog) ρ₁)
lemma walkTrace {s₁ s₂ : prog.State} {ρ₁ ρ₂ : Env} {st₁ : S.St ρ₁}
(hjoin : joinForKey s₁ (result L prog) ρ₁ st₁)
(tr : Trace prog.cfg s₁ s₂ ρ₁ ρ₂) :
variablesAt s₂ (result L prog) ρ₂ := by
variablesAt s₂ (result L prog) ρ₂ (stepTraceState tr st₁) := by
induction tr with
| single hbss => exact stepTrace hjoin hbss
| @edge _ ρ' _ i₁ i₂ _ hbss hedge _ ih =>
have hstep : variablesAt i₁ (result L prog) ρ' :=
stepTrace hjoin hbss
| single hnode => exact stepTrace hjoin hnode
| @edge _ ρ' _ i₁ i₂ _ hnode hedge _ ih =>
have hstep : variablesAt i₁ (result L prog) ρ' (stepNode i₁ hnode st₁) :=
stepTrace hjoin hnode
have hmem : variablesAt i₁ (result L prog)
(result L prog).valuesAt (prog.incoming i₂) :=
FiniteMap.mem_valuesAt prog.states_nodup
(prog.mem_incoming_of_edge hedge) (variablesAt_mem i₁ (result L prog))
exact ih (interp_foldr hstep hmem)
omit V in
lemma interp_joinForKey_initialState :
joinForKey prog.initialState (result L prog) [] := by
variable (L prog) in
theorem analyze_correct_state {ρ : Env} (hrun : EvalStmt [] prog.rootStmt ρ) :
variablesAt prog.finalState (result L prog) ρ
(stepTraceState (prog.trace hrun) S.init) := by
refine walkTrace ?_ (prog.trace hrun)
rw [joinForKey_initialState]
exact interp_botV_nil
exact ValidStateEvaluator.botV_init
end
variable (L prog) in
theorem analyze_correct {ρ : Env} (hrun : EvalStmt [] prog.rootStmt ρ) :
variablesAt prog.finalState (result L prog) ρ :=
walkTrace interp_joinForKey_initialState (prog.trace hrun)
theorem analyze_correct [LatticeInterpretation L] [ValidStmtEvaluator L prog]
{ρ : Env} (hrun : EvalStmt [] prog.rootStmt ρ) :
variablesAt prog.finalState (result L prog) ρ () :=
analyze_correct_state L prog hrun
end Forward