462 lines
15 KiB
Haskell
462 lines
15 KiB
Haskell
|
module LanguageThree where
|
||
|
import qualified CommonParsing as P
|
||
|
import qualified PythonAst as Py
|
||
|
import Control.Monad.State
|
||
|
import Data.Bifunctor
|
||
|
import Data.Foldable
|
||
|
import Data.Functor
|
||
|
import qualified Data.Map as Map
|
||
|
import Data.Maybe
|
||
|
import Text.Parsec hiding (State)
|
||
|
import Text.Parsec.Char
|
||
|
import Text.Parsec.Combinator
|
||
|
|
||
|
{- Data Types -}
|
||
|
-- | Binary operators of the surface language.
data Op
  = Add              -- ^ @+@
  | Subtract         -- ^ @-@
  | Multiply         -- ^ @*@
  | Divide           -- ^ @/@
  | LessThan         -- ^ @<@
  | LessThanEqual    -- ^ @<=@
  | GreaterThan      -- ^ @>@
  | GreaterThanEqual -- ^ @>=@
  | Equal            -- ^ @==@
  | NotEqual         -- ^ @!=@
  | And              -- ^ @and@
  | Or               -- ^ @or@
|
||
|
|
||
|
-- | Expressions of the surface language.
data Expr
  = TraverserCall String [Expr] -- ^ @name!(args)@: a built-in traverser operation
  | FunctionCall String [Expr]  -- ^ @name(args)@
  | BinOp Op Expr Expr          -- ^ binary operator application
  | Lambda [String] Expr        -- ^ @(params) -> body@
  | Var String                  -- ^ variable reference
  | IntLiteral Int              -- ^ integer literal
  | BoolLiteral Bool            -- ^ @true@ / @false@
  | ListLiteral [Expr]          -- ^ @[a, b, ...]@
  | TupleLiteral [Expr]         -- ^ @(a, b, ...)@; a single parenthesised expression is not a tuple
|
||
|
|
||
|
-- | A guarded block: a condition paired with the statements it controls.
type Branch =
  (Expr, [Stmt])
|
||
|
|
||
|
-- | Statements of the surface language.
data Stmt
  = IfElse Branch [Branch] [Stmt]     -- ^ @if@ branch, @elif@ branches, @else@ body (empty when absent)
  | While Branch                      -- ^ @while cond { ... }@
  | Traverser String [(String, Expr)] -- ^ @traverser name(key: value, ...);@
  | Let Pat Expr                      -- ^ @let pat = expr;@
  | Return Expr                       -- ^ @return expr;@
  | Standalone Expr                   -- ^ expression statement @expr;@
|
||
|
|
||
|
-- | Binding patterns accepted by @let@.
data Pat
  = VarPat String  -- ^ binds a single name
  | TuplePat [Pat] -- ^ destructures a tuple into nested patterns
|
||
|
|
||
|
-- | Whether a function carries the keyword recognised by 'P.kwSorted'.
-- A 'Sorted' function sorts its first parameter on entry.
data SortedMarker
  = Sorted
  | Unsorted
  deriving (Eq)
|
||
|
|
||
|
-- | A top-level function: sortedness marker, name, parameter names and body.
data Function
  = Function SortedMarker String [String] [Stmt]
|
||
|
|
||
|
-- | A whole program: its function definitions in source order.
data Prog
  = Prog [Function]
|
||
|
|
||
|
{- Parser -}
|
||
|
-- | Parsers in this module read a 'String' and carry no user state.
type Parser =
  Parsec String ()
|
||
|
|
||
|
-- | Parse an identifier.
--
-- The word list is handed to 'P.var'. Judging by its contents these are the
-- language's keywords, presumably rejected as identifiers (confirm in
-- CommonParsing). The parser appears to consume trailing whitespace, since
-- 'parseLet' expects @=@ directly after a pattern.
--
-- NOTE(review): "return", "and" and "or" are not in the list; verify whether
-- they should be reserved too.
parseVar :: Parser String
parseVar = P.var reservedWords
  where
    reservedWords =
      [ "if", "elif", "else"
      , "while", "let", "traverser"
      , "function", "sort"
      , "true", "false"
      ]
