Implementing wc command – Rust coding challenges
This is the second post of the Rust coding challenges series, in which we learn computer science fundamentals and write actual code to build usable programs. In this challenge, we will learn how to implement a simple word count command efficiently in Rust. To start, here are the two main functions our program needs to perform:
- count the number of words from a given text input
- count the number of occurrences of a particular word from a given text input
Meeting this functionality is indeed as straightforward as it sounds: we loop through the text once, tokenize it, and sum up the numbers. Easy, right? But here is the caveat for this challenge: the input text has 10 million lines of random text. That doesn’t seem too big objectively, since it’s only around 700MB, but it’s a good size to start learning some basic optimization. The cool thing about Rust is that we don’t need much experience or knowledge of low-level details to make our program performant. By doing this challenge, we will learn about a data structure specifically suited to handling large files, take advantage of CPU power to do parallel processing, efficient chunking and tokenizing, unsafe operations, borrowed data, and reusable functions in Rust. So let’s tackle this problem!
Naive approach (30 seconds)
Here is a simple program that we can first come up with to solve this word count exercise:
use std::fs::File;
use std::io::Read;

pub fn read_file(file_path: &str) -> String {
    let mut file = File::open(file_path).expect("Failed to open the file");
    let mut content = String::new();
    file.read_to_string(&mut content).expect("Failed to read the file");
    content
}
pub fn count_words_simple(text: String) -> usize {
    text.lines().flat_map(|line| line.split_whitespace()).count()
}
We first load all the file content into memory, and then a tiny function does the word-counting job: we loop through each line, flatten each line into words, and count them. The result is that it correctly counts the number of words in our 10-million-line input, but in… 30 seconds. The native wc command on a Unix-based OS outputs the same result almost instantly. Let’s make some improvements!
For the read_file function, do you think we need to adjust it somehow? First, we can measure how long it takes to load a 10-million-line file into memory. As simple as that:
use std::fs::File;
use std::io::Read;
use std::time::Instant;

pub fn read_file(file_path: &str) -> String {
    let mut file = File::open(file_path).expect("Failed to open the file");
    let mut content = String::new();
    file.read_to_string(&mut content).expect("Failed to read the file");
    content
}

pub fn main() {
    let start = Instant::now();
    let file_path = "/path/to/the/input/file.txt";
    let _content = read_file(file_path);
    println!("Total time read file: {:?}", start.elapsed());
}
Total time read file: 137.034ms
Turns out the performance of our read_file is already impressive: it took only around 137ms to load all the file content on my computer. File::open establishes a connection between our application and the file through a system call like open() in Unix. The actual work happens in file.read_to_string(...); internally Rust uses a buffer for this purpose, so instead of reading the file byte by byte, the buffer in memory is the intermediate location where data is stored. One question might arise: why is performance better when there are more layers between the two endpoints? Now data has to go from the file to a buffer residing in memory, and only then reach the application.
Because each time file data is read, a “system call” has to be performed, and for that the CPU needs to switch from user mode to kernel mode, and context switching is expensive. So the idea of a buffer is to reduce the number of system calls by using a temporary location in memory to store the data. The size of the buffer can be tuned to maximize performance; a typical buffer size is anywhere from 4KB to 4MB. So instead of issuing a system call for every byte, we only do it once per 4MB, for example. The advantage of this approach is tremendous if we look at some numbers.
For example, say we have a 100MB file, each system call costs 200ns, and we read one byte per call. Ignoring memory access latency and other factors, the time we need to wait will be:
100 million bytes × 200 ns/system call = 100_000_000 * 200 ns = 20 seconds
But with a buffer size of 4KB, the amount of time we have to wait is drastically reduced:
100 MB / 4 KB = 25_000 system calls
25_000 * 200 ns = 5_000_000 ns = 5 ms
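To see buffered reading in practice, here is a minimal sketch (my own illustration, not the code used in this post) that wraps the file in a std::io::BufReader with an explicit 4MB capacity; the buffer size and the line-counting body are arbitrary choices for the example:
use std::fs::File;
use std::io::{BufRead, BufReader};

// Read the file through a 4MB in-memory buffer: the kernel is asked for data
// once per buffer refill instead of once per tiny read.
fn count_lines_buffered(file_path: &str) -> std::io::Result<usize> {
    let file = File::open(file_path)?;
    let reader = BufReader::with_capacity(4 * 1024 * 1024, file);
    Ok(reader.lines().count())
}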
The first optimization
Let’s pay attention to the count_words_simple function, since it takes most of our time. What can we do here? The first intuition is to take advantage of the available CPU power: we can break the text into chunks, count each chunk in parallel, and at the end sum them up, which sounds a lot like a small variant of map-reduce! First, we write a function to count words in a chunk:
fn count_words_in_chunk_1(chunk: &[String]) -> usize {
    chunk
        .iter()
        .flat_map(|line| line.split_whitespace())
        .count()
}
Then we have our main function to do the parallel operations:
use std::cmp::min;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;

pub fn count_word_parallel_1(text: String) -> usize {
    let num_threads = std::thread::available_parallelism().unwrap().get();
    let lines: Vec<String> = text.lines().map(|line| line.to_string()).collect();
    let lines = Arc::new(lines);
    let lines_per_threads = (lines.len() + num_threads - 1) / num_threads;
    let total_word_count = Arc::new(AtomicUsize::new(0));
    let mut handles = Vec::new();
    for i in 0..num_threads {
        let start = lines_per_threads * i;
        let end = min(start + lines_per_threads, lines.len());
        if start < end {
            let total_word_count_clone = Arc::clone(&total_word_count);
            let lines_clone = Arc::clone(&lines);
            let handle = thread::spawn(move || {
                let chunk = &lines_clone[start..end];
                let count = count_words_in_chunk_1(chunk);
                total_word_count_clone.fetch_add(count, Ordering::Relaxed);
            });
            handles.push(handle);
        }
    }
    for handle in handles {
        handle.join().unwrap();
    }
    total_word_count.load(Ordering::Relaxed)
}
The result? From 30 seconds, here is our new metric:
Total words: 85039300, elapsed time: 7.608273208s
About 4 times faster. Before diving into further improvements for our word-counting function, let’s break down some new concepts we’ve encountered here. First, we convert the text into a vector of lines, and because the lines will be shared between different threads, we wrap them in an Arc, which stands for ‘Atomically Reference Counted’, so they can be accessed safely. Then we split the counting work equally between the available threads. If we used lines.len() / num_threads to distribute the workload, all the workers would get equal chunks, but the remainder, meaning some lines at the end, would be left uncounted. That’s why we use the ceiling division (lines.len() + num_threads - 1) / num_threads to cover all the lines from the input: for example, with 10 lines and 3 threads, floor division gives 3 lines per thread and leaves 1 line behind, while (10 + 3 - 1) / 3 = 4 covers everything.
Next, we spawn a number of threads, each working on its chunk independently. Notice the thread::spawn syntax and the move keyword. Everything used inside the closure after the move keyword is moved to the worker thread, meaning that values like lines_clone and total_word_count_clone are no longer valid in the main thread, since they have “moved” to the worker thread. To understand what that means exactly, here is a piece of code that will not compile:
let lines_clone = Arc::clone(&lines);
let handle = thread::spawn(move || {
    let chunk = &lines_clone[start..end];
    let count = count_words_in_chunk_1(chunk);
    total_word_count_clone.fetch_add(count, Ordering::Relaxed);
});
// Attempting to use lines_clone here causes a compile error
let lines_ref = lines_clone;
Let’s name the main thread main_thread and the spawned thread worker_thread; here is a visualization:
Before the Move: After the Move:
+---------------------+ +---------------------+
| Main Thread | | Main Thread |
| | | |
| +---------------+ | | +---------------+ |
| | lines_clone | | | | lines_clone | |
| | (Valid) | | | | (Invalid) | |
| +---------------+ | | +---------------+ |
| | | |
| +---------------+ | | +---------------+ |
| | lines_ref | | | | lines_ref | |
| | (Valid) | | | | (Valid) | |
| +---------------+ | | +---------------+ |
+---------------------+ +---------------------+
| |
| |
| |
| move |
V V
+-----------------------------+ +-----------------------------+
| Worker Thread | | Worker Thread |
| | | |
| +---------------------+ | | +---------------------+ |
| | lines_clone | | | | lines_clone | |
| | (Now owned) | | | | (Now owned) | |
| +---------------------+ | | +---------------------+ |
| | | |
+-----------------------------+ +-----------------------------+
Concurrency concepts
The combination of Arc and AtomicUsize is one of the correct ways to increment the counter in a thread-safe manner. The Arc lets you share the data with multiple threads without the overhead of copying the same data into each thread. Imagine if, instead of Arc::clone(&lines), we did something like lines.clone(): that would be highly inefficient in terms of memory usage and computational power, since we’d have to allocate more memory and copy the data over and over. Just think of Arc as a wrapper around exactly one value that can be shared safely between threads.
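As a small illustration of the difference (my own example, not from the original program), cloning an Arc only bumps a reference count, while cloning the underlying Vec copies every element:
use std::sync::Arc;

fn main() {
    let lines: Arc<Vec<String>> = Arc::new(vec!["a".to_string(); 1_000_000]);

    // Cheap: increments the atomic reference count, no line data is copied.
    let shared = Arc::clone(&lines);
    assert_eq!(Arc::strong_count(&lines), 2);

    // Expensive: allocates a brand-new Vec and copies one million Strings.
    let copied: Vec<String> = (*lines).clone();
    assert_eq!(copied.len(), shared.len());
}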
Rather than using a Mutex, which is a standard-library type that acts as a lock, we only need an AtomicUsize, since we’re interested in nothing but the total count. The fetch_add operation is guaranteed to be atomic, so it does its job correctly, and it is efficient at the same time because it doesn’t incur any locking.
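For comparison, here is roughly what the lock-based alternative would look like (a sketch, not code used in this post); every increment has to acquire and release the lock, which the atomic counter avoids:
use std::sync::{Arc, Mutex};
use std::thread;

fn main() {
    let total = Arc::new(Mutex::new(0usize));
    let mut handles = Vec::new();

    for _ in 0..4 {
        let total = Arc::clone(&total);
        handles.push(thread::spawn(move || {
            // Each update takes the lock, so threads serialize here.
            *total.lock().unwrap() += 1;
        }));
    }
    for handle in handles {
        handle.join().unwrap();
    }
    assert_eq!(*total.lock().unwrap(), 4);
}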
However, explaining what Ordering::Relaxed means can be a challenge; let me try to formulate it in a simple way. In a multi-threaded application, multiple threads perform operations concurrently. What if one thread is reading a value X while another thread is writing to X at the same time? Without any ordering guarantee between the threads, the outcome depends on the actual execution order, which isn’t always the same from run to run: sometimes the value is read before it’s written, and sometimes it’s written before it’s read, because of factors such as instruction reordering done by the CPU for optimization. For example, if two threads read and write a value at the same time, one possible execution order is:
+-------------------------+
| Thread A (Writer) |
|------------------------|
| 1. Lock |
| 2. Write 42 |
| 3. Unlock |
+-------------------------+
+-------------------------+
| Thread B (Reader) |
|------------------------|
| 4. Lock |
| 5. Read (expect 42) |
| 6. Unlock |
+-------------------------+
That is the ideal case we expect, here we have a shared variable holding an integer value, and the value 42 is read after it’s first written. But here is another possibility of the execution order:
+-------------------------+
| Thread A (Writer) |
|------------------------|
| 1. Lock |
| 2. Unlock |
| 3. Write 42 |
+-------------------------+
+-------------------------+
| Thread B (Reader) |
|------------------------|
| 4. Lock |
| 5. Read (sees 0, not 42)|
| 6. Unlock |
+-------------------------+
We’re not lucky in this case: the instructions have been reordered and the updated value is not visible at the time thread B reads it. That’s where the concept of Ordering comes into play. If the write uses Ordering::Release and the read uses Ordering::Acquire, then once the reader observes the written value, it is also guaranteed to see everything the writer did before that store. In our word count example, each worker counts its chunk independently and there is no direct communication between them, so we use Ordering::Relaxed: it doesn’t impose any ordering on surrounding instructions, it only guarantees that the operation itself is atomic.
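As an illustration of the stronger orderings (not needed for the word counter itself), a typical Release/Acquire handshake between a writer and a reader looks roughly like this:
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;

fn main() {
    let value = Arc::new(AtomicUsize::new(0));
    let ready = Arc::new(AtomicBool::new(false));

    let writer = {
        let (value, ready) = (Arc::clone(&value), Arc::clone(&ready));
        thread::spawn(move || {
            value.store(42, Ordering::Relaxed);
            // Release: everything written before this store is visible to a
            // thread that observes `ready == true` with an Acquire load.
            ready.store(true, Ordering::Release);
        })
    };

    let reader = {
        let (value, ready) = (Arc::clone(&value), Arc::clone(&ready));
        thread::spawn(move || {
            // Acquire: once we see `true`, the write of 42 is guaranteed visible.
            while !ready.load(Ordering::Acquire) {}
            assert_eq!(value.load(Ordering::Relaxed), 42);
        })
    };

    writer.join().unwrap();
    reader.join().unwrap();
}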
At the end of our count_word_parallel_1 function, handle.join() blocks the main thread until the computation in each worker is finished. Finally, we load the computed value and return it.
Further optimization
At the very beginning of count_word_parallel_1, we convert the text into a vector of lines:
fn count_word_parallel_1(text: String) -> usize {
    let lines: Vec<String> = text.lines().map(|line| line.to_string()).collect();
    ...
}
There are multiple problems with that: we loop over all the lines, convert each line to an owned String, and collect them into a vector, which is an expensive operation when the input text is large. Instead, we can remove this step entirely and work directly with the input string, something like this:
fn count_word_parallel_2(text: &str) -> usize {
    // what to do here?
}
Now we need to figure out how to do the chunking. Under the hood, the text is a sequence of UTF-8 bytes, and we can create a helper function that takes the input text and the number of chunks we want, and returns a vector of borrowed string slices (not owned Strings):
fn get_chunks(text: &str, partitions: usize) -> Vec<&str> {
    let mut end_index = 0;
    let mut chunks = Vec::with_capacity(partitions);
    for i in 0..partitions {
        // Skip the space byte sitting at the previous chunk's end_index.
        let start_index = if i == 0 { 0 } else { end_index + 1 };
        end_index = get_end_index(text, i, start_index, partitions);
        if start_index < text.len() {
            let chunk = &text[start_index..end_index];
            chunks.push(chunk);
        }
    }
    chunks
}
fn get_end_index(text: &str, i: usize, start_index: usize, partitions: usize) -> usize {
    let chunk_size = text.len() / partitions;
    let bytes = text.as_bytes();
    let mut end_index = start_index + chunk_size;
    if end_index >= text.len() || i == partitions - 1 {
        return text.len();
    }
    // Never cut in the middle of a word: advance until the next space byte.
    while end_index < text.len() && bytes[end_index] != b' ' {
        end_index += 1;
    }
    end_index
}
The number of partitions we pass in will be equal to the number of available workers our computer has. For each partition, we need to calculate the correct start and end indexes, and the only constraint is that we must not start or end in the middle of a word; that’s why we have the get_end_index function to help us here.
One observation you can make here is that we use text.as_bytes() directly rather than text.chars(). Why is it better? Because chars() is more costly: it needs to decode the UTF-8 bytes into valid char values (not character literals, but the 4-byte Unicode scalar values that represent them). With as_bytes(), there is no decoding at all. The space character is encoded as a single byte, and in UTF-8 no byte of a multi-byte character ever equals that byte, so we can safely check each byte for a space, and if it’s not one, keep incrementing end_index. The b' ' literal is simply the byte representation of the space character.
Then we modify count_words_in_chunk_1 just a little bit, switching it from owned string data to a borrowed slice:
fn count_words_in_chunk_1(chunk: &str) -> usize {
    chunk
        .lines()
        .flat_map(|line| line.split_whitespace())
        .count()
}
Let’s see how much of a boost we’ve achieved after these optimizations:
Total words: 85039300, elapsed time: 5.268985375s
About a 2.3-second difference, a pretty good improvement. Coming back to count_words_in_chunk_1, we can improve this function further without sacrificing correctness, with something like this:
fn count_word_in_chunk_2(chunk: &str) -> usize {
    let mut count = 0;
    let mut in_word = false;
    let bytes = chunk.as_bytes();
    for byte in bytes {
        // Check for any ASCII whitespace (space, newline, tab, etc.)
        if byte.is_ascii_whitespace() {
            in_word = false;
        } else if !in_word {
            count += 1;
            in_word = true;
        }
    }
    count
}
Here we apply the same technique that we used in the get_chunks function and get away from lines() and flat_map(), since those adapters add extra iterator layers and per-character decoding on top of the raw byte scan. Keep in mind that this version only considers ASCII whitespace; Unicode defines a few more whitespace characters.
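To make that caveat concrete, here is a tiny illustration of my own (hypothetical input, not from the benchmark): split_whitespace understands Unicode whitespace, while the ASCII byte scan does not.
fn main() {
    // U+00A0 is a non-breaking space: Unicode whitespace, but not ASCII whitespace.
    let text = "hello\u{00A0}world";
    assert_eq!(text.split_whitespace().count(), 2);
    assert_eq!(count_word_in_chunk_2(text), 1);
}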
Now, let’s see the final result of our word-count function:
Total words: 85039300, elapsed time: 1.22284s
10 million lines of text in just over 1.2 seconds, without diving too deep into low-level details: that’s what we’ve achieved through some basic optimizations.
Word occurrences
Next, we’ll create another function that takes an input file and a word and returns the number of occurrences of that word. Here we can apply the same techniques we used for counting words:
- Read files incrementally using a buffer
- Use borrowed data directly instead of allocating memory for the owned one
- Avoid intermediate allocation overhead and split the text into chunks efficiently
- Process all the chunks concurrently, count the number of occurrences in each chunk
- Finally, sum up all the occurrences.
Let’s define our count_word_occurrences_in_chunk:
fn count_word_occurrences_in_chunk(chunk: &str, word: &str) -> usize {
    let mut count = 0;
    let word_len = word.len();
    let mut start = 0;
    let chunk_bytes = chunk.as_bytes();
    while let Some(pos) = chunk[start..].find(word) {
        let end = start + pos + word_len;
        // Count only whole words: the match must be preceded and followed by
        // whitespace (or touch the start/end of the chunk).
        if (start + pos == 0 || chunk_bytes[start + pos - 1].is_ascii_whitespace())
            && (end == chunk.len() || chunk_bytes[end].is_ascii_whitespace())
        {
            count += 1;
        }
        start += pos + word_len;
    }
    count
}
This code can be a little confusing at first, so let’s break it down. We still convert the chunk to bytes as usual. Inside the while loop we keep searching for the word, and every time one is found we need to compute the next start position for the remaining slice chunk[start..]. We have end = start + pos + word_len rather than end = start + word_len because start is not the index of the first letter of the matched word; the first letter is at start + pos, since pos is relative to chunk[start..], not chunk[0..]. We then check that the byte right before the match and the byte right after it are whitespace (or that the match touches the start or end of the chunk), so only whole words are counted, increment the counter when that holds, and finally return the count.
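A quick sanity check of the whole-word behaviour could look like this (the test values are hypothetical, not taken from the post’s benchmark data):
#[cfg(test)]
mod tests {
    use super::count_word_occurrences_in_chunk;

    #[test]
    fn counts_only_whole_words() {
        // "rustacean" and "trust" must not be counted as occurrences of "rust".
        let chunk = "rust is fun rustacean loves rust but trust is different";
        assert_eq!(count_word_occurrences_in_chunk(chunk, "rust"), 2);
    }
}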
However, notice that for this second task we only had to create one extra function, count_word_occurrences_in_chunk; the old logic can be reused! We’re now ready to write a reusable function that serves both tasks:
fn count_in_parallel<F>(text: &str, count_fn: F) -> usize
where
    F: Fn(&str) -> usize + Send + Sync + 'static,
{
    let num_threads: usize = thread::available_parallelism().unwrap().get();
    let total_word_count: Arc<AtomicUsize> = Arc::new(AtomicUsize::new(0));
    let count_fn = Arc::new(count_fn);
    let chunks = get_chunks(text, num_threads);
    thread::scope(|scope| {
        for i in 0..chunks.len() {
            let total_word_count_clone = Arc::clone(&total_word_count);
            let count_fn_clone = Arc::clone(&count_fn);
            let chunk = chunks[i];
            scope.spawn(move || {
                let count = count_fn_clone(chunk);
                total_word_count_clone.fetch_add(count, Ordering::Relaxed);
            });
        }
    });
    total_word_count.load(Ordering::Relaxed)
}
The first argument is the text input, and the second is a custom function that we can pass in, depending on the task we need to perform. What the bound means is that the function we pass must take a string slice (&str) and return a usize; Send and Sync are the traits that ensure our count_fn can be accessed safely from multiple threads, and the 'static bound ensures the function doesn’t hold references to data that might go away while the threads are still running.
pub fn count_words(text: &str) -> usize {
    count_in_parallel(text, count_words_in_chunk)
}

pub fn count_word_occurrences(text: &str, word: String) -> usize {
    count_in_parallel(text, move |chunk| {
        count_word_occurrences_in_chunk(chunk, &word)
    })
}
Let’s take a look at the performance of the second task (i.e. counting word occurrences):
./wc-command rccwc -wo rust /Users/learntocodetogether/rcc/wc-command/input/random_10m_lines.txt
Total occurrences of 'rust': 8502920, elapsed time: 1.749215542s
Just over 1.7 seconds!
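For completeness, here is a rough sketch of how a dispatching main could tie these pieces together; the flag names and argument layout are only my assumption modelled on the invocation above, and the real interface lives in the repository linked at the end:
use std::env;
use std::time::Instant;

fn main() {
    // Hypothetical argument handling: `-wo <word> <file>` counts occurrences
    // of a word, while a single `<file>` argument counts all words.
    let args: Vec<String> = env::args().collect();
    let start = Instant::now();
    match args.get(1).map(String::as_str) {
        Some("-wo") if args.len() >= 4 => {
            let text = read_file(&args[3]);
            let count = count_word_occurrences(&text, args[2].clone());
            println!("Total occurrences of '{}': {}, elapsed time: {:?}", args[2], count, start.elapsed());
        }
        Some(path) => {
            let text = read_file(path);
            println!("Total words: {}, elapsed time: {:?}", count_words(&text), start.elapsed());
        }
        None => eprintln!("usage: wc-command [-wo <word>] <file>"),
    }
}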
Bonus: memory-mapped files (mmap)
mmap, short for memory mapping (as in memory-mapped files), is an efficient mechanism for handling large files, especially when random access is desirable. After the initial mmap system call, the file content is mapped directly into the process’s address space and paged in on demand by the operating system, with no read() calls in between. The file is not eagerly loaded fully into memory like with file.read_to_string(); instead only the portions that are actually touched get loaded, which reduces memory usage considerably when a big file is opened.
We first need to add the dependency:
// Cargo.toml
[dependencies]
memmap2 = "0.9.5"
We can then use this as follows:
use std::fs::File;
use std::str;

use memmap2::MmapOptions;

pub fn read_file_as_string(file_path: &str) -> String {
    let file = File::open(file_path).expect("Path is not valid");
    let mmap = unsafe { MmapOptions::new().map(&file).expect("Cannot map from file") };
    let text = unsafe { str::from_utf8_unchecked(&mmap).to_string() };
    text
}
Rust is known for memory safety, but the language itself is not 100% memory-safe, because we can still perform unsafe operations, ones whose safety the compiler cannot verify, such as dereferencing a raw or null pointer. In our case, the mmap operations are unsafe: if the file is modified by another process while it is mapped, or if its contents are not valid UTF-8, we get undefined behavior and the program may crash.
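If you’d rather keep the UTF-8 check, a slightly safer variant might look like this (a sketch of mine; it validates the bytes instead of calling from_utf8_unchecked, at the cost of one extra pass over the data):
use std::fs::File;

use memmap2::MmapOptions;

pub fn read_file_checked(file_path: &str) -> String {
    let file = File::open(file_path).expect("Path is not valid");
    // Mapping is still unsafe: another process could truncate or modify the
    // file while we hold the mapping.
    let mmap = unsafe { MmapOptions::new().map(&file).expect("Cannot map from file") };
    // Validate UTF-8 instead of assuming it, then copy into an owned String.
    std::str::from_utf8(&mmap)
        .expect("File is not valid UTF-8")
        .to_string()
}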
The complete implementation for this challenge can be found here: https://github.com/namvdo/rust-coding-challenges/tree/master/challenges/wc-command