blog-static/code/cs325-langs/src/LanguageThree.hs

462 lines
15 KiB
Haskell

module LanguageThree where
import qualified CommonParsing as P
import qualified PythonAst as Py
import Control.Monad.State
import Data.Bifunctor
import Data.Foldable
import Data.Functor
import qualified Data.Map as Map
import Data.Maybe
import Text.Parsec hiding (State)
import Text.Parsec.Char
import Text.Parsec.Combinator
{- Data Types -}
-- | Binary operators of the source language.  The spelling noted next to
-- each constructor is the token that 'parseExpr' maps to it.
data Op
    = Add               -- ^ @+@
    | Subtract          -- ^ @-@
    | Multiply          -- ^ @*@
    | Divide            -- ^ @/@
    | LessThan          -- ^ @<@
    | LessThanEqual     -- ^ @<=@
    | GreaterThan       -- ^ @>@
    | GreaterThanEqual  -- ^ @>=@
    | Equal             -- ^ @==@
    | NotEqual          -- ^ @!=@
    | And               -- ^ @and@
    | Or                -- ^ @or@
-- | Expressions of the source language.
data Expr
    = TraverserCall String [Expr]
      -- ^ Traverser operation, written @name!(args)@.
    | FunctionCall String [Expr]
      -- ^ Ordinary call, written @name(args)@.
    | BinOp Op Expr Expr
      -- ^ Binary operator applied to a left and a right operand.
    | Lambda [String] Expr
      -- ^ Anonymous function, written @(params) -> body@.
    | Var String
      -- ^ Variable reference.
    | IntLiteral Int
      -- ^ Integer literal.
    | BoolLiteral Bool
      -- ^ @true@ or @false@.
    | ListLiteral [Expr]
      -- ^ Bracketed, comma-separated list.
    | TupleLiteral [Expr]
      -- ^ Parenthesised, comma-separated tuple ('parseTuple' unwraps the
      --   one-element case instead of producing this constructor).
-- | A condition expression paired with the statements of its block; used by
-- both 'IfElse' and 'While'.
type Branch = (Expr, [Stmt])
-- | Statements of the source language.
data Stmt
    = IfElse Branch [Branch] [Stmt]
      -- ^ The leading @if@ branch, any else-if branches, and the (possibly
      --   empty) @else@ body.
    | While Branch
      -- ^ Loop condition together with the loop body.
    | Traverser String [(String, Expr)]
      -- ^ Traverser declaration: its name and its @key: value@ options.
    | Let Pat Expr
      -- ^ Binding of an expression to a pattern.
    | Return Expr
      -- ^ @return@ of an expression.
    | Standalone Expr
      -- ^ An expression used as a statement.
-- | Patterns allowed on the left-hand side of a @let@.
data Pat
    = VarPat String   -- ^ Bind a single variable.
    | TuplePat [Pat]  -- ^ Parenthesised, comma-separated nested patterns.
-- | Records whether a function definition carried the sorted keyword.
data SortedMarker = Sorted | Unsorted deriving Eq
-- | A function definition: sorted marker, name, parameter names, and body.
data Function = Function SortedMarker String [String] [Stmt]
-- | A whole program: the functions it defines, in the order they were parsed.
data Prog = Prog [Function]
{- Parser -}
-- | Parsec parser over 'String' input that carries no user state.
type Parser = Parsec String ()
-- | Parse an identifier via 'P.var', handing it the words this language
-- reserves (presumably so that they are rejected as identifiers).
--
-- NOTE(review): the list reserves @sort@ but neither @sorted@ nor @return@;
-- confirm against the keyword parsers in "CommonParsing" that this is intended.
parseVar :: Parser String
parseVar = P.var reservedWords
  where
    reservedWords =
        [ "if", "elif", "else"
        , "while", "let", "traverser"
        , "function", "sort"
        , "true", "false"
        ]