|
||
|
|
||
|
-- | Parse a boolean literal, @true@ or @false@, and skip any whitespace that
-- follows it (as identifiers do).
--
-- Each keyword is wrapped in 'try' because Parsec's 'string' may fail after
-- consuming part of its input. On an identifier such as @test@ or @fact@
-- that would abort the enclosing 'choice' in 'parseBasic' instead of letting
-- it fall through to the variable parser. 'notFollowedBy' keeps an
-- identifier that merely starts with a keyword (e.g. @trueCount@) from being
-- split into a literal plus leftover characters.
parseBool :: Parser Bool
parseBool = (keyword "true" $> True <|> keyword "false" $> False) <* spaces
  where
    keyword w = try (string w <* notFollowedBy (alphaNum <|> char '_'))
|
||
|
|
||
|
-- | Parse a bracketed, comma-separated list literal: @[e1, e2, ...]@.
parseList :: Parser Expr
parseList = fmap ListLiteral (P.list '[' ']' ',' parseExpr)
|
||
|
|
||
|
-- | Parse the comma-separated expressions inside a pair of parentheses.
-- Shared by parenthesised or tuple expressions and call argument lists.
parseTupleElems :: Parser [Expr]
parseTupleElems =
  P.list '(' ')' ',' parseExpr
|
||
|
|
||
|
-- | Parse a parenthesised expression.
--
-- Exactly one element is just grouping and yields that expression itself.
-- Any other number of elements yields a 'TupleLiteral'.
parseTuple :: Parser Expr
parseTuple = collapse <$> parseTupleElems
  where
    collapse [single] = single
    collapse elems = TupleLiteral elems
|
||
|
|
||
|
-- | Parse a lambda, @(x, y) -> body@.
--
-- The whole parse is wrapped in 'try', so a parenthesised expression that
-- turns out not to be a lambda can still be handled by a later alternative.
parseLambda :: Parser Expr
parseLambda =
  try $
    Lambda
      <$> (P.list '(' ')' ',' parseVar <* string "->" <* spaces)
      <*> parseExpr
|
||
|
|
||
|
-- | Parse a call.
--
-- @name!(args)@ is a traverser operation and @name(args)@ is an ordinary
-- function call. If no argument list follows the name, the parser backtracks
-- completely.
parseCall :: Parser Expr
parseCall = try (parseVar >>= callWith)
  where
    callWith name =
      (TraverserCall name <$> (char '!' *> parseTupleElems))
        <|> (FunctionCall name <$> parseTupleElems)
