||| The three machine registers a program can read and write.
data Reg : Type where
  A : Reg
  B : Reg
  R : Reg

||| Object-language types a register or expression value can have.
data Ty : Type where
  IntTy : Ty
  BoolTy : Ty

||| The static type of every register, ordered (A, B, R). Written out as the
||| right-nested pair that the tuple patterns in this file destructure.
TypeState : Type
TypeState = Pair Ty (Pair Ty Ty)

||| The type currently recorded for a register in a type state ordered
||| (A, B, R).
|||
||| Used in the types of `Load` and `getReg`, so the type checker evaluates
||| these clauses; the clauses of `getReg` rely on this reducing when both
||| the register and the tuple are concrete.
getRegTy : Reg -> TypeState -> Ty
getRegTy A (a, _, _) = a
getRegTy B (_, b, _) = b
getRegTy R (_, _, r) = r

||| Replace the type recorded for one register, leaving the other two
||| entries of the (A, B, R) type state unchanged.
|||
||| Used in the types of `Store` and `setReg`; the clauses of `setReg` rely
||| on this reducing when both the register and the tuple are concrete.
setRegTy : Reg -> Ty -> TypeState -> TypeState
setRegTy A a (_, b, r) = (a, b, r)
setRegTy B b (a, _, r) = (a, b, r)
setRegTy R r (a, b, _) = (a, b, r)

||| Expressions checked against the register type state `s` they are
||| evaluated in; the second index is the type of the value they produce.
data Expr : TypeState -> Ty -> Type where
  ||| Integer literal.
  Lit : Int -> Expr s IntTy
  ||| Read a register; its type is whatever the type state records for it.
  Load : (r : Reg) -> Expr s (getRegTy r s)
  ||| Integer addition.
  Add : Expr s IntTy -> Expr s IntTy -> Expr s IntTy
  ||| Integer less-than-or-equal comparison.
  Leq : Expr s IntTy -> Expr s IntTy -> Expr s BoolTy
  ||| Boolean negation.
  Not : Expr s BoolTy -> Expr s BoolTy

mutual
  ||| Statements indexed by three type states: `l`, the state at which a
  ||| `Break` leaves the enclosing loop; `s`, the state before the
  ||| statement runs; and `n`, the state after it completes normally.
  data Stmt : TypeState -> TypeState -> TypeState -> Type where
    ||| Evaluate an expression and store it in a register; the register's
    ||| recorded type becomes the expression's type.
    Store : (r : Reg) -> Expr s t -> Stmt l s (setRegTy r t s)
    ||| Branch on a Bool expression; both branches must agree on the break,
    ||| start and end states.
    If : Expr s BoolTy -> Prog l s n -> Prog l s n -> Stmt l s n
    ||| Repeat the body until it breaks. The body must preserve the type
    ||| state and may only break at it, so the loop leaves the state type
    ||| unchanged.
    Loop : Prog s s s -> Stmt l s s
    ||| Leave the innermost enclosing `Loop` (at top level, end the
    ||| program); only allowed where the current state is the break state.
    Break : Stmt s s s

  ||| A sequence of statements threading the type state from `s` to the
  ||| final state; every statement shares the same break state `l`.
  data Prog : TypeState -> TypeState -> TypeState -> Type where
    ||| The empty program leaves the type state unchanged.
    Nil : Prog l s s
    ||| Run one statement, then the rest of the program from its end state.
    (::) : Stmt l s n -> Prog l n m -> Prog l s m

||| Type state in which all three registers hold an `Int`; the example
||| programs start, break and finish in this state.
initialState : TypeState
initialState = (IntTy, (IntTy, IntTy))

||| Exercises type-changing stores: A first holds the Bool `1 <= 2`, the
||| program branches on it, and both branches overwrite A with an Int so the
||| program ends back in `initialState` (taking the first branch, R = 1 + 2).
testProg : Prog Main.initialState Main.initialState Main.initialState
testProg =
  [ Store A (Leq (Lit 1) (Lit 2))
  , If (Load A)
    [ Store A (Lit 1) ]
    [ Store A (Lit 2) ]
  , Store B (Lit 2)
  , Store R (Load A `Add` Load B)
  ]

||| Multiplies A = 7 by B = 9 through repeated addition: each iteration adds
||| B into R and decrements A, breaking out once A <= 0 (R ends as 63).
prodProg : Prog Main.initialState Main.initialState Main.initialState
prodProg =
  [ Store A (Lit 7)
  , Store B (Lit 9)
  , Store R (Lit 0)
  , Loop
    [ If (Leq (Load A) (Lit 0))
      [ Break ]
      [ Store R (Add (Load R) (Load B))
      , Store A (Add (Load A) (Lit (-1)))
      ]
    ]
  ]

||| Interpret an object-language type as the Idris type used for its
||| runtime values.
repr : Ty -> Type
repr IntTy = Int
repr BoolTy = Bool

||| Runtime register file: the values of A, B and R, each of the Idris type
||| that `repr` assigns to that register's entry in the type state.
data State : TypeState -> Type where
  MkState : (repr a, repr b, repr c) -> State (a, b, c)

||| Read a register's current value. The result type is the register's
||| entry in the type state, computed by `getRegTy`.
getReg : (r : Reg) -> State s -> repr (getRegTy r s)
getReg A (MkState (a, _, _)) = a
getReg B (MkState (_, b, _)) = b
getReg R (MkState (_, _, r)) = r

||| Write a value of type `repr t` into a register. The resulting state's
||| type records `t` for that register (via `setRegTy`), so a store may
||| change the register's type; the other two registers are untouched.
setReg : (r : Reg) -> repr t -> State s -> State (setRegTy r t s)
setReg A a (MkState (_, b, r)) = MkState (a, b, r)
setReg B b (MkState (a, _, r)) = MkState (a, b, r)
setReg R r (MkState (a, b, _)) = MkState (a, b, r)

||| Evaluate an expression against a runtime state whose register types
||| match the expression's type state, producing a value of `repr t`.
expr : Expr s t -> State s -> repr t
expr (Lit n) _ = n
expr (Load reg) st = getReg reg st
expr (Add x y) st = expr x st + expr y st
expr (Leq x y) st = expr x st <= expr y st
expr (Not x) st = not (expr x st)

mutual
  ||| Execute one statement. `Left` carries the state at a `Break` (type
  ||| state `l`); `Right` carries the state after normal completion.
  stmt : Stmt l s n -> State s -> Either (State l) (State n)
  stmt (Store reg val) st = Right (setReg reg (expr val st) st)
  stmt (If cond thn els) st = prog (if expr cond st then thn else els) st
  -- Run the body, then re-enter the loop. A Break (Left) anywhere in that
  -- chain stops it; either way the loop itself finishes normally, so the
  -- collapsed state is always returned as Right.
  stmt (Loop body) st =
    Right (either id id (prog body st >>= stmt (Loop body)))
  stmt Break st = Left st

  ||| Execute a program's statements in order, stopping at the first one
  ||| that yields `Left` (a `Break`).
  prog : Prog l s n -> State s -> Either (State l) (State n)
  prog Nil st = Right st
  prog (stm :: stms) st = do
    next <- stmt stm st
    prog stms next

||| Run a whole program and return the final register file. Because the
||| program's break state and end state are both `l`, a `Break` escaping to
||| top level and normal completion yield a state of the same type.
run : Prog l s l -> State s -> State l
run p st =
  case prog p st of
    Left result => result
    Right result => result