-- | Parse a boolean literal, @true@ or @false@.
--
-- Each keyword is wrapped in 'try' because 'string' consumes input as soon as
-- the first character matches.  Without it, an identifier that merely starts
-- with @t@ or @f@ (for example @tmp@ or @f(x)@) fails after consuming input,
-- which stops 'parseBasic' from trying its remaining alternatives.
-- 'notFollowedBy' additionally keeps a word such as @trueish@ from being split
-- into a literal followed by a stray suffix.
parseBool :: Parser Bool
parseBool = keyword "true" True <|> keyword "false" False
  where
    -- Match the exact word and yield the given value, backtracking on failure.
    keyword text value = try $
        string text *> notFollowedBy (alphaNum <|> char '_') $> value
-- | Parse a bracketed, comma-separated list of expressions into a
-- 'ListLiteral'.
parseList :: Parser Expr
parseList = do
    elems <- P.list '[' ']' ',' parseExpr
    return (ListLiteral elems)
-- | Parse a parenthesised, comma-separated sequence of expressions.  Shared by
-- tuple literals and by the argument lists of calls.
parseTupleElems :: Parser [Expr]
parseTupleElems = parenthesised parseExpr
  where
    parenthesised = P.list '(' ')' ','
-- | Parse a parenthesised expression group.  Exactly one element is returned
-- as that element itself (plain grouping); any other count, including zero,
-- becomes a 'TupleLiteral'.
parseTuple :: Parser Expr
parseTuple = unwrapSingleton <$> parseTupleElems
  where
    unwrapSingleton [inner] = inner
    unwrapSingleton elems   = TupleLiteral elems
-- | Parse a lambda of the form @(params) -> body@.  The whole parser is under
-- 'try' so that a parenthesised group that is not followed by @->@ can still
-- be parsed by a later alternative.
parseLambda :: Parser Expr
parseLambda = try (Lambda <$> params <* arrow <*> parseExpr)
  where
    params = P.list '(' ')' ',' parseVar
    arrow  = string "->" *> spaces
-- | Parse a call: @name!(args)@ yields a 'TraverserCall', @name(args)@ a
-- 'FunctionCall'.  Backtracks completely on failure so a bare variable can be
-- parsed afterwards.
parseCall :: Parser Expr
parseCall = try $ do
    name <- parseVar
    let traverserCall = TraverserCall name <$> (char '!' *> parseTupleElems)
        functionCall  = FunctionCall name <$> parseTupleElems
    choice [traverserCall, functionCall]
-- | Parse an operand of a binary expression.  Alternatives are tried in the
-- order written: integer, boolean, call, variable, list, lambda, and finally
-- a parenthesised group.
parseBasic :: Parser Expr
parseBasic = choice [integer, boolean, call, variable, parseList, parseLambda, parseTuple]
  where
    integer  = IntLiteral <$> P.int
    boolean  = BoolLiteral <$> parseBool
    call     = try parseCall
    variable = Var <$> parseVar
-- | Parse a full expression: 'parseBasic' operands combined with 'BinOp'
-- through 'P.precedence'.
--
-- The operator groups are passed in the order multiplicative, additive,
-- comparison, @and@, @or@ — presumably tightest-binding first; confirm against
-- 'P.precedence'.  The two-character comparisons sit under 'try' so that
-- their one-character prefixes can still match afterwards.
parseExpr :: Parser Expr
parseExpr = P.precedence BinOp parseBasic
    [ multiplicative
    , additive
    , comparison
    , conjunction
    , disjunction
    ]
  where
    multiplicative = P.op "*" Multiply <|> P.op "/" Divide
    additive       = P.op "+" Add <|> P.op "-" Subtract
    comparison     =
        P.op "==" Equal <|> P.op "!=" NotEqual <|>
        try (P.op "<=" LessThanEqual) <|> P.op "<" LessThan <|>
        try (P.op ">=" GreaterThanEqual) <|> P.op ">" GreaterThan
    conjunction    = P.op "and" And
    disjunction    = P.op "or" Or
-- | Parse a brace-delimited block of statements, consuming whitespace after
-- each brace.
parseBlock :: Parser [Stmt]
parseBlock = do
    _ <- char '{'
    spaces
    stmts <- many parseStmt
    _ <- char '}'
    spaces
    return stmts
