Efficient Vocabulary Generation for Very Large Corpora
For an ongoing project, I had to perform topic clustering on a large corpus of diverse and very long news articles. BERTopic usually works very well for such use cases, with a variety of memory-saving techniques already implemented. Where I ran into trouble was a usually innocuous intermediate step.
After embedding the documents, reducing the embedding feature dimensions, and clustering the corpus, a second set of features is estimated for each cluster. Specifically, BERTopic uses class-based TF-IDF scores to generate a topic-token matrix. The clustering in document embedding space is assumed to be non-convex, making estimation of a central tendency infeasible [1]. By extension, computing the distance between topics is difficult or intractable. Using topic-token TF-IDF representations instead, inter-topic distance can be robustly estimated. Another benefit is that we immediately get access to textual representations of our topics.
But this assumes access to a vocabulary estimated over the entire corpus. For most cases, this vocabulary needs to include higher-order n-grams, while simultaneously removing semantically poor words. Human language being what it is, this means the vocabulary size will quickly exceed working memory.
The Problem
Using the default `CountVectorizer` (a standard scikit-learn class) proved to be a problem, with an intermediate vocabulary that just would not fit into memory. It simply iterates over the corpus and tracks how often each n-gram token occurs. Only once it has run over the whole corpus does it cut the vocabulary down to the top $n$ tokens.
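To make the memory problem concrete, here is a rough standard-library sketch of the count-everything-then-truncate pattern described above. It is not scikit-learn's implementation; the helper names `ngrams` and `count_vocabulary` are made up for illustration.

```python
from collections import Counter
from typing import Iterable, List, Tuple


def ngrams(tokens: List[str], max_n: int) -> Iterable[str]:
    """Yield every 1..max_n-gram of a tokenised document."""
    for n in range(1, max_n + 1):
        for i in range(len(tokens) - n + 1):
            yield " ".join(tokens[i : i + n])


def count_vocabulary(docs: Iterable[str], max_n: int, top_n: int) -> List[Tuple[str, int]]:
    counts: Counter = Counter()
    for doc in docs:
        # Every distinct n-gram in the corpus ends up as a key in `counts`
        counts.update(ngrams(doc.lower().split(), max_n))
    # Truncation only happens here, after the full vocabulary was held in memory
    return counts.most_common(top_n)


docs = ["the cat sat", "the cat ran", "a dog sat"]
top = count_vocabulary(docs, max_n=2, top_n=2)
```

Even in this toy corpus of nine words, the intermediate counter holds eleven distinct n-grams before it is cut down to two.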
In general, the issue of open-ended vocabularies has been solved by using sub-word tokenization. After all, language models are more or less required to hold two $d\times |\mathcal{V}|$ layers in memory: one to encode each token $t\in\mathcal{V}$ and one to convert the $d$-dimensional contextual embeddings into a $|\mathcal{V}|$-sized softmax. Reasonably sized vocabularies are essentially a prerequisite. Sub-word tokenization strikes a trade-off between vocabulary size and sequence length, at the cost of tokens that are (individually) semantically meaningless. That rules the approach out here; I’m interested in semantic topic similarity, not whether the topics’ counts of sub-word tokens happen to overlap.
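As a back-of-the-envelope check on why language models need small vocabularies, the cost of the two $d\times |\mathcal{V}|$ layers can be computed directly. The sizes below are illustrative assumptions (32-bit parameters, $d=768$, a roughly 30k sub-word vocabulary versus a hypothetical 10 million entry n-gram vocabulary), not measurements from my corpus.

```python
def embedding_bytes(d: int, vocab_size: int, bytes_per_param: int = 4) -> int:
    """Memory for the input embedding plus the output projection, both d x |V|."""
    params_per_layer = d * vocab_size
    return 2 * params_per_layer * bytes_per_param


# Illustrative sizes only, see the assumptions stated above
subword = embedding_bytes(768, 30_522)
ngram = embedding_bytes(768, 10_000_000)
print(f"{subword / 1e9:.2f} GB vs {ngram / 1e9:.2f} GB")
```

Under these assumptions the sub-word model spends about 0.19 GB on the two layers, while the open-ended vocabulary would need over 61 GB.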
Using `HashingVectorizer`, I could keep working with full word tokens, and it keeps no state, greatly reducing the memory strain. Specifically, it only retains a token’s hash while discarding the actual orthographic form of the token. This also makes post-hoc inspection of the features impossible, whereas in the end I would like to have a readable representation of each topic available. Not to mention, we’d still have to keep a $|\mathcal{V}|$-sized dictionary in memory.
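The loss of the orthographic form is easy to see in a toy version of the hashing trick. This is only a sketch: `bucket` and `hashed_counts` are hypothetical helpers, and `zlib.crc32` stands in for whatever hash function a real implementation uses (scikit-learn uses a different one).

```python
import zlib
from collections import Counter
from typing import Iterable

N_BUCKETS = 8  # deliberately tiny so that collisions are likely


def bucket(token: str) -> int:
    # crc32 is stable across runs, unlike Python's salted built-in hash()
    return zlib.crc32(token.encode("utf-8")) % N_BUCKETS


def hashed_counts(tokens: Iterable[str]) -> Counter:
    counts: Counter = Counter()
    for token in tokens:
        # Only the bucket index is stored; the token itself is thrown away
        counts[bucket(token)] += 1
    return counts


tokens = "the cat sat on the mat".split()
features = hashed_counts(tokens)
```

The keys of `features` are plain integers. Without a separate token-to-bucket dictionary there is no way to tell which word (or colliding words) a given bucket stands for.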
To its credit, BERTopic provides an `OnlineCountVectorizer` meant to solve exactly this problem. Instead of estimating a vocabulary over the entire corpus, it uses small batches sampled from the corpus to learn an intermediate vocabulary, dropping any tokens that occur too infrequently or exceed the desired vocabulary size. Mini-batching alleviates the memory woes, but it results in inexact token frequencies at the end. Intermediate counts are forgotten, and which words make the vocabulary and which don’t depends largely on the order of the documents. Usually, the words that occur less frequently are also exactly the words that carry the most semantic information, and it is here that this approach is most likely to err.
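The order dependence can be reproduced with a toy counter that prunes rare tokens after every batch. To be clear, this is not BERTopic's `OnlineCountVectorizer`; `pruned_online_counts` is a made-up stand-in that only mimics the "forget infrequent tokens between batches" behaviour described above.

```python
from collections import Counter
from typing import List


def pruned_online_counts(docs: List[List[str]], batch_size: int, min_count: int) -> Counter:
    """Toy online counter: after each batch, forget tokens seen fewer than `min_count` times."""
    counts: Counter = Counter()
    for start in range(0, len(docs), batch_size):
        for doc in docs[start : start + batch_size]:
            counts.update(doc)
        # Pruning step: intermediate counts of rare tokens are lost here
        for token in [t for t, c in counts.items() if c < min_count]:
            del counts[token]
    return counts


d1, d2, d3, d4 = ["rare", "common"], ["rare", "common"], ["common"], ["common"]
order_a = pruned_online_counts([d1, d2, d3, d4], batch_size=2, min_count=2)
order_b = pruned_online_counts([d1, d3, d2, d4], batch_size=2, min_count=2)
```

Both runs see exactly the same four documents. In `order_a` the two occurrences of "rare" land in the same batch and survive; in `order_b` they are split across batches and the token never makes it into the vocabulary.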
I’d like to think we can do better.
The Solution (?)
Online estimation, or mini-batching, is not a bad idea altogether, though. If we iterate through our corpus until we have a vocabulary of size $|\mathcal{V}_{i}|=n_{\text{max_terms}}$, write that vocabulary away and start afresh, we’d be left with $m$ separate, relatively small vocabularies. Each of these we can easily store as a list on disk, meaning the memory footprint is as small as possible. In the end, to construct our final vocabulary, we’d just have to search through the $m$ lists and sum each token’s occurrences to get its exact count.
… except that this incurs an $\mathcal{O}\left(m\cdot n \cdot n_{\text{max_terms}}\right)$ search cost. For each one of the $n$ words, we’d possibly have to look through $n_{\text{max_terms}}$ words in each of the $m$ lists. For infrequent words (again, exactly the class of words we care most about), the probability of a word being in each mini-batch vocabulary list is relatively small, meaning we’re bound to hit worst-case performance often.
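Producing the $m$ on-disk lists could look roughly like the sketch below. The function name `write_batch_vocabs`, the file naming scheme and the whitespace tokenisation are all assumptions for illustration. Each list is written as `token,count` rows in sorted token order, which costs little per batch.

```python
import csv
import os
import tempfile
from collections import Counter
from typing import Iterable, List


def write_batch_vocabs(docs: Iterable[str], max_terms: int, out_dir: str) -> List[str]:
    """Count tokens until the vocabulary holds `max_terms` entries, then flush it to disk."""
    paths: List[str] = []
    counts: Counter = Counter()

    def flush() -> None:
        path = os.path.join(out_dir, f"vocab_{len(paths):05d}.csv")
        with open(path, "w", newline="", encoding="utf-8") as f:
            # One `token,count` row per token, sorted by token
            csv.writer(f).writerows(sorted(counts.items()))
        paths.append(path)
        counts.clear()

    for doc in docs:
        counts.update(doc.lower().split())
        # Checked per document, so a list can slightly overshoot `max_terms`
        if len(counts) >= max_terms:
            flush()
    if counts:
        flush()
    return paths


with tempfile.TemporaryDirectory() as tmp:
    vocab_fps = write_batch_vocabs(["a b a", "c a", "b b"], max_terms=2, out_dir=tmp)
    with open(vocab_fps[0], encoding="utf-8") as f:
        first_rows = f.read().splitlines()
```

Only one batch vocabulary lives in memory at any time; everything else sits on disk until it is needed.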
Luckily, we can sort the words alphabetically. Once sorted, we can find the token in each independent list in $\mathcal{O}(1)$ time. The sorted lists can be seen as a set of stacks, each with an independent pointer. At the top of each stack we have the smallest-valued item, i.e. the vocabulary’s lowest-ordered token.
```
      0    0    0    0
--> [ a  | aa | a  | ab ]   smallest = a
    [ aa | b  | ab | ac ]
    [ ac | ba | ac | ba ]

processed = []
```
At this point we scan the top of the stacks for the smallest token, ‘a’, and for stacks that have it at the top we pop it off (or equivalently increase the pointer by 1). All stacks that did not have that token at the top are guaranteed to only have higher value tokens, and we ignore those stacks. In other words, if the stack does not have the token on top, it simply does not have it.
```
      1    0    1    0
--> [ aa | aa | ab | ab ]   smallest = aa
    [ ac | b  | ac | ac ]
    [ ba | ba | c  | ba ]

processed = [a]
```
After processing the tokens, we get to a situation that is remarkably similar to the start. Once again, the smallest value token for each stack is guaranteed to be at the top, and we need only scan the top layer. Rinse, lather and repeat.
```
      2    1    1    0
--> [ ac | b  | ab | ab ]   smallest = ab
    [ ba | ba | ac | ac ]
    [ bc | bb | c  | ba ]

processed = [a, aa]
```
To process the entire vocabulary, we simply move gradually down each stack until we have exhausted the values within it. In each iteration, because we started with sorted stacks and we select the smallest element, we have guarantees that eliminate the need to search the rest of the list.
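The walkthrough above can be replayed in memory with plain lists and integer pointers. This is a minimal sketch: `merge_sorted_stacks` is a hypothetical helper, and the stack contents below the rows visible in the diagrams are my own guesses.

```python
from typing import Dict, List


def merge_sorted_stacks(stacks: List[List[str]]) -> Dict[str, int]:
    """Visit every token once, in sorted order, recording how many stacks held it."""
    pointers = [0] * len(stacks)
    processed: Dict[str, int] = {}

    while True:
        # The token on top of every stack that is not yet exhausted
        tops = {i: s[pointers[i]] for i, s in enumerate(stacks) if pointers[i] < len(s)}
        if not tops:
            break
        smallest = min(tops.values())
        # Pop the smallest token off each stack that has it on top;
        # stacks without it on top cannot contain it at all
        holders = [i for i, token in tops.items() if token == smallest]
        for i in holders:
            pointers[i] += 1
        processed[smallest] = len(holders)
    return processed


stacks = [
    ["a", "aa", "ac", "ba", "bc"],
    ["aa", "b", "ba", "bb"],
    ["a", "ab", "ac", "c"],
    ["ab", "ac", "ba"],
]
result = merge_sorted_stacks(stacks)
```

The first three keys of `result` are `a`, `aa` and `ab`, matching the `processed` list in the diagrams, and no stack is ever scanned beyond its top element.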
Implementing a vocabulary reader
The devil lies in the (implementation) details. Let’s assume the vocabulary lists are stored as `.csv` files, with each row containing just the token and its count within the batch. We can efficiently achieve the desired behaviour by writing a class with just two methods: [2]
- Peek: look at the next line of the vocabulary list, return the token on that line, and cache the token count
- Keep: return the peeked token’s count and empty the cache
Efficient is the operative term here. We do not want to load all lists into memory. Instead, we want to load in just a single row at a time, and move ever further down the list. Python allows moving along a file stream using the `tell` and `seek` methods. The former returns the current position in the data stream, whereas the latter moves the data stream to the provided position. We can use these to set our pointer.
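A small demonstration of `tell` and `seek` on a throwaway file, reading one row per file opening. The file contents are made up. Note that for text files the value returned by `tell` should be treated as opaque and only ever handed back to `seek`.

```python
import os
import tempfile

# Write a tiny three-row vocabulary list to a temporary file
with tempfile.NamedTemporaryFile(
    "w", suffix=".csv", delete=False, encoding="utf-8", newline=""
) as f:
    f.write("a,2\nab,1\nac,5\n")
    path = f.name

pointer = 0
rows = []
for _ in range(2):
    with open(path, "r", encoding="utf-8", newline="") as f:
        f.seek(pointer)                      # jump to where we stopped last time
        rows.append(f.readline().rstrip("\n"))
        pointer = f.tell()                   # remember where the next row starts

os.remove(path)
```

After two openings `rows` holds the first two rows, and `pointer` marks the start of the third without the rest of the file ever having been read.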
The Python backbone for this would look something like:
```python
class VocabReader:

    def __init__(self, vocab_fp: str):
        # Path to the vocabulary list on disk
        self.vocab_fp = vocab_fp
        # Set a pointer
        # This is the current location in the data stream
        self.pointer = 0

    def peek(self) -> str:
        # Look at the token located at `self.pointer`
        # Returns the token and caches its count
        ...

    def keep(self) -> int:
        # Returns the count of the peeked token and empties the cache
        ...

    @property
    def terminated(self) -> bool:
        # Whether or not the list has been exhausted
        ...
```
In the background, the `VocabReader` instances can cache the result of `peek`, only opening the file when the cache is empty. This way, we’re only opening the file once per token.
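One way the backbone could be filled in is sketched below. It assumes `token,count` rows without embedded newlines, a reader constructed from a file path, and a `peek` that raises `StopIteration` once the list is exhausted; the private names `_cache`, `_exhausted` and `_read_next` are my own.

```python
import csv
import os
import tempfile
from typing import Optional, Tuple


class VocabReader:
    """Reads a sorted `token,count` CSV one row at a time."""

    def __init__(self, vocab_fp: str):
        self.vocab_fp = vocab_fp
        self.pointer = 0  # position of the next unread row
        self._cache: Optional[Tuple[str, int]] = None
        self._exhausted = False

    def _read_next(self) -> None:
        # Open the file, read exactly one row, and remember where we stopped
        with open(self.vocab_fp, "r", encoding="utf-8", newline="") as f:
            f.seek(self.pointer)
            line = f.readline()
            self.pointer = f.tell()
        if not line:
            self._exhausted = True
            return
        token, count = next(csv.reader([line]))
        self._cache = (token, int(count))

    def peek(self) -> str:
        # Only touch the file when nothing is cached
        if self._cache is None and not self._exhausted:
            self._read_next()
        if self._cache is None:
            raise StopIteration(f"{self.vocab_fp} is exhausted")
        return self._cache[0]

    def keep(self) -> int:
        if self._cache is None:
            raise ValueError("keep() called without a preceding peek()")
        _, count = self._cache
        self._cache = None
        return count

    @property
    def terminated(self) -> bool:
        # Try to fill the cache first so that exhaustion is detected eagerly
        if self._cache is None and not self._exhausted:
            self._read_next()
        return self._cache is None


with tempfile.TemporaryDirectory() as tmp:
    fp = os.path.join(tmp, "vocab.csv")
    with open(fp, "w", encoding="utf-8", newline="") as f:
        csv.writer(f).writerows([("a", 2), ("ab", 1)])
    reader = VocabReader(fp)
    seen = []
    while not reader.terminated:
        seen.append((reader.peek(), reader.keep()))
```

Checking `terminated` eagerly is a design choice: it means a loop of the form `while not reader.terminated` never peeks at an empty list.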
At this point we have offloaded the vocabulary entirely to the disk, while the added compute cost is $\mathcal{O}(|\mathcal{V}|)$. The initial processing has the same cost, so we’ve essentially doubled the compute cost (negligible in big-O terms).
Finishing Touches
Ultimately, all we want is a list with at most $n_{\text{max_terms}}$ tokens in it. We can process the intermediate vocabularies efficiently, but we still need to track and store the large collated vocabulary somehow. For that, we only need three more, relatively simple data structures.
- Heap: quickly insert tokens into a ‘sorted’ vocabulary, with $\mathcal{O}(1)$ access to the least frequent token. Python implements this through the `heapq` module.
- LimitedCounter: some data structure tracking which tokens we’ve already processed. A `set` or `dict` wouldn’t work, as this would inevitably hash all $n$ tokens, the exact issue we want to avoid. Rather, we define a special instance of a `Counter` that deletes an entry once it has been seen $m$ times. Once we’ve seen a token $m$ times, we can be certain it has been seen by all $m$ `VocabReader` instances, and won’t appear again.
- token2reader: a mapping from all tokens on top of a stack to the readers that have that token on top. This way, we can quickly fetch which `VocabReader` instances need updating. This is easily implemented using a `collections.defaultdict(list)`.
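The three structures are small enough to sketch in one go. This is one possible reading of the descriptions above, not a reference implementation: the `add` method and `limit` argument of `LimitedCounter` follow how it is used in this post, and `push_bounded` is a hypothetical helper.

```python
import heapq
from collections import Counter, defaultdict
from typing import List, Tuple


class LimitedCounter(Counter):
    """Counter that forgets a key once it has been added `limit` times."""

    def __init__(self, limit: int):
        super().__init__()
        self.limit = limit

    def add(self, key: str) -> None:
        self[key] += 1
        if self[key] >= self.limit:
            del self[key]


def push_bounded(heap: List[Tuple[int, str]], item: Tuple[int, str], max_size: int) -> None:
    """Keep only the `max_size` largest (count, token) items in a min-heap."""
    if len(heap) < max_size:
        heapq.heappush(heap, item)
    else:
        # Pushes `item`, then pops and discards the smallest entry
        heapq.heappushpop(heap, item)


# Heap: retains the two most frequent tokens, least frequent of them at heap[0]
heap: List[Tuple[int, str]] = []
for count, token in [(5, "ac"), (1, "b"), (3, "ba"), (2, "a")]:
    push_bounded(heap, (count, token), max_size=2)

# LimitedCounter: "a" is dropped again after it has been added twice
seen = LimitedCounter(limit=2)
seen.add("a")
still_tracked = "a" in seen
seen.add("a")

# token2reader: which readers (by index here) hold which token on top
token2reader = defaultdict(list)
for reader_id, top_token in enumerate(["a", "aa", "a", "ab"]):
    token2reader[top_token].append(reader_id)
```

Storing `(count, token)` tuples means the heap orders by count first and breaks ties alphabetically, so no custom comparison is needed.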
Putting it all together, the vocabulary collation function would look something like this:
```python
import heapq
import typing


def collate_vocabs(
    vocab_fps: typing.List[str],
    min_df: int,
    max_features: int
) -> typing.Dict[str, int]:
    # Gather all of the `VocabReader` instances
    readers = [VocabReader(vocab_fp) for vocab_fp in vocab_fps]

    # Use a limited counter to track which tokens have been seen already
    seen_tokens = LimitedCounter(limit=len(readers))

    # Use a heap to keep track of the most frequent tokens
    # This allows for fast inserts and fast access to minimum
    # (an empty list is already a valid heap)
    vocab_heap = []

    # Iterate until we've exhausted every vocabulary stack
    while not all(reader.terminated for reader in readers):
        # Create a new token2reader instance
        token2reader = create_token2reader()

        # Peek at the next token on each reader's stack
        for reader in readers:
            try:
                reader_token = reader.peek()
            except StopIteration:
                continue

            # Add the token and reader to the `token2reader` mapping
            token2reader[reader_token].append(reader)

        # Guard against every reader having run dry during this pass
        if not token2reader:
            break

        # Find the token with the minimum value across all readers
        min_val_token = min(token2reader.keys())

        # Fetch all the readers that have the min_val_token on top
        vocab_readers_with_matches = token2reader[min_val_token]

        # Sum up the counts for the min_val_token in each stack
        token_count = 0
        for reader in vocab_readers_with_matches:
            token_count += reader.keep()

        # If the count is too small to include it in the final vocab,
        # skip it
        if token_count < min_df:
            continue

        # Otherwise, add the (count, token) tuple to the heap,
        # removing the lowest count token when necessary
        if len(vocab_heap) < max_features:
            heapq.heappush(vocab_heap, (token_count, min_val_token))
        else:
            heapq.heappushpop(vocab_heap, (token_count, min_val_token))

        # Add the token to the seen tokens collection
        seen_tokens.add(min_val_token)

    # Finally construct and output the final vocabulary
    # `vocab_heap` stores the tokens as (count, token) tuples
    vocab = {term: i for i, term in enumerate(sorted(token for _, token in vocab_heap))}

    return vocab
```
Voilà, we have a dictionary of at most $n_{\text{max_terms}}$ tokens, selected using counts that are the same as if we’d computed them on the entire corpus in one go. At no point did the memory consumption exceed the number of tokens present in a single batch, allowing for work on very large datasets without RAM being a constraint. I added one more variable here than strictly necessary: `min_df`. Very infrequent words likely only add noise, so the user can (somewhat arbitrarily) cull those terms before they are added to the heap. As a result, we can also be certain that all tokens in our dictionary occur at least `min_df` times.
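For comparison, the standard library can do the k-way merge itself. The sketch below swaps the hand-rolled readers for `heapq.merge` and `itertools.groupby`; it is an alternative formulation rather than the function above. The names `iter_vocab` and `collate_with_merge` are made up, it returns token-to-count pairs instead of token indices, and it needs no seen-token bookkeeping because grouping a sorted stream visits each token exactly once.

```python
import csv
import heapq
import itertools
import os
import tempfile
from typing import Dict, Iterator, List, Tuple


def iter_vocab(fp: str) -> Iterator[Tuple[str, int]]:
    """Lazily yield (token, count) rows from one sorted vocabulary file."""
    with open(fp, "r", encoding="utf-8", newline="") as f:
        for token, count in csv.reader(f):
            yield token, int(count)


def collate_with_merge(vocab_fps: List[str], min_df: int, max_features: int) -> Dict[str, int]:
    # heapq.merge holds one pending row per file and yields rows in token order
    merged = heapq.merge(*(iter_vocab(fp) for fp in vocab_fps), key=lambda row: row[0])
    top: List[Tuple[int, str]] = []
    for token, rows in itertools.groupby(merged, key=lambda row: row[0]):
        total = sum(count for _, count in rows)
        if total < min_df:
            continue
        if len(top) < max_features:
            heapq.heappush(top, (total, token))
        else:
            heapq.heappushpop(top, (total, token))
    # Most frequent token first
    return {token: count for count, token in sorted(top, reverse=True)}


batches = [
    [("a", 2), ("aa", 1), ("ac", 4)],
    [("aa", 3), ("b", 1), ("ba", 2)],
    [("a", 1), ("ab", 1), ("ac", 1)],
]
with tempfile.TemporaryDirectory() as tmp:
    fps = []
    for i, rows in enumerate(batches):
        fp = os.path.join(tmp, f"vocab_{i}.csv")
        with open(fp, "w", encoding="utf-8", newline="") as f:
            csv.writer(f).writerows(rows)
        fps.append(fp)
    vocab = collate_with_merge(fps, min_df=2, max_features=3)
```

The trade-off is that this version keeps all $m$ files open at the same time, which can run into the operating system's limit on open file handles for large $m$. The reader class with its `seek` pointer avoids that by opening one file at a time.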
[1] For example, imagine our cluster is a thin ring (or its $k$-dimensional equivalent). The mean would lie in the middle, far away from the cluster. The mode is spread evenly across the surface of the ring, and the median is not clearly defined. Choosing a single point to represent the cluster remains difficult.
[2] `peek` and `keep` being antigrams here is a happy but unintended coincidence.