This requires a few pieces: * Make node tags use `Fin n` instead of natural numbers. This makes it possible to build a finite lattice over AST nodes, and also ensure automatic, total indexing from CFG nodes into the AST that created them. For this, use the elaborator to derive the ordering statements etc. where possible. * Adjust the forward framework to enable proofs that don't just state correctness on the environment, but also on an arbitrary additional state accumulated from traversing the trace. * State the reaching definition analysis's correctness in terms of this new framework. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
163 lines
6.7 KiB
Lean4
163 lines
6.7 KiB
Lean4
import Spa.Analysis.Forward.Lattices
|
||
import Spa.Analysis.Forward.Evaluation
|
||
import Spa.Analysis.Forward.Adapters
|
||
import Spa.Fixedpoint
|
||
|
||
namespace Spa
|
||
|
||
namespace Forward
|
||
|
||
-- Ambient parameters shared by the declarations below: an abstract domain `L`
-- with a `FiniteHeightLattice` structure, a program `prog`, and a statement
-- evaluator `E` over that domain and program.
variable {L : Type} [FiniteHeightLattice L] {prog : Program} [E : StmtEvaluator L prog]
|
||
|
||
/-- Evaluate the optional basic statement `o` at state `s` on the abstract values `vs`.

`hco` witnesses that `o` is the code stored at `s`. In the `some bs` branch it is
composed with the branch equation (`hco.trans h`) to produce the
`prog.code s = some bs` proof that `E.eval` takes. `vs` is passed to `elimEq` as
the default value (presumably returned unchanged when `o = none` — confirm
against the definition of `elimEq`). -/
def evalStmtOrNone (s : prog.State) (o : Option BasicStmt) (hco : prog.code s = o)
    (vs : VariableValues L prog) : VariableValues L prog :=
  o.elimEq vs (fun bs h => E.eval s bs (hco.trans h))
|
||
|
||
/-- `evalStmtOrNone s o hco` is monotone in the abstract values it is applied to.

Obtained from `elimEq_self_mono`, using the monotonicity of the evaluator
(`E.eval_mono`) for the `some` branch. -/
lemma evalStmtOrNone_mono (s : prog.State) (o : Option BasicStmt)
    (hco : prog.code s = o) : Monotone (evalStmtOrNone (L := L) s o hco) :=
  elimEq_self_mono o (fun bs h vs => E.eval s bs (hco.trans h) vs)
    (fun bs h => E.eval_mono s bs (hco.trans h))
|
||
|
||
/-- Abstract values for state `s` after evaluating that state's own code.

Looks up the values recorded for `s` in `sv` (`variablesAt s sv`) and runs
`evalStmtOrNone` on them with the statement `prog.code s`; the required
equation `prog.code s = prog.code s` is `rfl`. -/
def updateVariablesForState (s : prog.State) (sv : StateVariables L prog) :
    VariableValues L prog :=
  evalStmtOrNone s (prog.code s) rfl (variablesAt s sv)
|
||
|
||
/-- `updateVariablesForState s` is monotone in the per-state map.

Combines `variablesAt_le` (the lookup at `s` respects the order) with
`evalStmtOrNone_mono`. -/
lemma updateVariablesForState_mono (s : prog.State) :
    Monotone (updateVariablesForState (L := L) s) := fun _ _ hle =>
  evalStmtOrNone_mono s (prog.code s) rfl (variablesAt_le hle s)
|
||
|
||
/-- Apply `updateVariablesForState` across the states listed in `prog.states`.

Implemented with `FiniteMap.generalizedUpdate`, using `id` as the first
function argument and `updateVariablesForState` as the per-key update. -/
def updateAll (sv : StateVariables L prog) : StateVariables L prog :=
  FiniteMap.generalizedUpdate id updateVariablesForState
    prog.states sv
|
||
|
||
/-- `updateAll` is monotone, by `FiniteMap.generalizedUpdate_monotone` applied to
`monotone_id` and `updateVariablesForState_mono`. -/
lemma updateAll_mono : Monotone (updateAll (L := L) (prog := prog)) :=
  FiniteMap.generalizedUpdate_monotone monotone_id updateVariablesForState_mono
|
||
|
||
/-- Any entry `(s, vs)` of `updateAll sv` has `vs = updateVariablesForState s sv`.

Uses `prog.states_complete s` as the side condition required by
`FiniteMap.generalizedUpdate_mem_eq`. -/
lemma updateAll_mem_eq {s : prog.State} {vs : VariableValues L prog}
    {sv : StateVariables L prog} (hmem : (s, vs) ∈ updateAll sv) :
    vs = updateVariablesForState s sv :=
  FiniteMap.generalizedUpdate_mem_eq (prog.states_complete s) hmem