-- | Parse a condition expression followed by the block it guards.
parseBranch :: Parser Branch
parseBranch = do
    cond <- parseExpr
    spaces
    body <- parseBlock
    return (cond, body)
-- | Parse an @if@ statement: the leading branch, zero or more else-if
-- branches, and an optional @else@ block (an absent @else@ yields an empty
-- statement list).
parseIf :: Parser Stmt
parseIf = IfElse
    <$> (P.kwIf *> parseBranch)
    <*> many (P.kwElsif *> parseBranch)
    <*> (try (P.kwElse *> parseBlock) <|> pure [])
-- | Parse a @while@ loop: the keyword followed by a condition and a block.
parseWhile :: Parser Stmt
parseWhile = do
    _ <- P.kwWhile
    While <$> parseBranch
-- | Parse a traverser declaration: the keyword, the traverser's name, a
-- parenthesised list of @key: value@ options, and a terminating semicolon.
parseTraverser :: Parser Stmt
parseTraverser = do
    _ <- P.kwTraverser
    name <- parseVar
    options <- P.list '(' ')' ',' parseKey
    _ <- char ';'
    spaces
    return (Traverser name options)
-- | Parse one traverser option of the form @key: expression@.
parseKey :: Parser (String, Expr)
parseKey = do
    key <- parseVar
    spaces
    _ <- char ':'
    spaces
    value <- parseExpr
    return (key, value)
-- | Parse a @let@ statement: keyword, pattern, @=@, expression, semicolon.
parseLet :: Parser Stmt
parseLet = do
    _ <- P.kwLet
    pat <- parsePat
    _ <- char '='
    spaces
    value <- parseExpr
    _ <- char ';'
    spaces
    return (Let pat value)
-- | Parse a @return@ statement: keyword, expression, semicolon.
parseReturn :: Parser Stmt
parseReturn = do
    _ <- P.kwReturn
    value <- parseExpr
    _ <- char ';'
    spaces
    return (Return value)
-- | Parse a binding pattern: either a single variable or a parenthesised,
-- comma-separated list of nested patterns.
parsePat :: Parser Pat
parsePat = variable <|> tuple
  where
    variable = VarPat <$> parseVar
    tuple    = TuplePat <$> P.list '(' ')' ',' parsePat
-- | Parse one statement.  The keyword-led forms are tried first; anything
-- else is treated as an expression statement terminated by a semicolon.
parseStmt :: Parser Stmt
parseStmt = choice
    [ parseTraverser
    , parseLet
    , parseIf
    , parseWhile
    , parseReturn
    , standalone
    ]
  where
    standalone = Standalone <$> parseExpr <* char ';' <* spaces
-- | Parse a function definition: an optional sorted marker, the function
-- keyword, the name, a parenthesised parameter list, and the body block.
parseFunction :: Parser Function
parseFunction = do
    marker <- (P.kwSorted $> Sorted) <|> return Unsorted
    name <- P.kwFunction >> parseVar
    params <- P.list '(' ')' ',' parseVar
    body <- parseBlock
    return (Function marker name params body)
-- | Parse a whole program: zero or more function definitions spanning the
-- entire input.
--
-- Leading whitespace is skipped, and 'eof' is required afterwards.  Without
-- 'eof', 'many' simply stops at the first position where no further function
-- can start, so any remaining text (a misspelt keyword, stray characters)
-- would be dropped silently and the parse reported as successful.
parseProg :: Parser Prog
parseProg = Prog <$> (spaces *> many parseFunction <* eof)
-- | Run 'parseProg' with unit user state.  Following 'runParser', the first
-- argument is the source name used in error messages and the second is the
-- program text.
parse :: String -> String -> Either ParseError Prog
parse sourceName input = runParser parseProg () sourceName input
{- Translation -}
-- | How a traverser moves over its list: within a range given by two Python
-- expressions (set by the @span@ option), or by jumping to random indices
-- (set by the @random@ option).
data TraverserBounds = Range Py.PyExpr Py.PyExpr | Random
-- | Traverser options gathered so far while a declaration is being
-- processed.  'list' and 'bounds' stay 'Nothing' until the corresponding
-- option has been seen.
data TraverserData = TraverserData
    { list   :: Maybe String           -- ^ Name of the list being traversed.
    , bounds :: Maybe TraverserBounds  -- ^ Range or random movement.
    , rev    :: Bool                   -- ^ Value of the @reverse@ option.
    }
