module LanguageThree where

import qualified CommonParsing as P
import qualified PythonAst as Py

import Control.Monad.State
import Data.Bifunctor
import Data.Foldable
import Data.Functor
import qualified Data.Map as Map
import Data.Maybe
import Text.Parsec hiding (State)
import Text.Parsec.Char
import Text.Parsec.Combinator

{- Data Types -}

-- | Binary operators of the source language.
data Op
  = Add
  | Subtract
  | Multiply
  | Divide
  | LessThan
  | LessThanEqual
  | GreaterThan
  | GreaterThanEqual
  | Equal
  | NotEqual
  | And
  | Or

-- | Source-language expressions.
data Expr
  = TraverserCall String [Expr] -- ^ @op!(args)@: traverser operation @op@
  | FunctionCall String [Expr]  -- ^ @f(args)@
  | BinOp Op Expr Expr
  | Lambda [String] Expr        -- ^ @(x, y) -> body@
  | Var String
  | IntLiteral Int
  | BoolLiteral Bool
  | ListLiteral [Expr]
  | TupleLiteral [Expr]

-- | A condition together with the block run when it holds.
type Branch = (Expr, [Stmt])

-- | Source-language statements.
data Stmt
  = IfElse Branch [Branch] [Stmt]     -- ^ if-branch, elif-branches, else-block ([] if absent)
  | While Branch
  | Traverser String [(String, Expr)] -- ^ traverser declaration: name and key/value options
  | Let Pat Expr
  | Return Expr
  | Standalone Expr                   -- ^ expression evaluated for its effect

-- | Binding patterns used by @let@.
data Pat = VarPat String | TuplePat [Pat]

-- | Whether a function was declared with the 'P.kwSorted' marker.
data SortedMarker = Sorted | Unsorted deriving Eq

-- | A function: sortedness marker, name, parameter names, body.
data Function = Function SortedMarker String [String] [Stmt]

-- | A whole program is a sequence of function definitions.
data Prog = Prog [Function]

{- Parser -}

type Parser = Parsec String ()

-- | An identifier that is not one of the language's reserved words.
parseVar :: Parser String
parseVar =
  P.var
    [ "if", "elif", "else"
    , "while", "let", "traverser"
    , "function", "sort", "return"
    , "true", "false"
    ]

-- | A boolean literal.
--
-- Each keyword is wrapped in 'try' because Parsec's 'string' consumes input
-- once its first character matches. Without it, identifiers such as @total@
-- or @f@ would make the whole 'parseBasic' choice fail instead of falling
-- through to the variable/call alternatives. 'notFollowedBy' keeps
-- @trueish@ from being read as @true@ plus leftover text.
parseBool :: Parser Bool
parseBool = keyword "true" True <|> keyword "false" False
  where
    keyword w b = try (string w <* notFollowedBy identChar) $> b
    identChar = alphaNum <|> char '_'

-- | @[e1, e2, ...]@
parseList :: Parser Expr
parseList = ListLiteral <$> P.list '[' ']' ',' parseExpr

-- | The parenthesised, comma-separated argument/element list @(e1, ...)@.
parseTupleElems :: Parser [Expr]
parseTupleElems = P.list '(' ')' ',' parseExpr

-- | A tuple literal; a single parenthesised expression is just that expression.
parseTuple :: Parser Expr
parseTuple = do
  es <- parseTupleElems
  return $ case es of
    [e] -> e
    _ -> TupleLiteral es

