Sunday, 15 August 2010

Hindley Milner Type Inference Sample Implementation

// Learn more about F# at http://fsharp.net

// Hindley Milner Type Inference Sample Implementation

// ************************************************
// AST
// ************************************************

type Literal
   = Char  of  char          // character literal
   | String of string        // string literal
   | Integer of  int         // integer literal
   | Float  of  float        // floating point literal   

type Exp
    = Var      of string               // variable   
    | Lam      of string * Exp         // lambda abstraction
    | App      of Exp * Exp            // application   
    | InfixApp of Exp * string * Exp   // infix application
    | Let      of string * Exp * Exp   // local definition   
    | Lit      of Literal              // literal

// ************************************************
// Type Tree
// ************************************************

type Type
    = TyLam of Type * Type
    | TyVar of string
    | TyCon of string * Type list
    with override this.ToString() =
            match this with
            | TyLam (t1, t2) -> sprintf "(%s -> %s)" (t1.ToString()) (t2.ToString())
            | TyVar a -> a
            | TyCon (s, ts) -> s

// ************************************************
// Subtitutions
// ************************************************

type Subst = Subst of Map<string, Type>

// extend: string -> Type -> Subst -> Subst
let extend v t (Subst s) = Subst (Map.add v t s)

// lookup: string -> Subst -> Type
let lookup v (Subst s) =
    if (Map.containsKey v s) then
        Map.find v s
    else
        TyVar v

// Apply a type substitution to a type
// subs: Type -> Subst -> Type
let rec subs t s =
    match t with
    | TyVar n ->
          let t' = lookup n s
          if t = t' then t'
          else subs t' s   
    | TyLam(a, r) ->
          TyLam(subs a s, subs r s)
    | TyCon(name, tyArgs) ->
            TyCon(name, tyArgs
            |> List.map (fun x -> subs x s))

// ************************************************
// Environments
// ************************************************

type TyScheme = TyScheme of Type * Set<string>

type Env = Env of Map<string, TyScheme>

// Calculate the list of type variables occurring in a type,
// without repeats
// getTVarsOfType: Type -> Set<string>
let rec getTVarsOfType t =
    match t with
    | TyVar n  -> Set.singleton n
    | TyLam(t1, t2) -> getTVarsOfType(t1) + getTVarsOfType(t2)
    | TyCon(_, args) ->
        List.fold (fun acc t -> acc + getTVarsOfType t) Set.empty args

// getTVarsOfScheme: TyScheme -> Set<string>
let getTVarsOfScheme (TyScheme (t, tvars)) =
    (getTVarsOfType t) - tvars

// getTVarsOfEnv: Env -> Set<string>
let getTVarsOfEnv (Env e) =
    let schemes = e |> Map.toSeq |> Seq.map snd
    Seq.fold (fun acc s -> acc + (getTVarsOfScheme s)) Set.empty schemes

// ************************************************************
// Unification
// ************************************************************

exception UnificationError of Type * Type

// Calculate the most general unifier of two types,
// raising a UnificationError if there isn't one
// mgu: Type -> Type -> Subst -> Subst
let rec mgu a b s =
    match (subs a s, subs b s) with
    | TyVar ta, TyVar tb when ta = tb -> s
    | TyVar ta, _ when not <| Set.contains ta (getTVarsOfType b) -> extend ta b s
    | _, TyVar _ -> mgu b a s
    | TyLam (a1, b1), TyLam (a2, b2) -> mgu a1 a2 (mgu b1 b2 s)
    | TyCon(name1, args1), TyCon(name2, args2) when name1 = name2 ->
            (s, args1, args2) |||> List.fold2 (fun subst t1 t2 -> mgu t1 t2 subst)
    | x,y -> raise <| UnificationError (x,y)

// ************************************************************
// State Monad
// ************************************************************

type State<'state, 'a> = State of ('state ->'a * 'state)

let run (State f) s = f s

type StateMonad() =
    member b.Bind(State m, f) =
      State (fun s -> 
               let (v,s') = m s in
               let (State n) = f v in n s')                                                   
    member b.Return x = State (fun s -> x, s)

    member b.ReturnFrom x = x

    member b.Zero () = State (fun s -> (), s)

    member b.Combine(r1, r2) = b.Bind(r1, fun () -> r2)

    member b.Delay f = State (fun s -> run (f()) s)



let state = StateMonad()

let getState = State (fun s -> s, s)
let setState s = State (fun _ -> (), s) 

let execute m s = match m with
                  | State f -> let r = f s
                               match r with
                               |(x,_) -> x

let mmap f xs =
    let rec MMap' (f, xs', out) =
          state { match xs' with
                  | h :: t ->
                      let! h' = f(h)
                      return! MMap'(f, t, List.append out [h'])
                  | [] -> return out }
    MMap' (f, xs, [])

