||| The object-level types an expression can have.
data ExprType : Type where
  IntType    : ExprType
  BoolType   : ExprType
  StringType : ExprType

||| Interpret an object-level type as the host Idris type used to hold
||| values of that type.
repr : ExprType -> Type
repr StringType = String
repr BoolType   = Bool
repr IntType    = Int

||| `IntType` and `BoolType` are distinct constructors, so no proof of their
||| equality can exist; the `Refl` case is rejected as impossible.
intBoolImpossible : IntType = BoolType -> Void
intBoolImpossible Refl impossible

||| `IntType` and `StringType` are distinct constructors, so no proof of
||| their equality can exist; the `Refl` case is rejected as impossible.
intStringImpossible : IntType = StringType -> Void
intStringImpossible Refl impossible

||| `BoolType` and `StringType` are distinct constructors, so no proof of
||| their equality can exist; the `Refl` case is rejected as impossible.
boolStringImpossible : BoolType = StringType -> Void
boolStringImpossible Refl impossible

||| Decide whether two object-level types are equal. Returns `Yes Refl` on
||| identical constructors and `No` with a refutation otherwise.
decEq : (a : ExprType) -> (b : ExprType) -> Dec (a = b)
-- Left operand IntType.
decEq IntType    IntType    = Yes Refl
decEq IntType    BoolType   = No intBoolImpossible
decEq IntType    StringType = No intStringImpossible
-- Left operand BoolType: flip the proof with `sym` to reuse the lemma.
decEq BoolType   IntType    = No (\prf => intBoolImpossible (sym prf))
decEq BoolType   BoolType   = Yes Refl
decEq BoolType   StringType = No boolStringImpossible
-- Left operand StringType: both refutations need the flipped proof.
decEq StringType IntType    = No (\prf => intStringImpossible (sym prf))
decEq StringType BoolType   = No (\prf => boolStringImpossible (sym prf))
decEq StringType StringType = Yes Refl

||| Binary operators available in the untyped surface syntax.
data Op : Type where
  Add      : Op
  Subtract : Op
  Multiply : Op
  Divide   : Op

||| Untyped surface expressions. Nothing here prevents ill-typed terms
||| (e.g. adding a string to a boolean); `typecheck` rejects those.
data Expr : Type where
  IntLit    : Int -> Expr
  BoolLit   : Bool -> Expr
  StringLit : String -> Expr
  BinOp     : Op -> Expr -> Expr -> Expr
  IfElse    : Expr -> Expr -> Expr -> Expr

||| Typed expression trees, indexed by the object-level type they evaluate
||| to.
data SafeExpr : ExprType -> Type where
  IntLiteral    : Int -> SafeExpr IntType
  BoolLiteral   : Bool -> SafeExpr BoolType
  StringLiteral : String -> SafeExpr StringType
  -- The operator is a host function over the representations of the
  -- operand types, producing the representation of the result type.
  BinOperation  : {a : ExprType} -> {b : ExprType} -> {c : ExprType} ->
                  (repr a -> repr b -> repr c) ->
                  SafeExpr a -> SafeExpr b -> SafeExpr c
  -- The condition is boolean; both branches share the result type.
  IfThenElse    : {t : ExprType} ->
                  SafeExpr BoolType -> SafeExpr t -> SafeExpr t -> SafeExpr t

||| Resolve a binary operator against its operand types. Every operator is
||| accepted only on two `IntType` operands and yields an `IntType` result
||| along with the matching host function; any other combination is an
||| error.
|||
||| NOTE(review): `Divide` maps to `div` with no zero-divisor guard, so a
||| division by zero is not caught by type checking -- confirm the intended
||| runtime behavior.
typecheckOp : Op -> (a : ExprType) -> (b : ExprType) -> Either String (c : ExprType ** repr a -> repr b -> repr c)
typecheckOp op IntType IntType = Right (IntType ** arith op)
  where
    arith : Op -> Int -> Int -> Int
    arith Add      = (+)
    arith Subtract = (-)
    arith Multiply = (*)
    arith Divide   = div
typecheckOp _ _ _ = Left "Invalid binary operator application"

||| Narrow a type-checked expression to a boolean one, failing when its
||| inferred type is anything other than `BoolType`.
requireBool : (n : ExprType ** SafeExpr n) -> Either String (SafeExpr BoolType)
requireBool (BoolType ** condition) = Right condition
requireBool (_ ** _)                = Left "Not a boolean."

||| Check an untyped `Expr`. On success, returns its inferred `ExprType`
||| paired with a `SafeExpr` indexed by that type. On failure, returns the
||| message of the first error reached. Sub-expressions are checked left to
||| right, and the `Either` monad stops at the first `Left`.
typecheck : Expr -> Either String (n : ExprType ** SafeExpr n)
typecheck (IntLit i)    = Right (IntType ** IntLiteral i)
typecheck (BoolLit b)   = Right (BoolType ** BoolLiteral b)
typecheck (StringLit s) = Right (StringType ** StringLiteral s)
typecheck (BinOp op lhs rhs) = do
  (lhsTy ** lhsExpr) <- typecheck lhs
  (rhsTy ** rhsExpr) <- typecheck rhs
  (resTy ** fn) <- typecheckOp op lhsTy rhsTy
  pure (resTy ** BinOperation fn lhsExpr rhsExpr)
typecheck (IfElse cond thenBranch elseBranch) = do
  condExpr <- typecheck cond >>= requireBool
  (thenTy ** thenExpr) <- typecheck thenBranch
  (elseTy ** elseExpr) <- typecheck elseBranch
  -- Both branches must have the same type. The equality proof re-indexes
  -- the then-branch at the else-branch's type so `IfThenElse` accepts it.
  case decEq thenTy elseTy of
    Yes prf => pure (elseTy ** IfThenElse condExpr (replace prf thenExpr) elseExpr)
    No _    => Left "Incompatible branch types."

||| Evaluate a typed expression to a host value of the corresponding type.
||| There are no runtime type checks; the `SafeExpr` index fixes the result
||| type.
eval : SafeExpr t -> repr t
eval (IfThenElse cond thenE elseE) = if eval cond then eval thenE else eval elseE
eval (BinOperation op lhs rhs)     = op (eval lhs) (eval rhs)
eval (StringLiteral str)           = str
eval (BoolLiteral flag)            = flag
eval (IntLiteral n)                = n

||| Render an evaluated value for display. Dispatches on the object-level
||| type to pick the right `Show` implementation. String values also go
||| through `show`, so they are printed quoted.
resultStr : {t : ExprType} -> repr t -> String
resultStr {t=StringType} str  = show str
resultStr {t=BoolType}   flag = show flag
resultStr {t=IntType}    n    = show n

||| Type-check and evaluate an expression.
|||
||| Returns the rendered value on success, or the type error prefixed with
||| "Type error: ".
tryEval : Expr -> String
tryEval ex with (typecheck ex)
  tryEval ex | Left err           = "Type error: " ++ err
  tryEval ex | Right (ty ** safe) = resultStr {t = ty} (eval {t = ty} safe)

||| Type-check and evaluate a fixed sample expression, printing either its
||| value or the type error.
main : IO ()
main = putStrLn (tryEval sample)
  where
    sample : Expr
    sample = BinOp Add
               (IfElse (BoolLit True) (IntLit 6) (IntLit 7))
               (BinOp Multiply (IntLit 160) (IntLit 2))