-- | @(x, y) -> body@. Wrapped in 'try' since it shares the @(@ prefix with tuples.
parseLambda :: Parser Expr
parseLambda = try $ do
  vs <- P.list '(' ')' ',' parseVar
  string "->" >> spaces
  Lambda vs <$> parseExpr

-- | @name!(args)@ (traverser operation) or @name(args)@ (function call).
parseCall :: Parser Expr
parseCall = try $ do
  v <- parseVar
  choice
    [ TraverserCall v <$> (char '!' *> parseTupleElems)
    , FunctionCall v <$> parseTupleElems
    ]

-- | Atomic expressions. Calls are tried before bare variables, and lambdas
-- before tuples, because each pair shares a prefix.
parseBasic :: Parser Expr
parseBasic =
  choice
    [ IntLiteral <$> P.int
    , BoolLiteral <$> parseBool
    , parseCall
    , Var <$> parseVar
    , parseList
    , parseLambda
    , parseTuple
    ]

-- | Full expressions: operator levels in the order given to 'P.precedence'
-- (multiplicative first, @or@ last). The two-character comparisons are
-- 'try'-ed so that a lone @<@ / @>@ can still match.
parseExpr :: Parser Expr
parseExpr =
  P.precedence BinOp parseBasic
    [ P.op "*" Multiply <|> P.op "/" Divide
    , P.op "+" Add <|> P.op "-" Subtract
    , P.op "==" Equal
        <|> P.op "!=" NotEqual
        <|> try (P.op "<=" LessThanEqual)
        <|> P.op "<" LessThan
        <|> try (P.op ">=" GreaterThanEqual)
        <|> P.op ">" GreaterThan
    , P.op "and" And
    , P.op "or" Or
    ]

-- | @{ stmt* }@, consuming trailing whitespace.
parseBlock :: Parser [Stmt]
parseBlock = char '{' >> spaces >> many parseStmt <* char '}' <* spaces

-- | A condition followed by a block.
parseBranch :: Parser Branch
parseBranch = (,) <$> (parseExpr <* spaces) <*> parseBlock

-- | if / elif* / optional else.
parseIf :: Parser Stmt
parseIf = do
  i <- P.kwIf >> parseBranch
  els <- many (P.kwElsif >> parseBranch)
  e <- try (P.kwElse >> parseBlock) <|> return []
  return $ IfElse i els e

parseWhile :: Parser Stmt
parseWhile = While <$> (P.kwWhile >> parseBranch)

-- | @traverser name(key: expr, ...);@
parseTraverser :: Parser Stmt
parseTraverser =
  Traverser
    <$> (P.kwTraverser *> parseVar)
    <*> P.list '(' ')' ',' parseKey
    <* char ';'
    <* spaces

-- | One @key: expr@ traverser option.
parseKey :: Parser (String, Expr)
parseKey = (,) <$> (parseVar <* spaces <* char ':' <* spaces) <*> parseExpr

-- | @let pat = expr;@ — whitespace is allowed between the pattern and @=@.
parseLet :: Parser Stmt
parseLet =
  Let
    <$> (P.kwLet >> parsePat <* spaces <* char '=' <* spaces)
    <*> parseExpr
    <* char ';'
    <* spaces

-- | @return expr;@
parseReturn :: Parser Stmt
parseReturn = Return <$> (P.kwReturn >> parseExpr <* char ';' <* spaces)

-- | A variable or a (nested) tuple of patterns.
parsePat :: Parser Pat
parsePat = (VarPat <$> parseVar) <|> (TuplePat <$> P.list '(' ')' ',' parsePat)

-- | Any statement; a bare @expr;@ is the fallback.
parseStmt :: Parser Stmt
parseStmt =
  choice
    [ parseTraverser
    , parseLet
    , parseIf
    , parseWhile
    , parseReturn
    , Standalone <$> (parseExpr <* char ';' <* spaces)
    ]

-- | @[sorted-marker] function name(params) { ... }@
parseFunction :: Parser Function
parseFunction =
  Function
    <$> (P.kwSorted $> Sorted <|> return Unsorted)
    <*> (P.kwFunction >> parseVar)
    <*> P.list '(' ')' ',' parseVar
    <*> parseBlock

-- | A whole program. Leading whitespace is skipped, and the entire input
-- must be consumed: without 'eof', a syntax error in front of a function
-- made 'many' stop early and silently return a truncated program.
parseProg :: Parser Prog
parseProg = Prog <$> (spaces *> many parseFunction <* eof)
-- | Parse a program; the first argument is the source name used in errors.
parse :: String -> String -> Either ParseError Prog
parse = runParser parseProg ()

{- Translation -}

-- | How a traverser's position is constrained.
--
-- @Range from to@ covers @[from, to)@. A forward traverser starts at @from@;
-- a reversed one starts at @to - 1@. @Random@ positions are drawn with
-- @random.randrange(len(list))@.
data TraverserBounds = Range Py.PyExpr Py.PyExpr | Random

-- | Traverser options accumulated while reading a declaration.
data TraverserData = TraverserData
  { list :: Maybe String
  , bounds :: Maybe TraverserBounds
  , rev :: Bool
  }

-- | A traverser declaration with every required option present.
data ValidTraverserData = ValidTraverserData
  { validList :: String
  , validBounds :: TraverserBounds
  , validRev :: Bool
  }

-- | Translator state, as a triple:
--
-- * the traversers currently in scope;
-- * the buffer of helper statements that must be emitted ahead of the
--   statement being translated (stored most-recent-first);
-- * the counter used for fresh temporary names.
type Translator = State (Map.Map String ValidTraverserData, [Py.PyStmt], Int)

getScoped :: Translator (Map.Map String ValidTraverserData)
getScoped = gets (\(m, _, _) -> m)

setScoped :: Map.Map String ValidTraverserData -> Translator ()
setScoped m = modify (\(_, ss, i) -> (m, ss, i))

-- | Run an action, then restore the traverser map to its prior contents.
-- The statement buffer and temp counter are not restored.
scope :: Translator a -> Translator a
scope m = do
  s <- getScoped
  a <- m
  setScoped s
  return a

clearTraverser :: String -> Translator ()
clearTraverser s = modify (\(m, ss, i) -> (Map.delete s m, ss, i))

putTraverser :: String -> ValidTraverserData -> Translator ()
putTraverser s vtd = modify (\(m, ss, i) -> (Map.insert s vtd m, ss, i))

-- | The current temporary name, @temp<counter>@.
getTemp :: Translator String
getTemp = gets $ \(_, _, i) -> "temp" ++ show i

-- | Bump the counter and return the new temporary name.
-- ('second' on a triple maps its last component.)
freshTemp :: Translator String
freshTemp = modify (second (+ 1)) >> getTemp

-- | Queue a helper statement ahead of the current one.
-- ('first' on a triple maps its middle component.)
emitStatement :: Py.PyStmt -> Translator ()
emitStatement = modify . first . (:)

-- | Run a translation with an empty statement buffer. Returns the
-- statements it emitted, in emission order, along with its result.
--
-- The enclosing buffer is saved and restored, so helpers already queued by
-- an outer expression (e.g. an if/while condition) survive nested
-- statement translation. 'emitStatement' conses, so the collected buffer is
-- reversed to restore emission order.
collectStatements :: Translator a -> Translator ([Py.PyStmt], a)
collectStatements t = do
  outer <- gets (\(_, ss, _) -> ss)
  modify (first (const []))
  a <- t
  inner <- gets (\(_, ss, _) -> ss)
  modify (first (const outer))
  return (reverse inner, a)

-- | Translate one statement and return it preceded by its helper statements.
withdrawStatements :: Translator Py.PyStmt -> Translator [Py.PyStmt]
withdrawStatements ts = (\(ss, s) -> ss ++ [s]) <$> collectStatements ts

-- | Look up a traverser in scope, aborting translation if it is unknown.
-- ('error' rather than 'fail': 'State' has no 'MonadFail' instance.)
requireTraverser :: String -> Translator ValidTraverserData
requireTraverser s =
  gets (\(m, _, _) -> Map.lookup s m) >>= maybe (error "Invalid traverser") return

-- | @e + by * 1@, or @e - by * 1@ for a reversed traverser.
traverserIncrement :: Bool -> Py.PyExpr -> Py.PyExpr -> Py.PyExpr
traverserIncrement rev by e =
  Py.BinOp op e (Py.BinOp Py.Multiply by (Py.IntLiteral 1))
  where
    op = if rev then Py.Subtract else Py.Add

-- | Python condition that position @e@ lies inside the traverser's range:
-- @e < to@ going forward, @e >= from@ in reverse, always true when random.
traverserValid :: Py.PyExpr -> ValidTraverserData -> Py.PyExpr
traverserValid e vtd = case validBounds vtd of
  Range f t
    | validRev vtd -> Py.BinOp Py.GreaterThanEq e f
    | otherwise -> Py.BinOp Py.LessThan e t
  Random -> Py.BoolLiteral True

-- | Python statement that advances traverser @s@ by one step.
traverserStep :: String -> ValidTraverserData -> Py.PyStmt
traverserStep s vtd = case validBounds vtd of
  Range _ _ -> Py.Assign (Py.VarPat s) $ Py.BinOp op (Py.Var s) (Py.IntLiteral 1)
  Random -> traverserRandom s (validList vtd)
  where
    op = if validRev vtd then Py.Subtract else Py.Add

-- | @s = random.randrange(len(l))@
traverserRandom :: String -> String -> Py.PyStmt
traverserRandom s l =
  Py.Assign (Py.VarPat s) $
    Py.FunctionCall (Py.Var "random.randrange") [Py.FunctionCall (Py.Var "len") [Py.Var l]]

-- | Whether pattern binds variable @s@.
hasVar :: String -> Py.PyPat -> Bool
hasVar s (Py.VarPat s') = s == s'
hasVar s (Py.TuplePat ps) = any (hasVar s) ps
hasVar s _ = False

-- | Replace free occurrences of variable @s@ with @e@.
--
-- A lambda that rebinds @s@ shadows it, so its body is left untouched.
-- Nodes without sub-expressions (literals etc.) are returned unchanged.
substituteVariable :: String -> Py.PyExpr -> Py.PyExpr -> Py.PyExpr
substituteVariable s e = go
  where
    go (Py.BinOp o l r) = Py.BinOp o (go l) (go r)
    go (Py.ListLiteral es) = Py.ListLiteral (map go es)
    go (Py.DictLiteral es) = Py.DictLiteral (map (bimap go go) es)
    go lam@(Py.Lambda ps body)
      | any (hasVar s) ps = lam
      | otherwise = Py.Lambda ps (go body)
    go (Py.Var s') | s == s' = e
    go (Py.TupleLiteral es) = Py.TupleLiteral (map go es)
    go (Py.FunctionCall f es) = Py.FunctionCall (go f) (map go es)
    go (Py.Access c es) = Py.Access (go c) (map go es)
    go (Py.Ternary i t el) = Py.Ternary (go i) (go t) (go el)
    go (Py.Member o m) = Py.Member (go o) m
    go (Py.In a b) = Py.In (go a) (go b)
    go (Py.NotIn a b) = Py.NotIn (go a) (go b)
    go (Py.Slice f t) = Py.Slice (go <$> f) (go <$> t)
    go other = other

-- | Translate an expression. Traverser operations may queue helper
-- statements via 'emitStatement'.
translateExpr :: Expr -> Translator Py.PyExpr
translateExpr (TraverserCall "pop" [Var s]) = do
  l <- validList <$> requireTraverser s
  return $ Py.FunctionCall (Py.Member (Py.Var l) "pop") [Py.Var s]
translateExpr (TraverserCall "pos" [Var s]) = do
  _ <- requireTraverser s
  return $ Py.Var s
translateExpr (TraverserCall "at" [Var s]) = do
  l <- validList <$> requireTraverser s
  return $ Py.Access (Py.Var l) [Py.Var s]
translateExpr (TraverserCall "at" [Var s, IntLiteral i]) = do
  vtd <- requireTraverser s
  return $
    Py.Access (Py.Var $ validList vtd) [traverserIncrement (validRev vtd) (Py.IntLiteral i) (Py.Var s)]
translateExpr (TraverserCall "step" [Var s]) = do
  vtd <- requireTraverser s
  emitStatement $ traverserStep s vtd
  return $ Py.IntLiteral 0
translateExpr (TraverserCall "canstep" [Var s]) = do
  vtd <- requireTraverser s
  return $ traverserValid (traverserIncrement (validRev vtd) (Py.IntLiteral 1) (Py.Var s)) vtd
translateExpr (TraverserCall "valid" [Var s]) = do
  vtd <- requireTraverser s
  return $ traverserValid (Py.Var s) vtd
translateExpr (TraverserCall "subset" [Var s1, Var s2]) = do
  l1 <- validList <$> requireTraverser s1
  l2 <- validList <$> requireTraverser s2
  if l1 == l2
    then return $ Py.Access (Py.Var l1) [Py.Slice (Just $ Py.Var s1) (Just $ Py.Var s2)]
    else error "Incompatible traversers!"
translateExpr (TraverserCall "bisect" [Var s, Lambda [x] e]) = do
  vtd <- requireTraverser s
  fname <- freshTemp
  predicate <- translateExpr e
  let access = Py.Access (Py.Var $ validList vtd) [Py.Var s]
      translated = substituteVariable x access predicate
      -- Accumulator names derive from the fresh helper name so they cannot
      -- shadow the traversed list or a variable used in the predicate.
      leftAcc = fname ++ "_l"
      rightAcc = fname ++ "_r"
      appendTo acc = Py.Standalone $ Py.FunctionCall (Py.Member (Py.Var acc) "append") [access]
      -- Helper: walk the traverser to the end of its range, splitting the
      -- visited elements by the predicate; returns (matching, rest).
      bisectStmt =
        Py.FunctionDef fname []
          [ Py.Nonlocal [s]
          , Py.Assign (Py.VarPat leftAcc) (Py.ListLiteral [])
          , Py.Assign (Py.VarPat rightAcc) (Py.ListLiteral [])
          , Py.While (traverserValid (Py.Var s) vtd)
              [ Py.IfElse translated [appendTo leftAcc] [] (Just [appendTo rightAcc])
              , traverserStep s vtd
              ]
          , Py.Return $ Py.TupleLiteral [Py.Var leftAcc, Py.Var rightAcc]
          ]
  emitStatement bisectStmt
  return $ Py.FunctionCall (Py.Var fname) []
translateExpr (TraverserCall _ _) = error "Invalid traverser operation"
translateExpr (FunctionCall f ps) = Py.FunctionCall (Py.Var f) <$> mapM translateExpr ps
translateExpr (BinOp o l r) = Py.BinOp (translateOp o) <$> translateExpr l <*> translateExpr r
translateExpr (Lambda ps e) = Py.Lambda (map Py.VarPat ps) <$> translateExpr e
translateExpr (Var s) = return $ Py.Var s
translateExpr (IntLiteral i) = return $ Py.IntLiteral i
translateExpr (BoolLiteral b) = return $ Py.BoolLiteral b
translateExpr (ListLiteral es) = Py.ListLiteral <$> mapM translateExpr es
translateExpr (TupleLiteral es) = Py.TupleLiteral <$> mapM translateExpr es

-- | Apply one translated traverser option; 'Nothing' for unknown keys or
-- ill-shaped values.
applyOption :: TraverserData -> (String, Py.PyExpr) -> Maybe TraverserData
applyOption td ("list", Py.Var s) = return $ td { list = Just s }
applyOption td ("span", Py.TupleLiteral [f, t]) = return $ td { bounds = Just $ Range f t }
applyOption td ("random", Py.BoolLiteral True) = return $ td { bounds = Just Random }
applyOption td ("reverse", Py.BoolLiteral b) = return $ td { rev = b }
applyOption _ _ = Nothing

translateOption :: (String, Expr) -> Translator (String, Py.PyExpr)
translateOption (s, e) = (,) s <$> translateExpr e

-- | Options before any are applied: no list, no bounds, forward direction.
defaultTraverser :: TraverserData
defaultTraverser = TraverserData { list = Nothing, bounds = Nothing, rev = False }

-- | Translate a condition and its block. Each body statement is preceded by
-- its own helpers; helpers from the condition stay in the enclosing buffer.
translateBranch :: Branch -> Translator (Py.PyExpr, [Py.PyStmt])
translateBranch (e, s) =
  (,) <$> translateExpr e <*> (concat <$> mapM (withdrawStatements . translateStmt) s)

-- | Translate one statement (helpers it needs are queued, not returned).
translateStmt :: Stmt -> Translator Py.PyStmt
translateStmt (IfElse i els e) =
  uncurry Py.IfElse <$> translateBranch i <*> mapM translateBranch els <*> convertElse e
  where
    convertElse [] = return Nothing
    convertElse es = Just . concat <$> mapM (withdrawStatements . translateStmt) es
translateStmt (While b) = uncurry Py.While <$> translateBranch b
translateStmt (Traverser s os) =
  foldlM applyOption defaultTraverser <$> mapM translateOption os >>= saveTraverser
  where
    saveTraverser :: Maybe TraverserData -> Translator Py.PyStmt
    saveTraverser (Just td@TraverserData { list = Just l, bounds = Just bs }) = do
      let vtd = ValidTraverserData { validList = l, validBounds = bs, validRev = rev td }
      putTraverser s vtd
      return $ translateInitialBounds s vtd
    saveTraverser (Just _) = error "Invalid traverser: a list and a span or random option are required"
    saveTraverser Nothing = error "Invalid traverser (!)"
translateStmt (Let p e) = do
  -- Translate the right-hand side before the pattern: binding a name drops
  -- any traverser of that name, and the RHS may still refer to it.
  rhs <- translateExpr e
  pat <- translatePat p
  return $ Py.Assign pat rhs
translateStmt (Return e) = Py.Return <$> translateExpr e
translateStmt (Standalone e) = Py.Standalone <$> translateExpr e

-- | Initial position assignment for a newly declared traverser.
-- A range is half-open @[from, to)@ (the forward validity check is
-- @pos < to@), so a reversed traverser starts at @to - 1@, not @to@.
translateInitialBounds :: String -> ValidTraverserData -> Py.PyStmt
translateInitialBounds s vtd = case (validBounds vtd, validRev vtd) of
  (Random, _) -> traverserRandom s (validList vtd)
  (Range from _, False) -> Py.Assign (Py.VarPat s) from
  (Range _ to, True) -> Py.Assign (Py.VarPat s) (Py.BinOp Py.Subtract to (Py.IntLiteral 1))

-- | Translate a pattern; every bound name stops denoting a traverser.
translatePat :: Pat -> Translator Py.PyPat
translatePat (VarPat s) = clearTraverser s $> Py.VarPat s
translatePat (TuplePat ts) = Py.TuplePat <$> mapM translatePat ts

translateOp :: Op -> Py.PyBinOp
translateOp Add = Py.Add
translateOp Subtract = Py.Subtract
translateOp Multiply = Py.Multiply
translateOp Divide = Py.Divide
translateOp LessThan = Py.LessThan
translateOp LessThanEqual = Py.LessThanEq
translateOp GreaterThan = Py.GreaterThan
translateOp GreaterThanEqual = Py.GreaterThanEq
translateOp Equal = Py.Equal
translateOp NotEqual = Py.NotEqual
translateOp And = Py.And
translateOp Or = Py.Or

-- | Translate one function with fresh translator state.
-- A sorted function first sorts its first parameter in place.
translateFunction :: Function -> [Py.PyStmt]
translateFunction (Function m s ps ss) = [Py.FunctionDef s ps (sortStmts ++ stmts)]
  where
    sortStmts =
      [ Py.Standalone $ Py.FunctionCall (Py.Member (Py.Var p) "sort") []
      | p <- take 1 ps
      , m == Sorted
      ]
    stmts = concat $ evalState (mapM (withdrawStatements . translateStmt) ss) (Map.empty, [], 0)

-- | Translate a whole program, prefixed with the imports the output needs.
translate :: Prog -> [Py.PyStmt]
translate (Prog fs) =
  Py.FromImport "bisect" ["bisect"] : Py.Import "random" : concatMap translateFunction fs