|
||
|
|
||
|
-- | Parse an operand: a literal, call, variable, list, lambda or
-- parenthesised expression.
--
-- Order matters:
--
-- * calls are tried before bare variables, because both start with a name;
-- * lambdas are tried before parenthesised expressions, because both start
--   with @(@.
parseBasic :: Parser Expr
parseBasic = choice alternatives
  where
    alternatives =
      [ IntLiteral <$> P.int
      , BoolLiteral <$> parseBool
      , try parseCall
      , Var <$> parseVar
      , parseList
      , parseLambda
      , parseTuple
      ]
|
||
|
|
||
|
-- | Parse a full expression built from operands and binary operators.
--
-- Operator levels are listed from multiplicative through @or@. This is
-- presumably tightest to loosest binding, as interpreted by 'P.precedence';
-- confirm in CommonParsing.
parseExpr :: Parser Expr
parseExpr = P.precedence BinOp parseBasic operatorLevels
  where
    operatorLevels =
      [ multiplicative
      , additive
      , comparison
      , P.op "and" And
      , P.op "or" Or
      ]
    multiplicative = P.op "*" Multiply <|> P.op "/" Divide
    additive = P.op "+" Add <|> P.op "-" Subtract
    -- "<=" and ">=" are tried before "<" and ">", under 'try', so a lone
    -- "<" or ">" can still match after a failed two-character attempt.
    comparison =
      P.op "==" Equal
        <|> P.op "!=" NotEqual
        <|> try (P.op "<=" LessThanEqual)
        <|> P.op "<" LessThan
        <|> try (P.op ">=" GreaterThanEqual)
        <|> P.op ">" GreaterThan
|
||
|
|
||
|
-- | Parse a brace-delimited sequence of statements.
-- Whitespace after each brace is skipped.
parseBlock :: Parser [Stmt]
parseBlock =
  between (char '{' *> spaces) (char '}' *> spaces) (many parseStmt)
|
||
|
|
||
|
-- | Parse a condition expression followed by the block it guards.
parseBranch :: Parser Branch
parseBranch = do
  cond <- parseExpr
  spaces
  body <- parseBlock
  return (cond, body)
|
||
|
|
||
|
-- | Parse @if cond { ... }@ with optional @elif@ branches and an optional
-- @else@ block. A missing @else@ becomes an empty statement list.
parseIf :: Parser Stmt
parseIf =
  IfElse
    <$> (P.kwIf *> parseBranch)
    <*> many (P.kwElsif *> parseBranch)
    <*> (try (P.kwElse *> parseBlock) <|> pure [])
|
||
|
|
||
|
-- | Parse @while cond { ... }@.
parseWhile :: Parser Stmt
parseWhile = P.kwWhile *> fmap While parseBranch
|
||
|
|
||
|
-- | Parse a traverser declaration: @traverser name(key: value, ...);@.
parseTraverser :: Parser Stmt
parseTraverser = do
  _ <- P.kwTraverser
  name <- parseVar
  options <- P.list '(' ')' ',' parseKey
  _ <- char ';'
  spaces
  return (Traverser name options)
|
||
|
|
||
|
-- | Parse one @key: value@ option of a traverser declaration.
parseKey :: Parser (String, Expr)
parseKey = do
  key <- parseVar
  spaces
  _ <- char ':'
  spaces
  value <- parseExpr
  return (key, value)
|
||
|
|
||
|
-- | Parse a binding: @let pattern = expr;@.
parseLet :: Parser Stmt
parseLet = do
  _ <- P.kwLet
  pat <- parsePat
  _ <- char '='
  spaces
  rhs <- parseExpr
  _ <- char ';'
  spaces
  return (Let pat rhs)
|
||
|
|
||
|
-- | Parse @return expr;@.
parseReturn :: Parser Stmt
parseReturn = do
  _ <- P.kwReturn
  value <- parseExpr
  _ <- char ';'
  spaces
  return (Return value)
|
||
|
|
||
|
-- | Parse a binding pattern: a single name, or a parenthesised,
-- comma-separated list of nested patterns.
parsePat :: Parser Pat
parsePat = choice [single, tuple]
  where
    single = VarPat <$> parseVar
    tuple = TuplePat <$> P.list '(' ')' ',' parsePat
|
||
|
|
||
|
-- | Parse a single statement.
--
-- The keyword-introduced forms are tried first. A bare expression followed
-- by @;@ is the fallback.
parseStmt :: Parser Stmt
parseStmt =
  parseTraverser
    <|> parseLet
    <|> parseIf
    <|> parseWhile
    <|> parseReturn
    <|> expressionStatement
  where
    expressionStatement = Standalone <$> parseExpr <* char ';' <* spaces
|
||
|
|
||
|
-- | Parse a function definition, which consists of:
--
-- * an optional keyword recognised by 'P.kwSorted';
-- * the keyword recognised by 'P.kwFunction';
-- * the function name;
-- * a parenthesised parameter list;
-- * a body block.
parseFunction :: Parser Function
parseFunction = do
  marker <- option Unsorted (P.kwSorted $> Sorted)
  name <- P.kwFunction *> parseVar
  params <- P.list '(' ')' ',' parseVar
  body <- parseBlock
  return (Function marker name params body)
|
||
|
|
||
|
-- | Parse a whole program: any number of function definitions.
--
-- Leading whitespace is skipped, and the parser must reach end of input.
-- Without 'eof', text after the last well-formed function was silently
-- ignored, and input starting with whitespace parsed as an empty program.
-- Both cases are now reported as parse errors.
parseProg :: Parser Prog
parseProg = Prog <$> (spaces *> many parseFunction <* eof)
|
||
|
|
||
|
-- | Parse a program.
--
-- The first argument is the source name used in error messages. The second
-- argument is the program text.
parse :: String -> String -> Either ParseError Prog
parse srcName src = runParser parseProg () srcName src
|
||
|
|
||
|
{- Translation -}
|
||
|
-- | How a traverser moves through its list.
--
-- * 'Range': between the two bounds given by the @span@ option.
-- * 'Random': to random indices, selected by @random: true@.
data TraverserBounds
  = Range Py.PyExpr Py.PyExpr
  | Random
|
||
|
|
||
|
-- | Traverser options accumulated while reading a declaration.
-- Optional fields stay 'Nothing' until the corresponding option is seen.
data TraverserData = TraverserData
  { list :: Maybe String            -- ^ name of the traversed list (@list@ option)
  , bounds :: Maybe TraverserBounds -- ^ set by @span@ or @random@
  , rev :: Bool                     -- ^ set by @reverse@
  }
|
||
|
|
||
|
-- | A fully specified traverser, as stored in the translator's scope.
data ValidTraverserData = ValidTraverserData
  { validList :: String           -- ^ name of the traversed Python list
  , validBounds :: TraverserBounds -- ^ range or random movement
  , validRev :: Bool              -- ^ whether it moves backwards
  }
|
||
|
|
||
|
-- | Translation monad. The state holds, in order:
--
-- 1. the traversers currently in scope, keyed by name;
-- 2. the Python statements emitted so far that must precede the statement
--    being translated;
-- 3. the counter used to generate fresh temporary names.
type Translator =
  State (Map.Map String ValidTraverserData, [Py.PyStmt], Int)
|
||
|
|
||
|
-- | Read the traversers currently in scope.
getScoped :: Translator (Map.Map String ValidTraverserData)
getScoped = do
  (scoped, _, _) <- get
  return scoped
|
||
|
|
||
|
-- | Replace the set of traversers in scope. The pending statements and the
-- temp counter are left alone.
setScoped :: Map.Map String ValidTraverserData -> Translator ()
setScoped m = do
  (_, pending, counter) <- get
  put (m, pending, counter)
|
||
|
|
||
|
-- | Run an action, then restore the traverser scope to what it was before.
-- Traversers declared inside the action therefore do not outlive it.
-- Emitted statements and the temp counter are not restored.
scope :: Translator a -> Translator a
scope action = getScoped >>= \saved -> action <* setScoped saved
|
||
|
|
||
|
-- | Forget that a name refers to a traverser. Nothing happens if it did not.
clearTraverser :: String -> Translator ()
clearTraverser name = getScoped >>= setScoped . Map.delete name
|
||
|
|
||
|
-- | Bring a traverser into scope under the given name. Any previous
-- traverser with that name is replaced.
putTraverser :: String -> ValidTraverserData -> Translator ()
putTraverser name vtd = getScoped >>= setScoped . Map.insert name vtd
|
||
|
|
||
|
-- | The temporary name for the current counter value, e.g. @"temp3"@.
getTemp :: Translator String
getTemp = do
  (_, _, counter) <- get
  return ("temp" ++ show counter)
|
||
|
|
||
|
-- | Advance the temp counter, then return the name for its new value.
-- Starting from a counter of 0, the first call yields @"temp1"@.
freshTemp :: Translator String
freshTemp = do
  modify (\(scoped, pending, counter) -> (scoped, pending, counter + 1))
  getTemp
|
||
|
|
||
|
-- | Queue a statement that must run before the statement currently being
-- translated. It is consed onto the front of the pending list, so the list
-- is stored newest-first.
emitStatement :: Py.PyStmt -> Translator ()
emitStatement stmt =
  modify (\(scoped, pending, counter) -> (scoped, stmt : pending, counter))
|
||
|
|
||
|
-- | Run a translation step and capture the statements it emits.
--
-- Returns the captured statements in emission order, together with the
-- step's result.
--
-- Statements that were already pending before the step are put back
-- afterwards instead of being discarded. An example is a helper definition
-- emitted while translating an enclosing @if@/@while@ condition before its
-- body is translated. Nested collection therefore no longer loses them.
collectStatements :: Translator a -> Translator ([Py.PyStmt], a)
collectStatements t = do
  outer <- gets (\(_, ss, _) -> ss)
  modify (\(m, _, i) -> (m, [], i))
  a <- t
  inner <- gets (\(_, ss, _) -> ss)
  modify (\(m, _, i) -> (m, outer, i))
  -- 'emitStatement' conses onto the front, so 'inner' is newest-first.
  return (reverse inner, a)
|
||
|
|
||
|
-- | Run a statement translation. Return the prerequisite statements it
-- emitted, followed by the translated statement itself.
withdrawStatements :: Translator (Py.PyStmt) -> Translator [Py.PyStmt]
withdrawStatements ts = do
  (pre, final) <- collectStatements ts
  return (pre ++ [final])
|
||
|
|
||
|
-- | Look up a traverser that must be in scope.
--
-- Aborts translation with @error "Invalid traverser"@ when the name is not
-- a known traverser. 'fail' cannot be used here: 'Translator' is a plain
-- 'State' over 'Identity', which has no 'MonadFail' instance on GHC >= 8.8.
-- On older GHC, 'fail' in this monad reduced to 'error' anyway.
requireTraverser :: String -> Translator ValidTraverserData
requireTraverser s = do
  found <- gets (\(m, _, _) -> Map.lookup s m)
  maybe (error "Invalid traverser") return found
|
||
|
|
||
|
-- | Position @by@ steps ahead of @e@ in the direction of travel.
--
-- Builds the Python expression @e + by * 1@, or @e - by * 1@ when @rev@ is
-- set.
traverserIncrement :: Bool -> Py.PyExpr -> Py.PyExpr -> Py.PyExpr
traverserIncrement rev by e = Py.BinOp direction e scaled
  where
    direction
      | rev = Py.Subtract
      | otherwise = Py.Add
    scaled = Py.BinOp Py.Multiply by (Py.IntLiteral 1)
|
||
|
|
||
|
-- | Python condition that holds while position @e@ is usable.
--
-- * 'Range': only the bound in the direction of travel is checked, i.e.
--   @e < hi@ going forward and @e >= lo@ in reverse.
-- * 'Random': the traverser is always valid.
traverserValid :: Py.PyExpr -> ValidTraverserData -> Py.PyExpr
traverserValid e vtd = check (validBounds vtd)
  where
    check Random = Py.BoolLiteral True
    check (Range lo hi)
      | validRev vtd = Py.BinOp Py.GreaterThanEq e lo
      | otherwise = Py.BinOp Py.LessThan e hi
|
||
|
|
||
|
-- | Statement that advances traverser @s@ once.
--
-- * 'Range': add 1, or subtract 1 in reverse.
-- * 'Random': pick a fresh random index.
traverserStep :: String -> ValidTraverserData -> Py.PyStmt
traverserStep s vtd = case validBounds vtd of
  Random -> traverserRandom s (validList vtd)
  Range _ _ ->
    Py.Assign (Py.VarPat s) (Py.BinOp stepOp (Py.Var s) (Py.IntLiteral 1))
  where
    stepOp = if validRev vtd then Py.Subtract else Py.Add
|
||
|
|
||
|
-- | Statement @s = random.randrange(len(l))@, which moves traverser @s@ to a
-- random index of list @l@.
traverserRandom :: String -> String -> Py.PyStmt
traverserRandom s l = Py.Assign (Py.VarPat s) randomIndex
  where
    randomIndex = Py.FunctionCall (Py.Var "random.randrange") [lengthOfList]
    lengthOfList = Py.FunctionCall (Py.Var "len") [Py.Var l]
|
||
|
|
||
|
-- | Does the Python pattern bind the given name anywhere inside it?
hasVar :: String -> Py.PyPat -> Bool
hasVar name pat = case pat of
  Py.VarPat bound -> name == bound
  Py.TuplePat parts -> any (hasVar name) parts
  _ -> False
|
||
|
|
||
|
-- | @substituteVariable s e body@ replaces free occurrences of variable @s@
-- in @body@ with @e@.
--
-- A lambda whose parameters bind @s@ shadows it, so its body is left
-- untouched. (The previous test was inverted and substituted only inside
-- shadowing lambdas.) Expressions with no sub-expressions, such as literals
-- and any constructor not listed, are returned unchanged. Previously they
-- crashed with a non-exhaustive pattern, e.g. the @5@ in a bisect predicate
-- @(x) -> x < 5@.
substituteVariable :: String -> Py.PyExpr -> Py.PyExpr -> Py.PyExpr
substituteVariable s e = go
  where
    go (Py.BinOp o l r) = Py.BinOp o (go l) (go r)
    go (Py.ListLiteral es) = Py.ListLiteral (map go es)
    go (Py.DictLiteral kvs) = Py.DictLiteral (map (bimap go go) kvs)
    go (Py.Lambda ps body)
      | any (hasVar s) ps = Py.Lambda ps body
      | otherwise = Py.Lambda ps (go body)
    go (Py.Var s')
      | s == s' = e
      | otherwise = Py.Var s'
    go (Py.TupleLiteral es) = Py.TupleLiteral (map go es)
    go (Py.FunctionCall f args) = Py.FunctionCall (go f) (map go args)
    go (Py.Access target idxs) = Py.Access (go target) (map go idxs)
    go (Py.Ternary c t f) = Py.Ternary (go c) (go t) (go f)
    go (Py.Member target m) = Py.Member (go target) m
    go (Py.In a b) = Py.In (go a) (go b)
    go (Py.NotIn a b) = Py.NotIn (go a) (go b)
    go (Py.Slice from to) = Py.Slice (go <$> from) (go <$> to)
    go other = other
|
||
|
|
||
|
-- | Translate a surface expression into a Python expression.
--
-- Traverser operations (@op!(t, ...)@) require @t@ to be a traverser in
-- scope. Some of them emit prerequisite statements into the translator
-- state:
--
-- * @step@ emits the assignment that advances the traverser;
-- * @bisect@ emits a helper function definition.
--
-- Translation aborts with 'error' in these cases:
--
-- * an unknown traverser;
-- * an unknown or malformed traverser operation;
-- * @subset@ over traversers of different lists.
--
-- 'fail' is not used because 'Translator' has no 'MonadFail' instance on
-- modern GHC.
translateExpr :: Expr -> Translator Py.PyExpr
-- pop!(t): remove and return the element at t's position, via l.pop(t).
translateExpr (TraverserCall "pop" [Var s]) = do
  l <- validList <$> requireTraverser s
  return $ Py.FunctionCall (Py.Member (Py.Var l) "pop") [Py.Var s]
-- pos!(t): the traverser's current index.
translateExpr (TraverserCall "pos" [Var s]) = do
  _ <- requireTraverser s
  return $ Py.Var s
-- at!(t): the element at t's position.
translateExpr (TraverserCall "at" [Var s]) = do
  l <- validList <$> requireTraverser s
  return $ Py.Access (Py.Var l) [Py.Var s]
-- at!(t, i): the element i steps ahead of t in its direction of travel.
translateExpr (TraverserCall "at" [Var s, IntLiteral i]) = do
  vtd <- requireTraverser s
  return $
    Py.Access
      (Py.Var $ validList vtd)
      [traverserIncrement (validRev vtd) (Py.IntLiteral i) (Py.Var s)]
-- step!(t): advance t. The step is emitted as a statement, and the
-- expression itself evaluates to 0.
translateExpr (TraverserCall "step" [Var s]) = do
  vtd <- requireTraverser s
  emitStatement $ traverserStep s vtd
  return $ Py.IntLiteral 0
-- canstep!(t): whether the position one step ahead is still valid.
translateExpr (TraverserCall "canstep" [Var s]) = do
  vtd <- requireTraverser s
  return $
    traverserValid
      (traverserIncrement (validRev vtd) (Py.IntLiteral 1) (Py.Var s))
      vtd
-- valid!(t): whether t's current position is valid.
translateExpr (TraverserCall "valid" [Var s]) = do
  vtd <- requireTraverser s
  return $ traverserValid (Py.Var s) vtd
-- subset!(a, b): the slice l[a:b]. Both traversers must share a list.
translateExpr (TraverserCall "subset" [Var s1, Var s2]) = do
  l1 <- validList <$> requireTraverser s1
  l2 <- validList <$> requireTraverser s2
  if l1 == l2
    then
      return $
        Py.Access (Py.Var l1) [Py.Slice (Just $ Py.Var s1) (Just $ Py.Var s2)]
    else error "Incompatible traversers!"
-- bisect!(t, (x) -> p): emit a helper function, then call it.
-- The helper steps t while it is valid, appends each element to l when p
-- holds for it and to r otherwise, and returns (l, r).
-- Because the helper declares t nonlocal, stepping also moves the caller's
-- traverser.
translateExpr (TraverserCall "bisect" [Var s, Lambda [x] e]) = do
  vtd <- requireTraverser s
  newTemp <- freshTemp
  lambdaExpr <- translateExpr e
  let access = Py.Access (Py.Var $ validList vtd) [Py.Var s]
      translated = substituteVariable x access lambdaExpr
      appendTo side =
        Py.FunctionCall (Py.Member (Py.Var side) "append") [access]
      bisectStmt =
        Py.FunctionDef newTemp []
          [ Py.Nonlocal [s]
          , Py.Assign (Py.VarPat "l") (Py.ListLiteral [])
          , Py.Assign (Py.VarPat "r") (Py.ListLiteral [])
          , Py.While (traverserValid (Py.Var s) vtd)
              [ Py.IfElse translated
                  [Py.Standalone $ appendTo "l"]
                  []
                  (Just [Py.Standalone $ appendTo "r"])
              , traverserStep s vtd
              ]
          , Py.Return $ Py.TupleLiteral [Py.Var "l", Py.Var "r"]
          ]
  emitStatement bisectStmt
  return $ Py.FunctionCall (Py.Var newTemp) []
translateExpr (TraverserCall _ _) = error "Invalid traverser operation"
translateExpr (FunctionCall f ps) = do
  pes <- mapM translateExpr ps
  return $ Py.FunctionCall (Py.Var f) pes
translateExpr (BinOp o l r) =
  Py.BinOp (translateOp o) <$> translateExpr l <*> translateExpr r
translateExpr (Lambda ps e) =
  Py.Lambda (map Py.VarPat ps) <$> translateExpr e
translateExpr (Var s) = return $ Py.Var s
translateExpr (IntLiteral i) = return $ Py.IntLiteral i
translateExpr (BoolLiteral b) = return $ Py.BoolLiteral b
translateExpr (ListLiteral es) = Py.ListLiteral <$> mapM translateExpr es
translateExpr (TupleLiteral es) = Py.TupleLiteral <$> mapM translateExpr es
|
||
|
|
||
|
-- | Apply one translated @key: value@ option to the traverser being built.
--
-- Recognised options:
--
-- * @list@: a variable naming the list;
-- * @span@: a pair of bounds;
-- * @random@: must be @true@;
-- * @reverse@: a boolean.
--
-- Anything else yields 'Nothing', which marks the declaration invalid.
applyOption :: TraverserData -> (String, Py.PyExpr) -> Maybe TraverserData
applyOption td option = case option of
  ("list", Py.Var name) -> Just td {list = Just name}
  ("span", Py.TupleLiteral [from, to]) -> Just td {bounds = Just (Range from to)}
  ("random", Py.BoolLiteral True) -> Just td {bounds = Just Random}
  ("reverse", Py.BoolLiteral flag) -> Just td {rev = flag}
  _ -> Nothing
|
||
|
|
||
|
-- | Translate the value of a @key: value@ option, keeping its key.
translateOption :: (String, Expr) -> Translator (String, Py.PyExpr)
translateOption (key, value) = do
  translated <- translateExpr value
  return (key, translated)
|
||
|
|
||
|
-- | Starting point before any option is applied: no list, no bounds, and
-- forward direction.
defaultTraverser :: TraverserData
defaultTraverser = TraverserData Nothing Nothing False
|
||
|
|
||
|
-- | Translate a condition and its body.
--
-- Statements emitted while translating the condition are not captured here;
-- they go to the translator's pending list. Each body statement is expanded
-- into its emitted prerequisites followed by the statement itself.
translateBranch :: Branch -> Translator (Py.PyExpr, [Py.PyStmt])
translateBranch (cond, body) = do
  pyCond <- translateExpr cond
  pyBody <- concat <$> mapM (withdrawStatements . translateStmt) body
  return (pyCond, pyBody)
|
||
|
|
||
|
-- | Translate a single statement.
--
-- Statements emitted while translating sub-expressions are left in the
-- translator state for the caller to collect ('withdrawStatements').
--
-- A @traverser@ declaration brings the traverser into scope and becomes the
-- assignment of its starting position. The declaration must supply a @list@
-- and a @span@ or @random@ option, and must not contain unrecognised
-- options. Otherwise translation aborts with 'error'. A missing list or
-- bounds used to fall through the patterns and crash with a non-exhaustive
-- match.
translateStmt :: Stmt -> Translator Py.PyStmt
translateStmt (IfElse i els e) =
  uncurry Py.IfElse
    <$> translateBranch i
    <*> mapM translateBranch els
    <*> convertElse e
  where
    convertElse [] = return Nothing
    convertElse es =
      Just . concat <$> mapM (withdrawStatements . translateStmt) es
translateStmt (While b) = uncurry Py.While <$> translateBranch b
translateStmt (Traverser s os) = do
  options <- mapM translateOption os
  saveTraverser (foldlM applyOption defaultTraverser options)
  where
    saveTraverser :: Maybe TraverserData -> Translator Py.PyStmt
    saveTraverser (Just (TraverserData {list = Just l, bounds = Just bs, rev = r})) = do
      let vtd = ValidTraverserData {validList = l, validBounds = bs, validRev = r}
      putTraverser s vtd
      return (translateInitialBounds s vtd)
    -- Both an unrecognised option ('Nothing') and a missing list or bounds.
    saveTraverser _ = error "Invalid traverser (!)"
translateStmt (Let p e) = Py.Assign <$> translatePat p <*> translateExpr e
translateStmt (Return e) = Py.Return <$> translateExpr e
translateStmt (Standalone e) = Py.Standalone <$> translateExpr e
|
||
|
|
||
|
-- | Statement that gives traverser @s@ its starting position.
--
-- * 'Random': a random index.
-- * 'Range': the first bound going forward, or the second bound in reverse.
--
-- NOTE(review): 'traverserValid' treats the span as half-open going forward
-- (@pos < hi@), but a reverse traverser starts exactly at @hi@. Under that
-- reading, @hi@ is one past the last valid index. Confirm whether spans are
-- meant to be inclusive for reverse traversal.
translateInitialBounds :: String -> ValidTraverserData -> Py.PyStmt
translateInitialBounds s vtd = case validBounds vtd of
  Random -> traverserRandom s (validList vtd)
  Range lo hi ->
    Py.Assign (Py.VarPat s) (if validRev vtd then hi else lo)
|
||
|
|
||
|
-- | Translate a binding pattern.
--
-- Every name the pattern binds stops referring to a traverser, because the
-- binding overwrites that variable.
translatePat :: Pat -> Translator Py.PyPat
translatePat pat = case pat of
  VarPat name -> do
    clearTraverser name
    return (Py.VarPat name)
  TuplePat parts -> Py.TuplePat <$> traverse translatePat parts
|
||
|
|
||
|
-- | Map a surface operator to the corresponding Python operator.
translateOp :: Op -> Py.PyBinOp
translateOp op = case op of
  Add -> Py.Add
  Subtract -> Py.Subtract
  Multiply -> Py.Multiply
  Divide -> Py.Divide
  LessThan -> Py.LessThan
  LessThanEqual -> Py.LessThanEq
  GreaterThan -> Py.GreaterThan
  GreaterThanEqual -> Py.GreaterThanEq
  Equal -> Py.Equal
  NotEqual -> Py.NotEqual
  And -> Py.And
  Or -> Py.Or
|
||
|
|
||
|
-- | Translate one function into a single Python function definition.
--
-- A 'Sorted' function first calls @.sort()@ on its first parameter, if it
-- has one. Each function is translated starting from an empty traverser
-- scope, no pending statements and a temp counter of 0.
translateFunction :: Function -> [Py.PyStmt]
translateFunction (Function marker name params body) =
  [Py.FunctionDef name params (sortCall ++ translatedBody)]
  where
    sortCall = case (marker, params) of
      (Sorted, sortedParam : _) ->
        [Py.Standalone $ Py.FunctionCall (Py.Member (Py.Var sortedParam) "sort") []]
      _ -> []
    translatedBody =
      concat $
        evalState
          (mapM (withdrawStatements . translateStmt) body)
          (Map.empty, [], 0)
|
||
|
|
||
|
-- | Translate a whole program.
--
-- The output starts with the module imports for @bisect@ and @random@,
-- followed by every translated function in order.
translate :: Prog -> [Py.PyStmt]
translate (Prog fs) = header ++ concatMap translateFunction fs
  where
    header =
      [ Py.FromImport "bisect" ["bisect"]
      , Py.Import "random"
      ]
|