149 lines
3.5 KiB
Haskell
149 lines
3.5 KiB
Haskell
{-# LANGUAGE PatternSynonyms #-}
|
|
|
|
module Type where
|
|
|
|
import Control.Monad.Except
|
|
import Control.Monad.RWS as R
|
|
( MonadReader (ask, local),
|
|
MonadState (get, put),
|
|
MonadWriter (tell),
|
|
RWST,
|
|
)
|
|
import qualified Data.Map as Map
|
|
import qualified Data.Set as Set
|
|
import qualified Premitive as P
|
|
|
|
-- | Identifiers used for program variables and for type-constructor names.
type Name = String

-- | A type variable, identified by its name.
newtype TVar = TV Name
  deriving (Eq, Ord, Show)

-- | Monotypes: an applied type constructor (e.g. @TCon "List" [a]@) or a
-- type variable.
--
-- Constructor order matters for the derived 'Ord' instance: every 'TCon'
-- compares less than every 'TVar'.
data Type
  = TCon Name [Type] -- ^ type constructor applied to its argument types
  | TVar TVar -- ^ type variable
  deriving (Eq, Ord, Show)
|
|
|
|
-- | The built-in integer type.
pattern TInt :: Type
pattern TInt = TCon "Int" []

-- | The built-in boolean type.
pattern TBool :: Type
pattern TBool = TCon "Bool" []

-- | Function type @a :-> b@, encoded as the constructor @"->"@ applied to
-- the argument and result types.
pattern (:->) :: Type -> Type -> Type
pattern a :-> b = TCon "->" [a, b]

-- Without an explicit declaration the operator defaults to @infixl 9@, which
-- would read @a :-> b :-> c@ as @(a :-> b) :-> c@. Curried function types
-- associate to the right; the precedence (9) is left unchanged.
infixr 9 :->

-- | The @Number@ type constructor.
pattern TNumber :: Type
pattern TNumber = TCon "Number" []

-- | List type with the given element type.
pattern TList :: Type -> Type
pattern TList a = TCon "List" [a]

-- | The @Nil@ type constructor.
pattern TNil :: Type
pattern TNil = TCon "Nil" []
|
|
|
|
-- | A type scheme: the variables in the set are universally quantified over
-- the body type.
data Scheme
  = Forall (Set.Set TVar) Type

-- | An equality between two types, recorded during inference.
-- NOTE(review): presumably solved later by a unifier not shown here.
data Constraint
  = Constraint Type Type

-- | Typing environment: maps each in-scope name to its scheme.
type Context = Map.Map Name Scheme

-- | State counter from which fresh type-variable names are derived.
type Count = Int

-- | Constraints accumulated while inferring.
type Constraints = [Constraint]

-- | The inference monad: reads a 'Context', writes 'Constraints', threads a
-- 'Count', and may fail with a 'String' error message.
type Infer a = RWST Context Constraints Count (Except String) a
|
|
|
|
-- constrain :: Type -> Type -> Infer ()
|
|
-- constrain t1 t2 = tell [Constraint t1 t2]
|
|
|
|
-- | Produce a type variable named after the current counter value, then
-- advance the counter so the next call yields a different name (as long as
-- nothing else rewinds the state).
fresh :: Infer Type
fresh = do
  n <- get
  put (n + 1)
  pure (TVar (TV (show n)))
|
|
|
|
-- | A substitution: maps type variables to the types that replace them.
type Subst = Map.Map TVar Type

-- | @compose a b@ is the substitution that applies @b@ first and then @a@,
-- so that @apply (compose a b) == apply a . apply b@.
--
-- @a@ is applied to the right-hand sides of @b@'s bindings. Bindings that
-- only @a@ has are kept as they are: they must NOT be rewritten by @a@
-- again, or the law above breaks for non-idempotent @a@.
-- 'Map.union' is left-biased, so where both bind a variable @b@'s
-- (rewritten) binding wins. @b@ runs first, so @a@ never sees that variable.
compose :: Subst -> Subst -> Subst
compose a b = Map.union (Map.map (apply a) b) a
|
|
|
|
-- | Values that mention type variables and can have a substitution applied.
class Substitutable a where
  -- | The type variables of the value that a substitution may replace.
  tvs :: a -> Set.Set TVar

  -- | Replace type variables as dictated by the substitution.
  apply :: Subst -> a -> a
|
|
|
|
-- | Variables are looked up in the substitution (left alone when unbound);
-- constructors are traversed structurally.
instance Substitutable Type where
  apply sub ty = case ty of
    TVar v -> Map.findWithDefault ty v sub
    TCon name args -> TCon name (apply sub <$> args)

  tvs ty = case ty of
    TVar v -> Set.singleton v
    TCon _ args -> Set.unions (map tvs args)
|
|
|
|
-- | Quantified variables are hidden: they are neither reported by 'tvs' nor
-- replaced by 'apply'.
instance Substitutable Scheme where
  tvs (Forall bound body) = tvs body Set.\\ bound

  apply sub (Forall bound body) =
    let visible = Map.withoutKeys sub bound
     in Forall bound (apply visible body)
|
|
|
|
-- | A constraint is handled by treating both of its sides.
instance Substitutable Constraint where
  apply sub (Constraint lhs rhs) = Constraint (apply sub lhs) (apply sub rhs)

  tvs (Constraint lhs rhs) = Set.union (tvs lhs) (tvs rhs)
|
|
|
|
-- | Lists are handled element-wise; their variables are the union over all
-- elements.
instance (Substitutable a) => Substitutable [a] where
  tvs xs = Set.unions (map tvs xs)

  apply sub xs = fmap (apply sub) xs
|
|
|
|
-- | Quantify over every type variable of @t@ that does not occur in the
-- context's schemes.
generalize :: Context -> Type -> Scheme
generalize ctx t = Forall quantified t
  where
    envVars = tvs (Map.elems ctx)
    quantified = Set.difference (tvs t) envVars
|
|
|
|
-- | Replace each quantified variable of a scheme with its own fresh type
-- variable. Fresh variables are drawn in ascending order of the bound
-- variables.
instantiate :: Scheme -> Infer Type
instantiate (Forall vs t) = do
  let bound = Set.toList vs
  replacements <- mapM (const fresh) bound
  pure (apply (Map.fromList (zip bound replacements)) t)
|
|
|
|
-- | Record (via the writer output) that the two types must be equal.
constrain :: Type -> Type -> Infer ()
constrain lhs rhs = tell [Constraint lhs rhs]
|
|
|
|
-- | Infer the type of an expression.
--
-- Equality requirements are emitted with 'constrain' rather than solved
-- here, so the returned type may mention type variables that are only
-- pinned down by the accumulated constraints. Fails with a 'String' error
-- for unknown operators and unbound symbols.
infer :: P.Expr -> Infer Type
infer expr = case expr of
  P.Num _ -> return TInt
  P.Bool _ -> return TBool -- Boolean literal
  P.IfExpr cond thenE elseE -> do
    -- If expression: the condition must be Bool and both branches agree.
    tCond <- infer cond
    constrain tCond TBool
    tThen <- infer thenE
    tElse <- infer elseE
    constrain tThen tElse
    return tThen
  P.Bin op e1 e2 -> do
    -- Binary operator: every supported operator takes and returns Int.
    t1 <- infer e1
    t2 <- infer e2
    constrain t1 TInt
    constrain t2 TInt
    case op of
      "+" -> return TInt
      "-" -> return TInt
      "*" -> return TInt
      "/" -> return TInt
      _ -> throwError $ "unknown operator: " ++ op
  P.Con e1 e2 -> do
    -- NOTE(review): this forces both operands to the same type. If Con is
    -- list cons, the usual rule is `constrain t2 (TList t1)` and returning
    -- `TList t1`. Confirm the intended semantics of P.Con.
    t1 <- infer e1
    t2 <- infer e2
    constrain t1 t2
    return t2
  P.Nil -> return TNil
  P.List es -> do
    -- All elements share one element type. A fresh variable is used instead
    -- of the first element's type so that the empty list is handled too
    -- (`head ts` crashed on it), and every element is constrained to it
    -- (previously only the first element determined the type).
    ts <- traverse infer es
    elemTy <- fresh
    mapM_ (constrain elemTy) ts
    return (TList elemTy)
  P.Symbol s -> do
    ctx <- ask
    case Map.lookup s ctx of
      Nothing -> throwError $ "unbound variable: " ++ s
      Just scheme -> instantiate scheme
  P.Declare names es body -> do
    -- The bound expressions are inferred in the enclosing context and do
    -- not see one another. `zip` drops any surplus names or expressions.
    -- Generalizing against the enclosing context (not an empty one) keeps
    -- variables that are still free in that context from being quantified.
    -- NOTE(review): generalization happens before the emitted constraints
    -- are solved, so constraints on quantified variables are not reflected
    -- in the resulting schemes. Confirm this is intended.
    outer <- ask
    ts <- traverse infer es
    let bound = Map.fromList $ zip names (map (generalize outer) ts)
    local (Map.union bound) (infer body)
  P.Block _ body -> infer body
  P.Process _ _ -> return TNil
  P.Pipeln _ _ -> return TNil
  -- NOTE(review): string literals are typed as TNil; presumably a placeholder.
  P.Str _ -> return TNil