import Spa.Analysis.Forward.Lattices
import Spa.Analysis.Forward.Evaluation
import Spa.Analysis.Forward.Adapters
import Spa.Fixedpoint

/-!
# Forward analysis: fixed point and soundness

Defines one round of the forward analysis (`analyze` = `updateAll ∘ joinAll`), takes its
fixed point with `Fixedpoint.aFix` (`result`), and proves that `result` is sound with
respect to a `StateInterpretation`: at every node reached by a run of the program
(`analyze_correct_at`), and in particular at `prog.finalState` (`analyze_correct'`,
`analyze_correct`).
-/

namespace Spa

namespace Forward

variable {L : Type} [FiniteHeightLattice L] {prog : Program} [E : StmtEvaluator L prog]

/-- Evaluate node `s` with `E.eval`, starting from the variable values that `sv` stores
for `s` (read with `variablesAt s sv`). -/
def updateVariablesForState (s : prog.State) (sv : StateVariables L prog) :
    VariableValues L prog :=
  E.eval s (variablesAt s sv)

/-- `updateVariablesForState s` is monotone: `variablesAt s` is monotone
(`variablesAt_le`) and so is `E.eval s` (`E.eval_mono`). -/
lemma updateVariablesForState_mono (s : prog.State) :
    Monotone (updateVariablesForState (L := L) s) :=
  fun _ _ hle => E.eval_mono s (variablesAt_le hle s)

/-- Rebuild `sv` over the keys `prog.states` with `FiniteMap.generalizedUpdate`, using
`updateVariablesForState` to compute each entry (see `updateAll_mem_eq`). -/
def updateAll (sv : StateVariables L prog) : StateVariables L prog :=
  FiniteMap.generalizedUpdate id updateVariablesForState prog.states sv

/-- `updateAll` is monotone, inherited from `updateVariablesForState_mono`. -/
lemma updateAll_mono : Monotone (updateAll (L := L) (prog := prog)) :=
  FiniteMap.generalizedUpdate_monotone monotone_id updateVariablesForState_mono

/-- Every entry `(s, vs)` of `updateAll sv` carries the value
`updateVariablesForState s sv`. -/
lemma updateAll_mem_eq {s : prog.State} {vs : VariableValues L prog}
    {sv : StateVariables L prog} (hmem : (s, vs) ∈ updateAll sv) :
    vs = updateVariablesForState s sv :=
  FiniteMap.generalizedUpdate_mem_eq (prog.states_complete s) hmem

/-- Reading state `s` out of `updateAll sv` yields `updateVariablesForState s sv`. -/
lemma variablesAt_updateAll (s : prog.State) (sv : StateVariables L prog) :
    variablesAt s (updateAll sv) = updateVariablesForState s sv :=
  updateAll_mem_eq (variablesAt_mem s (updateAll sv))

/-- One round of the analysis: combine incoming values with `joinAll`, then evaluate
every node with `updateAll`. -/
def analyze (sv : StateVariables L prog) : StateVariables L prog :=
  updateAll (joinAll sv)

/-- `analyze` is monotone, as a composition of the monotone `joinAll` and `updateAll`. -/
lemma analyze_mono : Monotone (analyze (L := L) (prog := prog)) :=
  fun _ _ hle => updateAll_mono (joinAll_mono hle)

-- NOTE(review): `DecidableEq L` is presumably what lets `Fixedpoint.aFix` decide when
-- the iteration has stabilised; confirm against `Spa.Fixedpoint`.
variable [DecidableEq L]

variable (L prog) in
/-- The analysis result: the fixed point of the monotone `analyze`, computed by
`Fixedpoint.aFix`. -/
def result : StateVariables L prog :=
  Fixedpoint.aFix analyze analyze_mono

variable (L prog) in
/-- The fixed-point equation for `result`: one more round of `analyze` changes nothing. -/
lemma result_eq : result L prog = analyze (result L prog) :=
  Fixedpoint.aFix_eq analyze analyze_mono

/-- Joining the incoming values of `prog.initialState` in `result` gives `botV`:
`prog.incoming_initialState_eq_nil` rewrites its incoming list to `[]`, after which the
join reduces to `botV L prog` definitionally. -/
lemma joinForKey_initialState :
    joinForKey prog.initialState (result L prog) = botV L prog := by
  rw [joinForKey, prog.incoming_initialState_eq_nil]
  rfl

