Custom Collector Basics

If you’ve been playing with Java streams, chances are you’ve used the Collector interface. Said differently, you’ve probably invoked the collect() method on a stream to trigger a reduction operation. Something like this:

list.stream().collect(Collectors.toList());

There’s a load of predefined collectors that you can obtain from the static factory methods of the Collectors class, and Collectors.toList() is just one of them. Some of the others are groupingBy, partitioningBy, maxBy, minBy, reducing and summingInt.
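To make that concrete, here’s a small sketch that applies three of those factory methods (partitioningBy, summingInt and maxBy) to an in-memory list of words. The class name PredefinedCollectorsDemo and the sample words are made up for illustration.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

// Hypothetical demo class: a few predefined collectors applied to an in-memory word list.
class PredefinedCollectorsDemo {

    // partitioningBy: splits the words into a Map<Boolean, List<String>> using a predicate.
    static Map<Boolean, List<String>> partitionShort(List<String> words) {
        return words.stream().collect(Collectors.partitioningBy(w -> w.length() <= 6));
    }

    // summingInt: adds up an int-valued property (here, the length) of every word.
    static int totalLetters(List<String> words) {
        return words.stream().collect(Collectors.summingInt(String::length));
    }

    // maxBy: picks the largest element under a comparator, wrapped in an Optional.
    static Optional<String> longest(List<String> words) {
        return words.stream().collect(Collectors.maxBy(Comparator.comparingInt(String::length)));
    }

    public static void main(String[] args) {
        List<String> words = Arrays.asList("the", "count", "of", "monte", "cristo", "imprisonment");
        System.out.println(partitionShort(words));
        System.out.println(totalLetters(words));       // 33
        System.out.println(longest(words).orElse("")); // imprisonment
    }
}
```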

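In this post, I’ll present a really simple example that defines a custom collector, i.e., a concrete implementation of the Collector<T,A,R> interface, where T is the type of the stream elements, A is the type of the mutable accumulation container, and R is the type of the final result.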

Implementing this interface means providing five methods:

  • supplier()
  • accumulator()
  • combiner()
  • finisher()
  • characteristics()
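Before the real example, here’s a deliberately tiny sketch that implements all five methods for the simplest possible job: counting elements. The class CountingCollector is hypothetical and not part of the JDK (which already offers Collectors.counting()). The container is a one-element long[] so the accumulator can mutate it in place, and because the finisher unwraps that array into a Long, the collector does not declare IDENTITY_FINISH.

```java
import java.util.Collections;
import java.util.EnumSet;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collector;
import java.util.stream.Stream;

// Hypothetical minimal collector that counts stream elements.
// The mutable container is a one-element long[] so the accumulator can update it in place.
class CountingCollector<T> implements Collector<T, long[], Long> {

    @Override
    public Supplier<long[]> supplier() {
        return () -> new long[1];              // fresh container holding a count of 0
    }

    @Override
    public BiConsumer<long[], T> accumulator() {
        return (count, element) -> count[0]++; // fold one element into the container
    }

    @Override
    public BinaryOperator<long[]> combiner() {
        return (left, right) -> {              // merge two partial counts (parallel streams)
            left[0] += right[0];
            return left;
        };
    }

    @Override
    public Function<long[], Long> finisher() {
        return count -> count[0];              // unwrap the container into the final result
    }

    @Override
    public Set<Characteristics> characteristics() {
        // Order is irrelevant to a count; the finisher is not the identity, so no IDENTITY_FINISH.
        return Collections.unmodifiableSet(EnumSet.of(Characteristics.UNORDERED));
    }

    // Demo helpers: count the given words sequentially or in parallel.
    static long countWords(String... words) {
        return Stream.of(words).collect(new CountingCollector<String>());
    }

    static long countWordsParallel(String... words) {
        return Stream.of(words).parallel().collect(new CountingCollector<String>());
    }

    public static void main(String[] args) {
        System.out.println(countWords("a", "b", "c")); // 3
    }
}
```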

The Goal

Define a collector that takes a stream of words (read from a book file) and partitions them into three distinct buckets based on word length: SHORT (up to 6 characters), MEDIUM (7 to 12 characters) and LONG (more than 12 characters).
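For comparison, the same three-bucket split can probably be produced with the built-in Collectors.groupingBy and a classifier function, which makes a handy cross-check for the custom collector. The sketch below assumes the thresholds used in BookWordsCollector (SHORT_MAX = 6, MEDIUM_MAX = 12); the class GroupingByBuckets and the WordLength enum are illustrative stand-ins for the post’s own types.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical cross-check: the three-bucket split built from a predefined collector.
class GroupingByBuckets {

    // Illustrative stand-in for the post's WORD_LENGTH enum.
    enum WordLength { SHORT, MEDIUM, LONG }

    private static final int SHORT_MAX = 6;
    private static final int MEDIUM_MAX = 12;

    // Classifier: maps a word to its bucket using the same thresholds as the custom collector.
    static WordLength bucketOf(String word) {
        if (word.length() <= SHORT_MAX) {
            return WordLength.SHORT;
        } else if (word.length() <= MEDIUM_MAX) {
            return WordLength.MEDIUM;
        }
        return WordLength.LONG;
    }

    // groupingBy builds the Map<bucket, List<word>>; it only creates keys for non-empty buckets.
    static Map<WordLength, List<String>> bucket(List<String> words) {
        return words.stream().collect(Collectors.groupingBy(GroupingByBuckets::bucketOf));
    }

    public static void main(String[] args) {
        System.out.println(bucket(Arrays.asList("the", "prisoner", "extraordinarily", "of")));
    }
}
```

One difference worth noting: groupingBy only creates a key for a bucket that receives at least one word, whereas the custom collector’s supplier pre-creates all three.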

The BookWordsCollector Class

package com.creatosix.sandbox.collectors;

import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collector;

import com.creatosix.sandbox.collectors.BookWordsCollector.WORD_LENGTH;

/**
 * A custom collector that manually partitions the words in a file based 
 * on the word length. 
 * 
 * The result of invoking this Collector is a Map with three entries for 
 * SHORT, MEDIUM and LONG words.
 *  
 * @author javeedchida
 *
 */
public class BookWordsCollector implements Collector<String, Map<WORD_LENGTH, List<String>>, Map<WORD_LENGTH, List<String>>> {

	private static final int SHORT_MAX = 6;
	private static final int MEDIUM_MAX = 12;
	
	@Override
	public Supplier<Map<WORD_LENGTH,List<String>>> supplier() {
		// Each call returns a fresh map with an empty list for every bucket.
		return () -> {
			Map<WORD_LENGTH, List<String>> map = new HashMap<>();
			for (WORD_LENGTH bucket : WORD_LENGTH.values()) {
				map.put(bucket, new ArrayList<>());
			}
			return map;
		};
	}

	@Override
	public BiConsumer<Map<WORD_LENGTH,List<String>>, String> accumulator() {
		return (Map<WORD_LENGTH, List<String>> map, String word) -> {
			if(word.length() <= SHORT_MAX) {
				map.get(WORD_LENGTH.SHORT).add(word);
			} else if(word.length() <= MEDIUM_MAX) { // already known to be > SHORT_MAX here
				map.get(WORD_LENGTH.MEDIUM).add(word);
			} else {
				map.get(WORD_LENGTH.LONG).add(word);
			}
		};
	}

	@Override
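	public BinaryOperator<Map<WORD_LENGTH,List<String>>> combiner() {
		// Merge every bucket of m2 into the matching bucket of m1
		// (invoked when partial results of a parallel stream are combined).
		return (Map<WORD_LENGTH,List<String>> m1, Map<WORD_LENGTH, List<String>> m2) -> {
			for (WORD_LENGTH bucket : WORD_LENGTH.values()) {
				m1.get(bucket).addAll(m2.get(bucket));
			}
			return m1;
		};
	}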

	@Override
	public Function<Map<WORD_LENGTH,List<String>>, Map<WORD_LENGTH,List<String>>> finisher() {
		return Function.identity();
	}

	@Override
	public Set<Characteristics> characteristics() {
		return Collections.unmodifiableSet(EnumSet.of(Characteristics.IDENTITY_FINISH, Characteristics.UNORDERED));
	}

	public enum WORD_LENGTH {
		SHORT, MEDIUM, LONG
	};

}

The Test Class to Invoke the Custom Collector

package com.creatosix.sandbox.collectors;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.IntSummaryStatistics;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

public class TestBookWordsCollector {
	
	private static final String OUTPUT_HEADER = "Word distribution by length [SHORT: <= 6; MEDIUM: 6 - 12; LONG > 12]";

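	public static void main(String[] args) {
		Map<BookWordsCollector.WORD_LENGTH, List<String>> bookWordsCollector = readAndCollect();
		// readAndCollect() returns null if the book file could not be read
		if (bookWordsCollector != null) {
			outputCollectorStats(bookWordsCollector);
		}
	}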
	
	private static Map<BookWordsCollector.WORD_LENGTH, List<String>> readAndCollect() {
		Map<BookWordsCollector.WORD_LENGTH, List<String>> bookWordsCollector = null;
		/* Read in the lines from a document and call the custom collector on a Stream of Strings */
		try(Stream<String> lines = Files.lines(Paths.get("/Users/javeedchida/Downloads/the_count_of_monte_cristo.txt"))) {
			Stream<String> stream = lines.map(line -> line.replace(",", "")).flatMap(line -> Arrays.stream(line.split(" ")));
			bookWordsCollector = stream.collect(new BookWordsCollector());
		} catch(IOException e) {
			e.printStackTrace();
		}
		return bookWordsCollector;
	}
	
	private static void outputCollectorStats(Map<BookWordsCollector.WORD_LENGTH, List<String>> bookWordsCollector) {
		System.out.println(OUTPUT_HEADER);
		IntStream.rangeClosed(1, OUTPUT_HEADER.length()).mapToObj(i -> "-").forEach(System.out::print);
		System.out.println();
		bookWordsCollector.forEach((k, v) -> {
			int wordCount = v.size();
			// mapToInt avoids boxing every word length before summarizing
			IntSummaryStatistics summary = v.stream().mapToInt(String::length).summaryStatistics();
			System.out.println(String.format("%s \nAverage word length=%d \nLongest word length=%d", k, Math.round(summary.getAverage()), summary.getMax()));
			System.out.println(String.format("Total word count = %d\n", wordCount));
		});
	}

}

Sample Run

Word distribution by length [SHORT: <= 6; MEDIUM: 6 - 12; LONG > 12]
--------------------------------------------------------------------
MEDIUM 
Average word length=8 
Longest word length=12
Total word count = 91423

LONG 
Average word length=14 
Longest word length=53
Total word count = 2509

SHORT 
Average word length=3 
Longest word length=6
Total word count = 382488

Quick Observations

The supplier() creates the mutable result container: a fresh Map holding three entries, one per WORD_LENGTH bucket, each an empty List.
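The supplier must return a brand-new container on every call, since a parallel stream may ask for one container per sub-task. Here’s a sketch that builds the same pre-populated map by looping over the enum’s values() instead of listing each key. It uses an EnumMap, which is my substitution (the post uses a HashMap) and keeps the keys in declaration order; BucketSupplier and WordLength are illustrative names.

```java
import java.util.ArrayList;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;

// Hypothetical supplier sketch for a three-bucket container.
class BucketSupplier {

    // Illustrative stand-in for the post's WORD_LENGTH enum.
    enum WordLength { SHORT, MEDIUM, LONG }

    // Each get() builds a new map with one empty list per enum constant, so adding a
    // constant to the enum automatically adds a bucket here.
    static Supplier<Map<WordLength, List<String>>> supplier() {
        return () -> {
            Map<WordLength, List<String>> buckets = new EnumMap<>(WordLength.class);
            for (WordLength length : WordLength.values()) {
                buckets.put(length, new ArrayList<>());
            }
            return buckets;
        };
    }

    public static void main(String[] args) {
        Supplier<Map<WordLength, List<String>>> s = supplier();
        Map<WordLength, List<String>> first = s.get();
        Map<WordLength, List<String>> second = s.get();
        System.out.println(first);           // {SHORT=[], MEDIUM=[], LONG=[]}
        System.out.println(first != second); // true: a fresh container per call
    }
}
```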

The accumulator() is where the interesting work happens. Every word in the stream is passed as the second argument to the accept(T, U) method of the returned BiConsumer; in our implementation that argument is a String. The first argument is the result container, a Map of Lists keyed by the WORD_LENGTH enum, which the accumulator modifies by adding the incoming word to the correct bucket.
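To see the accumulator’s role in isolation, the sketch below drives it by hand: one container, one accept() call per word, and the container itself returned as the result. This is roughly what a sequential collect() does with an identity-finish collector, not the JDK’s actual implementation; ManualAccumulation and WordLength are illustrative names, and the thresholds mirror SHORT_MAX and MEDIUM_MAX.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;

// Hypothetical illustration of the accumulation step, without a Stream.
class ManualAccumulation {

    enum WordLength { SHORT, MEDIUM, LONG }

    // Same thresholds as the post: <= 6 SHORT, 7..12 MEDIUM, > 12 LONG.
    static final BiConsumer<Map<WordLength, List<String>>, String> ACCUMULATOR = (map, word) -> {
        WordLength bucket = word.length() <= 6 ? WordLength.SHORT
                : word.length() <= 12 ? WordLength.MEDIUM
                : WordLength.LONG;
        map.get(bucket).add(word); // mutate the container in place; nothing is returned
    };

    // Roughly what a sequential collect() does: one container, one accept() per element.
    static Map<WordLength, List<String>> collectByHand(List<String> words) {
        Map<WordLength, List<String>> container = new HashMap<>();
        for (WordLength w : WordLength.values()) {
            container.put(w, new ArrayList<>());
        }
        for (String word : words) {
            ACCUMULATOR.accept(container, word);
        }
        return container; // identity finisher: the container is the result
    }

    public static void main(String[] args) {
        System.out.println(collectByHand(Arrays.asList("edmond", "dantes", "marseilles", "unintentionally")));
    }
}
```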

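The combiner() merges two partial results: it calls addAll to copy each list in the second Map into the corresponding list of the first Map, then returns the first Map. In practice it is only called when the stream is processed in parallel, so a sequential run like the one above never exercises it.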
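Because a sequential stream in practice never calls the combiner, one way to check that it merges every bucket is to run the same collector sequentially and in parallel and compare the bucket sizes. This sketch rebuilds the collector with Collector.of and assumes synthetic words made of 'x' characters whose lengths cycle through 1 to 20; CombinerCheck, WordLength and the helper names are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.function.BinaryOperator;
import java.util.stream.Collector;
import java.util.stream.IntStream;

// Hypothetical harness comparing a parallel run (which uses the combiner) with a sequential one.
class CombinerCheck {

    enum WordLength { SHORT, MEDIUM, LONG }

    static WordLength bucketOf(String word) {
        return word.length() <= 6 ? WordLength.SHORT
                : word.length() <= 12 ? WordLength.MEDIUM
                : WordLength.LONG;
    }

    // Merge EVERY bucket of the right-hand partial result into the left-hand one.
    static final BinaryOperator<Map<WordLength, List<String>>> COMBINER = (left, right) -> {
        for (WordLength bucket : WordLength.values()) {
            left.get(bucket).addAll(right.get(bucket));
        }
        return left;
    };

    static Collector<String, Map<WordLength, List<String>>, Map<WordLength, List<String>>> collector() {
        return Collector.of(
                () -> {
                    Map<WordLength, List<String>> buckets = new EnumMap<>(WordLength.class);
                    for (WordLength b : WordLength.values()) {
                        buckets.put(b, new ArrayList<>());
                    }
                    return buckets;
                },
                (buckets, word) -> buckets.get(bucketOf(word)).add(word),
                COMBINER,
                Collector.Characteristics.UNORDERED);
    }

    // Synthetic "word" made of `length` copies of 'x'.
    static String word(int length) {
        char[] chars = new char[length];
        Arrays.fill(chars, 'x');
        return new String(chars);
    }

    // Bucket sizes for n synthetic words whose lengths cycle through 1..20.
    static Map<WordLength, Integer> bucketSizes(int n, boolean parallel) {
        IntStream ints = IntStream.range(0, n);
        if (parallel) {
            ints = ints.parallel();
        }
        Map<WordLength, List<String>> result = ints.mapToObj(i -> word(i % 20 + 1)).collect(collector());
        Map<WordLength, Integer> counts = new EnumMap<>(WordLength.class);
        result.forEach((bucket, words) -> counts.put(bucket, words.size()));
        return counts;
    }

    public static void main(String[] args) {
        System.out.println(bucketSizes(1000, false)); // {SHORT=300, MEDIUM=300, LONG=400}
        System.out.println(bucketSizes(1000, true));  // same counts when every bucket is merged
    }
}
```

If the combiner copied only some of the buckets, the parallel sizes would come out smaller than the sequential ones for the buckets it skipped.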

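The finisher() is the identity function: it returns the Map produced by the accumulation step unchanged. Because the collector also declares IDENTITY_FINISH, the stream framework is allowed to skip calling the finisher and use the container directly as the result.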
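An identity finisher is the common case, but the finisher is also the place to reshape the container into a different result type. As an illustration (not how the post’s collector works), this sketch accumulates into the same kind of Map of Lists and lets a finisher convert each list into its size, using the five-argument Collector.of overload. Since the container and result types differ, it does not declare IDENTITY_FINISH. CountingFinisher, WordLength and newBuckets are hypothetical names.

```java
import java.util.ArrayList;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collector;
import java.util.stream.Stream;

// Hypothetical variant with a non-identity finisher: List<String> per bucket -> word count per bucket.
class CountingFinisher {

    enum WordLength { SHORT, MEDIUM, LONG }

    static WordLength bucketOf(String word) {
        return word.length() <= 6 ? WordLength.SHORT
                : word.length() <= 12 ? WordLength.MEDIUM
                : WordLength.LONG;
    }

    static Map<WordLength, List<String>> newBuckets() {
        Map<WordLength, List<String>> buckets = new EnumMap<>(WordLength.class);
        for (WordLength b : WordLength.values()) {
            buckets.put(b, new ArrayList<>());
        }
        return buckets;
    }

    static Collector<String, Map<WordLength, List<String>>, Map<WordLength, Integer>> countingCollector() {
        return Collector.of(
                CountingFinisher::newBuckets,                             // supplier
                (buckets, word) -> buckets.get(bucketOf(word)).add(word), // accumulator
                (left, right) -> {                                        // combiner
                    right.forEach((bucket, words) -> left.get(bucket).addAll(words));
                    return left;
                },
                buckets -> {                                              // finisher: A -> R
                    Map<WordLength, Integer> counts = new EnumMap<>(WordLength.class);
                    buckets.forEach((bucket, words) -> counts.put(bucket, words.size()));
                    return counts;
                },
                // No IDENTITY_FINISH: the container type differs from the result type.
                Collector.Characteristics.UNORDERED);
    }

    static Map<WordLength, Integer> count(String... words) {
        return Stream.of(words).collect(countingCollector());
    }

    public static void main(String[] args) {
        System.out.println(count("the", "count", "prisoner", "extraordinarily", "of")); // {SHORT=3, MEDIUM=1, LONG=1}
    }
}
```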

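The characteristics() marks the collector as UNORDERED and IDENTITY_FINISH. We don’t mark it as CONCURRENT, because that characteristic tells the framework that multiple threads may call the accumulator on the same result container at the same time. Our container is a HashMap of ArrayLists, neither of which is thread-safe, so concurrent accumulation would be a race condition. Without CONCURRENT, a parallel stream gives each sub-task its own container and merges them with the combiner instead.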
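For contrast, here’s a sketch of a collector that can honestly declare CONCURRENT, because its container tolerates simultaneous updates: a ConcurrentHashMap whose values are LongAdder counters. It counts words per bucket rather than keeping the words themselves, which is a different trade-off from the post’s design. When the stream is parallel and either the stream or the collector is unordered, a CONCURRENT collector lets the framework share one container across threads instead of merging per-sub-task containers. ConcurrentBucketCounter and its helpers are hypothetical names.

```java
import java.util.Arrays;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.Collector;
import java.util.stream.IntStream;

// Hypothetical collector that may declare CONCURRENT: its container is thread-safe.
class ConcurrentBucketCounter {

    enum WordLength { SHORT, MEDIUM, LONG }

    static WordLength bucketOf(String word) {
        return word.length() <= 6 ? WordLength.SHORT
                : word.length() <= 12 ? WordLength.MEDIUM
                : WordLength.LONG;
    }

    static Collector<String, ConcurrentMap<WordLength, LongAdder>, ConcurrentMap<WordLength, LongAdder>> collector() {
        return Collector.of(
                ConcurrentHashMap::new,
                // computeIfAbsent is atomic on ConcurrentHashMap, and LongAdder.increment() is thread-safe.
                (counts, word) -> counts.computeIfAbsent(bucketOf(word), b -> new LongAdder()).increment(),
                // Still required by the API; used whenever separate containers do get merged.
                (left, right) -> {
                    right.forEach((bucket, n) -> left.computeIfAbsent(bucket, b -> new LongAdder()).add(n.sum()));
                    return left;
                },
                Collector.Characteristics.CONCURRENT,
                Collector.Characteristics.UNORDERED);
    }

    // Synthetic "word" made of `length` copies of 'x'.
    static String word(int length) {
        char[] chars = new char[length];
        Arrays.fill(chars, 'x');
        return new String(chars);
    }

    // How many of n synthetic words (lengths cycling through 1..20) land in the given bucket.
    static long countInBucket(int n, WordLength bucket) {
        ConcurrentMap<WordLength, LongAdder> counts = IntStream.range(0, n).parallel()
                .mapToObj(i -> word(i % 20 + 1))
                .collect(collector());
        LongAdder adder = counts.get(bucket);
        return adder == null ? 0L : adder.sum();
    }

    public static void main(String[] args) {
        for (WordLength bucket : WordLength.values()) {
            // prints SHORT = 300, then MEDIUM = 300, then LONG = 400
            System.out.println(bucket + " = " + countInBucket(1000, bucket));
        }
    }
}
```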

Published by Javeed

Java/JEE solution provider. Liferay Portal enthusiast. I am a Software Engineer with a focus on highly maintainable solutions that lie at the sweet spot intersection of modularity, extensibility and development team productivity. I enjoy experimenting with innovative documentation ideas.
