Wednesday, December 19, 2007

Scala: Tail-Recursive Higher-Order Function

Last time I showed a simple handbook tail recursion solution. Digging deeper into Scala and functional programming, I've bumped into more interesting stuff. Let me introduce some new terminology:

1. A higher-order function is a function that takes other functions as parameters or returns a function as its result. For example,

def sum(f: Int => Int, a: Int, b: Int): Int =
  if (a > b) 0 else f(a) + sum(f, a + 1, b)

defines a higher-order function that takes a function f plus two integers and returns the sum of f(x) for every integer x in the closed interval [a, b], that is, f(a) + f(a + 1) + ... + f(b).
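To see the higher-order part in action, here is a minimal sketch that passes ordinary named functions to sum. The helpers `id` and `square` are hypothetical names chosen just for this illustration. The calls also show that both endpoints a and b are included in the sum.

```scala
// The higher-order sum from above: f is applied to every integer in [a, b].
def sum(f: Int => Int, a: Int, b: Int): Int =
  if (a > b) 0 else f(a) + sum(f, a + 1, b)

// Hypothetical named helpers, passed as the function argument f.
def id(x: Int): Int = x
def square(x: Int): Int = x * x

println(sum(id, 1, 3))     // 1 + 2 + 3 = 6
println(sum(square, 1, 3)) // 1 + 4 + 9 = 14
println(sum(square, 3, 1)) // empty interval, so 0
```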

2. An anonymous function is simply a function with no name. It's really just syntactic sugar, but it comes in handy when fiddling with higher-order functions. For example,

sum(x => x * x, 10, 20)

is a call to the function defined above, with the (unnamed) square function as the first argument.
Watch how the compiler automagically infers the type of x here: since sum expects an Int => Int, x must be an Int.
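A small sketch comparing the inferred form with two equivalent spellings of the same anonymous function: one with the parameter type written out, and one stored in a val first. The names `inferred`, `explicit`, `bound` and `viaVal` are hypothetical and only serve the comparison.

```scala
def sum(f: Int => Int, a: Int, b: Int): Int =
  if (a > b) 0 else f(a) + sum(f, a + 1, b)

// Three equivalent ways to pass "square" as an anonymous function.
val inferred = sum(x => x * x, 10, 20)        // x: Int inferred from f: Int => Int
val explicit = sum((x: Int) => x * x, 10, 20) // parameter type written out
val bound: Int => Int = x => x * x            // function value stored in a val
val viaVal = sum(bound, 10, 20)

println(inferred) // 10*10 + 11*11 + ... + 20*20 = 2585
```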

3. Currying is another interesting concept, nicely described here. Currying turns a function that takes several arguments into a chain of functions that each take some of those arguments, which makes it a nice basis for refactoring. Taking the sum example, we can factor the interval parameters out of its argument list and make sum return another function, which in turn deals with the interval. Like so:

def sum(f: Int => Int): (Int, Int) => Int = {
  def sumF(a: Int, b: Int): Int =
    if (a > b) 0 else f(a) + sumF(a + 1, b)
  sumF
}

Try it out:

sum(x => x*x)(10, 20)
def s = sum(x => x*x)
s(10,20)
s(20,30)
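Because sum(f) is an ordinary function value of type (Int, Int) => Int, it can be built once and reused for any interval. A minimal sketch, assuming the function-returning definition of sum above; the names sumInts, sumSquares and sumCubes are hypothetical.

```scala
// Function-returning version from above: sum(f) yields a (Int, Int) => Int.
def sum(f: Int => Int): (Int, Int) => Int = {
  def sumF(a: Int, b: Int): Int =
    if (a > b) 0 else f(a) + sumF(a + 1, b)
  sumF
}

// Specialised summing functions, each created by applying sum to one argument.
val sumInts    = sum(x => x)
val sumSquares = sum(x => x * x)
val sumCubes   = sum(x => x * x * x)

println(sumInts(1, 10))     // 55
println(sumSquares(10, 20)) // 2585
println(sumSquares(20, 30)) // 6985
println(sumCubes(1, 3))     // 1 + 8 + 27 = 36
```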

Scala's multiple parameter lists let us write the same curried function more concisely:

def sum(f: Int => Int)(a: Int, b: Int): Int =
  if (a > b) 0 else f(a) + sum(f)(a + 1, b)

This compact form means the same as the previous one, so you can still write:

sum(x => x*x)(10, 20)
def s = sum(x => x*x) _ // note the underscore: it turns the partially applied method into a function value
s(10,20)
s(20,30)
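The underscore is one way to get a function value out of the partially applied method. Another, sketched below under the assumption of a reasonably recent Scala 2 or Scala 3 compiler, is to state the expected function type: when the compiler sees that a function is expected, it converts the method on its own.

```scala
// Curried version using two parameter lists.
def sum(f: Int => Int)(a: Int, b: Int): Int =
  if (a > b) 0 else f(a) + sum(f)(a + 1, b)

// The type annotation tells the compiler a function is expected,
// so sum(x => x * x) is converted to a (Int, Int) => Int without the underscore.
val s: (Int, Int) => Int = sum(x => x * x)

println(s(10, 20))               // 2585
println(s(20, 30))               // 6985
println(sum(x => x * x)(10, 20)) // 2585, both argument lists supplied at once
```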


To wrap it all up, take exercise 5.2.1 from Scala By Example: rewrite sum (the curried version) to use tail recursion.
Converting linear recursion to tail recursion shouldn't be too hard: instead of making the recursive call first and adding f(a) to its result afterwards, you carry the partial sum along in an accumulator parameter, so that the recursive call becomes the very last thing the method does. If you look at the solution from last time, you'll find it's really easy; it goes like this:

def sum(f: Int => Int)(a: Int, b: Int): Int = {
  def sum1(a: Int, result: Int): Int = {
    if (a > b) result
    else sum1(a + 1, f(a) + result)
  }
  sum1(a, 0)
}
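Later Scala versions (2.8 onward) added the scala.annotation.tailrec annotation, which makes the compiler report an error if the annotated method is not tail-recursive. A sketch of the same solution with the annotation added, plus a trace of how the accumulator evolves; the large-range call is just an illustrative check that deep intervals run as a loop.

```scala
import scala.annotation.tailrec

// Tail-recursive sum: the accumulator `result` carries the partial sum,
// so the recursive call to sum1 is the last action and is compiled into a loop.
def sum(f: Int => Int)(a: Int, b: Int): Int = {
  @tailrec
  def sum1(a: Int, result: Int): Int =
    if (a > b) result
    else sum1(a + 1, f(a) + result)
  sum1(a, 0)
}

// Trace for sum(x => x)(1, 3):
//   sum1(1, 0) -> sum1(2, 1) -> sum1(3, 3) -> sum1(4, 6) -> 6
println(sum(x => x)(1, 3))        // 6
println(sum(x => x * x)(10, 20))  // 2585
println(sum(_ => 1)(1, 1000000))  // 1000000, no stack growth per element
```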