/-- Semantic validity of the per-state evaluator `E` with respect to a state
interpretation `S`. -/
class ValidStateEvaluator (L : Type) [FiniteHeightLattice L] (prog : Program)
    [E : StmtEvaluator L prog] [S : StateInterpretation L prog] where
  /-- Running the statement of node `s₂` (`prog.cfg.nodes s₂`, via `EvalBasicStmtOpt`)
  at the end of the trace `tr` preserves validity. If `vs` holds of `S.Pre tr`, then
  `E.eval s₂ vs` holds of `S.Post (tr ++ hbs)`. -/
  valid : ∀ (s₁ s₂ : prog.State) {ρ₁ ρ₂ ρ₃ : Env} {vs : VariableValues L prog},
    (tr : Traceₗ prog.cfg s₁ s₂ ρ₁ ρ₂) →
    (hbs : EvalBasicStmtOpt ρ₂ (prog.cfg.nodes s₂) ρ₃) →
    ⟦ vs ⟧ (S.Pre tr) → ⟦ E.eval s₂ vs ⟧ (S.Post (tr ++ hbs))
  /-- `botV` holds of the pre-interpretation of the single-node trace at
  `prog.initialState` with environment `[]`. -/
  botV_init : ⟦ botV L prog ⟧ (S.Pre (Traceₗ.single prog.cfg prog.initialState []))

/-- A `ValidStmtEvaluator` over a `LatticeInterpretation` yields a
`ValidStateEvaluator`. -/
instance [LatticeInterpretation L] [ValidStmtEvaluator L prog] :
    ValidStateEvaluator L prog where
  valid := by
    intro _ _ _ _ _ _ tr hbs hvs
    exact ValidStmtEvaluator.valid hbs hvs
  botV_init := by
    -- No membership fact about `botV` can hold, so the case split is empty.
    intro k l _ v hmem
    cases hmem

section

variable [S : StateInterpretation L prog] [V : ValidStateEvaluator L prog]

omit [DecidableEq L] in
/-- Validity of one `updateAll` step at node `s₂`. If the values `sv` stores at `s₂`
hold before running `s₂`, then the values `updateAll sv` stores at `s₂` hold after it. -/
lemma updateAll_matches {s₁ s₂ : prog.State} {sv : StateVariables L prog}
    {ρ₁ ρ₂ ρ₃ : Env} (tr : Traceₗ prog.cfg s₁ s₂ ρ₁ ρ₂)
    (hnode : EvalBasicStmtOpt ρ₂ (prog.code s₂) ρ₃)
    (hvs : ⟦ variablesAt s₂ sv ⟧ (S.Pre tr)) :
    ⟦ variablesAt s₂ (updateAll sv) ⟧ (S.Post (tr ++ hnode)) := by
  rw [variablesAt_updateAll]
  exact V.valid s₁ s₂ tr hnode hvs

/-- One node of a run, against the analysis result. Suppose the join of `result`'s
values flowing into `s₂` holds before executing `s₂`. Then `result`'s values at `s₂`
hold after it. The proof unfolds `result` once with the fixed-point equation
`result_eq`. -/
lemma stepTrace {s₁ s₂ : prog.State} {ρ₁ ρ₂ ρ₃ : Env}
    (tr : Traceₗ prog.cfg s₁ s₂ ρ₁ ρ₂)
    (hjoin : ⟦ joinForKey s₂ (result L prog) ⟧ (S.Pre tr))
    (hnode : EvalBasicStmtOpt ρ₂ (prog.code s₂) ρ₃) :
    ⟦ variablesAt s₂ (result L prog) ⟧ (S.Post (tr ++ hnode)) := by
  -- `result = analyze result = updateAll (joinAll result)`.
  rw [result_eq L prog]
  refine updateAll_matches tr hnode ?_
  -- Reading `s₂` out of `joinAll result` is the join over `s₂`'s incoming values.
  rw [variablesAt_joinAll]
  exact hjoin

/-- Soundness at *every* visited node. Assume the analysis result over-approximates
the incoming environment at the start of the trace. Then at each node `s` reached
along the way, it over-approximates both the environment entering `s` (via
`joinForKey`) and the environment leaving `s` (via `variablesAt`).

