Sometimes it is not enough to generate just the most probable sentence from a language model; you might want, say, the top 3 most probable sentences instead. In that case we need to modify our beam search a little. We will turn the function into a generator that yields sentences in order of decreasing probability rather than returning only the single most probable sentence. Here are the changes we need to make:
In the single-sentence version we took the most probable prefix in the current beam and checked whether it was complete; if it was, we returned it and stopped there. Now we will not stop until the beam is empty (or until the caller stops requesting sentences). After yielding the most probable prefix we check the second most probable one, and keep yielding complete prefixes until we either find one that is incomplete or we have yielded the whole beam. If the whole beam is yielded then the algorithm stops, since nothing is left from which to generate new prefixes; this means that the beam width puts an upper limit on the number of sentences that can be returned. Otherwise we continue generating prefixes with the remainder of the beam.
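To make the "caller stops requesting" part concrete, here is a sketch of how the generator version might be consumed. It assumes the beamsearch function defined further down, and my_probabilities_function is a stand-in for your own language model:

```python
import itertools

#Stand-in for your own language model: given a prefix (a list of words
#starting with '<start>'), return (probability, word) pairs for the next word.
def my_probabilities_function(prefix):
    ...

#Take only the top 3 most probable sentences; the generator is never
#advanced beyond that, so no extra search work is done.
for (sentence, prob) in itertools.islice(beamsearch(my_probabilities_function), 3):
    print(prob, ' '.join(sentence))
```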
If any complete sentences were yielded, they also need to be removed from the beam before we continue generating. The beam is implemented as a min-first heap queue (min-first because we want to pop the least probable prefix quickly whenever the beam grows beyond the beam width), so we cannot also remove the most probable complete sentence quickly. Instead, we first turn the heap into a list sorted by probability and pop items off the end of the list for as long as they are complete sentences. We then heapify the list back into a min-first heap queue and continue as normal. This sorting and re-heapifying should not hurt performance too much as long as the beam width is relatively small.
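As a small self-contained illustration of this heap-to-sorted-list-and-back step (the probabilities and prefixes here are made up):

```python
import heapq

#Toy beam entries of the form (probability, complete, prefix).
beam = [(0.2, True, ['a']), (0.5, False, ['b']), (0.9, True, ['c'])]
heapq.heapify(beam) #min-first heap: beam[0] is now the least probable entry

sorted_beam = sorted(beam) #ascending by probability
while len(sorted_beam) > 0 and sorted_beam[-1][1] == True:
    print('yielding', sorted_beam.pop()) #pop complete sentences off the top

heapq.heapify(sorted_beam) #turn what remains back into a min-first heap
```

Only (0.9, True, ['c']) gets popped here: the next most probable entry is incomplete, so the loop stops and the remaining two entries go back into the heap.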
If the clip length is reached then the whole beam is immediately yielded in order of probability. This is because as soon as one prefix reaches the maximum allowed length, the entire beam must consist of
- incomplete sentences that are also as long as the allowed maximum (since all prefixes grow together), and
- complete sentences that were found earlier but never had the highest probability in the beam (which is why they were not yielded before).
Here is the modified Python 3 code:
```python
import heapq

class Beam(object):

    def __init__(self, beam_width, init_beam=None):
        if init_beam is None:
            self.heap = list()
        else:
            self.heap = init_beam
            heapq.heapify(self.heap) #make the list into a heap
        self.beam_width = beam_width

    def add(self, prob, complete, prefix):
        heapq.heappush(self.heap, (prob, complete, prefix))
        if len(self.heap) > self.beam_width:
            heapq.heappop(self.heap) #pop the least probable prefix to keep the beam at beam_width

    def __iter__(self):
        return iter(self.heap)

def beamsearch(probabilities_function, beam_width=10, clip_len=-1):
    prev_beam = Beam(beam_width)
    prev_beam.add(1.0, False, [ '<start>' ])

    while True:
        curr_beam = Beam(beam_width)

        #Add complete sentences that do not yet have the best probability to the current beam; the rest get more words added to them.
        for (prefix_prob, complete, prefix) in prev_beam:
            if complete == True:
                curr_beam.add(prefix_prob, True, prefix)
            else:
                #Get probability of each possible next word for the incomplete prefix.
                for (next_prob, next_word) in probabilities_function(prefix):
                    if next_word == '<end>': #if next word is the end token then mark prefix as complete and leave out the end token
                        curr_beam.add(prefix_prob*next_prob, True, prefix)
                    else: #if next word is a non-end token then mark prefix as incomplete
                        curr_beam.add(prefix_prob*next_prob, False, prefix+[next_word])

        sorted_beam = sorted(curr_beam) #get all prefixes in the current beam sorted by probability
        any_removals = False
        while True:
            (best_prob, best_complete, best_prefix) = sorted_beam[-1] #get the highest probability prefix

            if best_complete == True or len(best_prefix)-1 == clip_len: #if the most probable prefix is a complete sentence or has reached the clip length (ignoring the start token) then yield it
                yield (best_prefix[1:], best_prob) #yield the best sentence without the start token, together with its probability
                sorted_beam.pop() #remove the yielded sentence and check the next highest probability prefix
                any_removals = True
                if len(sorted_beam) == 0: #if there are no more sentences in the beam then stop checking
                    break
            else:
                break

        if any_removals == True: #if there were any changes in the current beam then...
            if len(sorted_beam) == 0: #...if there are no more prefixes in the current beam (due to the clip length being reached) then end the beam search
                break
            else: #...otherwise set the previous beam to the modified current beam
                prev_beam = Beam(beam_width, sorted_beam)
        else: #if the current beam was left unmodified then assign it to the previous beam as is
            prev_beam = curr_beam
```
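To see it run end to end, here is a quick test with a toy probabilities function. This is not a real language model, just made-up fixed probabilities so that the search has something to work with; the '<end>' token is withheld right after the start token so that the empty sentence is never a candidate:

```python
import itertools

def toy_probabilities_function(prefix):
    #made-up next-word distribution; a real one would come from a language model
    if len(prefix) == 1: #only the start token so far: do not allow ending immediately
        return [(0.6, 'hello'), (0.4, 'world')]
    return [(0.5, 'world'), (0.3, '<end>'), (0.2, 'hello')]

#print the top 3 most probable sentences together with their probabilities
#(we only draw 3 sentences via islice, so no clip length is needed here)
for (sentence, prob) in itertools.islice(beamsearch(toy_probabilities_function, beam_width=4), 3):
    print(prob, sentence)
```

The sentences come out in order of decreasing probability, starting with ['hello'] (probability 0.6*0.3 = 0.18).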