Sunday, October 30, 2016

Using beam search to generate the most probable sentence

In my last blog post I talked about how to generate random text using a language model that gives the probability of a particular word following a prefix of a sentence. For example, given the prefix "The dog", a language model might tell you that "barked" has a 5% chance of being the next word whilst "meowed" has a 0.3% chance. It's one thing to generate random text in a way that is guided by these word probabilities, but it is an entirely different thing to generate the most probable text according to them. By most probable text we mean the sentence for which multiplying the probabilities of all its words together gives the maximum product. This is useful for conditioned language models, which give different probabilities depending on some extra input, such as an image description generator which accepts an image in addition to the prefix and returns probabilities for the next word based on what's in the image. In that case we'd like to find the most probable description for a particular image.

You might think that the solution is to always pick the most probable next word and add it to the prefix. This is called a greedy search and it is known to not give the optimal solution. The reason is that the best first word might only lead to mediocre continuations, whilst the second best first word might be followed by words that make the sentence as a whole more probable. We need a more exploratory search than greedy search. We can get one by thinking of the problem as searching a probability tree like this:

[Figure: a probability tree showing which words can follow each prefix, together with their probabilities, from <start> to <end>.]

The tree shows which words can follow a sequence of words, together with their probabilities. To find the probability of a sentence, multiply every probability along the sentence's path from <start> to <end>. For example, the sentence "the dog barked" has a probability of 75% × 73% × 25% × 100% = 13.7%. The problem we want to solve is how to find the sentence with the highest probability.
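
For contrast, here is what the greedy search described above looks like in code. This is a minimal sketch which assumes a "probabilities_function" that, given a prefix (a list of words), returns a list of (probability, word) pairs, the same interface used by the beam search code further down:

def greedysearch(probabilities_function):
    #Repeatedly add the single most probable next word until '<end>' is predicted.
    prefix = ['<start>']
    prob = 1.0
    while True:
        (next_prob, next_word) = max(probabilities_function(prefix))
        prob *= next_prob
        if next_word == '<end>':
            return (prefix, prob)
        prefix = prefix + [next_word]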

One way to do this is to use a breadth first search. Starting from the <start> node, go through every node connected to it, then every node connected to those nodes, and so on. Each node represents a prefix of a sentence. For each prefix compute its probability, which is the product of all the probabilities on its path from the <start> node. As soon as the most probable prefix found is a complete sentence, that is the most probable sentence. The reason no less probable prefix could ever grow into a more probable sentence is that as a prefix grows, its probability shrinks: every additional word multiplies the prefix's probability by a number that is at most 1. For example, if a prefix has a probability of 20% and a word with a probability of 99% is added to it, the new probability is 20% × 99% = 19.8%, which is smaller.
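
Here is a minimal sketch of this exhaustive search, again assuming the "probabilities_function" interface described above. Strictly speaking it expands the most probable prefix first rather than going level by level, but it follows the same stopping rule and returns the same answer (and is just as hopelessly slow on a realistic vocabulary):

def exhaustivesearch(probabilities_function):
    prefixes = [ (1.0, False, ['<start>']) ]
    while True:
        #Find the most probable prefix found so far, preferring complete ones on ties.
        (prob, complete, prefix) = max(prefixes, key=lambda x: (x[0], x[1]))
        if complete:
            return (prefix, prob)
        #Replace the chosen prefix with all of its one-word continuations.
        prefixes.remove((prob, complete, prefix))
        for (next_prob, next_word) in probabilities_function(prefix):
            if next_word == '<end>':
                prefixes.append((prob*next_prob, True, prefix))
            else:
                prefixes.append((prob*next_prob, False, prefix+[next_word]))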

Of course a breadth first search is impractical on any language model with a realistic vocabulary and sentence length, since it would take far too long to check every prefix in the tree. We can instead take an approximate approach where we only check a subset of the prefixes: the top 10 most probable prefixes found up to that point. We do the same breadth first search as before, but this time only the top 10 most probable prefixes are kept, and we stop when the most probable of these prefixes is a complete sentence. This is a beam search, and the set of prefixes we keep is called the beam.

This is practical, but it's important that the way we find the top 10 prefixes is fast. We can't sort all the prefixes and take the first 10, as there would be too many. We can instead use a heap data structure, which is designed to quickly take in a bunch of numbers and quickly pop out the smallest one. With a heap of maximum size 10 you insert the prefix probabilities one by one until it holds 10 of them. After that, you compare each new prefix probability with the smallest one in the heap and keep the larger of the two.
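
Python's standard library already provides a heap in the "heapq" module. Just to make the idea concrete, here is a minimal sketch of this keep-the-top-10 logic written with "heapq" (the hand-rolled class below does the same job):

import heapq

def top_k(candidates, k):
    #"candidates" is an iterable of (prob, complete, prefix) tuples.
    #Keeps the k largest by (prob, complete) in a min-heap of size k.
    heap = []
    for (prob, complete, prefix) in candidates:
        if len(heap) < k:
            heapq.heappush(heap, (prob, complete, prefix))
        elif (prob, complete) > (heap[0][0], heap[0][1]):
            #The new item beats the current minimum, so replace it.
            heapq.heapreplace(heap, (prob, complete, prefix))
    return heap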

Here is Python 3 code for a class implementing this fixed-size min-heap.

class NBest(object):

    #################################################################
    @staticmethod
    def _left(pos):
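        #Return the index of the left child of the node at position "pos".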
        return 2*pos + 1

    #################################################################
    @staticmethod
    def _parent(pos):
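        #Return the index of the parent of the node at position "pos".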
        return (pos+1)//2 - 1
    
    #################################################################
    def __init__(self, max_size):
        self.array = list()
        self.max_size = max_size

    #################################################################
    def add(self, prob, complete, prefix):
        #The (prob, complete) tuple is used for comparisons so that if two
        #probabilities are equal, a complete prefix beats an incomplete one,
        #since (0.5, False) < (0.5, True).
        if len(self.array) == self.max_size:
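            #The heap is full: only keep the new prefix if it beats the current minimum.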
            (min_prob, min_complete, _) = self.array[0]
            if (prob, complete) < (min_prob, min_complete):
                return
            else:
                self.array[0] = (prob, complete, prefix)
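                #Sift the new root down to restore the min-heap property.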
                pos = 0
                last_pos = len(self.array)-1
                while True:
                    left_pos = NBest._left(pos)
                    if left_pos > last_pos:
                        break
                    (left_prob, left_complete, _) = self.array[left_pos]

                    right_pos = left_pos + 1
                    if right_pos > last_pos:
                        min_pos = left_pos
                        min_prob = left_prob
                        min_complete = left_complete
                    else:
                        (right_prob, right_complete, _) = self.array[right_pos]
                        if (left_prob, left_complete) < (right_prob, right_complete):
                            min_pos = left_pos
                            min_prob = left_prob
                            min_complete = left_complete
                        else:
                            min_pos = right_pos
                            min_prob = right_prob
                            min_complete = right_complete
                    
                    if (prob, complete) > (min_prob, min_complete):
                        (self.array[pos], self.array[min_pos]) = (self.array[min_pos], self.array[pos])
                        pos = min_pos
                    else:
                        break
        else:
            self.array.append((prob, complete, prefix))
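            #Sift the new entry up to restore the min-heap property.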
            pos = len(self.array)-1
            while True:
                if pos == 0:
                    break
                parent_pos = NBest._parent(pos)
                (parent_prob, parent_complete, _) = self.array[parent_pos]
                if (prob, complete) < (parent_prob, parent_complete):
                    (self.array[pos], self.array[parent_pos]) = (self.array[parent_pos], self.array[pos])
                    pos = parent_pos
                else:
                    break
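
As a quick sanity check, here is how the class behaves on some made-up probabilities. Adding a fourth item to a three-item NBest evicts the least probable one:

nbest = NBest(3)
nbest.add(0.5, False, ['<start>', 'the'])
nbest.add(0.2, True, ['<start>', 'hi'])
nbest.add(0.1, False, ['<start>', 'a'])
nbest.add(0.3, False, ['<start>', 'my']) #evicts the 0.1 prefix
print(sorted(nbest.array, reverse=True))
#[(0.5, False, ['<start>', 'the']), (0.3, False, ['<start>', 'my']), (0.2, True, ['<start>', 'hi'])]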

The code to perform the actual beam search is this:

def beamsearch(probabilities_function, beam_width=10):
    prefix = ['<start>']
    beam = [ (1.0, False, prefix) ]
    
    while True:
        heap = NBest(beam_width)
        for (prefix_prob, complete, prefix) in beam:
            if complete:
                #Keep completed sentences in the beam, otherwise they would be
                #dropped before getting the chance to become the most probable
                #prefix.
                heap.add(prefix_prob, True, prefix)
                continue
            for (next_prob, next_word) in probabilities_function(prefix):
                if next_word == '<end>':
                    heap.add(prefix_prob*next_prob, True, prefix)
                else:
                    heap.add(prefix_prob*next_prob, False, prefix+[next_word])
        beam = heap.array
        
        (best_prob, best_complete, best_prefix) = max(beam, key=lambda x: (x[0], x[1]))
        if best_complete:
            return (best_prefix, best_prob)

"probabilities_function" returns a list of word/probability pairs given a prefix. "beam_width" is the number of prefixes to keep (so that instead of keeping the top 10 prefixes you can keep the top 100 for example). By making the beam search bigger you can get closer to the actual most probable sentence but it would also take longer to process.