|
||
|
||
/-- Looking up `s` in `updateAll sv` yields `updateVariablesForState s sv`.

Follows from `updateAll_mem_eq` applied to the membership fact supplied by
`variablesAt_mem`. -/
lemma variablesAt_updateAll (s : prog.State) (sv : StateVariables L prog) :
    variablesAt s (updateAll sv) = updateVariablesForState s sv :=
  updateAll_mem_eq (variablesAt_mem s (updateAll sv))
|
||
|
||
/-- One round of the forward analysis: first `joinAll`, then `updateAll`.

This is the function whose fixed point is taken in `result`. -/
def analyze (sv : StateVariables L prog) : StateVariables L prog :=
  updateAll (joinAll sv)
|
||
|
||
/-- `analyze` is monotone, as the composition of the monotone steps `joinAll`
(`joinAll_mono`) and `updateAll` (`updateAll_mono`). -/
lemma analyze_mono : Monotone (analyze (L := L) (prog := prog)) := fun _ _ hle =>
  updateAll_mono (joinAll_mono hle)
|
||
|
||
-- From here on `L` is assumed to have decidable equality. Lemmas further down
-- that do not need it drop the assumption with `omit [DecidableEq L] in`.
-- NOTE(review): presumably required by `Fixedpoint.aFix` — confirm.
variable [DecidableEq L]
|
||
|
||
variable (L prog) in
/-- The analysis result for `prog` over the domain `L`: the value produced by
`Fixedpoint.aFix` for `analyze`, given its monotonicity proof `analyze_mono`.

`L` and `prog` are made explicit so callers write `result L prog`. -/
def result : StateVariables L prog :=
  Fixedpoint.aFix analyze analyze_mono
|
||
|
||
variable (L prog) in
/-- `result L prog` is a fixed point of `analyze`: running one more round of the
analysis on it gives it back. Direct from `Fixedpoint.aFix_eq`. -/
lemma result_eq : result L prog = analyze (result L prog) :=
  Fixedpoint.aFix_eq analyze analyze_mono
|
||
|
||
/-- The join computed for the initial state is the bottom assignment `botV L prog`.

After unfolding `joinForKey`, the list of incoming states of `prog.initialState`
is rewritten to `[]` (`prog.incoming_initialState_eq_nil`), and the remaining
goal holds by `rfl`. -/
lemma joinForKey_initialState :
    joinForKey prog.initialState (result L prog) = botV L prog := by
  rw [joinForKey, prog.incoming_initialState_eq_nil]
  rfl
|
||
|
||
/-- Soundness interface for a statement evaluator `E` relative to a state
interpretation `S`.

Besides the environment, the interpretation `⟦ vs ⟧ ρ st` mentions an extra
state `st : S.St ρ`. This class says how that state advances across a basic
statement (`step`), that the evaluator is sound with respect to that
advancement (`valid`), and that the bottom assignment describes the empty
environment with the initial state (`botV_init`). -/
class ValidStateEvaluator (L : Type) [FiniteHeightLattice L] (prog : Program)
    [E : StmtEvaluator L prog] [S : StateInterp L prog] where
  /-- Advance the extra state across one concrete execution
  `EvalBasicStmt ρ₁ bs ρ₂` of the statement `bs` stored at state `s`. -/
  step : (s : prog.State) → {ρ₁ ρ₂ : Env} → {bs : BasicStmt} →
    prog.code s = some bs → EvalBasicStmt ρ₁ bs ρ₂ → S.St ρ₁ → S.St ρ₂
  /-- If `vs` describes `ρ₁` with state `st`, then `E.eval s bs hcode vs`
  describes the post-environment `ρ₂` with the stepped state. -/
  valid : ∀ (s : prog.State) {ρ₁ ρ₂ : Env} {bs : BasicStmt}
    {vs : VariableValues L prog} {st : S.St ρ₁},
    (hcode : prog.code s = some bs) → (hbs : EvalBasicStmt ρ₁ bs ρ₂) → ⟦ vs ⟧ ρ₁ st →
    ⟦ E.eval s bs hcode vs ⟧ ρ₂ (step s hcode hbs st)
  /-- The bottom assignment describes the empty environment `[]` together with
  the initial state `S.init`. -/
  botV_init : ⟦ botV L prog ⟧ [] S.init
|
||
|
||
/-- A `ValidStmtEvaluator` over a `LatticeInterpretation` gives a
`ValidStateEvaluator` with a trivial extra state.