-- | A fully specified traverser: the same information as 'TraverserData', but
-- with the list and the bounds guaranteed to be present.
data ValidTraverserData = ValidTraverserData
    { validList   :: String           -- ^ Name of the list being traversed.
    , validBounds :: TraverserBounds  -- ^ Range or random movement.
    , validRev    :: Bool             -- ^ Value of the @reverse@ option.
    }
-- | Translation monad.  The state is a triple of: the traversers currently in
-- scope (by name), the statements emitted so far via 'emitStatement' (most
-- recent first), and the counter behind 'freshTemp'.  Because the state is a
-- 3-tuple, 'first' and 'second' from "Data.Bifunctor" act on its second and
-- third components respectively.
type Translator = State (Map.Map String ValidTraverserData, [Py.PyStmt], Int)
-- | Read the map of traversers currently in scope.
getScoped :: Translator (Map.Map String ValidTraverserData)
getScoped = gets traversersOf
  where
    traversersOf (traversers, _, _) = traversers
-- | Replace the map of traversers in scope, leaving the emitted statements
-- and the temporary counter untouched.
setScoped :: Map.Map String ValidTraverserData -> Translator ()
setScoped m = modify replaceTraversers
  where
    replaceTraversers (_, stmts, counter) = (m, stmts, counter)
-- | Run an action and then restore the traverser map to what it was before,
-- so traversers declared or cleared inside the action do not leak out.
scope :: Translator a -> Translator a
scope m = do
    saved <- getScoped
    m <* setScoped saved
-- | Remove the named traverser from scope (a no-op when it is not present).
clearTraverser :: String -> Translator ()
clearTraverser s = modify dropEntry
  where
    dropEntry (traversers, stmts, counter) =
        (Map.delete s traversers, stmts, counter)
-- | Bring a traverser into scope under the given name, replacing any existing
-- entry with that name.
putTraverser :: String -> ValidTraverserData -> Translator ()
putTraverser s vtd = modify addEntry
  where
    addEntry (traversers, stmts, counter) =
        (Map.insert s vtd traversers, stmts, counter)
-- | Name of the current temporary: @temp@ followed by the counter value.
getTemp :: Translator String
getTemp = gets tempName
  where
    tempName (_, _, counter) = "temp" ++ show counter
-- | Advance the temporary counter and return the resulting name.
freshTemp :: Translator String
freshTemp = do
    -- On the 3-tuple state, 'second' targets the third component (the counter).
    modify (second (+ 1))
    getTemp
-- | Record a Python statement to be placed ahead of the statement currently
-- being translated.  The statement is prepended, so the pending list holds
-- the most recently emitted statement first.
emitStatement :: Py.PyStmt -> Translator ()
emitStatement stmt =
    -- On the 3-tuple state, 'first' targets the second component (the list).
    modify (first (stmt :))
-- | Run a translation and return the statements it emitted through
-- 'emitStatement', in the order they were emitted, together with its result.
--
-- Two details matter for correctness:
--
-- * The statements that were pending before the call are saved and put back
--   afterwards.  Calls nest (a branch body is translated while the enclosing
--   statement is still in progress), and clearing the list unconditionally
--   would discard what the enclosing statement had already emitted, such as
--   a helper produced by an @if@ condition.
--
-- * 'emitStatement' prepends, so the collected list is reversed to recover
--   emission order.
collectStatements :: Translator a -> Translator ([Py.PyStmt], a)
collectStatements t = do
    outer <- gets pending
    modify (first $ const [])
    a <- t
    emitted <- gets pending
    modify (first $ const outer)
    return (reverse emitted, a)
  where
    pending (_, ss, _) = ss
