Wednesday, January 9, 2008

Extra type safety using polymorphic types as first-level refinements

This post is literate Haskell.

I will demonstrate a new technique for using polytypes as first-level refinement types (mirror). The goal, as usual, is for types to better express program invariants and ensure programs are safe.

I'm going to demonstrate using the risers function, as presented in Dana N. Xu's ESC/Haskell (mirror), which references Neil Mitchell's Catch.

> {-# OPTIONS -fglasgow-exts #-}
> -- The LANGUAGE pragma is usually a pain for exploratory programming.

Below are the risers functions as presented by Xu and Mitchell. They are the same function, though slightly different syntactically. risers splits a list into its maximal non-decreasing runs, that is, its sorted contiguous sublists.

Risers has two properties that we are going to discuss:

  1. None of the lists in the returned value are empty
  2. If the argument is non-empty, the return value is also non-empty.

> risersXu :: (Ord t) => [t] -> [[t]]
> risersXu [] = []
> risersXu [x] = [[x]]
> risersXu (x:y:etc) =
>     let ss = risersXu (y : etc)
>     in case x <= y of
>          True -> (x : (head ss)) : (tail ss)
>          False -> ([x]) : ss
> risersMitchell :: Ord a => [a] -> [[a]]
> risersMitchell [] = []
> risersMitchell [x] = [[x]]
> risersMitchell (x:y:etc) = if x <= y
>                            then (x:s):ss
>                            else [x]:(s:ss)
>     where (s:ss) = risersMitchell (y:etc)

Neither one of these functions is obviously safe. Xu uses head and tail and ESC/Haskell to create a proof of their safety. The unsafe part of Mitchell's code is the where clause, and Mitchell also presents a tool to prove this safe.

Our goal will be to write this function in a safe way, with a type signature that ensures the two properties we expect from the function. We also aim to do this without having to change the shape of the code, only the list implementation we are using.

The present unsafe operations in risersXu and risersMitchell depend on the second property of the function: non-null inputs generate non-null outputs. We could write a type for this function using a trusted library with phantom types for branding (paper mirror). This technique (called lightweight static capabilities) can do this and much else as well, but since clients lose all ability to pattern match (even in case), risers becomes much more verbose. We could also write a type signature guaranteeing this by using GADTs. Without using one of these, incomplete pattern matching or calling unsafe head and tail on the result of the recursive call seems inevitable.

Here's another way to write the function which does away with the need for the second property on the recursive call, substituting instead the need for the first property that no lists in the return value are empty:

> risersAlt :: (Ord t) => [t] -> [[t]]
> risersAlt [] = []
> risersAlt (x:xs) = 
>     case risersAlt xs of
>       [] -> [[x]]
>       w@((y:ys):z) -> 
>           if x <= y
>           then (x:y:ys):z
>           else ([x]):w
>       ([]:_) -> error "risersAlt"

The error is never reached.
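
That claim can be spot-checked. The following is a standalone sketch, separate from this literate module: it copies risersAlt from above and evaluates it on every list over {0,1,2} of length at most six. The names smallLists and checkAll are made up for the illustration, and a finite check like this is evidence, not a proof.

```haskell
-- risersAlt, copied from the text above.
risersAlt :: Ord t => [t] -> [[t]]
risersAlt [] = []
risersAlt (x:xs) =
    case risersAlt xs of
      [] -> [[x]]
      w@((y:ys):z) ->
          if x <= y
          then (x:y:ys):z
          else [x]:w
      ([]:_) -> error "risersAlt"

-- Hypothetical helper: all lists over {0,1,2} of length at most n.
smallLists :: Int -> [[Int]]
smallLists n = concatMap (\k -> sequence (replicate k [0, 1, 2])) [0 .. n]

-- Forces every result, so reaching the error branch would abort the check.
-- Also checks both properties, and that the risers concatenate back to the input.
checkAll :: Bool
checkAll = all ok (smallLists 6)
  where
    ok xs =
      let r = risersAlt xs
      in all (not . null) r              -- property 1
         && (null xs || not (null r))    -- property 2
         && concat r == xs
```

Evaluating checkAll in GHCi should walk all 1093 inputs without hitting the error call.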

Though ensuring the second property with our usual types seems tricky, ensuring the first is not too tough:

> type And1 a = (a,[a])
> risersAlt' :: Ord a => [a] -> [And1 a]
> risersAlt' [] = []
> risersAlt' (x:xs) = 
>     case risersAlt' xs of
>       [] -> [(x,[])]
>       w@(yys:z) -> 
>           if x <= fst yys
>           then (x,fst yys : snd yys):z
>           else (x,[]):w

It is now much easier to see that risers is safe: There is one pattern match and one case, and each is simple. No unsafe functions like head or tail are called. It does have three disadvantages, however.

First, the second property is still true, but the function type does not enforce it. This means that any other callers of risers may have to use incomplete pattern matching or unsafe functions, since they may not be so easy to transform. My intuition is that such functions are rarely tricky to transform, but perhaps Neil Mitchell disagrees.

We could fix this by writing another risers function with type And1 a -> And1 (And1 a), but this brings us to the second problem: And1 a is not a subtype of [a]. This means that callers of our hypothetical other risers function (as well as consumers of the output from risersAlt') must explicitly coerce the results back to lists.
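
To make the coercion cost concrete, here is a small standalone sketch. And1 is the synonym from above; fromAnd1, toAnd1 and flatten are hypothetical helper names that do not appear elsewhere in this post.

```haskell
type And1 a = (a, [a])

-- Forget the non-emptiness evidence: always succeeds.
fromAnd1 :: And1 a -> [a]
fromAnd1 (x, xs) = x : xs

-- Recover the evidence: may fail, so the result is a Maybe.
toAnd1 :: [a] -> Maybe (And1 a)
toAnd1 []       = Nothing
toAnd1 (x : xs) = Just (x, xs)

-- A consumer of risersAlt' output that wants plain lists must map the coercion.
flatten :: [And1 a] -> [[a]]
flatten = map fromAnd1
```

Every boundary between And1 code and ordinary list code needs a call like one of these, which is exactly the friction a subtype relation would avoid.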

Finally, if we are wrong about the second property, and risers does return an empty list for some non-empty input i, then for any x, risers (x:i) is _|_, while risersAlt' (x:i) is [(x,[])]. Thus, the equivalence of these two function definitions depends on the truth of the second property on the first function, which is something we were trying to get out of proving in the first place! Of course, if we're interested in the correctness of risersAlt', rather than its equivalence with risersXu or risersMitchell, then it is not too difficult to reason about. But part of the point of this was to get the compiler to do some of this reasoning for us, without having to change the shape of the code.

Let's write more expressive signatures using something we may have noticed when using GHCi. The only value of type forall a. [a] (excluding lists of _|_s) is []. Every value of type forall a . Maybe a is Nothing, and every forall a . Either a Int has an Int in Right. Andrew Koenig noticed something similar when learning ML (archived), and the ST monad operates on a similar principle. (paper with details about ST, mirror)
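
The observation can be turned into code. This is a standalone sketch, not part of the literate module: it assumes the RankNTypes extension and Data.Void from base (which postdates this post), and polyLength and getInt are names invented for the illustration.

```haskell
{-# LANGUAGE RankNTypes #-}
import Data.Void (Void, absurd)

-- The argument must be a list at every element type at once, so
-- (ignoring bottoms) the only thing a caller can pass is [].
polyLength :: (forall a. [a]) -> Int
polyLength xs = length (xs :: [()])

-- The argument must be an Either a Int for every a, including a = Void.
-- A Left would have to hold a Void, so the match on Left is discharged
-- by absurd and only the Right case carries information.
getInt :: (forall a. Either a Int) -> Int
getInt e = either absurd id e
```

For example, polyLength [] and getInt (Right 3) typecheck, while polyLength [True] and getInt (Left 'x') are rejected because their arguments are not polymorphic enough.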

> data List a n = Nil n
>               | forall r . a :| (List a r)

The only values of type forall a . List a n use the Nil constructor, and the only values of type forall n . List a n use the :| constructor.

> infixr :|
> box x = x :| Nil ()
> type NonEmpty a = forall n . List a n
> onebox :: NonEmpty Int
> onebox = box 1
> onebox' :: List Int Dummy
> onebox' = onebox
> data Dummy
> -- This doesn't compile
> -- empty :: NonEmpty a
> -- empty = Nil ()

NonEmpty a is a subtype of List a x for all types x.
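
Here is what that subtyping looks like in use, as a standalone sketch that repeats the List and NonEmpty definitions so it compiles on its own. The names toUnit, toList and emptyTagged are hypothetical, and the infixr precedence of 5 is my choice for the sketch.

```haskell
{-# LANGUAGE ExistentialQuantification, RankNTypes #-}

data List a n = Nil n
              | forall r . a :| (List a r)
infixr 5 :|

type NonEmpty a = forall n . List a n

-- The "upcast" is just instantiation of n, so no conversion code is needed.
toUnit :: NonEmpty a -> List a ()
toUnit xs = xs

-- Ordinary consumers work at any tag, so they accept NonEmpty values too.
toList :: List a n -> [a]
toList (Nil _)   = []
toList (x :| xs) = x : toList xs

-- A list with a fixed tag is not a NonEmpty:
emptyTagged :: List a ()
emptyTagged = Nil ()
-- toUnit emptyTagged   -- rejected: List a () is not polymorphic in its tag
```

Passing 1 :| 2 :| Nil () to toUnit is accepted, because a cons cell can be given any tag.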

> data Some fa = forall n . Some (fa n)
> safeHead :: NonEmpty a -> a
> safeHead x = unsafeHead x where
>     unsafeHead (x :| _) = x
> safeTail :: NonEmpty a -> Some (List a)
> safeTail x = unsafeTail x where
>     unsafeTail (_ :| xs) = Some xs

Unfortunately, we'll be forced to Some and un-Some some values, since Haskell does not have first-class existentials, and it takes some thinking to see that safeHead and safeTail are actually safe.

Here is a transformed version of Mitchell's risers:

> risersMitchell' :: Ord a => List a n -> List (NonEmpty a) n
> risersMitchell' (Nil x) = (Nil x)
> risersMitchell' (x :| Nil _) = box (box x)
> risersMitchell' ((x {- :: a -}) :| y :| etc) =
>     case risersMitchell' (y :| etc) {- :: NonEmpty (NonEmpty a) -} of
>       Nil _ -> error "risersMitchell'"
>       s :| ss -> if x <= y
>                  then (x :| s) :| ss
>                  else (box x) :| s :| ss

Since we can't put the recursive call in a where clause, we must use a case with some dead code. The type annotations are commented out here to show they are not needed, but uncommenting them shows that the recursive call really does return a non-empty list, and so the Nil case really is dead code.
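
One way to discharge the dead branch without calling error is to let the recursive call's tag be instantiated at an uninhabited type. The following standalone sketch is not the code above: it assumes Data.Void from base (which postdates this post), and it swaps the NonEmpty synonym for a hypothetical newtype NE so that the polymorphic list can be stored as a list element without impredicative instantiation. The names risers and toLists are also local to the sketch.

```haskell
{-# LANGUAGE ExistentialQuantification, RankNTypes #-}
import Data.Void (absurd)

data List a n = Nil n
              | forall r . a :| (List a r)
infixr 5 :|

-- Hypothetical wrapper standing in for the NonEmpty synonym.
newtype NE a = NE (forall n . List a n)

box :: a -> List a n
box x = x :| Nil ()

risers :: Ord a => List a n -> List (NE a) n
risers (Nil n)      = Nil n
risers (x :| Nil _) = box (NE (box x))
risers (x :| y :| etc) =
    -- Matching v against absurd forces the recursive call's tag to be Void,
    -- which is allowed because the argument (y :| etc) fits any tag.
    case risers (y :| etc) of
      Nil v -> absurd v
      NE s :| ss
        | x <= y    -> NE (x :| s) :| ss
        | otherwise -> NE (box x) :| NE s :| ss

-- Convert back to plain lists for inspection.
toList :: List a n -> [a]
toList (Nil _)   = []
toList (x :| xs) = x : toList xs

toLists :: List (NE a) n -> [[a]]
toLists (Nil _)       = []
toLists (NE s :| ss)  = toList s : toLists ss
```

The Nil branch is still written out, but it is now total: the only way to reach it would be to produce a value of type Void.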

This type signature ensures both of the properties listed when introducing risers. The key to the non-empty-arguments-produce-non-empty-results property is that the variable n in the signature is used twice. That means applying risersMitchell' to a list with a fixed (or existential) type as its second parameter can't produce a NonEmpty list.

> risersXu' :: Ord a => List a r -> List (NonEmpty a) r
> risersXu' (Nil x) = Nil x
> risersXu' (x :| Nil _) = box (box x)
> risersXu' (((x :: a) :| y :| etc) :: List a r) = 
>     let ss = risersXu' (y :| etc)
>     in case x <= y of
>          True -> case safeTail ss of
>                    Some v -> (x :| (safeHead ss)) :| v
>          False -> (box x) :| ss

Here we see that the type annotation isn't necessary to infer that risers applied to a non-empty list returns a non-empty list. The value ss isn't given a type signature, but we can apply safeHead and safeTail. The case matching on safeTail is the pain of boxing up existentials.

This is the first version of risers with a type signature that gives us the original invariant Xu and Mitchell can infer, as well as calling no unsafe functions and containing no incomplete case or let matching. It also returns a list of lists, just like the original function, and has a definition in the same shape.

With first-class existentials, this would look just like Xu's risers (modulo built-in syntax for lists). With let binding for polymorphic values, risersMitchell' would look just like Mitchell's original risers, but be safe by construction. Let binding for polymorphic values would also allow non-trusted implementations of safeTail and safeHead to be actually safe.

For contrast, here is a GADT implementation of risers:

> data GList a n where
>     GNil :: GList a IsNil
>     (:||) :: a -> GList a m -> GList a IsCons
> infixr :||
> data IsNil
> data IsCons
> gbox x = x :|| GNil
> risersMitchellG :: Ord a => GList a n -> GList (GList a IsCons) n
> risersMitchellG GNil = GNil
> risersMitchellG (c :|| GNil) = gbox $ gbox c
> risersMitchellG (x :|| y :|| etc) =
>     case risersMitchellG (y :|| etc) of
> --    GHC complains, "Inaccessible case alternative: Can't match types `IsCons' and `IsNil'
> --    In the pattern: GNil"
> --    GNil -> error "risers" 
>       s :|| ss ->
>           if x <= y
>           then (x :|| s) :|| ss
>           else (gbox x) :|| s :|| ss

This is safe and has its safety checked by GHC. It also does not require existentials, though when using this encoding, many other list functions (such as filter) will.
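
To illustrate the remark about filter, here is a standalone sketch that repeats the GList definition. Since the emptiness of a filtered list is not known statically, the result is wrapped in an existential. SomeGList, gfilter, ghead and gtoList are hypothetical names, not taken from any library.

```haskell
{-# LANGUAGE GADTs #-}

data IsNil
data IsCons

data GList a n where
    GNil  :: GList a IsNil
    (:||) :: a -> GList a m -> GList a IsCons
infixr 5 :||

-- Existential wrapper: a GList whose emptiness tag is hidden.
data SomeGList a where
    SomeGList :: GList a n -> SomeGList a

-- The result may be empty or not, depending on the predicate.
gfilter :: (a -> Bool) -> GList a n -> SomeGList a
gfilter _ GNil = SomeGList GNil
gfilter p (x :|| xs) =
    case gfilter p xs of
      SomeGList ys
        | p x       -> SomeGList (x :|| ys)
        | otherwise -> SomeGList ys

-- Total head: the GNil case is ruled out by the index.
ghead :: GList a IsCons -> a
ghead (x :|| _) = x

gtoList :: GList a n -> [a]
gtoList GNil       = []
gtoList (x :|| xs) = x : gtoList xs
```

Callers of gfilter must unwrap SomeGList and match on the contents before they can use ghead, which is the same boxing cost that Some imposes on safeTail earlier.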

Now here is a lightweight static capabilities version of risers:

> -- module Protected where
> -- Export type constructor, do not export value constructor
> newtype LWSC b a = LWSC_do_not_export_me [a]
> data Full
> type FullList = LWSC Full
> data Any
> lnil :: LWSC Any a
> lnil = LWSC_do_not_export_me []
> lcons :: a -> LWSC b a -> FullList a
> lcons x (LWSC_do_not_export_me xs) = LWSC_do_not_export_me (x:xs) 
> lwhead :: FullList a -> a
> lwhead (LWSC_do_not_export_me x) = head x
> data Some' f n = forall a . Some' (f a n)
> lwtail :: FullList a -> Some' LWSC a
> lwtail (LWSC_do_not_export_me a) = Some' (LWSC_do_not_export_me (tail a))
> deal :: LWSC b a -> LWSC Any c -> (FullList a -> FullList c) -> LWSC b c
> deal (LWSC_do_not_export_me []) _ _ = LWSC_do_not_export_me []
> deal (LWSC_do_not_export_me x) _ f = 
>     case f (LWSC_do_not_export_me x) of
>       LWSC_do_not_export_me z -> LWSC_do_not_export_me z
> nullcase :: LWSC b a -> c -> (FullList a -> c) -> c
> nullcase (LWSC_do_not_export_me []) z _ = z
> nullcase (LWSC_do_not_export_me x) _ f = f (LWSC_do_not_export_me x)
> -- module Risers where
> lbox x = lcons x lnil
> risersXuLW :: Ord a => LWSC b a -> LWSC b (FullList a)
> risersXuLW x = 
>     deal x
>     lnil
>     (\z -> let x = lwhead z
>            in case lwtail z of
>                 Some' rest -> 
>                     nullcase rest
>                     (lbox (lbox x))
>                     (\yetc -> 
>                          let y = lwhead yetc
>                              etc = lwtail yetc
>                              ss = risersXuLW yetc
>                          in if x <= y
>                             then case lwtail ss of
>                                    Some' v -> lcons (lcons x $ lwhead ss) v
>                             else lcons (lbox x) ss))

There is a good bit of code that must go in the protected module, including two different functions for case dispatch. These functions are used instead of pattern matching, and make the definition of risers much more verbose, though I may not have written it with all the oleg-magic possible to make it more usable.

Out of the three, I think GADTs are still the winning approach, but it's fun to explore this new idea, especially since Haskell doesn't have traditional subtyping.

1 comment:

Neil Mitchell said...

The reason for the implementations being different is that ESC did not support if or where at the time the paper was written. In addition, the ESC paper does not use a type class, but is restricted to Int in the type of risers. Catch does use a definition nearly identical to the one you present.

You also miss one way to remove the error. If you have a highly-optimising compiler, such as Supero, it can actually optimise away the error case. If that happens, you can guarantee the original code is safe.

I also don't necessarily disagree that functions are hard to transform, in the most part. Occasionally you get a tricky one, but the majority are probably trivial. The problem is that as soon as you introduce a few hoops in front of someone, they tend to stop, and don't check their code is correct. Catch is partly about proving difficult functions, and partly about automating ones you could have done manually.

It's nice that you can use these techniques to enforce the invariant, but you would still need to add some kind of "did the user use these techniques everywhere" check afterwards. Just because you supply a safe list, the user is free to use the original list data type in other places, and it won't be immediately obvious.