* `step` ignores all of its arguments and returns `PUnit.unit`.
* `valid` delegates to `ValidStmtEvaluator.valid`.
* `botV_init` is proved by case analysis on the membership hypothesis, which
  has no cases.

NOTE(review): the `StateInterp L prog` instance used here is found by instance
resolution, presumably from `Spa.Analysis.Forward.Adapters` — confirm. -/
instance [LatticeInterpretation L] [ValidStmtEvaluator L prog] :
    ValidStateEvaluator L prog where
  step := by intro _ _ _ _ _ _ _; exact PUnit.unit
  valid := by intro _ _ _ _ _ _ hcode hbs hvs; exact ValidStmtEvaluator.valid hcode hbs hvs
  botV_init := by intro k l _ v hmem; cases hmem
|
||
|
||
section
|
||
-- Within this section, fix a state interpretation `S` and a proof `V` that the
-- evaluator is valid with respect to it.
variable [S : StateInterp L prog] [V : ValidStateEvaluator L prog]
|
||
|
||
/-- Advance the extra state across the execution of an optional basic statement
at state `s`.

* `none`: nothing is executed, so the state `st` is returned as is.
* `some _`: the state is advanced with `V.step`, using `hco` as the proof that
  the statement is the code stored at `s`. -/
noncomputable def stepStmtOrNone (s : prog.State) {ρ₁ ρ₂ : Env} :
    (o : Option BasicStmt) → prog.code s = o → EvalBasicStmtOpt ρ₁ o ρ₂ →
    S.St ρ₁ → S.St ρ₂
  | none, _, .none, st => st
  | some _, hco, .some hbs, st => V.step s hco hbs st
|
||
|
||
/-- Advance the extra state across the execution `h` of the code stored at
state `s`. This is `stepStmtOrNone` specialised to `prog.code s`, with `rfl` as
the code equation. -/
noncomputable def stepNode (s : prog.State) {ρ₁ ρ₂ : Env}
    (h : EvalBasicStmtOpt ρ₁ (prog.code s) ρ₂) (st : S.St ρ₁) : S.St ρ₂ :=
  stepStmtOrNone s (prog.code s) rfl h st
|
||
|
||
/-- Advance the extra state along a whole trace by applying `stepNode` at each
node in turn.

* `.single hnode`: one node; step across it.
* `.edge hnode _ subtr`: step across the first node, then recurse on the rest
  of the trace with the stepped state. -/
noncomputable def stepTraceState :
    {s₁ s₂ : prog.State} → {ρ₁ ρ₂ : Env} →
    Trace prog.cfg s₁ s₂ ρ₁ ρ₂ → S.St ρ₁ → S.St ρ₂
  | s₁, _, _, _, .single hnode, st => stepNode s₁ hnode st
  | s₁, _, _, _, .edge hnode _ subtr, st =>
    stepTraceState subtr (stepNode s₁ hnode st)
|
||
|
||
omit [DecidableEq L] in
/-- Soundness of `evalStmtOrNone`: if `vs` describes `ρ₁` with state `st` and the
optional statement `o` takes `ρ₁` to `ρ₂`, then `evalStmtOrNone s o hco vs`
describes `ρ₂` with the state produced by `stepStmtOrNone`.

By cases on the execution: with no statement the hypothesis itself is the goal;
with a statement this is `V.valid`. -/
lemma evalStmtOrNone_valid {s : prog.State} {ρ₁ ρ₂ : Env} {st : S.St ρ₁}
    {vs : VariableValues L prog} (o : Option BasicStmt) (hco : prog.code s = o)
    (he : EvalBasicStmtOpt ρ₁ o ρ₂) (hvs : ⟦ vs ⟧ ρ₁ st) :
    ⟦ evalStmtOrNone s o hco vs ⟧ ρ₂ (stepStmtOrNone s o hco he st) := by
  cases he with
  | none => exact hvs
  | some hbs => exact V.valid s hco hbs hvs
|
||
|
||
omit [DecidableEq L] in
/-- Soundness of `updateAll` at a single state: if the values stored for `s` in
`sv` describe `ρ₁` with state `st`, and the code at `s` takes `ρ₁` to `ρ₂`, then
the values stored for `s` in `updateAll sv` describe `ρ₂` with the stepped
state `stepNode s hnode st`. -/
lemma updateAll_matches {s : prog.State} {sv : StateVariables L prog}
    {ρ₁ ρ₂ : Env} {st : S.St ρ₁}
    (hnode : EvalBasicStmtOpt ρ₁ (prog.code s) ρ₂)
    (hvs : ⟦ variablesAt s sv ⟧ ρ₁ st) :
    ⟦ variablesAt s (updateAll sv) ⟧ ρ₂ (stepNode s hnode st) := by
  -- Replace the lookup in `updateAll sv` by `updateVariablesForState s sv`.
  rw [variablesAt_updateAll]
  exact evalStmtOrNone_valid (prog.code s) rfl hnode hvs
