F# code: memoize a recursive function

In this article we’ll look at how to memoize a general function f : ('a -> 'b) -> 'a -> 'b in F#, along with some interesting applications.

Generic version

Most versions of this code that you can find online only deal with the simple case 'a -> 'b, which doesn’t allow the function to call a memoized form of itself for recursion. This version does. The trick is to define mutually recursive functions g : 'a -> 'b and h : ('a -> 'b) -> 'a -> 'b, where g is the memoized form of f that will get passed to f itself, and h does the actual work.

open System.Collections.Generic

// memoize : (('a -> 'b) -> 'a -> 'b) -> ('a -> 'b)
let memoize f =
  let mem = Dictionary<'a, 'b>()
  let rec g key = h g key
  and h r key =
    match mem.TryGetValue(key) with
    | (true, value) -> value
    | _ ->
      let value = f g key
      mem.Add(key, value)
      value
  g

Example (Fibonacci numbers):

let fib r n =
  match n with
  | 0 -> 0
  | 1 -> 1
  | n -> r (n - 1) + r (n - 2)

printfn "%d" ((memoize fib) 20)
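
To see the memoization at work, here is a minimal sketch that counts how often the body of the Fibonacci function actually runs. The counter calls and the name countingFib are illustrative additions, not part of the code above; memoize is repeated only so that the snippet runs on its own.

```fsharp
open System.Collections.Generic

// Same memoize as above, repeated so the snippet is self-contained.
let memoize f =
  let mem = Dictionary<'a, 'b>()
  let rec g key = h g key
  and h r key =
    match mem.TryGetValue(key) with
    | (true, value) -> value
    | _ ->
      let value = f g key
      mem.Add(key, value)
      value
  g

// Hypothetical counter, used only to observe how often the body runs.
let calls = ref 0

let countingFib r n =
  calls.Value <- calls.Value + 1
  match n with
  | 0 -> 0
  | 1 -> 1
  | n -> r (n - 1) + r (n - 2)

let result = memoize countingFib 20
printfn "fib 20 = %d, body evaluated %d times" result calls.Value
// prints "fib 20 = 6765, body evaluated 21 times"
```

Each key from 0 to 20 is computed exactly once, so the body runs 21 times. A plain recursive version would call itself 21,891 times for the same input.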

Array-based versions

The generic memoize function can work with any key type 'a that Dictionary supports, but this can be unnecessarily slow if we only need int keys and have a known bound for the key. In that case, using an array is substantially faster than using memoize with an int key:

let memoize1D ni invalid (f : (int -> 'a) -> int -> 'a) =
  let mem = Array.create ni invalid
  let rec g i =
    h g i
  and h r i =
    if 0 <= i && i < ni then
      match mem.[i] with
      | value when value <> invalid -> value
      | value ->
        let value = f g i
        mem.[i] <- value
        value
    else
      f g i
  g

Example (Unbounded knapsack problem):

// Computes the maximum total weight not exceeding n, where each of the
// positive weights in w may be used any number of times.
let knapsack w n =
  let memoized = memoize1D (n + 1) -1 (fun r k ->
    Array.fold (fun acc x -> if x <= k then max acc (x + r (k - x)) else acc) 0 w)
  memoized n

F# has 2D and 3D arrays, which we can take advantage of when we need int * int or int * int * int keys:

let memoize2D ni nj invalid (f : (int -> int -> 'a) -> int -> int -> 'a) =
  let mem = Array2D.create ni nj invalid
  let rec g i j =
    h g i j
  and h r i j =
    if 0 <= i && i < ni && 0 <= j && j < nj then
      match mem.[i, j] with
      | value when value <> invalid -> value
      | value ->
        let value = f g i j
        mem.[i, j] <- value
        value
    else
      f g i j
  g

let memoize3D ni nj nk invalid (f : (int -> int -> int -> 'a) -> int -> int -> int -> 'a) =
  let mem = Array3D.create ni nj nk invalid
  let rec g i j k =
    h g i j k
  and h r i j k =
    if 0 <= i && i < ni && 0 <= j && j < nj && 0 <= k && k < nk then
      match mem.[i, j, k] with
      | value when value <> invalid -> value
      | value ->
        let value = f g i j k
        mem.[i, j, k] <- value
        value
    else
      f g i j k
  g
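
A possible example for memoize3D (lattice paths). This is a hypothetical illustration rather than something from the original code: it counts the monotone paths from (0, 0, 0) to (n, n, n), where each step increases exactly one coordinate by 1. The name latticePaths is made up for this sketch, and memoize3D is repeated so that the snippet is self-contained.

```fsharp
// Same memoize3D as above, repeated so the snippet runs on its own.
let memoize3D ni nj nk invalid (f : (int -> int -> int -> 'a) -> int -> int -> int -> 'a) =
  let mem = Array3D.create ni nj nk invalid
  let rec g i j k =
    h g i j k
  and h r i j k =
    if 0 <= i && i < ni && 0 <= j && j < nj && 0 <= k && k < nk then
      match mem.[i, j, k] with
      | value when value <> invalid -> value
      | value ->
        let value = f g i j k
        mem.[i, j, k] <- value
        value
    else
      f g i j k
  g

// Hypothetical example: number of monotone lattice paths to (n, n, n).
let latticePaths n =
  let memoized = memoize3D (n + 1) (n + 1) (n + 1) (-1L) (fun r i j k ->
    if i = 0 && j = 0 && k = 0 then
      1L
    else
      // A path reaches (i, j, k) from a neighbour with one smaller coordinate.
      let fromI = if i > 0 then r (i - 1) j k else 0L
      let fromJ = if j > 0 then r i (j - 1) k else 0L
      let fromK = if k > 0 then r i j (k - 1) else 0L
      fromI + fromJ + fromK)
  memoized n n n

printfn "%d" (latticePaths 2)
// prints 90, the multinomial coefficient 6! / (2! 2! 2!)
```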

Example (Money changing problem):

// How many ways are there to make change for n cents, given an unlimited
// supply of coins with the specified values in cents?
let waysToMakeChange (coins : int array) n =
  let memoized = memoize2D coins.Length (n + 1) -1L (fun r minCoin remaining ->
    if remaining = 0 then
      1L
    else
      let mutable sum = 0L
      for i = minCoin to coins.Length - 1 do
        if remaining >= coins.[i] then
          sum <- sum + r i (remaining - coins.[i])
      sum)
  memoized 0 n

For U.S. coin denominations and a total value of $1, we get:

> waysToMakeChange [|1; 5; 10; 25; 50; 100|] 100;;
val it : int64 = 293L

Minimizing stack usage

The above code doesn’t work for $30 because we quickly run out of stack space:

> waysToMakeChange [|1; 5; 10; 25; 50; 100|] 3000;;

Process is terminated due to StackOverflowException.

This occurs because the expression memoized 0 n expands to an extremely long chain of recursive calls, which we don’t have enough stack space for. To fix this, we can use what we know about the problem to enforce a specific order of execution. Since memoized i n only calls memoized i' n' with \(n' < n\), we can fill the table in order of increasing n by inserting the following block just before memoized 0 n:

  for i = 0 to n - 1 do
    memoized 0 i |> ignore

Now we get:

> waysToMakeChange [|1; 5; 10; 25; 50; 100|] 3000;;
val it : int64 = 379747086L
> waysToMakeChange [|1; 5; 10; 25; 50; 100|] 10000;;
val it : int64 = 139946140451L
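
Putting the pieces together, the complete function with the warm-up loop might look like the following sketch. It only combines the code already shown above (memoize2D is repeated so that the snippet runs on its own); the one cosmetic change is that the sentinel is written as (-1L).

```fsharp
// Same memoize2D as above, repeated so the snippet is self-contained.
let memoize2D ni nj invalid (f : (int -> int -> 'a) -> int -> int -> 'a) =
  let mem = Array2D.create ni nj invalid
  let rec g i j =
    h g i j
  and h r i j =
    if 0 <= i && i < ni && 0 <= j && j < nj then
      match mem.[i, j] with
      | value when value <> invalid -> value
      | value ->
        let value = f g i j
        mem.[i, j] <- value
        value
    else
      f g i j
  g

let waysToMakeChange (coins : int array) n =
  let memoized = memoize2D coins.Length (n + 1) (-1L) (fun r minCoin remaining ->
    if remaining = 0 then
      1L
    else
      let mutable sum = 0L
      for i = minCoin to coins.Length - 1 do
        if remaining >= coins.[i] then
          sum <- sum + r i (remaining - coins.[i])
      sum)
  // Fill the table for smaller amounts first, so that the final call
  // only ever recurses into entries that are already cached.
  for i = 0 to n - 1 do
    memoized 0 i |> ignore
  memoized 0 n

printfn "%d" (waysToMakeChange [|1; 5; 10; 25; 50; 100|] 100)
// prints 293
```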

Infinite sequences

Given coin values \(1 \le c_1 < \cdots < c_m \), the solution to the money changing problem is the coefficient of \(x^n\) in the generating function $$ \frac{1}{(1-x^{c_1})\cdots(1-x^{c_m})}. $$ When the number of coins is small, a closed-form solution can easily be found by hand using partial fraction decomposition and the formula $$ \frac{1}{(1-x)^{k+1}} = \sum_{n\ge 0}\binom{n+k}{k}x^n.$$ This quickly becomes infeasible as the number of coins increases, while the algorithm in waysToMakeChange scales fairly well.
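
As a small worked example of the partial fraction approach, take the two coins \(c_1 = 1\) and \(c_2 = 2\) (a coin set chosen here purely for illustration). The generating function decomposes as $$ \frac{1}{(1-x)(1-x^2)} = \frac{1}{(1-x)^2(1+x)} = \frac{1/4}{1-x} + \frac{1/2}{(1-x)^2} + \frac{1/4}{1+x}. $$ Expanding each term as a power series, with \(\frac{1}{1-x} = \sum_{n\ge 0} x^n\), \(\frac{1}{(1-x)^2} = \sum_{n\ge 0} (n+1)x^n\) and \(\frac{1}{1+x} = \sum_{n\ge 0} (-1)^n x^n\), the coefficient of \(x^n\) is $$ \frac{1}{4} + \frac{n+1}{2} + \frac{(-1)^n}{4} = \left\lfloor \frac{n}{2} \right\rfloor + 1. $$ This matches a direct count: there are \(\lfloor n/2 \rfloor + 1\) possible choices for the number of 2-coins, and the remainder is made up of 1-coins.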

We can extend the algorithm to work with an infinite sequence of coins with values \(1 \le c_1 < c_2 < \cdots\), which is possible because F# sequences (really just IEnumerables) are lazy by design:

// Note: coins must be a strictly increasing, possibly infinite sequence.
let waysToMakeChangeSeq (coins : int seq) n =
  let memoized = memoize2D (n + 1) (n + 1) -1L (fun r minCoin remaining ->
    if remaining = 0 then
      1L
    else
      // Sum over every usable coin; Seq.sum avoids mutating a captured local.
      coins |> Seq.skip minCoin |> Seq.takeWhile ((>=) remaining)
      |> Seq.mapi (fun i coin -> r (minCoin + i) (remaining - coin))
      |> Seq.sum)
  for i = 0 to n - 1 do
    memoized 0 i |> ignore
  memoized 0 n

The associated generating function is $$
\frac{1}{(1-x^{c_1})(1-x^{c_2})\cdots}.
$$

A partition of an integer \(n\) is a representation of \(n\) as an unordered sum of positive integers. (An empty sum evaluates to \(0\), so the number of partitions of \(0\) is \(1\).) Counting the partitions of \(n\) is exactly the money changing problem with the infinite sequence of coins \((1,2,\dots)\):

> let positiveIntegers = Seq.initInfinite ((+) 1);;

val positiveIntegers : seq<int>

> positiveIntegers;;
val it : seq<int> = seq [1; 2; 3; 4; ...]
> waysToMakeChangeSeq positiveIntegers 0;;
val it : int64 = 1L
> waysToMakeChangeSeq positiveIntegers 5;;
val it : int64 = 7L
> waysToMakeChangeSeq positiveIntegers 100;;
val it : int64 = 190569292L

So there are \(190,569,292\) different ways of writing \(100\) as an unordered sum of positive integers.
