{-# LANGUAGE PatternSynonyms #-}

-- | Type representation, substitutions, and constraint generation for
-- Hindley-Milner-style inference over 'P.Expr'.
module Type where

import Control.Monad.Except
import Control.Monad.RWS as R
  ( MonadReader (ask, local),
    MonadState (get, put),
    MonadWriter (tell),
    RWST,
  )
import qualified Data.Map as Map
import qualified Data.Set as Set
import qualified Premitive as P

type Name = String

-- | A type variable, identified by name.
newtype TVar = TV Name
  deriving (Eq, Show, Ord)

-- | Monotypes: an applied type constructor, or a type variable.
data Type
  = TCon Name [Type]
  | TVar TVar
  deriving (Show, Eq, Ord)

pattern TInt :: Type
pattern TInt = TCon "Int" []

pattern TBool :: Type
pattern TBool = TCon "Bool" []

-- | Function type.
--
-- NOTE(review): there is no fixity declaration, so this defaults to
-- @infixl 9@; curried function types must be written @a :-> (b :-> c)@.
pattern (:->) :: Type -> Type -> Type
pattern a :-> b = TCon "->" [a, b]

pattern TNumber :: Type
pattern TNumber = TCon "Number" []

pattern TList :: Type -> Type
pattern TList a = TCon "List" [a]

pattern TNil :: Type
pattern TNil = TCon "Nil" []

-- | A type scheme; the variables in the set are universally quantified.
data Scheme = Forall (Set.Set TVar) Type

-- | An equality constraint between two types, collected during inference.
data Constraint = Constraint Type Type

-- | Typing environment mapping variable names to their schemes.
type Context = Map.Map Name Scheme

-- | Counter used to name fresh type variables.
type Count = Int

type Constraints = [Constraint]

-- | Inference monad: reads the typing context, accumulates constraints,
-- threads the fresh-variable counter, and may fail with an error message.
type Infer a = RWST Context Constraints Count (Except String) a

-- | Produce a fresh type variable named after the current counter value,
-- then advance the counter.
fresh :: Infer Type
fresh = do
  count <- get
  put (count + 1)
  return . TVar . TV $ show count

type Subst = Map.Map TVar Type

-- | Compose substitutions so that @apply (compose s1 s2) == apply s1 . apply s2@.
--
-- @s1@ is applied to the range of @s2@; bindings of @s1@ for variables
-- that @s2@ does not bind are kept as they are. 'Map.union' is left-biased,
-- so @s2@'s (substituted) binding wins when both bind the same variable.
compose :: Subst -> Subst -> Subst
compose s1 s2 = Map.map (apply s1) s2 `Map.union` s1

-- | Things that contain type variables and can have a substitution applied.
class Substitutable a where
  -- | Apply a substitution.
  apply :: Subst -> a -> a

  -- | Free type variables.
  tvs :: a -> Set.Set TVar

instance Substitutable Type where
  tvs (TVar tv) = Set.singleton tv
  tvs (TCon _ ts) = foldr (Set.union . tvs) Set.empty ts

  apply s t@(TVar tv) = Map.findWithDefault t tv s
  apply s (TCon c ts) = TCon c $ map (apply s) ts

instance Substitutable Scheme where
  -- Quantified variables are bound, so they are not free.
  tvs (Forall vs t) = tvs t `Set.difference` vs

  -- Quantified variables are removed from the substitution so that bound
  -- occurrences are left untouched.
  apply s (Forall vs t) = Forall vs $ apply (foldr Map.delete s vs) t

instance Substitutable Constraint where
  tvs (Constraint t1 t2) = tvs t1 `Set.union` tvs t2
  apply s (Constraint t1 t2) = Constraint (apply s t1) (apply s t2)

instance (Substitutable a) => Substitutable [a] where
  apply s = map (apply s)
  tvs = foldr (Set.union . tvs) Set.empty

-- | Quantify every type variable of @t@ that is not free in the context.
generalize :: Context -> Type -> Scheme
generalize ctx t = Forall (tvs t `Set.difference` tvs (Map.elems ctx)) t

-- | Replace each quantified variable of a scheme with a fresh type variable.
instantiate :: Scheme -> Infer Type
instantiate (Forall vs t) = do
  let vars = Set.toList vs
  ftvs <- traverse (const fresh) vars
  let subst = Map.fromList (zip vars ftvs)
  return $ apply subst t

-- | Record that the two types must be equal.
constrain :: Type -> Type -> Infer ()
constrain a b = tell [Constraint a b]

-- | Infer the type of an expression, emitting equality constraints.
--
-- Throws an error for unbound variables and unknown binary operators.
infer :: P.Expr -> Infer Type
infer e = case e of
  P.Num _ -> return TInt
  -- Boolean literal
  P.Bool _ -> return TBool
  -- If expression: the condition is Bool and both branches share a type.
  P.IfExpr e1 e2 e3 -> do
    t1 <- infer e1
    constrain t1 TBool
    t2 <- infer e2
    t3 <- infer e3
    constrain t2 t3
    return t2
  -- Binary operator: both operands and the result are Int. Operands are
  -- inferred first, so their errors are reported before an unknown operator.
  P.Bin op e1 e2 -> do
    t1 <- infer e1
    t2 <- infer e2
    constrain t1 TInt
    constrain t2 TInt
    if op `elem` ["+", "-", "*", "/"]
      then return TInt
      else throwError $ "unknown operator: " ++ op
  -- NOTE(review): both operands are constrained to the same type and the
  -- second operand's type is returned; if Con is list cons, the tail should
  -- presumably be constrained to @TList t1@ instead -- confirm.
  P.Con e1 e2 -> do
    t1 <- infer e1
    t2 <- infer e2
    constrain t1 t2
    return t2
  P.Nil -> return TNil
  -- List literal: every element must share one element type. A fresh
  -- variable stands in for it, so the empty list is handled as well
  -- (previously `head` crashed on it, and elements were never unified).
  P.List es -> do
    elemTy <- fresh
    ts <- traverse infer es
    mapM_ (constrain elemTy) ts
    return $ TList elemTy
  P.Symbol s -> do
    ctx <- ask
    case Map.lookup s ctx of
      Nothing -> throwError $ "unbound variable: " ++ s
      Just scheme -> instantiate scheme
  -- Non-recursive declaration: the bound expressions are inferred in the
  -- enclosing context, then generalized against that same context so that
  -- variables still free in the environment are not quantified.
  --
  -- NOTE(review): generalization happens before the collected constraints
  -- are solved, which can still over-generalize variables that later
  -- constraints would fix -- confirm the intended solving strategy.
  P.Declare names es body -> do
    env <- ask
    ts <- traverse infer es
    let bindings = Map.fromList $ zip names (map (generalize env) ts)
    local (Map.union bindings) (infer body)
  P.Block _ body -> infer body
  P.Process _ _ -> return TNil
  P.Pipeln _ _ -> return TNil
  -- NOTE(review): string literals are typed as TNil; a dedicated string
  -- type may be intended -- confirm.
  P.Str _ -> return TNil