|
||
|
||
/-- One-node soundness for the analysis result: if the join at `s₁` describes
`ρ₁` with state `st`, then after executing the code at `s₁` the values recorded
for `s₁` in `result L prog` describe `ρ₂` with the stepped state. -/
lemma stepTrace {s₁ : prog.State} {ρ₁ ρ₂ : Env} {st : S.St ρ₁}
    (hjoin : ⟦ joinForKey s₁ (result L prog) ⟧ ρ₁ st)
    (hnode : EvalBasicStmtOpt ρ₁ (prog.code s₁) ρ₂) :
    ⟦ variablesAt s₁ (result L prog) ⟧ ρ₂ (stepNode s₁ hnode st) := by
  -- Unfold the result once via its fixed-point equation, exposing `analyze`.
  rw [result_eq L prog]
  refine updateAll_matches hnode ?_
  -- The lookup in `joinAll` at `s₁` is the join `joinForKey s₁`.
  rw [variablesAt_joinAll]
  exact hjoin
|
||
|
||
/-- Soundness along a trace: if the join at the trace's first state `s₁`
describes the starting environment `ρ₁` with state `st₁`, then the values
recorded for the last state `s₂` in `result L prog` describe the final
environment `ρ₂` with the state accumulated by `stepTraceState`.

Proved by induction on the trace. -/
lemma walkTrace {s₁ s₂ : prog.State} {ρ₁ ρ₂ : Env} {st₁ : S.St ρ₁}
    (hjoin : ⟦ joinForKey s₁ (result L prog) ⟧ ρ₁ st₁)
    (tr : Trace prog.cfg s₁ s₂ ρ₁ ρ₂) :
    ⟦ variablesAt s₂ (result L prog) ⟧ ρ₂ (stepTraceState tr st₁) := by
  induction tr with
  | single hnode => exact stepTrace hjoin hnode
  | @edge _ ρ' _ i₁ i₂ _ hnode hedge _ ih =>
    -- The result at `i₁` describes the environment reached after executing `i₁`.
    have hstep : ⟦ variablesAt i₁ (result L prog) ⟧ ρ' (stepNode i₁ hnode st₁) :=
      stepTrace hjoin hnode
    -- Because of the edge `i₁ → i₂`, that value is among those joined at `i₂`.
    have hmem : variablesAt i₁ (result L prog)
        ∈ (result L prog).valuesAt (prog.incoming i₂) :=
      FiniteMap.mem_valuesAt prog.states_nodup
        (prog.mem_incoming_of_edge hedge) (variablesAt_mem i₁ (result L prog))
    -- Lift the interpretation to the join at `i₂` and continue along the trace.
    exact ih (interp_foldr hstep hmem)
|
||
|
||
variable (L prog) in
/-- Correctness of the forward analysis, including the extra state.

For any run `hrun` of the program's root statement from the empty environment
to `ρ`, the values recorded for `prog.finalState` in `result L prog` describe
`ρ` together with the state obtained by folding `stepTraceState` over the
program's trace, starting from `S.init`. -/
theorem analyze_correct_state {ρ : Env} (hrun : EvalStmt [] prog.rootStmt ρ) :
    ⟦ variablesAt prog.finalState (result L prog) ⟧ ρ
      (stepTraceState (prog.trace hrun) S.init) := by
  refine walkTrace ?_ (prog.trace hrun)
  -- At the initial state the join is `botV`, which describes `[]` with `S.init`.
  rw [joinForKey_initialState]
  exact ValidStateEvaluator.botV_init
|
||
|
||
end
|
||
|
||
variable (L prog) in
/-- Correctness of the forward analysis for a plain lattice interpretation.

Specialises `analyze_correct_state` to the setting of a `LatticeInterpretation`
with a `ValidStmtEvaluator`, where the extra state is the unit value `()`. -/
theorem analyze_correct [LatticeInterpretation L] [ValidStmtEvaluator L prog]
    {ρ : Env} (hrun : EvalStmt [] prog.rootStmt ρ) :
    ⟦ variablesAt prog.finalState (result L prog) ⟧ ρ () :=
  analyze_correct_state L prog hrun
|
||
|
||
end Forward
|
||
|
||
end Spa
|