// ************************************************************
// Alpha converter (Converts T4, T5, T6 to 'a, 'b, 'c)
// ************************************************************

let getName k =
    let containsKey k = state { let! (map, id) = getState
                                return Map.containsKey k map }

    let addName k = state { let! (map, id) = getState
                            let newid = char (int id + 1)
                            do! setState(Map.add k id map, newid)
                            return () }

    state { let! success = containsKey k
            if (not success) then
                do! addName k
            let! (map, id) = getState
            return Map.find k map }

// renameTVarsToLetters: Type -> Type
let renameTVarsToLetters t =   
    let rec run x =
        state {
                match x with
                | TyVar(name) ->
                    let! newName = getName name            
                    return TyVar(sprintf "'%c" newName)
                | TyLam(arg, res) ->
                    let! t1 = run arg
                    let! t2 = run res
                    return TyLam(t1, t2)
                | TyCon(name, typeArgs) ->
                    let! list = mmap (fun x -> run x) typeArgs
                    return TyCon(name, list) }
    execute (run t) (Map.empty, 'a')

// ****************************************************************
// Calculate principal Type
// ****************************************************************

let newTyVar =
    state { let! x = getState
            do! setState(x + 1)
            return TyVar(sprintf "T%d" x) }

let integerCon = TyCon("int", [])
let floatCon = TyCon("float", [])
let charCon = TyCon("char", [])
let stringCon = TyCon("string", [])

let litToTy lit =
    match lit with
    | Integer _ -> integerCon
    | Float _ -> floatCon
    | Char _ -> charCon
    | String  _ -> stringCon

// Calculate the principal type scheme for an expression in a given
// typing environment
// tp: Env -> Exp -> Type -> Subst -> State<int, Subst>
let rec tp env exp bt s =
  let findSc n (Env e) = Map.find n e
  let containsSc n (Env e) = Map.containsKey n e
  let addSc n sc (Env e) = Env (Map.add n sc e)
  state {
        match exp with
        | Lit v -> 
            return mgu (litToTy v) bt s
        | Var n -> 
            if not (containsSc n env)
            then failwith "Name %s no found" n
            let (TyScheme (t, _)) = findSc n env
            return mgu (subs t s) bt s
        | Lam (x, e) -> 
            let! a = newTyVar
            let! b = newTyVar
            let s1 = mgu bt (TyLam (a, b)) s
            let newEnv = addSc x (TyScheme (a, Set.empty)) env
            return! tp newEnv e b s1
        | App(e1, e2) ->
            let! a = newTyVar
            let! s1 = tp env e1 (TyLam(a, bt)) s
            return! tp env e2 a s1
        | InfixApp(e1, op, e2) ->
            let exp1 = App(App(Var op, e1), e2)
            return! tp env exp1 bt s
        | Let(name, inV, body) ->
            let! a = newTyVar
            let! s1 = tp env inV a s
            let t = subs a s1
            let newScheme = TyScheme (t, ((getTVarsOfType t) - (getTVarsOfEnv env)))
            return! tp (addSc name newScheme env) body bt s1 }

let predefinedEnv =
    Env(["+", TyScheme (TyLam(integerCon, TyLam(integerCon, integerCon)), Set.empty)
         "*", TyScheme (TyLam(integerCon, TyLam(integerCon, integerCon)), Set.empty)
         "-", TyScheme (TyLam(integerCon, TyLam(integerCon, integerCon)), Set.empty)
           ] |> Map.ofList)

// typeOf: Exp -> Type
let typeOf exp =
   let typeOf' exp =
    state { let! (a:Type) = newTyVar
            let emptySubst = Subst (Map.empty)
            let! s1 = tp predefinedEnv exp a emptySubst
            return subs a s1 }
   in execute (typeOf' exp) 0 |> renameTVarsToLetters

// ***********************************************************
// Example
// ***********************************************************
   
let composeAst = Let("compose",
                    Lam("f",
                        Lam("g",
                            Lam ("x",
                                App(Var "g", App(Var "f", Var "x"))))),
                        Var "compose")

let t = typeOf composeAst

printfn "%s" (t.ToString())

1 comment:

Anonymous said...

Great fun, and more fun is that I have programmed something very much along the same lines. I have found that the following two functions are great to use:

let mapState f = State(fun s-> (), f s)
let foldState f = State (fun s -> f s)

Having defined operators on the "environement" that is carried by the state monad, these two operators can be used to convert them into state monad compatible form. For example, having newVar: Env -> (Type, Env), then I can write:
let newVar' = foldState newVar and in the computation expression write:
let! v1 = newVar'

James