I was reading the excellent Programming Phoenix the other day and came across this code snippet for finding the item in a list with the highest score:

items
|> Enum.sort(&(&1.score > &2.score))
|> Enum.take(1)

My first thought on reading that code was, “Why sort the whole list if you only care about the item with the highest score?” (Of course, the code in their example never deals with more than a dozen items, so readability matters most.)
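For the special case of a single top item (n = 1), Enum.max_by/2 already makes a single pass over the list. The sketch below compares it with the book's snippet; the items list is made-up sample data, not from the book.

```elixir
# Hypothetical sample data: each item is a map with a :score key.
items = [%{name: "a", score: 3}, %{name: "b", score: 9}, %{name: "c", score: 5}]

# The snippet from the book: sort everything (O(n log n)), keep one item.
top_by_sort = items |> Enum.sort(&(&1.score > &2.score)) |> Enum.take(1)

# Enum.max_by/2 makes a single O(n) pass and returns the item itself,
# not a one-element list.
top_by_max = Enum.max_by(items, & &1.score)

IO.inspect(top_by_sort)  # [%{name: "b", score: 9}]
IO.inspect(top_by_max)   # %{name: "b", score: 9}
```

That only covers n = 1, though; keeping the top n items in one pass needs a bit more bookkeeping.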

The approach I thought of was to iterate through the list once, building a second list of the n items with the highest scores. Until the second list holds n items, add each item to it. After that, for each item i, compare its score to that of the lowest-scoring element s in the second list. If i scores higher than s, remove s from the second list, add i, and find the new smallest element. Otherwise, move on to the next item without changing the second list.
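Here is a minimal sketch of that approach using Enum.reduce/3 instead of explicit recursion. The module name TopNSketch is hypothetical, and the sketch assumes the second list is kept sorted ascending by score so that s is always its head. It re-sorts with Enum.sort_by/2 whenever the list changes, which is simple rather than optimal.

```elixir
defmodule TopNSketch do
  # Keeps at most n items, sorted ascending by score, so the head of the
  # kept list is always the smallest kept item s.
  def take_max(list, n, score) when n > 0 do
    list
    |> Enum.reduce({[], 0}, fn item, {kept, count} ->
      cond do
        # The kept list isn't full yet: every item goes in.
        count < n ->
          {Enum.sort_by([item | kept], score), count + 1}

        # i > s: drop s (the head) and add i.
        score.(item) > score.(hd(kept)) ->
          {Enum.sort_by([item | tl(kept)], score), count}

        # i <= s: leave the kept list unchanged.
        true ->
          {kept, count}
      end
    end)
    |> elem(0)
  end

  # Asking for zero (or fewer) items returns nothing.
  def take_max(_list, _n, _score), do: []
end

IO.inspect(TopNSketch.take_max([500, 700, 800, 100, 400, 900], 3, & &1))
# [700, 800, 900]
```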

I couldn’t find a function in Elixir’s Enum module that does that (as of v1.3), so I decided to write it myself. It turned out to be a great example of the power of Elixir’s pattern matching and function guards.

I started a new project with mix new take_max and wrote a few base case tests:

defmodule TakeMaxTest do
  use ExUnit.Case
  doctest TakeMax

  test "taking 0 elements returns an empty list" do
    assert TakeMax.take_max([1,2,3], 0, &(&1)) == []
  end

  test "taking from an empty list returns an empty list" do
    assert TakeMax.take_max([], 5, &(&1)) == []
  end
end

Those tests fail, naturally, because TakeMax.take_max/3 doesn’t exist yet. Let’s write just enough to get them green.

defmodule TakeMax do
  def take_max(_list, 0, _func), do: []
  def take_max([], _num, _func), do: []
end

Pattern matching lets us write very little code here: if asked for 0 items or given an empty list, return an empty list.

There’s another test we can write that lets us put off the real work of implementing take_max a little longer:

test "taking more elements than are in the list returns everything in the list" do
  source = [1, 2, 3]
  taken = TakeMax.take_max(source, 20, &(&1))
  assert Enum.sort(source) == Enum.sort(taken)
end

This test fails because no TakeMax.take_max/3 clause matches those arguments. Let’s add one:

def take_max(list, num, func) do
  take_max(list, num, [], func)
end

defp take_max([], _remaining, keepers, _func), do: keepers
defp take_max([head | tail], remaining, keepers, func) do
  take_max(tail, remaining - 1, [head | keepers], func)
end

We use a private function, take_max/4, to accumulate the elements to return in keepers, while remaining counts down how many more elements we still need. Carrying that count along saves us from repeatedly counting the elements in keepers with length/1, which is an O(n) operation on Elixir’s linked lists.
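To make the cost difference concrete, here is a hypothetical pair of functions (CountdownSketch and its function names are illustrative, not from the project) that both collect the first num elements of a list, in reverse order like keepers. The first checks length(keepers) on every step, which walks the whole accumulator each time; the second carries a countdown the way take_max/4 does, so each step does constant work.

```elixir
defmodule CountdownSketch do
  # Counting version: length/1 traverses all of keepers on every step,
  # so collecting k elements costs O(k^2) in total.
  def collect_by_length(list, num), do: by_length(list, num, [])

  defp by_length([], _num, keepers), do: keepers

  defp by_length([head | tail], num, keepers) do
    if length(keepers) < num, do: by_length(tail, num, [head | keepers]), else: keepers
  end

  # Countdown version: the number still needed travels with the recursion,
  # so each step is O(1).
  def collect_by_countdown(list, num), do: by_countdown(list, num, [])

  defp by_countdown([head | tail], remaining, keepers) when remaining > 0,
    do: by_countdown(tail, remaining - 1, [head | keepers])

  defp by_countdown(_list, _remaining, keepers), do: keepers
end

IO.inspect(CountdownSketch.collect_by_length([1, 2, 3, 4, 5], 3))     # [3, 2, 1]
IO.inspect(CountdownSketch.collect_by_countdown([1, 2, 3, 4, 5], 3))  # [3, 2, 1]
```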

The current implementation doesn’t respect the value of num; it will return the entire list, regardless of how many elements were requested. That needs to change, and it’s the next test we write:

test "it returns the N largest elements in the list" do
  # Use large numbers so that failure output doesn't display these lists as charlists
  source = [500, 700, 800, 100, 400, 900]
  taken = TakeMax.take_max(source, 3, &(&1))
  assert Enum.sort(taken) == [700, 800, 900]
end

When remaining reaches 0 and elements are still left in the source list, each element we add to keepers has to replace one that’s already there, and we should only add an element if it’s larger than the smallest element in keepers. With a new clause for the private function to handle that, our implementation looks like this:

defp take_max([], _remaining, keepers, _func), do: keepers
defp take_max([head | tail], remaining, keepers, func) do
  take_max(tail, remaining - 1, [head | keepers], func)
end

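defp take_max([head | tail], 0, keepers, func) do
  [min_keeper | rest_keepers] = Enum.sort_by(keepers, func)
  # Replace the smallest keeper only if head scores higher. Rebinding keepers
  # inside an if block doesn't change it outside the block in current Elixir,
  # so each branch makes its own recursive call instead.
  if func.(head) > func.(min_keeper),
    do: take_max(tail, 0, [head | rest_keepers], func),
    else: take_max(tail, 0, keepers, func)
end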

But the tests still fail, and the output from mix test tells us why:

bschmeckpeper@raptor
~/src/elixir/take_max > mix test
lib/take_max.ex:14: warning: this clause cannot match because a previous clause at line 10 always matches
Compiled lib/take_max.ex

The order of the clauses matters: Elixir runs the first clause whose pattern (and guard, if any) matches. We could move the new clause above the earlier two, but I prefer keeping the special cases up top. Adding a when remaining > 0 guard to the second take_max clause lets us keep the current ordering and also makes our intentions more explicit. Our implementation now looks like this:

defp take_max([], _remaining, keepers, _func), do: keepers
defp take_max([head | tail], remaining, keepers, func) when remaining > 0 do
  take_max(tail, remaining - 1, [head | keepers], func)
end

defp take_max([head | tail], 0, keepers, func) do
  [min_keeper | rest_keepers] = Enum.sort_by(keepers, func)
  # Replace the smallest keeper only if head scores higher.
  if func.(head) > func.(min_keeper),
    do: take_max(tail, 0, [head | rest_keepers], func),
    else: take_max(tail, 0, keepers, func)
end
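The first-match rule can be seen in isolation in a stripped-down, hypothetical module (ClauseOrderDemo is an illustrative name) that mirrors the three clause shapes of take_max/4 and just reports which clause ran:

```elixir
defmodule ClauseOrderDemo do
  # Clauses are tried top to bottom; the first one whose pattern and
  # guard both match is the one that runs.
  def classify([], _remaining), do: :source_empty

  # Without the guard, this clause would also match remaining == 0, so the
  # clause below it could never run (and the compiler would warn about it).
  def classify([_head | _tail], remaining) when remaining > 0, do: :still_filling

  def classify([_head | _tail], 0), do: :keepers_full
end

IO.inspect(ClauseOrderDemo.classify([], 3))      # :source_empty
IO.inspect(ClauseOrderDemo.classify([1, 2], 2))  # :still_filling
IO.inspect(ClauseOrderDemo.classify([1, 2], 0))  # :keepers_full
```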

We should also test that the scoring function actually gets used:

test "it sorts elements in the list by the given function" do
  source = ["long string", "abc", "this is really long", "foobar", "nope, this is the longest"]
  taken = TakeMax.take_max(source, 2, &String.length/1)
  assert Enum.sort(taken) == ["nope, this is the longest", "this is really long"]
end

That test passes out of the box!

The current implementation can be improved, though. Once keepers is full, we find min_keeper by sorting keepers on every call. If we passed keepers in already sorted, finding min_keeper would be an O(1) operation (it’s simply the head of the list), and we would only need to re-sort when min_keeper is removed and a new element is added. This change needs a couple of new clauses, because remaining == 1 is now a special case:

defp take_max([head | []], remaining, keepers, _func) when remaining > 0, do: [head | keepers]
defp take_max([head | tail], 1, keepers, func) do
  take_max(tail, 0, Enum.sort_by([head | keepers], func), func)
end

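defp take_max([head | tail], 0, [min_keeper | rest_keepers] = keepers, func) do
  # keepers arrives sorted, so min_keeper is its smallest element; re-sort
  # only when head replaces it.
  if func.(head) > func.(min_keeper),
    do: take_max(tail, 0, Enum.sort_by([head | rest_keepers], func), func),
    else: take_max(tail, 0, keepers, func)
end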

If only one element is left and there’s still room for it, we return it along with the rest of keepers. (This clause should be placed above the defp take_max([head | tail], remaining, keepers, func) when remaining > 1 clause, so that it matches as often as possible.) When remaining is 1 and more elements follow, we add head to keepers, sort the result, and pass that into the next call to take_max. In the remaining == 0 clause, we use pattern matching in the function head to bind min_keeper and rest_keepers, and we re-sort the list only if we’re replacing min_keeper with head. We also need to tighten the earlier guard to remaining > 1; otherwise that clause would also match when remaining is 1, the new clause would never run, and keepers would never get sorted.
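One further refinement, not part of the original project: because rest_keepers is already sorted, head could be inserted into its proper position with a single O(k) walk instead of re-sorting all k keepers with Enum.sort_by/2, which costs O(k log k). A minimal sketch of such a helper, with the hypothetical name SortedInsertSketch.insert_sorted/3:

```elixir
defmodule SortedInsertSketch do
  # Inserts item into a list already sorted ascending by func, walking the
  # list at most once instead of re-sorting it.
  def insert_sorted([], item, _func), do: [item]

  def insert_sorted([first | rest] = sorted, item, func) do
    if func.(item) <= func.(first) do
      # item belongs before everything that remains.
      [item | sorted]
    else
      [first | insert_sorted(rest, item, func)]
    end
  end
end

IO.inspect(SortedInsertSketch.insert_sorted([700, 800, 900], 850, & &1))
# [700, 800, 850, 900]
```

Under that assumption, the remaining == 0 clause could call SortedInsertSketch.insert_sorted(rest_keepers, head, func) in place of Enum.sort_by([head | rest_keepers], func). For small n the difference is modest either way.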

Our complete implementation looks like this:

defmodule TakeMax do
  def take_max(_list, 0, _func), do: []
  def take_max([], _num, _func), do: []

  def take_max(list, num, func) do
    take_max(list, num, [], func)
  end

  defp take_max([], _remaining, keepers, _func), do: keepers
  defp take_max([head | []], remaining, keepers, _func) when remaining > 0, do: [head | keepers]
  defp take_max([head | tail], remaining, keepers, func) when remaining > 1 do
    take_max(tail, remaining - 1, [head | keepers], func)
  end

  defp take_max([head | tail], 1, keepers, func) do
    take_max(tail, 0, Enum.sort_by([head | keepers], func), func)
  end

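  defp take_max([head | tail], 0, [min_keeper | rest_keepers] = keepers, func) do
    # keepers is sorted, so min_keeper is its smallest element; re-sort only
    # when head replaces it.
    if func.(head) > func.(min_keeper),
      do: take_max(tail, 0, Enum.sort_by([head | rest_keepers], func), func),
      else: take_max(tail, 0, keepers, func)
  end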
end

This code highlights a lot of what I like about Elixir. Recursion, pattern matching, and guards isolate the edge cases, and destructuring in the function heads removes boilerplate from the function bodies. As a result, the heart of TakeMax, the only clause that contains any real logic, is just 3 lines long.
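As a quick sanity check, the sketch below repeats the complete implementation (with both branches of the remaining == 0 clause making their own recursive call) and compares it against the sort-then-take approach on random integer lists. It then applies take_max to score maps shaped like the ones in the book's snippet; those sample items are made up.

```elixir
defmodule TakeMax do
  def take_max(_list, 0, _func), do: []
  def take_max([], _num, _func), do: []
  def take_max(list, num, func), do: take_max(list, num, [], func)

  # Source exhausted: whatever is in keepers is the answer.
  defp take_max([], _remaining, keepers, _func), do: keepers
  # Last element and still room for it: keep it without comparing.
  defp take_max([head | []], remaining, keepers, _func) when remaining > 0, do: [head | keepers]
  # Still filling keepers: add head unconditionally.
  defp take_max([head | tail], remaining, keepers, func) when remaining > 1 do
    take_max(tail, remaining - 1, [head | keepers], func)
  end
  # Filling the last slot: sort once so the smallest keeper is the head.
  defp take_max([head | tail], 1, keepers, func) do
    take_max(tail, 0, Enum.sort_by([head | keepers], func), func)
  end
  # keepers is full and sorted: replace its head only if head scores higher.
  defp take_max([head | tail], 0, [min_keeper | rest_keepers] = keepers, func) do
    if func.(head) > func.(min_keeper),
      do: take_max(tail, 0, Enum.sort_by([head | rest_keepers], func), func),
      else: take_max(tail, 0, keepers, func)
  end
end

# Compare against sort-then-take on random integer lists of length 0..20.
check = fn ->
  list = Enum.take(Stream.repeatedly(fn -> :rand.uniform(100) end), :rand.uniform(21) - 1)
  n = :rand.uniform(25) - 1
  expected = list |> Enum.sort(&(&1 >= &2)) |> Enum.take(n) |> Enum.sort()
  Enum.sort(TakeMax.take_max(list, n, & &1)) == expected
end

all_ok = Enum.all?(1..500, fn _ -> check.() end)
IO.puts("random checks passed: #{all_ok}")

# Hypothetical items shaped like the book's example.
items = [%{name: "a", score: 3}, %{name: "b", score: 9}, %{name: "c", score: 5}]
IO.inspect(TakeMax.take_max(items, 1, & &1.score))  # [%{name: "b", score: 9}]
```

With n = 1 this returns the same one-element list as the book's sort-then-take snippet, without sorting the whole list.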

All of the code for this post is available on GitHub.