diff --git a/lean/Spa.lean b/lean/Spa.lean index d248ebf..bb2bcd9 100644 --- a/lean/Spa.lean +++ b/lean/Spa.lean @@ -19,3 +19,10 @@ import Spa.Showable import Spa.Analysis.Utils import Spa.Analysis.Sign import Spa.Analysis.Constant +import Spa.Language.Tagged.Id +import Spa.Language.Tagged.Derive +import Spa.Language.Tagged.Basic +import Spa.Language.Tagged.Properties +import Spa.Language.Tagged.Graphs +import Spa.Analysis.Reaching +import Spa.Transformation.Licm diff --git a/lean/Spa/Analysis/Constant.lean b/lean/Spa/Analysis/Constant.lean index 62386de..3efa702 100644 --- a/lean/Spa/Analysis/Constant.lean +++ b/lean/Spa/Analysis/Constant.lean @@ -158,7 +158,7 @@ instance eval_valid : ValidExprEvaluator ConstLattice prog := by exact minus_valid h₁ h₂ theorem analyze_correct {ρ : Env} (hrun : EvalStmt [] prog.rootStmt ρ) : - ⟦ variablesAt prog.finalState (result ConstLattice prog) ⟧ ρ := + ⟦ variablesAt prog.finalState (result ConstLattice prog) ⟧ ρ () := Forward.analyze_correct ConstLattice prog hrun end ConstAnalysis diff --git a/lean/Spa/Analysis/Forward.lean b/lean/Spa/Analysis/Forward.lean index 497226a..cb2777f 100644 --- a/lean/Spa/Analysis/Forward.lean +++ b/lean/Spa/Analysis/Forward.lean @@ -9,13 +9,22 @@ namespace Forward variable {L : Type} [FiniteHeightLattice L] {prog : Program} [E : StmtEvaluator L prog] +def evalStmtOrNone (s : prog.State) (o : Option BasicStmt) (hco : prog.code s = o) + (vs : VariableValues L prog) : VariableValues L prog := + o.elimEq vs (fun bs h => E.eval s bs (hco.trans h)) + +lemma evalStmtOrNone_mono (s : prog.State) (o : Option BasicStmt) + (hco : prog.code s = o) : Monotone (evalStmtOrNone (L := L) s o hco) := + elimEq_self_mono o (fun bs h vs => E.eval s bs (hco.trans h) vs) + (fun bs h => E.eval_mono s bs (hco.trans h)) + def updateVariablesForState (s : prog.State) (sv : StateVariables L prog) : VariableValues L prog := - (prog.code s).foldl (fun vs bs => E.eval s bs vs) (variablesAt s sv) + evalStmtOrNone s (prog.code s) rfl (variablesAt s sv) lemma updateVariablesForState_mono (s : prog.State) : Monotone (updateVariablesForState (L := L) s) := fun _ _ hle => - foldl_mono' (prog.code s) _ (E.eval_mono s ·) (variablesAt_le hle s) + evalStmtOrNone_mono s (prog.code s) rfl (variablesAt_le hle s) def updateAll (sv : StateVariables L prog) : StateVariables L prog := FiniteMap.generalizedUpdate id updateVariablesForState @@ -54,67 +63,99 @@ lemma joinForKey_initialState : rw [joinForKey, prog.incoming_initialState_eq_nil] rfl -variable [I : LatticeInterpretation L] [V : ValidStmtEvaluator L prog] +class ValidStateEvaluator (L : Type) [FiniteHeightLattice L] (prog : Program) + [E : StmtEvaluator L prog] [S : StateInterp L prog] where + step : (s : prog.State) → {ρ₁ ρ₂ : Env} → {bs : BasicStmt} → + prog.code s = some bs → EvalBasicStmt ρ₁ bs ρ₂ → S.St ρ₁ → S.St ρ₂ + valid : ∀ (s : prog.State) {ρ₁ ρ₂ : Env} {bs : BasicStmt} + {vs : VariableValues L prog} {st : S.St ρ₁}, + (hcode : prog.code s = some bs) → (hbs : EvalBasicStmt ρ₁ bs ρ₂) → ⟦ vs ⟧ ρ₁ st → + ⟦ E.eval s bs hcode vs ⟧ ρ₂ (step s hcode hbs st) + botV_init : ⟦ botV L prog ⟧ [] S.init + +instance [LatticeInterpretation L] [ValidStmtEvaluator L prog] : + ValidStateEvaluator L prog where + step := by intro _ _ _ _ _ _ _; exact PUnit.unit + valid := by intro _ _ _ _ _ _ hcode hbs hvs; exact ValidStmtEvaluator.valid hcode hbs hvs + botV_init := by intro k l _ v hmem; cases hmem + +section +variable [S : StateInterp L prog] [V : ValidStateEvaluator L prog] + +noncomputable def stepStmtOrNone (s : prog.State) {ρ₁ ρ₂ : Env} : + (o : Option BasicStmt) → prog.code s = o → EvalBasicStmtOpt ρ₁ o ρ₂ → + S.St ρ₁ → S.St ρ₂ + | none, _, .none, st => st + | some _, hco, .some hbs, st => V.step s hco hbs st + +noncomputable def stepNode (s : prog.State) {ρ₁ ρ₂ : Env} + (h : EvalBasicStmtOpt ρ₁ (prog.code s) ρ₂) (st : S.St ρ₁) : S.St ρ₂ := + stepStmtOrNone s (prog.code s) rfl h st + +noncomputable def stepTraceState : + {s₁ s₂ : prog.State} → {ρ₁ ρ₂ : Env} → + Trace prog.cfg s₁ s₂ ρ₁ ρ₂ → S.St ρ₁ → S.St ρ₂ + | s₁, _, _, _, .single hnode, st => stepNode s₁ hnode st + | s₁, _, _, _, .edge hnode _ subtr, st => + stepTraceState subtr (stepNode s₁ hnode st) omit [DecidableEq L] in -lemma eval_fold_valid {s : prog.State} {bss : List BasicStmt} - {vs : VariableValues L prog} {ρ₁ ρ₂ : Env} - (hbss : EvalBasicStmts ρ₁ bss ρ₂) (hvs : ⟦ vs ⟧ ρ₁) : - ⟦ bss.foldl (fun vs bs => E.eval s bs vs) vs ⟧ ρ₂ := by - induction hbss generalizing vs with - | nil => exact hvs - | cons hbs _ ih => exact ih (ValidStmtEvaluator.valid hbs hvs) - -omit [DecidableEq L] in -lemma updateVariablesForState_matches {s : prog.State} - {sv : StateVariables L prog} {ρ₁ ρ₂ : Env} - (hbss : EvalBasicStmts ρ₁ (prog.code s) ρ₂) - (hvs : ⟦ variablesAt s sv ⟧ ρ₁) : - ⟦ updateVariablesForState s sv ⟧ ρ₂ := - eval_fold_valid hbss hvs +lemma evalStmtOrNone_valid {s : prog.State} {ρ₁ ρ₂ : Env} {st : S.St ρ₁} + {vs : VariableValues L prog} (o : Option BasicStmt) (hco : prog.code s = o) + (he : EvalBasicStmtOpt ρ₁ o ρ₂) (hvs : ⟦ vs ⟧ ρ₁ st) : + ⟦ evalStmtOrNone s o hco vs ⟧ ρ₂ (stepStmtOrNone s o hco he st) := by + cases he with + | none => exact hvs + | some hbs => exact V.valid s hco hbs hvs omit [DecidableEq L] in lemma updateAll_matches {s : prog.State} {sv : StateVariables L prog} - {ρ₁ ρ₂ : Env} (hbss : EvalBasicStmts ρ₁ (prog.code s) ρ₂) - (hvs : ⟦ variablesAt s sv ⟧ ρ₁) : - ⟦ variablesAt s (updateAll sv) ⟧ ρ₂ := by + {ρ₁ ρ₂ : Env} {st : S.St ρ₁} + (hnode : EvalBasicStmtOpt ρ₁ (prog.code s) ρ₂) + (hvs : ⟦ variablesAt s sv ⟧ ρ₁ st) : + ⟦ variablesAt s (updateAll sv) ⟧ ρ₂ (stepNode s hnode st) := by rw [variablesAt_updateAll] - exact updateVariablesForState_matches hbss hvs + exact evalStmtOrNone_valid (prog.code s) rfl hnode hvs -lemma stepTrace {s₁ : prog.State} {ρ₁ ρ₂ : Env} - (hjoin : ⟦ joinForKey s₁ (result L prog) ⟧ ρ₁) - (hbss : EvalBasicStmts ρ₁ (prog.code s₁) ρ₂) : - ⟦ variablesAt s₁ (result L prog) ⟧ ρ₂ := by +lemma stepTrace {s₁ : prog.State} {ρ₁ ρ₂ : Env} {st : S.St ρ₁} + (hjoin : ⟦ joinForKey s₁ (result L prog) ⟧ ρ₁ st) + (hnode : EvalBasicStmtOpt ρ₁ (prog.code s₁) ρ₂) : + ⟦ variablesAt s₁ (result L prog) ⟧ ρ₂ (stepNode s₁ hnode st) := by rw [result_eq L prog] - refine updateAll_matches hbss ?_ + refine updateAll_matches hnode ?_ rw [variablesAt_joinAll] exact hjoin -lemma walkTrace {s₁ s₂ : prog.State} {ρ₁ ρ₂ : Env} - (hjoin : ⟦ joinForKey s₁ (result L prog) ⟧ ρ₁) +lemma walkTrace {s₁ s₂ : prog.State} {ρ₁ ρ₂ : Env} {st₁ : S.St ρ₁} + (hjoin : ⟦ joinForKey s₁ (result L prog) ⟧ ρ₁ st₁) (tr : Trace prog.cfg s₁ s₂ ρ₁ ρ₂) : - ⟦ variablesAt s₂ (result L prog) ⟧ ρ₂ := by + ⟦ variablesAt s₂ (result L prog) ⟧ ρ₂ (stepTraceState tr st₁) := by induction tr with - | single hbss => exact stepTrace hjoin hbss - | @edge _ ρ' _ i₁ i₂ _ hbss hedge _ ih => - have hstep : ⟦ variablesAt i₁ (result L prog) ⟧ ρ' := - stepTrace hjoin hbss + | single hnode => exact stepTrace hjoin hnode + | @edge _ ρ' _ i₁ i₂ _ hnode hedge _ ih => + have hstep : ⟦ variablesAt i₁ (result L prog) ⟧ ρ' (stepNode i₁ hnode st₁) := + stepTrace hjoin hnode have hmem : variablesAt i₁ (result L prog) ∈ (result L prog).valuesAt (prog.incoming i₂) := FiniteMap.mem_valuesAt prog.states_nodup (prog.mem_incoming_of_edge hedge) (variablesAt_mem i₁ (result L prog)) exact ih (interp_foldr hstep hmem) -omit V in -lemma interp_joinForKey_initialState : - ⟦ joinForKey prog.initialState (result L prog) ⟧ [] := by +variable (L prog) in +theorem analyze_correct_state {ρ : Env} (hrun : EvalStmt [] prog.rootStmt ρ) : + ⟦ variablesAt prog.finalState (result L prog) ⟧ ρ + (stepTraceState (prog.trace hrun) S.init) := by + refine walkTrace ?_ (prog.trace hrun) rw [joinForKey_initialState] - exact interp_botV_nil + exact ValidStateEvaluator.botV_init + +end variable (L prog) in -theorem analyze_correct {ρ : Env} (hrun : EvalStmt [] prog.rootStmt ρ) : - ⟦ variablesAt prog.finalState (result L prog) ⟧ ρ := - walkTrace interp_joinForKey_initialState (prog.trace hrun) +theorem analyze_correct [LatticeInterpretation L] [ValidStmtEvaluator L prog] + {ρ : Env} (hrun : EvalStmt [] prog.rootStmt ρ) : + ⟦ variablesAt prog.finalState (result L prog) ⟧ ρ () := + analyze_correct_state L prog hrun end Forward diff --git a/lean/Spa/Analysis/Forward/Adapters.lean b/lean/Spa/Analysis/Forward/Adapters.lean index 8d0f9ca..8baeea4 100644 --- a/lean/Spa/Analysis/Forward/Adapters.lean +++ b/lean/Spa/Analysis/Forward/Adapters.lean @@ -14,14 +14,14 @@ lemma updateVariablesFromExpression_mono (k : String) (e : Expr) : Monotone (updateVariablesFromExpression (L := L) (prog := prog) k e) := FiniteMap.generalizedUpdate_monotone monotone_id (fun _ => E.eval_mono e) -def evalBasicStmt (_ : prog.State) (bs : BasicStmt) +def evalBasicStmt (s : prog.State) (bs : BasicStmt) (_h : prog.code s = some bs) (vs : VariableValues L prog) : VariableValues L prog := match bs with | .assign k e => updateVariablesFromExpression k e vs | .noop => vs -lemma evalBasicStmt_mono (s : prog.State) (bs : BasicStmt) : - Monotone (evalBasicStmt (L := L) (prog := prog) s bs) := by +lemma evalBasicStmt_mono (s : prog.State) (bs : BasicStmt) (h : prog.code s = some bs) : + Monotone (evalBasicStmt (L := L) (prog := prog) s bs h) := by cases bs with | assign k e => exact updateVariablesFromExpression_mono k e | noop => exact monotone_id @@ -32,7 +32,7 @@ instance ExprEvaluator.toStmtEvaluator : StmtEvaluator L prog := instance ExprEvaluator.toStmtEvaluator_valid [LatticeInterpretation L] [ValidExprEvaluator L prog] : ValidStmtEvaluator L prog := by constructor - intro s vs ρ₁ ρ₂ bs hbs hvs + intro s vs ρ₁ ρ₂ bs hcode hbs hvs cases hbs with | noop => exact hvs | assign k e v hev => diff --git a/lean/Spa/Analysis/Forward/Evaluation.lean b/lean/Spa/Analysis/Forward/Evaluation.lean index 9f30b9f..c55b121 100644 --- a/lean/Spa/Analysis/Forward/Evaluation.lean +++ b/lean/Spa/Analysis/Forward/Evaluation.lean @@ -7,8 +7,9 @@ namespace Forward variable (L : Type) [Lattice L] (prog : Program) class StmtEvaluator where - eval : prog.State → BasicStmt → VariableValues L prog → VariableValues L prog - eval_mono : ∀ s bs, Monotone (eval s bs) + eval : (s : prog.State) → (bs : BasicStmt) → prog.code s = some bs → + VariableValues L prog → VariableValues L prog + eval_mono : ∀ s bs h, Monotone (eval s bs h) class ExprEvaluator where eval : Expr → VariableValues L prog → L @@ -17,13 +18,13 @@ class ExprEvaluator where class ValidExprEvaluator [ExprEvaluator L prog] [I : LatticeInterpretation L] : Prop where valid : ∀ {vs : VariableValues L prog} {ρ : Env} {e : Expr} {v : Value}, - EvalExpr ρ e v → ⟦ vs ⟧ ρ → I.interp (ExprEvaluator.eval e vs) v + EvalExpr ρ e v → ⟦ vs ⟧ ρ () → I.interp (ExprEvaluator.eval e vs) v class ValidStmtEvaluator [E : StmtEvaluator L prog] [LatticeInterpretation L] : Prop where valid : ∀ {s : prog.State} {vs : VariableValues L prog} {ρ₁ ρ₂ : Env} - {bs : BasicStmt}, - EvalBasicStmt ρ₁ bs ρ₂ → ⟦ vs ⟧ ρ₁ → ⟦ E.eval s bs vs ⟧ ρ₂ + {bs : BasicStmt} (hcode : prog.code s = some bs), + EvalBasicStmt ρ₁ bs ρ₂ → ⟦ vs ⟧ ρ₁ () → ⟦ E.eval s bs hcode vs ⟧ ρ₂ () end Forward diff --git a/lean/Spa/Analysis/Forward/Lattices.lean b/lean/Spa/Analysis/Forward/Lattices.lean index 80b7e3b..1af6aa4 100644 --- a/lean/Spa/Analysis/Forward/Lattices.lean +++ b/lean/Spa/Analysis/Forward/Lattices.lean @@ -64,39 +64,47 @@ lemma variablesAt_joinAll (s : prog.State) (sv : StateVariables L prog) : variablesAt s (joinAll sv) = joinForKey s sv := joinAll_mem_eq (variablesAt_mem s (joinAll sv)) -/-! ### Lifting an interpretation to variable maps -/ +class StateInterp (L : Type) [Lattice L] (prog : Program) where + St : Env → Type + init : St [] + interp : VariableValues L prog → (ρ : Env) → St ρ → Prop + interp_sup : ∀ {vs₁ vs₂ : VariableValues L prog} {ρ : Env} {st : St ρ}, + interp vs₁ ρ st ∨ interp vs₂ ρ st → interp (vs₁ ⊔ vs₂) ρ st + interp_inf : ∀ {vs₁ vs₂ : VariableValues L prog} {ρ : Env} {st : St ρ}, + interp vs₁ ρ st ∧ interp vs₂ ρ st → interp (vs₁ ⊓ vs₂) ρ st -variable [I : LatticeInterpretation L] +instance [S : StateInterp L prog] : + Interp (VariableValues L prog) ((ρ : Env) → S.St ρ → Prop) := + ⟨S.interp⟩ -omit [FiniteHeightLattice L] in -instance : Interp (VariableValues L prog) (Env → Prop) where - interp (vs : VariableValues L prog) (ρ : Env) : Prop := - ∀ (k : String) (l : L), (k, l) ∈ vs → - ∀ (v : Value), Env.Mem (k, v) ρ → I.interp l v - -lemma interp_botV_nil : ⟦ botV L prog ⟧ [] := by - intro k l _ v hmem - cases hmem - -omit [FiniteHeightLattice L] in -lemma interp_sup {vs₁ vs₂ : VariableValues L prog} {ρ : Env} - (h : ⟦ vs₁⟧ ρ ∨ ⟦ vs₂ ⟧ ρ) : ⟦ vs₁ ⊔ vs₂ ⟧ ρ := by - intro k l hmem v hv - obtain ⟨l₁, l₂, rfl, h₁, h₂⟩ := FiniteMap.mem_sup hmem - rcases h with h | h - · exact I.interp_sup v (Or.inl (h _ _ h₁ _ hv)) - · exact I.interp_sup v (Or.inr (h _ _ h₂ _ hv)) - -lemma interp_foldr {vs : VariableValues L prog} - {vss : List (VariableValues L prog)} {ρ : Env} - (hvs : ⟦ vs ⟧ ρ) (hmem : vs ∈ vss) : - ⟦ vss.foldr (· ⊔ ·) (botV L prog) ⟧ ρ := by +lemma interp_foldr [S : StateInterp L prog] + {vs : VariableValues L prog} {vss : List (VariableValues L prog)} + {ρ : Env} {st : S.St ρ} (hvs : ⟦ vs ⟧ ρ st) (hmem : vs ∈ vss) : + ⟦ vss.foldr (· ⊔ ·) (botV L prog) ⟧ ρ st := by induction vss with | nil => cases hmem | cons vs' vss' ih => rcases List.mem_cons.mp hmem with rfl | hmem' - · exact interp_sup (Or.inl hvs) - · exact interp_sup (Or.inr (ih hmem')) + · exact S.interp_sup (Or.inl hvs) + · exact S.interp_sup (Or.inr (ih hmem')) + +variable [I : LatticeInterpretation L] + +instance : StateInterp L prog where + St := fun _ => PUnit + init := PUnit.unit + interp vs ρ _ := ∀ (k : String) (l : L), (k, l) ∈ vs → + ∀ (v : Value), Env.Mem (k, v) ρ → I.interp l v + interp_sup := by + intro vs₁ vs₂ ρ st h k l hmem v hv + obtain ⟨l₁, l₂, rfl, h₁, h₂⟩ := FiniteMap.mem_sup hmem + rcases h with h | h + · exact I.interp_sup v (Or.inl (h _ _ h₁ _ hv)) + · exact I.interp_sup v (Or.inr (h _ _ h₂ _ hv)) + interp_inf := by + intro vs₁ vs₂ ρ st h k l hmem v hv + obtain ⟨l₁, l₂, rfl, h₁, h₂⟩ := FiniteMap.mem_inf hmem + exact I.interp_inf v ⟨h.1 _ _ h₁ _ hv, h.2 _ _ h₂ _ hv⟩ end Forward diff --git a/lean/Spa/Analysis/Reaching.lean b/lean/Spa/Analysis/Reaching.lean index b34137f..e156b7b 100644 --- a/lean/Spa/Analysis/Reaching.lean +++ b/lean/Spa/Analysis/Reaching.lean @@ -1,5 +1,7 @@ import Spa.Analysis.Forward import Spa.Lattice.Bool +import Spa.Lattice.Tuple +import Spa.Language.Tagged.Graphs import Spa.Showable namespace Spa @@ -8,23 +10,31 @@ open Forward instance : Showable Bool := ⟨fun b => if b then "true" else "false"⟩ -abbrev DefSet (prog : Program) : Type := FiniteMap prog.State Bool prog.states +instance {n : ℕ} {β : Type*} [Showable β] : Showable (Fin n → β) := + ⟨fun f => + "{" ++ (List.finRange n).foldr + (fun i rest => show' i ++ " ↦ " ++ show' (f i) ++ ", " ++ rest) "" + ++ "}"⟩ + +abbrev DefSet (prog : Program) : Type := prog.NodeId → Bool namespace ReachingAnalysis variable (prog : Program) -def genSet (s : prog.State) : DefSet prog := - FiniteMap.updating (⊥ : DefSet prog) [s] (fun _ => true) +def genSet (s : prog.State) {bs : BasicStmt} (h : prog.code s = some bs) : + DefSet prog := + Function.update (⊥ : DefSet prog) (prog.nodeIdOfNonempty s h) true def eval (s : prog.State) : - BasicStmt → VariableValues (DefSet prog) prog → VariableValues (DefSet prog) prog - | .assign k _, vs => - FiniteMap.generalizedUpdate id (fun _ _ => genSet prog s) [k] vs - | .noop, vs => vs + (bs : BasicStmt) → prog.code s = some bs → + VariableValues (DefSet prog) prog → VariableValues (DefSet prog) prog + | .assign k _, h, vs => + FiniteMap.generalizedUpdate id (fun _ _ => genSet prog s h) [k] vs + | .noop, _, vs => vs -lemma eval_mono (s : prog.State) (bs : BasicStmt) : - Monotone (eval prog s bs) := by +lemma eval_mono (s : prog.State) (bs : BasicStmt) (h : prog.code s = some bs) : + Monotone (eval prog s bs h) := by cases bs with | assign k e => exact FiniteMap.generalizedUpdate_monotone monotone_id (fun _ => monotone_const) @@ -36,6 +46,86 @@ instance stmtEvaluator : StmtEvaluator (DefSet prog) prog := def output : String := show' (result (DefSet prog) prog) +inductive Run (prog : Program) where + | nil : Run prog + | cons (s : prog.State) (bs : BasicStmt) (hc : prog.code s = some bs) + (rest : Run prog) : Run prog + +inductive LastAssign (prog : Program) (x : String) : Run prog → prog.NodeId → Prop + | here (s : prog.State) (e : Expr) (hc : prog.code s = some (.assign x e)) + (rest : Run prog) : + LastAssign prog x (Run.cons s (.assign x e) hc rest) (prog.nodeIdOfNonempty s hc) + | there (s : prog.State) (bs : BasicStmt) (hc : prog.code s = some bs) + (rest : Run prog) {n : prog.NodeId} : + (∀ e, bs ≠ .assign x e) → LastAssign prog x rest n → + LastAssign prog x (Run.cons s bs hc rest) n + +lemma lastAssign_cons_here {x : String} {s : prog.State} {e : Expr} + {hc : prog.code s = some (.assign x e)} {rest : Run prog} {n : prog.NodeId} + (h : LastAssign prog x (Run.cons s (.assign x e) hc rest) n) : + n = prog.nodeIdOfNonempty s hc := by + cases h with + | here _ _ _ _ => rfl + | there _ _ _ _ hne _ => exact absurd rfl (hne e) + +lemma lastAssign_cons_of_ne {x : String} {s : prog.State} {bs : BasicStmt} + {hc : prog.code s = some bs} {rest : Run prog} {n : prog.NodeId} + (h : LastAssign prog x (Run.cons s bs hc rest) n) + (hne : ∀ e, bs ≠ .assign x e) : LastAssign prog x rest n := by + cases h with + | here _ e' _ _ => exact absurd rfl (hne e') + | there _ _ _ _ _ hp => exact hp + +instance stateInterp : StateInterp (DefSet prog) prog where + St := fun _ => Run prog + init := Run.nil + interp vs _ run := ∀ (x : String) (assigners : DefSet prog), (x, assigners) ∈ vs → + ∀ (n : prog.NodeId), LastAssign prog x run n → assigners n = true + interp_sup := by + intro vs₁ vs₂ ρ run h x assigners hmem n hla + obtain ⟨a₁, a₂, rfl, h₁, h₂⟩ := FiniteMap.mem_sup hmem + rw [Pi.sup_apply] + rcases h with h | h + · aesop + · aesop + interp_inf := by + intro vs₁ vs₂ ρ run h x assigners hmem n hla + obtain ⟨a₁, a₂, rfl, h₁, h₂⟩ := FiniteMap.mem_inf hmem + rw [Pi.inf_apply] + aesop + +instance validStateEvaluator : ValidStateEvaluator (DefSet prog) prog where + step := by intro s _ _ bs hcode _ rest; exact Run.cons s bs hcode rest + valid := by + intro s ρ₁ ρ₂ bs vs st hcode hbs hvs + cases hbs with + | noop => + intro x assigners hmem n hla + exact hvs x assigners hmem n + (lastAssign_cons_of_ne prog hla (fun _ h => BasicStmt.noConfusion h)) + | assign x e v hev => + intro k assigners hmem n hla + have hmem2 : (k, assigners) ∈ + FiniteMap.generalizedUpdate id (fun _ _ => genSet prog s hcode) [x] vs := hmem + by_cases hx : k = x + · subst hx + have hd := FiniteMap.generalizedUpdate_mem_eq (List.mem_singleton.mpr rfl) hmem2 + have hn := lastAssign_cons_here prog hla + subst hd + rw [hn] + simp only [genSet, Function.update_self] + · have hp := lastAssign_cons_of_ne prog hla + (by intro e' h; injection h with h1 _; exact hx h1.symm) + have hmem' := FiniteMap.generalizedUpdate_not_mem_backward + (fun hc => hx (List.mem_singleton.mp hc)) hmem2 + exact hvs k assigners hmem' n hp + botV_init := by intro x assigners _ n hla; cases hla + +theorem analyze_correct {ρ : Env} (hrun : EvalStmt [] prog.rootStmt ρ) : + ⟦ variablesAt prog.finalState (result (DefSet prog) prog) ⟧ ρ + (stepTraceState (prog.trace hrun) (stateInterp prog).init) := + Forward.analyze_correct_state (DefSet prog) prog hrun + end ReachingAnalysis end Spa diff --git a/lean/Spa/Analysis/Sign.lean b/lean/Spa/Analysis/Sign.lean index 24e43ac..bd91104 100644 --- a/lean/Spa/Analysis/Sign.lean +++ b/lean/Spa/Analysis/Sign.lean @@ -216,7 +216,7 @@ instance eval_valid : ValidExprEvaluator SignLattice prog := by exact minus_valid h₁ h₂ theorem analyze_correct {ρ : Env} (hrun : EvalStmt [] prog.rootStmt ρ) : - ⟦ variablesAt prog.finalState (result SignLattice prog) ⟧ ρ := + ⟦ variablesAt prog.finalState (result SignLattice prog) ⟧ ρ () := Forward.analyze_correct SignLattice prog hrun end SignAnalysis diff --git a/lean/Spa/Language.lean b/lean/Spa/Language.lean index 69400ce..f42f75c 100644 --- a/lean/Spa/Language.lean +++ b/lean/Spa/Language.lean @@ -23,7 +23,7 @@ def initialState : p.State := p.rootStmt.cfg.wrapInput def finalState : p.State := p.rootStmt.cfg.wrapOutput -theorem trace {ρ : Env} (h : EvalStmt [] p.rootStmt ρ) : +noncomputable def trace {ρ : Env} (h : EvalStmt [] p.rootStmt ρ) : Trace p.cfg p.initialState p.finalState [] ρ := by obtain ⟨i₁, h₁, i₂, h₂, tr⟩ := EndToEndTrace.wrap (Stmt.cfg_sufficient h) rw [Graph.wrap_inputs, List.mem_singleton] at h₁ @@ -41,7 +41,7 @@ lemma states_complete (s : p.State) : s ∈ p.states := p.cfg.mem_indices s lemma states_nodup : p.states.Nodup := p.cfg.nodup_indices -def code (st : p.State) : List BasicStmt := p.cfg.nodes st +def code (st : p.State) : Option BasicStmt := p.cfg.nodes st def incoming (s : p.State) : List p.State := p.cfg.predecessors s diff --git a/lean/Spa/Language/Graphs.lean b/lean/Spa/Language/Graphs.lean index 59c5223..6f7503d 100644 --- a/lean/Spa/Language/Graphs.lean +++ b/lean/Spa/Language/Graphs.lean @@ -130,9 +130,9 @@ def loopOut (g : GGraph α) : Fin (2 + g.size) := (1 : Fin 2).castAdd g.size This is technically sloppy (see module comment), but it's simple. -/ -def loop (g : GGraph (List β)) : GGraph (List β) where +def loop (g : GGraph (Option β)) : GGraph (Option β) where size := 2 + g.size - nodes := Fin.append (fun _ : Fin 2 => []) g.nodes + nodes := Fin.append (fun _ : Fin 2 => none) g.nodes edges := g.edges.finNatAddProd 2 ++ ((g.loopIn, ·) <$> g.inputs.finNatAdd 2) ++ ((·, g.loopOut) <$> g.outputs.finNatAdd 2) ++ @@ -140,9 +140,9 @@ def loop (g : GGraph (List β)) : GGraph (List β) where inputs := [g.loopIn] outputs := [g.loopOut] -@[simp] lemma loop_inputs (g : GGraph (List β)) : (loop g).inputs = [g.loopIn] := rfl +@[simp] lemma loop_inputs (g : GGraph (Option β)) : (loop g).inputs = [g.loopIn] := rfl -@[simp] lemma loop_outputs (g : GGraph (List β)) : (loop g).outputs = [g.loopOut] := rfl +@[simp] lemma loop_outputs (g : GGraph (Option β)) : (loop g).outputs = [g.loopOut] := rfl /-- Creates a single-node graph whose node contains the given value. -/ def singleton (a : α) : GGraph α where @@ -154,8 +154,8 @@ def singleton (a : α) : GGraph α where /-- Creates a new graph with a single input and single output node. Useful to ensure there's a single point of entry and single point of exit. -/ -def wrap (g : GGraph (List β)) : GGraph (List β) := - singleton [] ⤳ g ⤳ singleton [] +def wrap (g : GGraph (Option β)) : GGraph (Option β) := + singleton none ⤳ g ⤳ singleton none @[simp] lemma map_singleton (f : α → β) (a : α) : f <$> singleton a = singleton (f a) := rfl @@ -176,16 +176,16 @@ def wrap (g : GGraph (List β)) : GGraph (List β) := funext i refine Fin.addCases ?_ ?_ i <;> intro j <;> simp [Fin.append_left, Fin.append_right] -@[simp] lemma map_loop (h : β → γ) (g : GGraph (List β)) : - (List.map h) <$> (loop g) = loop (List.map h <$> g) := by +@[simp] lemma map_loop (h : β → γ) (g : GGraph (Option β)) : + (Option.map h) <$> (loop g) = loop (Option.map h <$> g) := by rcases g with ⟨n, nd, e, i, o⟩ simp only [Functor.map, GGraph.loop] congr 1 funext i refine Fin.addCases ?_ ?_ i <;> intro j <;> simp [Fin.append_left, Fin.append_right] -@[simp] lemma map_wrap (h : β → γ) (g : GGraph (List β)) : - (List.map h) <$> wrap g = wrap (List.map h <$> g) := by +@[simp] lemma map_wrap (h : β → γ) (g : GGraph (Option β)) : + (Option.map h) <$> wrap g = wrap (Option.map h <$> g) := by simp [GGraph.wrap, GGraph.map_sequence, GGraph.map_singleton] variable (g : GGraph α) @@ -220,8 +220,8 @@ lemma edge_of_mem_predecessors {idx₁ idx₂ : g.Index} end GGraph /-- "Normal" graphs, for the purposes of the analyses in this - framework, have basic blocks in their nodes, and nothing else. -/ -abbrev Graph : Type := GGraph (List BasicStmt) + framework, have basic statements in their nodes, and nothing else. -/ +abbrev Graph : Type := GGraph (Option BasicStmt) namespace Graph @@ -235,7 +235,7 @@ end Graph open Graph in def Stmt.cfg : Stmt → Graph -- A basic statement goes into a single basic block - | .basic bs => singleton [bs] + | .basic bs => singleton (some bs) -- Sequencing of statements corresponds naturally to CFG sequencing | .andThen s₁ s₂ => s₁.cfg ⤳ s₂.cfg -- An if can execute either one branch or the other; overlap them. diff --git a/lean/Spa/Language/Properties.lean b/lean/Spa/Language/Properties.lean index e55c69d..505b2c6 100644 --- a/lean/Spa/Language/Properties.lean +++ b/lean/Spa/Language/Properties.lean @@ -17,7 +17,7 @@ section Embeddings variable {g₁ g₂ : Graph} {ρ₁ ρ₂ : Env} -lemma Trace.overlay_left {idx₁ idx₂ : g₁.Index} +noncomputable def Trace.overlay_left {idx₁ idx₂ : g₁.Index} (tr : Trace g₁ idx₁ idx₂ ρ₁ ρ₂) : Trace (g₁ ∙ g₂) (idx₁.castAdd g₂.size) (idx₂.castAdd g₂.size) ρ₁ ρ₂ := by induction tr with @@ -29,7 +29,7 @@ lemma Trace.overlay_left {idx₁ idx₂ : g₁.Index} · rwa [show (g₁ ∙ g₂).nodes = Fin.append g₁.nodes g₂.nodes from rfl, Fin.append_left] · exact List.mem_append_left _ (List.mem_map_of_mem _ he) -lemma Trace.overlay_right {idx₁ idx₂ : g₂.Index} +noncomputable def Trace.overlay_right {idx₁ idx₂ : g₂.Index} (tr : Trace g₂ idx₁ idx₂ ρ₁ ρ₂) : Trace (g₁ ∙ g₂) (idx₁.natAdd g₁.size) (idx₂.natAdd g₁.size) ρ₁ ρ₂ := by induction tr with @@ -41,7 +41,7 @@ lemma Trace.overlay_right {idx₁ idx₂ : g₂.Index} · rwa [show (g₁ ∙ g₂).nodes = Fin.append g₁.nodes g₂.nodes from rfl, Fin.append_right] · exact List.mem_append_right _ (List.mem_map_of_mem _ he) -lemma Trace.sequence_left {idx₁ idx₂ : g₁.Index} +noncomputable def Trace.sequence_left {idx₁ idx₂ : g₁.Index} (tr : Trace g₁ idx₁ idx₂ ρ₁ ρ₂) : Trace (g₁ ⤳ g₂) (idx₁.castAdd g₂.size) (idx₂.castAdd g₂.size) ρ₁ ρ₂ := by induction tr with @@ -53,7 +53,7 @@ lemma Trace.sequence_left {idx₁ idx₂ : g₁.Index} · rwa [show (g₁ ⤳ g₂).nodes = Fin.append g₁.nodes g₂.nodes from rfl, Fin.append_left] · exact List.mem_append_left _ (List.mem_append_left _ (List.mem_map_of_mem _ he)) -lemma Trace.sequence_right {idx₁ idx₂ : g₂.Index} +noncomputable def Trace.sequence_right {idx₁ idx₂ : g₂.Index} (tr : Trace g₂ idx₁ idx₂ ρ₁ ρ₂) : Trace (g₁ ⤳ g₂) (idx₁.natAdd g₁.size) (idx₂.natAdd g₁.size) ρ₁ ρ₂ := by induction tr with @@ -66,21 +66,21 @@ lemma Trace.sequence_right {idx₁ idx₂ : g₂.Index} · exact List.mem_append_left _ (List.mem_append_right _ (List.mem_map_of_mem _ he)) -lemma EndToEndTrace.overlay_left (etr : EndToEndTrace g₁ ρ₁ ρ₂) : +noncomputable def EndToEndTrace.overlay_left (etr : EndToEndTrace g₁ ρ₁ ρ₂) : EndToEndTrace (g₁ ∙ g₂) ρ₁ ρ₂ := by obtain ⟨i₁, h₁, i₂, h₂, tr⟩ := etr exact ⟨i₁.castAdd g₂.size, List.mem_append_left _ (List.mem_map_of_mem _ h₁), i₂.castAdd g₂.size, List.mem_append_left _ (List.mem_map_of_mem _ h₂), tr.overlay_left⟩ -lemma EndToEndTrace.overlay_right (etr : EndToEndTrace g₂ ρ₁ ρ₂) : +noncomputable def EndToEndTrace.overlay_right (etr : EndToEndTrace g₂ ρ₁ ρ₂) : EndToEndTrace (g₁ ∙ g₂) ρ₁ ρ₂ := by obtain ⟨i₁, h₁, i₂, h₂, tr⟩ := etr exact ⟨i₁.natAdd g₁.size, List.mem_append_right _ (List.mem_map_of_mem _ h₁), i₂.natAdd g₁.size, List.mem_append_right _ (List.mem_map_of_mem _ h₂), tr.overlay_right⟩ -lemma EndToEndTrace.concat {ρ₃ : Env} (etr₁ : EndToEndTrace g₁ ρ₁ ρ₂) +noncomputable def EndToEndTrace.concat {ρ₃ : Env} (etr₁ : EndToEndTrace g₁ ρ₁ ρ₂) (etr₂ : EndToEndTrace g₂ ρ₂ ρ₃) : EndToEndTrace (g₁ ⤳ g₂) ρ₁ ρ₃ := by obtain ⟨i₁, h₁, i₂, h₂, tr₁⟩ := etr₁ obtain ⟨j₁, k₁, j₂, k₂, tr₂⟩ := etr₂ @@ -98,29 +98,29 @@ section Loop variable {g : Graph} {ρ₁ ρ₂ ρ₃ : Env} -lemma Trace.loop {idx₁ idx₂ : g.Index} (tr : Trace g idx₁ idx₂ ρ₁ ρ₂) : +noncomputable def Trace.loop {idx₁ idx₂ : g.Index} (tr : Trace g idx₁ idx₂ ρ₁ ρ₂) : Trace (Graph.loop g) (idx₁.natAdd 2) (idx₂.natAdd 2) ρ₁ ρ₂ := by induction tr with | single hbs => exact Trace.single (by - rwa [show (Graph.loop g).nodes = Fin.append (fun _ : Fin 2 => []) g.nodes from rfl, + rwa [show (Graph.loop g).nodes = Fin.append (fun _ : Fin 2 => none) g.nodes from rfl, Fin.append_right]) | edge hbs he _ ih => refine Trace.edge ?_ ?_ ih - · rwa [show (Graph.loop g).nodes = Fin.append (fun _ : Fin 2 => []) g.nodes from rfl, + · rwa [show (Graph.loop g).nodes = Fin.append (fun _ : Fin 2 => none) g.nodes from rfl, Fin.append_right] · exact List.mem_append_left _ (List.mem_append_left _ (List.mem_append_left _ (List.mem_map_of_mem _ he))) private lemma loop_nodes_at_in : - (Graph.loop g).nodes g.loopIn = [] := - Fin.append_left (fun _ : Fin 2 => []) g.nodes 0 + (Graph.loop g).nodes g.loopIn = none := + Fin.append_left (fun _ : Fin 2 => none) g.nodes 0 private lemma loop_nodes_at_out : - (Graph.loop g).nodes g.loopOut = [] := - Fin.append_left (fun _ : Fin 2 => []) g.nodes 1 + (Graph.loop g).nodes g.loopOut = none := + Fin.append_left (fun _ : Fin 2 => none) g.nodes 1 -lemma EndToEndTrace.loop (etr : EndToEndTrace g ρ₁ ρ₂) : +noncomputable def EndToEndTrace.loop (etr : EndToEndTrace g ρ₁ ρ₂) : EndToEndTrace (Graph.loop g) ρ₁ ρ₂ := by obtain ⟨i₁, h₁, i₂, h₂, tr⟩ := etr -- the edge in → (2 ↑ʳ i₁), reached through the second edge group @@ -132,15 +132,15 @@ lemma EndToEndTrace.loop (etr : EndToEndTrace g ρ₁ ρ₂) : refine List.mem_append_left _ (List.mem_append_right _ ?_) exact List.mem_map_of_mem _ (List.mem_map_of_mem _ h₂) refine ⟨g.loopIn, List.mem_singleton_self _, g.loopOut, List.mem_singleton_self _, ?_⟩ - exact Trace.concat (Trace.single (loop_nodes_at_in ▸ EvalBasicStmts.nil)) hin - (Trace.concat tr.loop hout (Trace.single (loop_nodes_at_out ▸ EvalBasicStmts.nil))) + exact Trace.concat (Trace.single (loop_nodes_at_in ▸ EvalBasicStmtOpt.none)) hin + (Trace.concat tr.loop hout (Trace.single (loop_nodes_at_out ▸ EvalBasicStmtOpt.none))) private lemma loop_edge_out_in : ((g.loopOut, g.loopIn) : (Graph.loop g).Edge) ∈ (Graph.loop g).edges := by refine List.mem_append_right _ ?_ exact List.mem_cons_self _ _ -lemma EndToEndTrace.loop_concat (etr₁ : EndToEndTrace (Graph.loop g) ρ₁ ρ₂) +noncomputable def EndToEndTrace.loop_concat (etr₁ : EndToEndTrace (Graph.loop g) ρ₁ ρ₂) (etr₂ : EndToEndTrace (Graph.loop g) ρ₂ ρ₃) : EndToEndTrace (Graph.loop g) ρ₁ ρ₃ := by obtain ⟨i₁, h₁, i₂, h₂, tr₁⟩ := etr₁ @@ -150,35 +150,35 @@ lemma EndToEndTrace.loop_concat (etr₁ : EndToEndTrace (Graph.loop g) ρ₁ ρ exact ⟨g.loopIn, List.mem_singleton_self _, g.loopOut, List.mem_singleton_self _, Trace.concat tr₁ loop_edge_out_in tr₂⟩ -lemma EndToEndTrace.loop_empty {ρ : Env} : EndToEndTrace (Graph.loop g) ρ ρ := by +noncomputable def EndToEndTrace.loop_empty {ρ : Env} : EndToEndTrace (Graph.loop g) ρ ρ := by have hedge : ((g.loopIn, g.loopOut) : (Graph.loop g).Edge) ∈ (Graph.loop g).edges := List.mem_append_right _ (List.mem_cons_of_mem _ (List.mem_cons_self _ _)) exact ⟨g.loopIn, List.mem_singleton_self _, g.loopOut, List.mem_singleton_self _, - Trace.concat (Trace.single (loop_nodes_at_in ▸ EvalBasicStmts.nil)) hedge - (Trace.single (loop_nodes_at_out ▸ EvalBasicStmts.nil))⟩ + Trace.concat (Trace.single (loop_nodes_at_in ▸ EvalBasicStmtOpt.none)) hedge + (Trace.single (loop_nodes_at_out ▸ EvalBasicStmtOpt.none))⟩ end Loop /-! ### Singletons, wrap, and the main result -/ -lemma EndToEndTrace.singleton {bss : List BasicStmt} {ρ₁ ρ₂ : Env} - (h : EvalBasicStmts ρ₁ bss ρ₂) : EndToEndTrace (Graph.singleton bss) ρ₁ ρ₂ := +noncomputable def EndToEndTrace.singleton {o : Option BasicStmt} {ρ₁ ρ₂ : Env} + (h : EvalBasicStmtOpt ρ₁ o ρ₂) : EndToEndTrace (Graph.singleton o) ρ₁ ρ₂ := ⟨(0 : Fin 1), List.mem_singleton_self _, (0 : Fin 1), List.mem_singleton_self _, Trace.single h⟩ -lemma EndToEndTrace.singleton_nil (ρ : Env) : - EndToEndTrace (Graph.singleton []) ρ ρ := - EndToEndTrace.singleton EvalBasicStmts.nil +noncomputable def EndToEndTrace.singleton_nil (ρ : Env) : + EndToEndTrace (Graph.singleton none) ρ ρ := + EndToEndTrace.singleton EvalBasicStmtOpt.none -lemma EndToEndTrace.wrap {g : Graph} {ρ₁ ρ₂ : Env} +noncomputable def EndToEndTrace.wrap {g : Graph} {ρ₁ ρ₂ : Env} (etr : EndToEndTrace g ρ₁ ρ₂) : EndToEndTrace (Graph.wrap g) ρ₁ ρ₂ := (EndToEndTrace.singleton_nil ρ₁).concat (etr.concat (EndToEndTrace.singleton_nil ρ₂)) -theorem Stmt.cfg_sufficient {s : Stmt} {ρ₁ ρ₂ : Env} +noncomputable def Stmt.cfg_sufficient {s : Stmt} {ρ₁ ρ₂ : Env} (h : EvalStmt ρ₁ s ρ₂) : EndToEndTrace s.cfg ρ₁ ρ₂ := by induction h with | basic ρ₁ ρ₂ bs hbs => - exact EndToEndTrace.singleton (EvalBasicStmts.cons hbs EvalBasicStmts.nil) + exact EndToEndTrace.singleton (EvalBasicStmtOpt.some hbs) | andThen ρ₁ ρ₂ ρ₃ s₁ s₂ _ _ ih₁ ih₂ => exact ih₁.concat ih₂ | ifTrue ρ₁ ρ₂ e z s₁ s₂ _ _ _ ih => @@ -193,7 +193,7 @@ theorem Stmt.cfg_sufficient {s : Stmt} {ρ₁ ρ₂ : Env} /-! ### The wrapped graph's entry has no predecessors (Agda's "ugly" block) -/ def Graph.wrapInput (g : Graph) : (Graph.wrap g).Index := - (0 : Fin 1).castAdd ((g ⤳ Graph.singleton []).size) + (0 : Fin 1).castAdd ((g ⤳ Graph.singleton none).size) def Graph.wrapOutput (g : Graph) : (Graph.wrap g).Index := Fin.natAdd 1 ((Fin.natAdd g.size (0 : Fin 1))) @@ -205,9 +205,9 @@ lemma Graph.wrap_outputs (g : Graph) : (Graph.wrap g).outputs = [g.wrapOutput] := rfl private lemma not_mem_edges_castAdd_sequence {g₂ : Graph} (i : Fin 1) - (idx : (Graph.singleton [] ⤳ g₂).Index) : - ((idx, i.castAdd g₂.size) : (Graph.singleton [] ⤳ g₂).Edge) - ∉ (Graph.singleton [] ⤳ g₂).edges := by + (idx : (Graph.singleton none ⤳ g₂).Index) : + ((idx, i.castAdd g₂.size) : (Graph.singleton none ⤳ g₂).Edge) + ∉ (Graph.singleton none ⤳ g₂).edges := by intro h rcases List.mem_append.mp h with h' | h' · rcases List.mem_append.mp h' with h'' | h'' @@ -228,6 +228,6 @@ lemma Graph.wrap_predecessors_eq_nil (g : Graph) (idx : (Graph.wrap g).Index) subst h rw [GGraph.predecessors, List.filter_eq_nil_iff] intro idx' _ - simpa using not_mem_edges_castAdd_sequence (g₂ := g ⤳ Graph.singleton []) 0 idx' + simpa using not_mem_edges_castAdd_sequence (g₂ := g ⤳ Graph.singleton none) 0 idx' end Spa diff --git a/lean/Spa/Language/Semantics.lean b/lean/Spa/Language/Semantics.lean index 0eedc50..e2e6633 100644 --- a/lean/Spa/Language/Semantics.lean +++ b/lean/Spa/Language/Semantics.lean @@ -46,22 +46,20 @@ inductive EvalExpr : Env → Expr → Value → Prop /-- Inference rules for evaluating a basic statement (`Spa.BasicStmt`) in a given environment, potentially changing the environment. Pretty standard big-step evaluation. -/ -inductive EvalBasicStmt : Env → BasicStmt → Env → Prop +inductive EvalBasicStmt : Env → BasicStmt → Env → Type | noop (ρ : Env) : EvalBasicStmt ρ .noop ρ | assign (ρ : Env) (x : String) (e : Expr) (v : Value) : EvalExpr ρ e v → EvalBasicStmt ρ (.assign x e) ((x, v) :: ρ) -/-- Inference rules for evaluating a sequence of basic statements. -/ -inductive EvalBasicStmts : Env → List BasicStmt → Env → Prop - | nil {ρ : Env} : EvalBasicStmts ρ [] ρ - | cons {ρ₁ ρ₂ ρ₃ : Env} {bs : BasicStmt} {bss : List BasicStmt} : - EvalBasicStmt ρ₁ bs ρ₂ → EvalBasicStmts ρ₂ bss ρ₃ → - EvalBasicStmts ρ₁ (bs :: bss) ρ₃ +inductive EvalBasicStmtOpt : Env → Option BasicStmt → Env → Type + | none {ρ : Env} : EvalBasicStmtOpt ρ Option.none ρ + | some {ρ₁ ρ₂ : Env} {bs : BasicStmt} : + EvalBasicStmt ρ₁ bs ρ₂ → EvalBasicStmtOpt ρ₁ (Option.some bs) ρ₂ /-- Inference rules for evaluating statements (`Spa.Stmt`) in a given environment, potentially changing the environment. Pretty standard big-step evaluation. -/ -inductive EvalStmt : Env → Stmt → Env → Prop +inductive EvalStmt : Env → Stmt → Env → Type | basic (ρ₁ ρ₂ : Env) (bs : BasicStmt) : EvalBasicStmt ρ₁ bs ρ₂ → EvalStmt ρ₁ (.basic bs) ρ₂ | andThen (ρ₁ ρ₂ ρ₃ : Env) (s₁ s₂ : Stmt) : diff --git a/lean/Spa/Language/Tagged/Basic.lean b/lean/Spa/Language/Tagged/Basic.lean index fb3a937..ee61b20 100644 --- a/lean/Spa/Language/Tagged/Basic.lean +++ b/lean/Spa/Language/Tagged/Basic.lean @@ -6,12 +6,13 @@ derive_tagged Spa.Expr Spa.BasicStmt Spa.Stmt namespace Spa -def tagStmt (s : Stmt) : Stmt.Tagged NodeId := (s.tag 0).1 +def tagStmt (s : Stmt) : Stmt.Tagged RawId := (s.tag 0).1 -def Stmt.Tagged.subtreeIds (s : Stmt.Tagged NodeId) : List NodeId := +def Stmt.Tagged.subtreeIds {τ : Type} (s : Stmt.Tagged τ) : List τ := s.foldTags (· :: ·) [] -def Stmt.Tagged.isInLoopBody (body : Stmt.Tagged NodeId) (id : NodeId) : Bool := +def Stmt.Tagged.isInLoopBody {τ : Type} [DecidableEq τ] + (body : Stmt.Tagged τ) (id : τ) : Bool := decide (id ∈ body.subtreeIds) end Spa diff --git a/lean/Spa/Language/Tagged/Derive.lean b/lean/Spa/Language/Tagged/Derive.lean index e946e9d..b021a93 100644 --- a/lean/Spa/Language/Tagged/Derive.lean +++ b/lean/Spa/Language/Tagged/Derive.lean @@ -13,8 +13,8 @@ inductive types and generates, for each `Tᵢ`: carries a leading `tag : τ` field and every field whose type is a family member is retyped to its `.Tagged τ` counterpart; * `Tᵢ.Tagged.erase : Tᵢ.Tagged τ → Tᵢ`, forgetting all tags; -* `Tᵢ.tag : Tᵢ → ℕ → Tᵢ.Tagged NodeId × ℕ`, assigning every node a unique - `NodeId` (its postorder index) by a single unified traversal that threads a +* `Tᵢ.tag : Tᵢ → ℕ → Tᵢ.Tagged RawId × ℕ`, assigning every node a unique + `RawId` (its postorder index) by a single unified traversal that threads a counter; the whole family shares one counter, so identifiers are unique across types. @@ -54,6 +54,45 @@ def eraseOf (n : Name) : Name := n ++ `Tagged ++ `erase def rootTagOf (n : Name) : Name := n ++ `Tagged ++ `rootTag def tagOf (n : Name) : Name := n ++ `tag def foldTagsOf (n : Name) : Name := n ++ `Tagged ++ `foldTags +def wfOf (n : Name) : Name := n ++ `Tagged ++ `WF +def narrowOf (n : Name) : Name := n ++ `Tagged ++ `narrow +def narrowEraseOf (n : Name) : Name := n ++ `Tagged ++ `narrow_erase +def tagLeOf (n : Name) : Name := n ++ `tag_le +def tagRootTagPostOf (n : Name) : Name := n ++ `tag_rootTag_post +def tagWfOf (n : Name) : Name := n ++ `tag_wf + +/-- Project the `i`-th conjunct (1-based) out of `hyp`, which has type a +right-nested `And` of `total` conjuncts, e.g. `hyp |>.2 |>.2 |>.1`. -/ +def projAnd {m : Type → Type} [Monad m] [MonadQuotation m] + (hyp : Term) (i total : Nat) : m Term := do + let mut t := hyp + for _ in [0:i-1] do + t ← `($t |>.2) + if i < total then + t ← `($t |>.1) + return t + +/-- Combine a non-empty array of propositions into a right-nested conjunction. -/ +def mkAndR {m : Type → Type} [Monad m] [MonadQuotation m] + (cs : Array Term) : m Term := do + let mut t := cs.back! + for c in cs.pop.reverse do + t ← `($c ∧ $t) + return t + +/-- For a constructor, return one entry per *recursive* field: its argument +identifier, the family member it references, and the start-counter expression at +which it is tagged (`n`, then `(a.tag n).2`, …) — the same threading `mkTag` +uses. -/ +def recChildren (cd : CtorData) (argNames : Array Ident) (nStart : Term) : + CommandElabM (Array (Ident × Name × Term)) := do + let mut res : Array (Ident × Name × Term) := #[] + let mut cur := nStart + for (f, a) in cd.fields.zip argNames do + if f.isRec then + res := res.push (a, f.recType, cur) + cur ← `(($(mkIdent (tagOf f.recType)) $a $cur) |>.2) + return res /-- Inspect the family, classifying each constructor field. -/ def gather (family : Array Name) (τ : Ident) : TermElabM (Array TypeData) := do @@ -158,7 +197,7 @@ def mkRootTag (tds : Array TypeData) : CommandElabM (Array (TSyntax `command)) : /-- The postorder `tag` functions, one per family member (separate defs in dependency order). -/ def mkTag (tds : Array TypeData) : CommandElabM (Array (TSyntax `command)) := do - let nId := mkIdent ``Spa.NodeId + let nId := mkIdent ``Spa.RawId tds.mapM fun td => do let mut pats : Array Term := #[] let mut rhss : Array Term := #[] @@ -219,6 +258,229 @@ def mkFoldTags (tds : Array TypeData) : CommandElabM (Array (TSyntax `command)) $(mkIdent (taggedOf td.name)) $τ → $m := fun x => match x with $[| $pats => $rhss]*) +/-- The well-formedness predicate `T.Tagged.WF : T.Tagged RawId → Prop`: every +recursive child's root tag has a strictly smaller postorder index than the node's +own tag, and each child is itself well-formed. Leaf constructors are `True`. -/ +def mkWF (tds : Array TypeData) : CommandElabM (Array (TSyntax `command)) := do + let tId := mkIdent `t + let rawId := mkIdent ``Spa.RawId + tds.mapM fun td => do + let mut pats : Array Term := #[] + let mut rhss : Array Term := #[] + for cd in td.ctors do + let hasRec := cd.fields.any (·.isRec) + let mut patArgs : Array Term := #[] + let mut recArgs : Array Ident := #[] + let mut i := 0 + for f in cd.fields do + if f.isRec then + let a := mkIdent (.mkSimple s!"a{i}") + patArgs := patArgs.push a + recArgs := recArgs.push a + else + patArgs := patArgs.push (← `(_)) + i := i + 1 + let tagBind : Term ← if hasRec then `($tId) else `(_) + let pat ← `($(mkIdent (taggedOf td.name ++ cd.shortName)) $tagBind $patArgs*) + let rhs ← if recArgs.isEmpty then `(True) else do + let bounds ← recArgs.mapM fun a => `($(a).rootTag.post < $(tId).post) + let wfs ← recArgs.mapM fun a => `($(a).WF) + mkAndR (bounds ++ wfs) + pats := pats.push pat + rhss := rhss.push rhs + `(command| def $(mkIdent (wfOf td.name)) : + $(mkIdent (taggedOf td.name)) $rawId → Prop := + fun x => match x with $[| $pats => $rhss]*) + +/-- The `narrow` coercion `T.Tagged RawId → T.Tagged (Fin N)`, given a bound on +the root tag and a well-formedness proof. Each node's tag becomes the `Fin N` +built from its postorder index, and recursion threads the bound through `lt_trans` +and the (definitionally unfolded) `WF` conjunction. -/ +def mkNarrow (tds : Array TypeData) : CommandElabM (Array (TSyntax `command)) := do + let rawId := mkIdent ``Spa.RawId + let tId := mkIdent `t + let nId := mkIdent `N + let hId := mkIdent `h + let hwfId := mkIdent `hwf + let tgId := mkIdent `tg + tds.mapM fun td => do + let self ← `($(mkIdent (taggedOf td.name)) $rawId) + let mut patss : Array (Array Term) := #[] + let mut rhss : Array Term := #[] + for cd in td.ctors do + let argNames := (Array.range cd.fields.size).map fun i => mkIdent (.mkSimple s!"a{i}") + let ctorPat ← `($(mkIdent (taggedOf td.name ++ cd.shortName)) $tgId $argNames*) + let k := (cd.fields.filter (·.isRec)).size + let mut newArgs : Array Term := #[] + let mut ri := 0 + for (f, a) in cd.fields.zip argNames do + if f.isRec then + let bound ← projAnd hwfId (ri + 1) (2 * k) + let wf ← projAnd hwfId (k + ri + 1) (2 * k) + newArgs := newArgs.push (← `($(a).narrow (lt_trans $bound $hId) $wf)) + ri := ri + 1 + else + newArgs := newArgs.push a + let built ← `($(mkIdent (taggedOf td.name ++ cd.shortName)) ⟨$(tgId).post, $hId⟩ $newArgs*) + let nPat ← `(_) + let hPat ← `($hId) + let hwfPat : Term ← if k == 0 then `(_) else `($hwfId) + patss := patss.push #[ctorPat, nPat, hPat, hwfPat] + rhss := rhss.push built + `(command| def $(mkIdent (narrowOf td.name)) : ($tId : $self) → {$nId : ℕ} → + $(tId).rootTag.post < $nId → $(tId).WF → $(mkIdent (taggedOf td.name)) (Fin $nId) + $[| $[$patss],* => $rhss]*) + +/-- `T.tag_rootTag_post`: the root tag of a freshly tagged node is exactly one +below the threaded-out counter, i.e. the node itself is numbered last (postorder). +A uniform `cases <;> simp` discharges every constructor. -/ +def mkTagRootTagPost (tds : Array TypeData) : CommandElabM (Array (TSyntax `command)) := do + let eId := mkIdent `e + let nId := mkIdent `n + tds.mapM fun td => + `(command| theorem $(mkIdent (tagRootTagPostOf td.name)) + ($eId : $(mkIdent td.name)) ($nId : ℕ) : + ($(eId).tag $nId).1.rootTag.post + 1 = ($(eId).tag $nId).2 := by + cases $eId:ident <;> + simp [$(mkIdent (tagOf td.name)):ident, $(mkIdent (rootTagOf td.name)):ident]) + +/-- `T.tag_le`: tagging only ever advances the counter (`n ≤ (e.tag n).2`). +Proved by induction; each arm threads the counter through its recursive children +(using the relevant `tag_le`/induction hypothesis) and closes with `omega`. -/ +def mkTagLe (tds : Array TypeData) : CommandElabM (Array (TSyntax `command)) := do + let eId := mkIdent `e + let nId := mkIdent `n + tds.mapM fun td => do + let mut ctorLabels : Array Ident := #[] + let mut binderss : Array (Array Ident) := #[] + let mut tacs : Array (TSyntax ``Lean.Parser.Tactic.tacticSeq) := #[] + for cd in td.ctors do + let argNames := (Array.range cd.fields.size).map fun i => mkIdent (.mkSimple s!"a{i}") + let mut ihBinders : Array Ident := #[] + let mut haveTacs : Array (TSyntax `tactic) := #[] + let mut cur : Term ← `($nId) + let mut i := 0 + for (f, a) in cd.fields.zip argNames do + if f.isRec then + let fact ← if f.recType == td.name then + `($(mkIdent (.mkSimple s!"ih{i}")) $cur) + else + `($(mkIdent (tagLeOf f.recType)) $a $cur) + if f.recType == td.name then + ihBinders := ihBinders.push (mkIdent (.mkSimple s!"ih{i}")) + haveTacs := haveTacs.push (← `(tactic| have := $fact)) + cur ← `(($(mkIdent (tagOf f.recType)) $a $cur) |>.2) + i := i + 1 + let simpTac ← `(tactic| simp only [$(mkIdent (tagOf td.name)):ident]) + let omegaTac ← `(tactic| omega) + let allTacs := #[simpTac] ++ haveTacs ++ #[omegaTac] + ctorLabels := ctorLabels.push (mkIdent cd.shortName) + binderss := binderss.push (argNames ++ ihBinders) + tacs := tacs.push (← `(tacticSeq| $[$allTacs]*)) + `(command| theorem $(mkIdent (tagLeOf td.name)) ($eId : $(mkIdent td.name)) ($nId : ℕ) : + $nId ≤ ($(eId).tag $nId).2 := by + induction $eId:ident generalizing $nId:ident with + $[| $ctorLabels:ident $binderss* => $tacs]*) + +/-- `T.tag_wf`: a freshly tagged term is well-formed. Each recursive child's +bound conjunct is closed by `omega` from that child's `tag_rootTag_post` plus the +`tag_le` of every later child (which bounds the threaded-out counter), and each +well-formedness conjunct is the child's induction hypothesis / `tag_wf`. -/ +def mkTagWf (tds : Array TypeData) : CommandElabM (Array (TSyntax `command)) := do + let eId := mkIdent `e + let nId := mkIdent `n + tds.mapM fun td => do + let mut ctorLabels : Array Ident := #[] + let mut binderss : Array (Array Ident) := #[] + let mut tacs : Array (TSyntax ``Lean.Parser.Tactic.tacticSeq) := #[] + for cd in td.ctors do + let argNames := (Array.range cd.fields.size).map fun i => mkIdent (.mkSimple s!"a{i}") + -- recursive children: (arg, recType, startCounter, sameType?, fieldIndex) + let mut recs : Array (Ident × Name × Term × Bool × Nat) := #[] + let mut cur : Term ← `($nId) + let mut i := 0 + for (f, a) in cd.fields.zip argNames do + if f.isRec then + recs := recs.push (a, f.recType, cur, f.recType == td.name, i) + cur ← `(($(mkIdent (tagOf f.recType)) $a $cur) |>.2) + i := i + 1 + let k := recs.size + let ihBinders := (recs.filter (·.2.2.2.1)).map fun r => mkIdent (.mkSimple s!"ih{r.2.2.2.2}") + let tac : TSyntax ``Lean.Parser.Tactic.tacticSeq ← if k == 0 then + `(tacticSeq| exact True.intro) + else do + let mut comps : Array Term := #[] + -- bound conjuncts + for idx in [0:k] do + let (a, rt, s, _, _) := recs[idx]! + let mut bHaves : Array (TSyntax `tactic) := + #[← `(tactic| have := $(mkIdent (tagRootTagPostOf rt)) $a $s)] + for j in [idx+1:k] do + let (aj, rtj, sj, _, _) := recs[j]! + bHaves := bHaves.push (← `(tactic| have := $(mkIdent (tagLeOf rtj)) $aj $sj)) + bHaves := bHaves.push (← `(tactic| omega)) + comps := comps.push (← `(by $(← `(tacticSeq| $[$bHaves]*)))) + -- well-formedness conjuncts + for idx in [0:k] do + let (a, rt, s, same, fi) := recs[idx]! + comps := comps.push <| ← if same then `($(mkIdent (.mkSimple s!"ih{fi}")) $s) + else `($(mkIdent (tagWfOf rt)) $a $s) + let simpTac ← `(tactic| simp only + [$(mkIdent (tagOf td.name)):ident, $(mkIdent (wfOf td.name)):ident]) + let exactTac ← `(tactic| exact ⟨$comps,*⟩) + `(tacticSeq| $[$(#[simpTac, exactTac])]*) + ctorLabels := ctorLabels.push (mkIdent cd.shortName) + binderss := binderss.push (argNames ++ ihBinders) + tacs := tacs.push tac + `(command| theorem $(mkIdent (tagWfOf td.name)) ($eId : $(mkIdent td.name)) ($nId : ℕ) : + ($(eId).tag $nId).1.WF := by + induction $eId:ident generalizing $nId:ident with + $[| $ctorLabels:ident $binderss* => $tacs]*) + +/-- `T.Tagged.narrow_erase`: narrowing the tag type does not change the erased +(untagged) term. A per-constructor `simp` with the local `narrow`/`erase` +equations, the lower members' `narrow_erase`, and the induction hypotheses. -/ +def mkNarrowErase (tds : Array TypeData) : CommandElabM (Array (TSyntax `command)) := do + let rawId := mkIdent ``Spa.RawId + let tId := mkIdent `t + let nId := mkIdent `N + let hId := mkIdent `h + let hwfId := mkIdent `hwf + let tgId := mkIdent `tg + tds.mapM fun td => do + let mut ctorLabels : Array Ident := #[] + let mut binderss : Array (Array Ident) := #[] + let mut tacs : Array (TSyntax ``Lean.Parser.Tactic.tacticSeq) := #[] + for cd in td.ctors do + let argNames := (Array.range cd.fields.size).map fun i => mkIdent (.mkSimple s!"a{i}") + let mut lemmas : Array Term := + #[← `($(mkIdent (narrowOf td.name))), ← `($(mkIdent (eraseOf td.name)))] + let mut ihBinders : Array Ident := #[] + let mut seenLower : Array Name := #[] + let mut i := 0 + for f in cd.fields do + if f.isRec then + if f.recType == td.name then + let ih := mkIdent (.mkSimple s!"ih{i}") + ihBinders := ihBinders.push ih + lemmas := lemmas.push (← `($ih)) + else if !seenLower.contains f.recType then + seenLower := seenLower.push f.recType + lemmas := lemmas.push (← `($(mkIdent (narrowEraseOf f.recType)))) + i := i + 1 + let introTac ← `(tactic| intro $nId $hId $hwfId) + let simpTac ← `(tactic| simp [$[$lemmas:term],*]) + ctorLabels := ctorLabels.push (mkIdent cd.shortName) + binderss := binderss.push (#[tgId] ++ argNames ++ ihBinders) + tacs := tacs.push (← `(tacticSeq| $[$(#[introTac, simpTac])]*)) + `(command| theorem $(mkIdent (narrowEraseOf td.name)) : + ($tId : $(mkIdent (taggedOf td.name)) $rawId) → ∀ {$nId : ℕ} + ($hId : $(tId).rootTag.post < $nId) ($hwfId : $(tId).WF), + ($(tId).narrow $hId $hwfId).erase = $(tId).erase := by + intro $tId:ident + induction $tId:ident with + $[| $ctorLabels:ident $binderss* => $tacs]*) + /-- `derive_tagged T₁ … Tₙ` — generate tagged mirrors, `erase`, and `tag` for the given family of inductives. -/ syntax (name := deriveTaggedCmd) "derive_tagged " ident+ : command @@ -236,6 +498,12 @@ def elabDeriveTagged : CommandElab := fun stx => do for d in (← mkErase tds) do elabCommand d for d in (← mkTag tds) do elabCommand d for d in (← mkFoldTags tds) do elabCommand d + for d in (← mkWF tds) do elabCommand d + for d in (← mkNarrow tds) do elabCommand d + for d in (← mkTagRootTagPost tds) do elabCommand d + for d in (← mkTagLe tds) do elabCommand d + for d in (← mkTagWf tds) do elabCommand d + for d in (← mkNarrowErase tds) do elabCommand d | _ => throwUnsupportedSyntax end Spa.DeriveTagged diff --git a/lean/Spa/Language/Tagged/Graphs.lean b/lean/Spa/Language/Tagged/Graphs.lean index e0e2f50..88353b1 100644 --- a/lean/Spa/Language/Tagged/Graphs.lean +++ b/lean/Spa/Language/Tagged/Graphs.lean @@ -7,14 +7,14 @@ namespace Spa open GGraph -def Stmt.Tagged.cfg : Stmt.Tagged NodeId → GGraph (List (BasicStmt.Tagged NodeId)) - | .basic _ bs => GGraph.singleton [bs] +def Stmt.Tagged.cfg {τ : Type} : Stmt.Tagged τ → GGraph (Option (BasicStmt.Tagged τ)) + | .basic _ bs => GGraph.singleton (some bs) | .andThen _ s₁ s₂ => s₁.cfg ⤳ s₂.cfg | .ifElse _ _ s₁ s₂ => s₁.cfg ∙ s₂.cfg | .whileLoop _ _ s => GGraph.loop s.cfg -theorem Stmt.Tagged.cfg_graph : ∀ (t : Stmt.Tagged NodeId), - t.cfg.map (List.map BasicStmt.Tagged.erase) = t.erase.cfg +theorem Stmt.Tagged.cfg_graph {τ : Type} : ∀ (t : Stmt.Tagged τ), + (Option.map BasicStmt.Tagged.erase) <$> t.cfg = t.erase.cfg | .basic _ bs => by simp [Stmt.Tagged.cfg, Stmt.cfg, Stmt.Tagged.erase, BasicStmt.Tagged.erase] | .andThen _ s₁ s₂ => by simp [Stmt.Tagged.cfg, Stmt.cfg, Stmt.Tagged.erase, Stmt.Tagged.cfg_graph s₁, Stmt.Tagged.cfg_graph s₂] @@ -23,13 +23,16 @@ theorem Stmt.Tagged.cfg_graph : ∀ (t : Stmt.Tagged NodeId), | .whileLoop _ _ s => by simp [Stmt.Tagged.cfg, Stmt.cfg, Stmt.Tagged.erase, Stmt.Tagged.cfg_graph s] -def GGraph.nodeLabel (g : GGraph (List (BasicStmt.Tagged NodeId))) (i : g.Index) : Option NodeId := - (g.nodes i).head?.map BasicStmt.Tagged.rootTag +def GGraph.nodeLabel {τ : Type} (g : GGraph (Option (BasicStmt.Tagged τ))) (i : g.Index) : + Option τ := + (g.nodes i).map BasicStmt.Tagged.rootTag -def GGraph.stateOf (g : GGraph (List (BasicStmt.Tagged NodeId))) (id : NodeId) : Option g.Index := +def GGraph.stateOf {τ : Type} [DecidableEq τ] (g : GGraph (Option (BasicStmt.Tagged τ))) + (id : τ) : Option g.Index := g.indices.find? (fun i => decide (g.nodeLabel i = some id)) -theorem GGraph.stateOf_label {g : GGraph (List (BasicStmt.Tagged NodeId))} {id : NodeId} +theorem GGraph.stateOf_label {τ : Type} [DecidableEq τ] + {g : GGraph (Option (BasicStmt.Tagged τ))} {id : τ} {i : g.Index} (h : g.stateOf id = some i) : g.nodeLabel i = some id := by rw [GGraph.stateOf] at h simpa using List.find?_some h @@ -38,26 +41,64 @@ namespace Program variable (p : Program) -def tagged : Stmt.Tagged NodeId := tagStmt p.rootStmt +def tagged : Stmt.Tagged RawId := tagStmt p.rootStmt -def taggedCfg : GGraph (List (BasicStmt.Tagged NodeId)) := - GGraph.wrap p.tagged.cfg +def size : ℕ := p.tagged.rootTag.post + 1 + +theorem size_pos : 0 < p.size := Nat.succ_pos _ + +abbrev NodeId : Type := Fin p.size + +theorem tagged_wf : p.tagged.WF := Stmt.tag_wf p.rootStmt 0 + +def taggedFin : Stmt.Tagged p.NodeId := + p.tagged.narrow (Nat.lt_succ_self _) p.tagged_wf + +def taggedCfg : GGraph (Option (BasicStmt.Tagged p.NodeId)) := + GGraph.wrap p.taggedFin.cfg theorem taggedCfg_erase : - p.taggedCfg.map (List.map BasicStmt.Tagged.erase) = p.cfg := by - rw [taggedCfg, GGraph.map_wrap, Stmt.Tagged.cfg_graph, tagged, erase_tagStmt] + (Option.map BasicStmt.Tagged.erase) <$> p.taggedCfg = p.cfg := by + rw [taggedCfg, GGraph.map_wrap, Stmt.Tagged.cfg_graph, taggedFin, + Stmt.Tagged.narrow_erase, tagged, erase_tagStmt] rfl theorem taggedCfg_size : p.taggedCfg.size = p.cfg.size := by conv_rhs => rw [← p.taggedCfg_erase] rfl -def nodeIdOf (s : p.State) : Option NodeId := +def nodeIdOf (s : p.State) : Option p.NodeId := p.taggedCfg.nodeLabel (Fin.cast p.taggedCfg_size.symm s) -def stateOfNodeId (id : NodeId) : Option p.State := +def stateOfNodeId (id : p.NodeId) : Option p.State := (p.taggedCfg.stateOf id).map (Fin.cast p.taggedCfg_size) +theorem cfg_nodes_eq (s : p.State) : + p.cfg.nodes s = Option.map BasicStmt.Tagged.erase + (p.taggedCfg.nodes (Fin.cast p.taggedCfg_size.symm s)) := by + have key : ∀ (g : Graph) (hsz : p.taggedCfg.size = g.size), + (Option.map BasicStmt.Tagged.erase) <$> p.taggedCfg = g → + ∀ i : Fin g.size, + g.nodes i = Option.map BasicStmt.Tagged.erase + (p.taggedCfg.nodes (Fin.cast hsz.symm i)) := by + intro g hsz hg i + subst hg + rfl + exact key p.cfg p.taggedCfg_size p.taggedCfg_erase s + +theorem nodeIdOf_isSome_of_code {s : p.State} {bs : BasicStmt} + (h : p.code s = some bs) : (p.nodeIdOf s).isSome = true := by + have hc : Option.map BasicStmt.Tagged.erase + (p.taggedCfg.nodes (Fin.cast p.taggedCfg_size.symm s)) = some bs := by + rw [← p.cfg_nodes_eq s]; exact h + unfold Program.nodeIdOf GGraph.nodeLabel + cases hcase : p.taggedCfg.nodes (Fin.cast p.taggedCfg_size.symm s) with + | none => rw [hcase] at hc; simp at hc + | some tbs => simp + +def nodeIdOfNonempty (s : p.State) {bs : BasicStmt} (h : p.code s = some bs) : p.NodeId := + (p.nodeIdOf s).get (p.nodeIdOf_isSome_of_code h) + end Program end Spa diff --git a/lean/Spa/Language/Tagged/Id.lean b/lean/Spa/Language/Tagged/Id.lean index b572c7b..06b658a 100644 --- a/lean/Spa/Language/Tagged/Id.lean +++ b/lean/Spa/Language/Tagged/Id.lean @@ -2,7 +2,7 @@ import Mathlib.Data.Nat.Notation namespace Spa -structure NodeId where +structure RawId where post : ℕ deriving DecidableEq, Repr diff --git a/lean/Spa/Language/Traces.lean b/lean/Spa/Language/Traces.lean index 703b4f4..49ea6af 100644 --- a/lean/Spa/Language/Traces.lean +++ b/lean/Spa/Language/Traces.lean @@ -3,14 +3,14 @@ import Spa.Language.Graphs namespace Spa -inductive Trace (g : Graph) : g.Index → g.Index → Env → Env → Prop +inductive Trace (g : Graph) : g.Index → g.Index → Env → Env → Type | single {ρ₁ ρ₂ : Env} {idx : g.Index} : - EvalBasicStmts ρ₁ (g.nodes idx) ρ₂ → Trace g idx idx ρ₁ ρ₂ + EvalBasicStmtOpt ρ₁ (g.nodes idx) ρ₂ → Trace g idx idx ρ₁ ρ₂ | edge {ρ₁ ρ₂ ρ₃ : Env} {idx₁ idx₂ idx₃ : g.Index} : - EvalBasicStmts ρ₁ (g.nodes idx₁) ρ₂ → (idx₁, idx₂) ∈ g.edges → + EvalBasicStmtOpt ρ₁ (g.nodes idx₁) ρ₂ → (idx₁, idx₂) ∈ g.edges → Trace g idx₂ idx₃ ρ₂ ρ₃ → Trace g idx₁ idx₃ ρ₁ ρ₃ -lemma Trace.concat {g : Graph} {idx₁ idx₂ idx₃ idx₄ : g.Index} +noncomputable def Trace.concat {g : Graph} {idx₁ idx₂ idx₃ idx₄ : g.Index} {ρ₁ ρ₂ ρ₃ : Env} (tr₁ : Trace g idx₁ idx₂ ρ₁ ρ₂) (he : (idx₂, idx₃) ∈ g.edges) (tr₂ : Trace g idx₃ idx₄ ρ₂ ρ₃) : Trace g idx₁ idx₄ ρ₁ ρ₃ := by @@ -18,7 +18,7 @@ lemma Trace.concat {g : Graph} {idx₁ idx₂ idx₃ idx₄ : g.Index} | single hbs => exact Trace.edge hbs he tr₂ | edge hbs he' _ ih => exact Trace.edge hbs he' (ih he tr₂) -inductive EndToEndTrace (g : Graph) (ρ₁ ρ₂ : Env) : Prop +inductive EndToEndTrace (g : Graph) (ρ₁ ρ₂ : Env) : Type | intro (idx₁ : g.Index) (idx₁_mem : idx₁ ∈ g.inputs) (idx₂ : g.Index) (idx₂_mem : idx₂ ∈ g.outputs) (trace : Trace g idx₁ idx₂ ρ₁ ρ₂) : EndToEndTrace g ρ₁ ρ₂ diff --git a/lean/Spa/Lattice.lean b/lean/Spa/Lattice.lean index dcb4e8b..f6ce9f2 100644 --- a/lean/Spa/Lattice.lean +++ b/lean/Spa/Lattice.lean @@ -12,6 +12,18 @@ etc.. What remains are a couple of theorems about folds, as well as `FiniteHeightLattice`, the core concept of lattice-based static program analyses. See the documentation on that class for more information. -/ +namespace Option + +/-- Equality-sensitive eliminator for options in which the `some` case + is sensitive to the base `β`. This makes it mirror a one-element fold + more closely. -/ +def elimEq {α : Type*} {β : Sort*} : + (o : Option α) → β → ((a : α) → o = some a → β → β) → β + | none, b, _ => b + | some a, b, f => f a rfl b + +end Option + namespace Spa /-- Predicate for binary functions independently monotone in both their arguments. -/ @@ -61,6 +73,16 @@ lemma foldl_mono' (l : List α) (f : β → α → β) | nil => exact hb | cons x xs ih => exact ih (hf x hb) +omit [Preorder α] in +/-- The equality-aware eliminator (that also alters its behavior dependent on base case) + for option is monotonic. -/ +lemma elimEq_self_mono (o : Option α) (g : (a : α) → o = some a → β → β) + (hg : ∀ a h, Monotone (g a h)) : + Monotone (o.elimEq · g) := by + cases o with + | none => exact monotone_id + | some a => exact hg a rfl + end Folds /-- Predicate on types with `Preorder` that claims all $<$ chains in the type have at most `n` comparisons. -/ diff --git a/lean/Spa/Lattice/FiniteMap.lean b/lean/Spa/Lattice/FiniteMap.lean index 4f3ee6c..45e1900 100644 --- a/lean/Spa/Lattice/FiniteMap.lean +++ b/lean/Spa/Lattice/FiniteMap.lean @@ -72,6 +72,12 @@ lemma mem_sup {fm₁ fm₂ : FiniteMap A B ks} {k : A} {v : B} obtain ⟨i, hi, rfl⟩ := h exact ⟨fm₁ i, fm₂ i, rfl, ⟨i, hi, rfl⟩, ⟨i, hi, rfl⟩⟩ +lemma mem_inf {fm₁ fm₂ : FiniteMap A B ks} {k : A} {v : B} + (h : (k, v) ∈ fm₁ ⊓ fm₂) : + ∃ v₁ v₂, v = v₁ ⊓ v₂ ∧ (k, v₁) ∈ fm₁ ∧ (k, v₂) ∈ fm₂ := by + obtain ⟨i, hi, rfl⟩ := h + exact ⟨fm₁ i, fm₂ i, rfl, ⟨i, hi, rfl⟩, ⟨i, hi, rfl⟩⟩ + section Updating variable [DecidableEq A] diff --git a/lean/Spa/Transformation/Licm.lean b/lean/Spa/Transformation/Licm.lean index 63f7e98..d1ee5c3 100644 --- a/lean/Spa/Transformation/Licm.lean +++ b/lean/Spa/Transformation/Licm.lean @@ -16,7 +16,7 @@ The pipeline, for each assignment immediately enclosed by a loop: (`joinForKey s result` — the join over predecessors, i.e. before the assignment itself runs); 3. union the definition sets of the RHS variables; -4. map each definition site back to its `NodeId` (`Program.nodeIdOf`) and check +4. map each definition site back to its `RawId` (`Program.nodeIdOf`) and check it is **not** inside the loop body (structural `subtreeIds` membership). If every reaching definition of every RHS variable lies outside the loop, the @@ -33,13 +33,13 @@ open Forward /-- An assignment found inside a loop, paired with the data needed to test its invariance against that (immediately enclosing) loop. -/ -structure Candidate where +structure Candidate (prog : Program) where /-- The enclosing `whileLoop`'s tag (for reporting). -/ - loopId : NodeId - /-- Every `NodeId` inside the loop body (the "is-child-of-loop" set). -/ - bodyIds : List NodeId + loopId : prog.NodeId + /-- Every node id inside the loop body (the "is-child-of-loop" set). -/ + bodyIds : List prog.NodeId /-- The assignment `BasicStmt`'s tag — what labels its CFG node. -/ - assignId : NodeId + assignId : prog.NodeId /-- The variables read by the assignment's RHS. -/ rhsVars : List String @@ -47,47 +47,45 @@ structure Candidate where `enclosing` carries the current loop's tag and body id-set, or `none` outside any loop (in which case assignments are skipped — only in-loop assignments are candidates). -/ -def collectCandidates (enc : Option (NodeId × List NodeId)) : - Stmt.Tagged NodeId → List Candidate +def collectCandidates (prog : Program) (enc : Option (prog.NodeId × List prog.NodeId)) : + Stmt.Tagged prog.NodeId → List (Candidate prog) | .basic _ bs => match bs, enc with | .assign t _ e, some (loopId, bodyIds) => [{ loopId := loopId, bodyIds := bodyIds, assignId := t, rhsVars := e.erase.vars.sort (· ≤ ·) }] | _, _ => [] - | .andThen _ a b => collectCandidates enc a ++ collectCandidates enc b - | .ifElse _ _ a b => collectCandidates enc a ++ collectCandidates enc b + | .andThen _ a b => collectCandidates prog enc a ++ collectCandidates prog enc b + | .ifElse _ _ a b => collectCandidates prog enc a ++ collectCandidates prog enc b | .whileLoop loopT _ body => - collectCandidates (some (loopT, body.subtreeIds)) body + collectCandidates prog (some (loopT, body.subtreeIds)) body /-- Read the definition set assigned to variable `k`, or `⊥` if absent. -/ def lookupDef (prog : Program) (vs : VariableValues (DefSet prog) prog) (k : String) : DefSet prog := if h : FiniteMap.MemKey k vs then (FiniteMap.locate h).1 else ⊥ -/-- The CFG states marked as definition sites in a `DefSet` (those mapped to -`true`). -/ -def defSites (prog : Program) (d : DefSet prog) : List prog.State := - prog.states.filter (fun s => - if h : FiniteMap.MemKey s d then (FiniteMap.locate h).1 else false) +/-- The AST node ids marked as definition sites in a `DefSet` (those mapped to +`true`). With the AST-id-keyed lattice these are recovered directly. -/ +def defSites (prog : Program) (d : DefSet prog) : List prog.NodeId := + (List.finRange prog.size).filter (fun i => d i) /-- Is the candidate assignment loop-invariant: do all reaching definitions of -its RHS variables lie outside the loop body? -/ -def isInvariant (prog : Program) (c : Candidate) : Bool := +its RHS variables lie outside the loop body? Reaching sets are now keyed by AST +node id, so we compare against the loop-body ids directly (embedding the raw +body ids into `p.NodeId`). -/ +def isInvariant (prog : Program) (c : Candidate prog) : Bool := match prog.stateOfNodeId c.assignId with | none => false | some s => let entry := joinForKey s (result (DefSet prog) prog) let combined : DefSet prog := c.rhsVars.foldl (fun acc k => acc ⊔ lookupDef prog entry k) ⊥ - (defSites prog combined).all (fun site => - match prog.nodeIdOf site with - | some nid => ! decide (nid ∈ c.bodyIds) - | none => false) + (defSites prog combined).all (fun nid => ! decide (nid ∈ c.bodyIds)) /-- The loop-invariant assignments of `prog`, as `(loopId, assignId)` pairs. -/ -def licmCandidates (prog : Program) : List (NodeId × NodeId) := - (collectCandidates none prog.tagged).filterMap (fun c => +def licmCandidates (prog : Program) : List (prog.NodeId × prog.NodeId) := + (collectCandidates prog none prog.taggedFin).filterMap (fun c => if isInvariant prog c then some (c.loopId, c.assignId) else none) /-- A human-readable report of the loop-invariant assignments. -/ @@ -97,7 +95,7 @@ def output (prog : Program) : String := | cands => "loop-invariant assignments (loop ↦ assignment):\n" ++ String.intercalate "\n" - (cands.map (fun p => s!" loop #{p.1.post}: assignment #{p.2.post}")) + (cands.map (fun p => s!" loop #{p.1.val}: assignment #{p.2.val}")) end LicmTransformation