-- | Translate a single statement and return it preceded by whatever
-- statements were emitted while translating it.
withdrawStatements :: Translator (Py.PyStmt) -> Translator [Py.PyStmt]
withdrawStatements ts = appendResult <$> collectStatements ts
  where
    appendResult (emitted, stmt) = emitted ++ [stmt]
-- | Look up a traverser that must be in scope.
--
-- Aborts with @"Invalid traverser"@ when the name is unknown.  'error' is
-- used rather than 'fail': 'Translator' is a plain 'State', and 'fail' there
-- requires a @MonadFail Identity@ instance that newer versions of base do
-- not provide.  On older compilers 'fail' in this monad reduced to 'error'
-- anyway, so the runtime behaviour is unchanged.
requireTraverser :: String -> Translator ValidTraverserData
requireTraverser s =
    gets lookupTraverser >>= maybe (error "Invalid traverser") return
  where
    lookupTraverser (traversers, _, _) = Map.lookup s traversers
-- | Build the Python expression for a position offset from @e@: the offset
-- (multiplied by the literal @1@) is subtracted when the traverser is
-- reversed and added otherwise.
traverserIncrement :: Bool -> Py.PyExpr -> Py.PyExpr -> Py.PyExpr
traverserIncrement reversed by e = Py.BinOp direction e scaledOffset
  where
    scaledOffset = Py.BinOp Py.Multiply by (Py.IntLiteral 1)
    direction
        | reversed  = Py.Subtract
        | otherwise = Py.Add
-- | Build the Python condition that says position @e@ is valid for the
-- traverser: @e >= lower@ for a reversed range, @e < upper@ for a forward
-- range, and the constant @True@ for a random traverser.
traverserValid :: Py.PyExpr -> ValidTraverserData -> Py.PyExpr
traverserValid e vtd = check (validBounds vtd)
  where
    check Random = Py.BoolLiteral True
    check (Range lower upper)
        | validRev vtd = Py.BinOp Py.GreaterThanEq e lower
        | otherwise    = Py.BinOp Py.LessThan e upper
-- | Build the Python statement that advances the traverser variable: for a
-- range it moves by one (down when reversed, up otherwise); for a random
-- traverser it picks a new random index.
traverserStep :: String -> ValidTraverserData -> Py.PyStmt
traverserStep s vtd = case validBounds vtd of
    Random    -> traverserRandom s (validList vtd)
    Range _ _ -> Py.Assign (Py.VarPat s) moved
  where
    moved = Py.BinOp direction (Py.Var s) (Py.IntLiteral 1)
    direction
        | validRev vtd = Py.Subtract
        | otherwise    = Py.Add
-- | Build the Python assignment @s = random.randrange(len(l))@.
traverserRandom :: String -> String -> Py.PyStmt
traverserRandom s l = Py.Assign (Py.VarPat s) randomIndex
  where
    listLength  = Py.FunctionCall (Py.Var "len") [Py.Var l]
    randomIndex = Py.FunctionCall (Py.Var "random.randrange") [listLength]
-- | Does the Python pattern bind the given variable name?  Tuple patterns are
-- searched recursively; any other pattern form binds nothing.
hasVar :: String -> Py.PyPat -> Bool
hasVar s pat = case pat of
    Py.VarPat name   -> s == name
    Py.TuplePat pats -> any (hasVar s) pats
    _                -> False
