meo/src/Type.hs
2023-12-18 18:29:13 +08:00

149 lines
3.5 KiB
Haskell

{-# LANGUAGE PatternSynonyms #-}
module Type where
import Control.Monad.Except
import Control.Monad.RWS as R
( MonadReader (ask, local),
MonadState (get, put),
MonadWriter (tell),
RWST,
)
import qualified Data.Map as Map
import qualified Data.Set as Set
import qualified Premitive as P
type Name = String
newtype TVar = TV Name deriving (Eq, Show, Ord)
data Type
= TCon Name [Type]
| TVar TVar
deriving (Show, Eq, Ord)
pattern TInt = TCon "Int" []
pattern TBool = TCon "Bool" []
pattern a :-> b = TCon "->" [a, b]
pattern TNumber = TCon "Number" []
pattern TList a = TCon "List" [a]
pattern TNil = TCon "Nil" []
data Scheme = Forall (Set.Set TVar) Type
data Constraint = Constraint Type Type
type Context = Map.Map Name Scheme
type Count = Int
type Constraints = [Constraint]
type Infer a = RWST Context Constraints Count (Except String) a
-- constrain :: Type -> Type -> Infer ()
-- constrain t1 t2 = tell [Constraint t1 t2]
fresh :: Infer Type
fresh = do
count <- get
put (count + 1)
return . TVar . TV $ show count
type Subst = Map.Map TVar Type
compose :: Subst -> Subst -> Subst
compose a b = Map.map (apply a) (b `Map.union` a)
class Substitutable a where
apply :: Subst -> a -> a
tvs :: a -> Set.Set TVar
instance Substitutable Type where
tvs (TVar tv) = Set.singleton tv
tvs (TCon _ ts) = foldr (Set.union . tvs) Set.empty ts
apply s t@(TVar tv) = Map.findWithDefault t tv s
apply s (TCon c ts) = TCon c $ map (apply s) ts
instance Substitutable Scheme where
tvs (Forall vs t) = tvs t `Set.difference` vs
apply s (Forall vs t) = Forall vs $ apply (foldr Map.delete s vs) t
instance Substitutable Constraint where
tvs (Constraint t1 t2) = tvs t1 `Set.union` tvs t2
apply s (Constraint t1 t2) = Constraint (apply s t1) (apply s t2)
instance (Substitutable a) => Substitutable [a] where
apply s = map (apply s)
tvs = foldr (Set.union . tvs) Set.empty
generalize :: Context -> Type -> Scheme
generalize ctx t = Forall (tvs t `Set.difference` tvs (Map.elems ctx)) t
instantiate :: Scheme -> Infer Type
instantiate (Forall vs t) = do
let vars = Set.toList vs
ftvs <- traverse (const fresh) vars
let subst = Map.fromList (zip vars ftvs)
return $ apply subst t
constrain :: Type -> Type -> Infer ()
constrain a b = tell [Constraint a b]
infer :: P.Expr -> Infer Type
infer e = case e of
P.Num _ -> return TInt
P.Bool _ -> return TBool -- Boolean literal
P.IfExpr e1 e2 e3 -> do
-- If expression
t1 <- infer e1
constrain t1 TBool
t2 <- infer e2
t3 <- infer e3
constrain t2 t3
return t2
P.Bin op e1 e2 -> do
-- Binary operator
t1 <- infer e1
t2 <- infer e2
constrain t1 TInt
constrain t2 TInt
case op of
"+" -> return TInt
"-" -> return TInt
"*" -> return TInt
"/" -> return TInt
_ -> throwError $ "unknown operator: " ++ op
P.Con e1 e2 -> do
t1 <- infer e1
t2 <- infer e2
constrain t1 t2
return t2
P.Nil -> do
return TNil
P.List es -> do
ts <- traverse infer es
return $ TList (head ts)
P.Symbol s -> do
ctx <- ask
case Map.lookup s ctx of
Nothing -> throwError $ "unbound variable: " ++ s
Just s' -> instantiate s'
P.Declare names es e -> do
ts <- traverse infer es
let ctx = Map.fromList $ zip names (map (generalize Map.empty) ts)
local (Map.union ctx) (infer e)
P.Block s e -> do
t2 <- infer e
return t2
P.Process _ _ -> do
return TNil
P.Pipeln _ _ -> do
return TNil
P.Str _ -> do
return TNil