Both facts are returned, rather than only the final `variablesAt` one, so that callers
can state soundness at intermediate nodes as well as at the end of the run. -/
lemma walkTrace_reaches {s₁ s₂ s₃ : prog.State} {ρ₁ ρ₂ ρ₃ : Env}
    {s : prog.State} {ρin ρout : Env} {tr : Trace prog.cfg s₂ s₃ ρ₂ ρ₃}
    (hr : Reaches tr s ρin ρout) (trₗ : Traceₗ prog.cfg s₁ s₂ ρ₁ ρ₂)
    (hjoin : ⟦ joinForKey s₂ (result L prog) ⟧ (S.Pre trₗ)) :
    ⟦ joinForKey s (result L prog) ⟧ (S.Pre (trₗ ++ hr.pre)) ∧
      ⟦ variablesAt s (result L prog) ⟧ (S.Post (trₗ ++ hr.post)) := by
  induction hr with
  | single_here hnode =>
    -- `s` is the current node. After unfolding, `simpa` closes what it can
    -- (the join component from `hjoin`), and `stepTrace` closes the rest.
    simp [Reaches.pre, Reaches.post]
    refine ⟨?_, ?_⟩ <;> try simpa [HAppend.hAppend]
    exact stepTrace trₗ hjoin hnode
  | edge_here hnode hedge rest =>
    simp [Reaches.pre, Reaches.post]
    refine ⟨?_, ?_⟩ <;> try simpa [HAppend.hAppend]
    exact stepTrace trₗ hjoin hnode
  | edge_there hnode hedge rest hr' ih =>
    -- `s` lies further along. Step over the current node and the edge `hedge`,
    -- re-establish the join invariant at the successor, and recurse.
    have hstep := stepTrace trₗ hjoin hnode
    -- The values at the current node are among those collected for the successor's
    -- incoming edges.
    have hmem := FiniteMap.mem_valuesAt prog.states_nodup
      (prog.mem_incoming_of_edge hedge) (variablesAt_mem _ (result L prog))
    -- `S.post_pre` carries `hstep` across the edge. `interp_foldr` lifts it to the
    -- successor's join, presumably a `foldr` of joins over incoming values.
    simpa [Reaches.pre, Reaches.post, HAppend.hAppend] using
      ih ((trₗ ++ hnode).addEdge hedge)
        (interp_foldr (S.post_pre (trₗ ++ hnode) hedge hstep) hmem)

omit [DecidableEq L] in
/-- The final node `s₂` of a trace is always reached, leaving in the environment `ρ₂`
the trace ends in. The entry environment `ρin` is returned alongside the witness.
Used to recover the final-state soundness theorem from `walkTrace_reaches`. -/
def reaches_final {s₁ s₂ : prog.State} {ρ₁ ρ₂ : Env}
    (tr : Trace prog.cfg s₁ s₂ ρ₁ ρ₂) : Σ ρin, Reaches tr s₂ ρin ρ₂ :=
  match tr with
  | .single hnode => ⟨_, .single_here hnode⟩
  | .edge hnode hedge rest =>
    -- The final node lies in `rest`; prepend the current node and edge.
    let ⟨ρin, r'⟩ := reaches_final rest
    ⟨ρin, .edge_there hnode hedge _ r'⟩

omit [DecidableEq L] in
/-- Reaching the final node covers the whole trace. -/
@[simp] lemma reaches_final_post {s₁ s₂ : prog.State} {ρ₁ ρ₂ : Env}
    (tr : Trace prog.cfg s₁ s₂ ρ₁ ρ₂) : (reaches_final tr).2.post = tr := by
  induction tr with
  | single hnode => rfl
  | edge hnode hedge rest ih => simp [reaches_final, Reaches.post, ih]

variable (L prog) in
/-- Soundness at every program point reached during execution. For any node `s`
visited by the run `hrun` (witnessed by `hr`), the analysis result over-approximates
both the environment entering `s` and the one leaving it.

The final-state theorem `analyze_correct'` is the special case where `s` is the last
node of the run (`reaches_final`). -/
theorem analyze_correct_at {ρf : Env} (hrun : EvalStmt [] prog.rootStmt ρf)
    {s : prog.State} {ρin ρout : Env} (hr : Reaches (prog.trace hrun) s ρin ρout) :
    ⟦ joinForKey s (result L prog) ⟧ (S.Pre hr.pre) ∧
      ⟦ variablesAt s (result L prog) ⟧ (S.Post hr.post) := by
  -- Start from the single-node trace at the initial state, whose join is `botV`.
  refine walkTrace_reaches hr (Traceₗ.single _ _ []) ?_
  rw [joinForKey_initialState]
  exact ValidStateEvaluator.botV_init

variable (L prog) in
/-- Soundness at the final state. The values `result` stores at `prog.finalState` hold
of `S.Post` of the whole run's trace. -/
theorem analyze_correct' {ρ : Env} (hrun : EvalStmt [] prog.rootStmt ρ) :
    ⟦ variablesAt prog.finalState (result L prog) ⟧ (S.Post (prog.trace hrun)) := by
  have h := (analyze_correct_at L prog hrun (reaches_final (prog.trace hrun)).2).2
  rwa [reaches_final_post] at h

end

variable (L prog) in
/-- `analyze_correct'` specialised to the state interpretation obtained from a
`LatticeInterpretation` and a `ValidStmtEvaluator`. Here the conclusion is stated
directly on the final environment `ρ`. -/
theorem analyze_correct [LatticeInterpretation L] [ValidStmtEvaluator L prog] {ρ : Env}
    (hrun : EvalStmt [] prog.rootStmt ρ) :
    ⟦ variablesAt prog.finalState (result L prog) ⟧ ρ :=
  analyze_correct' L prog hrun

end Forward

end Spa