-- | @substituteVariable s e target@ replaces every free occurrence of the
-- variable @s@ inside @target@ with the expression @e@.
--
-- Two points differ from a naive traversal:
--
-- * A lambda whose parameters bind @s@ shadows it, so its body is left
--   untouched; every other lambda body is traversed.
--
-- * Expression forms without sub-expressions (integer and boolean literals,
--   and any other constructor not listed here) are returned unchanged rather
--   than falling through to a pattern-match failure.
substituteVariable :: String -> Py.PyExpr -> Py.PyExpr -> Py.PyExpr
substituteVariable s e = go
  where
    go (Py.BinOp o l r) = Py.BinOp o (go l) (go r)
    go (Py.ListLiteral es) = Py.ListLiteral (map go es)
    go (Py.DictLiteral es) = Py.DictLiteral (map (bimap go go) es)
    go (Py.Lambda ps body)
        | any (hasVar s) ps = Py.Lambda ps body  -- @s@ is rebound: stop here
        | otherwise         = Py.Lambda ps (go body)
    go (Py.Var s')
        | s == s'   = e
        | otherwise = Py.Var s'
    go (Py.TupleLiteral es) = Py.TupleLiteral (map go es)
    go (Py.FunctionCall f es) = Py.FunctionCall (go f) (map go es)
    go (Py.Access target es) = Py.Access (go target) (map go es)
    go (Py.Ternary i t f) = Py.Ternary (go i) (go t) (go f)
    go (Py.Member target m) = Py.Member (go target) m
    go (Py.In e1 e2) = Py.In (go e1) (go e2)
    go (Py.NotIn e1 e2) = Py.NotIn (go e1) (go e2)
    go (Py.Slice f t) = Py.Slice (go <$> f) (go <$> t)
    -- Literals and any other leaf form contain nothing to substitute.
    go other = other
-- | Translate a source expression into a Python expression.
--
-- Traverser operations (@name!(...)@) are recognised by name and argument
-- shape; each requires its traverser argument to be in scope.  Some of them
-- also emit supporting statements through 'emitStatement'.  Anything that
-- does not match a known operation aborts with an error.
--
-- Errors are raised with 'error' rather than 'fail': 'Translator' is a plain
-- 'State', for which 'fail' needs a @MonadFail Identity@ instance that newer
-- versions of base do not provide.
translateExpr :: Expr -> Translator Py.PyExpr
-- pop!(t): call @pop@ on the traverser's list at the traverser's position.
translateExpr (TraverserCall "pop" [Var s]) = do
    l <- validList <$> requireTraverser s
    return $ Py.FunctionCall (Py.Member (Py.Var l) "pop") [Py.Var s]
-- pos!(t): the traverser variable itself.
translateExpr (TraverserCall "pos" [Var s]) = do
    _ <- requireTraverser s
    return $ Py.Var s
-- at!(t): index the traverser's list at the traverser's position.
translateExpr (TraverserCall "at" [Var s]) = do
    l <- validList <$> requireTraverser s
    return $ Py.Access (Py.Var l) [Py.Var s]
-- at!(t, n): index at an offset of n in the traverser's direction.
translateExpr (TraverserCall "at" [Var s, IntLiteral i]) = do
    vtd <- requireTraverser s
    return $ Py.Access (Py.Var $ validList vtd)
        [traverserIncrement (validRev vtd) (Py.IntLiteral i) (Py.Var s)]
-- step!(t): emit the advancing statement; the expression itself is 0.
translateExpr (TraverserCall "step" [Var s]) = do
    vtd <- requireTraverser s
    emitStatement $ traverserStep s vtd
    return $ Py.IntLiteral 0
-- canstep!(t): validity check applied to the position one step ahead.
translateExpr (TraverserCall "canstep" [Var s]) = do
    vtd <- requireTraverser s
    return $
        traverserValid
            (traverserIncrement (validRev vtd) (Py.IntLiteral 1) (Py.Var s)) vtd
-- valid!(t): validity check applied to the current position.
translateExpr (TraverserCall "valid" [Var s]) = do
    vtd <- requireTraverser s
    return $ traverserValid (Py.Var s) vtd
-- subset!(a, b): slice between two traversers, which must share a list.
translateExpr (TraverserCall "subset" [Var s1, Var s2]) = do
    l1 <- validList <$> requireTraverser s1
    l2 <- validList <$> requireTraverser s2
    if l1 == l2
        then return $ Py.Access (Py.Var l1)
            [Py.Slice (Just $ Py.Var s1) (Just $ Py.Var s2)]
        else error "Incompatible traversers!"
-- bisect!(t, (x) -> cond): emit a helper function that walks the traverser
-- to the end, sorting elements into two lists by the condition, and call it.
translateExpr (TraverserCall "bisect" [Var s, Lambda [x] e]) = do
    vtd <- requireTraverser s
    newTemp <- freshTemp
    lambdaExpr <- translateExpr e
    let access = Py.Access (Py.Var $ validList vtd) [Py.Var s]
    -- The lambda parameter stands for the element under the traverser.
    let translated = substituteVariable x access lambdaExpr
    let append target =
            Py.FunctionCall (Py.Member (Py.Var target) "append") [ access ]
    let bisectStmt = Py.FunctionDef newTemp []
            [ Py.Nonlocal [s]
            , Py.Assign (Py.VarPat "l") (Py.ListLiteral [])
            , Py.Assign (Py.VarPat "r") (Py.ListLiteral [])
            , Py.While (traverserValid (Py.Var s) vtd)
                [ Py.IfElse translated
                    [ Py.Standalone $ append "l" ]
                    []
                    (Just [ Py.Standalone $ append "r" ])
                , traverserStep s vtd
                ]
            , Py.Return $ Py.TupleLiteral [Py.Var "l", Py.Var "r"]
            ]
    emitStatement bisectStmt
    return $ Py.FunctionCall (Py.Var newTemp) []
-- Any other traverser operation, or a known one with the wrong arguments.
translateExpr (TraverserCall _ _) = error "Invalid traverser operation"
translateExpr (FunctionCall f ps) = do
    pes <- mapM translateExpr ps
    return $ Py.FunctionCall (Py.Var f) pes
translateExpr (BinOp o l r) =
    Py.BinOp (translateOp o) <$> translateExpr l <*> translateExpr r
translateExpr (Lambda ps e) =
    Py.Lambda (map Py.VarPat ps) <$> translateExpr e
translateExpr (Var s) = return $ Py.Var s
translateExpr (IntLiteral i) = return $ Py.IntLiteral i
translateExpr (BoolLiteral b) = return $ Py.BoolLiteral b
translateExpr (ListLiteral es) = Py.ListLiteral <$> mapM translateExpr es
translateExpr (TupleLiteral es) = Py.TupleLiteral <$> mapM translateExpr es
-- | Apply one translated @key: value@ option to the traverser being built.
-- Recognised options are @list@ (a variable), @span@ (a two-element tuple),
-- @random@ (only the literal true) and @reverse@ (a boolean literal).
-- Anything else yields 'Nothing'.
applyOption :: TraverserData -> (String, Py.PyExpr) -> Maybe TraverserData
applyOption td option = case option of
    ("list", Py.Var s) ->
        Just td { list = Just s }
    ("span", Py.TupleLiteral [f, t]) ->
        Just td { bounds = Just (Range f t) }
    ("random", Py.BoolLiteral True) ->
        Just td { bounds = Just Random }
    ("reverse", Py.BoolLiteral b) ->
        Just td { rev = b }
    _ ->
        Nothing
-- | Translate the value of a traverser option, keeping its key unchanged.
translateOption :: (String, Expr) -> Translator (String, Py.PyExpr)
translateOption (key, value) = do
    value' <- translateExpr value
    return (key, value')
-- | Starting point for building a traverser: no list, no bounds, not reversed.
defaultTraverser :: TraverserData
defaultTraverser = TraverserData
    { list   = Nothing
    , bounds = Nothing
    , rev    = False
    }
-- | Translate a branch: first its condition, then each body statement
-- together with the statements that were emitted while translating it.
translateBranch :: Branch -> Translator (Py.PyExpr, [Py.PyStmt])
translateBranch (cond, body) = do
    cond' <- translateExpr cond
    body' <- mapM (withdrawStatements . translateStmt) body
    return (cond', concat body')
-- | Translate a source statement into a single Python statement.  Any
-- supporting statements are emitted through 'emitStatement' and picked up by
-- the caller (see 'withdrawStatements').
translateStmt :: Stmt -> Translator Py.PyStmt
translateStmt (IfElse i els e) = uncurry Py.IfElse
    <$> translateBranch i <*> mapM translateBranch els <*> convertElse e
  where
    -- An empty else body means "no else clause" on the Python side.
    convertElse [] = return Nothing
    convertElse es = Just . concat <$>
        mapM (withdrawStatements . translateStmt) es
translateStmt (While b) = uncurry Py.While <$> translateBranch b
-- A traverser declaration registers the traverser in scope and becomes the
-- assignment that initialises its position variable.
translateStmt (Traverser s os) =
    foldlM applyOption defaultTraverser <$> mapM translateOption os
        >>= saveTraverser
  where
    saveTraverser :: Maybe TraverserData -> Translator Py.PyStmt
    saveTraverser (Just td@TraverserData { list = Just l, bounds = Just bs }) =
        putTraverser s vtd $> translateInitialBounds s vtd
      where
        vtd = ValidTraverserData
            { validList = l
            , validBounds = bs
            , validRev = rev td
            }
    -- Covers both an unrecognised option ('Nothing') and a declaration that
    -- never supplied a list or bounds; the latter used to fall through to a
    -- pattern-match failure.  'error' replaces 'fail' because 'State' has no
    -- usable 'MonadFail' instance on newer versions of base.
    saveTraverser _ = error "Invalid traverser (!)"
translateStmt (Let p e) = Py.Assign <$> translatePat p <*> translateExpr e
translateStmt (Return e) = Py.Return <$> translateExpr e
translateStmt (Standalone e) = Py.Standalone <$> translateExpr e
-- | Build the Python statement that gives a freshly declared traverser its
-- starting position: a random index for a random traverser, the first range
-- bound for a forward range, and the second bound for a reversed range.
--
-- NOTE(review): a reversed traverser starts at the second bound itself, while
-- 'traverserValid' treats that bound as exclusive for forward traversal.
-- Confirm whether reversed spans are meant to be inclusive at the top.
translateInitialBounds :: String -> ValidTraverserData -> Py.PyStmt
translateInitialBounds s vtd = start (validBounds vtd)
  where
    start Random = traverserRandom s (validList vtd)
    start (Range lower upper)
        | validRev vtd = assignTo upper
        | otherwise    = assignTo lower
    assignTo = Py.Assign (Py.VarPat s)
-- | Translate a binding pattern.  Binding a variable removes any traverser of
-- the same name from scope, since the name now refers to an ordinary value.
translatePat :: Pat -> Translator Py.PyPat
translatePat pat = case pat of
    VarPat name -> do
        clearTraverser name
        return (Py.VarPat name)
    TuplePat pats ->
        Py.TuplePat <$> mapM translatePat pats
-- | Map each source operator to its Python AST counterpart.
translateOp :: Op -> Py.PyBinOp
translateOp op = case op of
    Add              -> Py.Add
    Subtract         -> Py.Subtract
    Multiply         -> Py.Multiply
    Divide           -> Py.Divide
    LessThan         -> Py.LessThan
    LessThanEqual    -> Py.LessThanEq
    GreaterThan      -> Py.GreaterThan
    GreaterThanEqual -> Py.GreaterThanEq
    Equal            -> Py.Equal
    NotEqual         -> Py.NotEqual
    And              -> Py.And
    Or               -> Py.Or
-- | Translate a function definition into a one-element list holding the
-- Python function.  For a function marked 'Sorted', the body is preceded by a
-- call to @sort@ on the first parameter (if there is one).  Each function is
-- translated with a fresh state: no traversers, no pending statements,
-- counter at zero.
translateFunction :: Function -> [Py.PyStmt]
translateFunction (Function m s ps ss) =
    [Py.FunctionDef s ps (sortPrelude ++ body)]
  where
    sortPrelude = case ps of
        firstParam : _ | m == Sorted -> [sortCall firstParam]
        _                            -> []
    sortCall p =
        Py.Standalone $ Py.FunctionCall (Py.Member (Py.Var p) "sort") []
    body = concat (evalState translatedBody initialState)
    translatedBody = mapM (withdrawStatements . translateStmt) ss
    initialState = (Map.empty, [], 0)
-- | Translate a whole program: the two imports the generated code relies on
-- (@from bisect import bisect@ and @import random@), followed by every
-- translated function in order.
translate :: Prog -> [Py.PyStmt]
translate (Prog fs) = preamble ++ concatMap translateFunction fs
  where
    preamble =
        [ Py.FromImport "bisect" ["bisect"]
        , Py.Import "random"
        ]