BeamSearch的基礎知識如下:
解碼及貪心搜索
生成式任務相比普通的分類、tagging等NLP任務會復雜不少。在生成的時候,模型的輸出是壹個時間步壹個時間步依次獲得的,而且前面時間步的結果還會影響後面時間步的結果。也就是說,每壹個時間步,模型給出的都是基於歷史生成結果的條件概率。為了生成完整的句子,需要壹個稱為解碼的額外動作來融合模型多個時間步的輸出,而且使得最終得到的序列的每壹步條件概率連乘起來最大。
在文本生成任務中,每壹個時間步可能的輸出種類稱為字典大小(vocabulary size,我們用表示),進行T步隨機的生成可能獲得的結果總***有種。拿中文文本生成來說,的值大約是5000-6000,即常用漢字的個數。在如此大的基數下,遍歷整個生成空間是不現實的。
最容易想到的策略是貪心搜索,即每壹個時間步都取出壹個條件概率最大的輸出,再將從開始到當前步的結果作為輸入去獲得下壹個時間步的輸出,直到模型給出生成結束的標誌。例如下圖,每壹個時間步都取出了條件概率最大壹個結果,生成了序列[A,B,C]。
很明顯,這樣做將原來指數級別的求解空間直接壓縮到了與長度線性相關的大小。由於丟棄了絕大多數的可能解,這種關註當下的策略無法保證最終得到的序列概率是最優的。
Beam Search
而beam search是對貪心策略壹個改進。思路也很簡單,就是稍微放寬壹些考察的範圍。在每壹個時間步,不再只保留當前分數最高的1個輸出,而是保留num_beams個。當num_beams=1時集束搜索就退化成了貪心搜索。
下圖是壹個實際的例子,每個時間步有ABCDE***5種可能的輸出,即,圖中的num_beams=2,也就是說每個時間步都會保留到當前步為止條件概率最優的2個序列。
在第壹個時間步,A和C是最優的兩個,因此得到了兩個結果[A],[C],其他三個就被拋棄了;第二步會基於這兩個結果繼續進行生成,在A這個分支可以得到5個候選人,[AA],[AB],[AC],[AD],[AE],C也同理得到5個,此時會對這10個進行統壹排名,再保留最優的兩個,即圖中的[AB]和[CE];第三步同理,也會從新的10個候選人裏再保留最好的兩個,最後得到了[ABD],[CED]兩個結果。
可以發現,beam search在每壹步需要考察的候選人數量是貪心搜索的num_beams倍,因此是壹種犧牲時間換性能的方法。
以上就是Beam Search的基本概念,下面我們解析壹種高效率實現方式。
Beam Search代碼解析
Beam Search的原理雖然簡單,但實際實現的時候卻有很多細節要考慮。下面要解析這個實現出自於NLP界著名Python包Transformers,我為了說明方便做了壹些改動。
壹個正確且高效的算法需要處理的問題大概有兩個:
充分利用硬件,可以處理批量數據,且盡量使用並行計算少用循環處理好長短不同的生成結果。
下面是基礎版的beam search函數定義。其中context是編碼器編碼獲得的向量,batch_size是每批數據中包含的樣本量,bos_token_id是句子開頭標誌的token id,pad_token_id是用於填充的token id,eos_token_id是句子結束標誌的token id。這裏給參數填上的默認值和我們後面講解時使用的例子是壹致的。
在函數中主要執行以下三個步驟:
1、準備初始輸入
2、在當前生成的序列長度未達到max_length時擴展生成序列
3、準備最終輸出的序列
下面我們分別解析。
準備初始輸入
其中BeamHypotheses是壹個容器類,每個樣本綁定壹個。每個容器中會維護num_beams個當前最優的序列。當往容器中添加壹個序列而導致序列數大於num_beams的時候,它會自動踢掉分數最低的那個序列。類代碼如下。
序列擴展
序列擴展是beam search的核心過程,我們特地畫了壹張圖來解釋這個版本的實現策略。
下面對照這個圖來講解代碼。
乍壹看是不是有些復雜,我感覺關鍵的有以下幾點:
1、只有出現了EOS token才會將生成的序列裝進該樣本對應的容器中
2、當前input_ids保存著當前得分最高的num_beams個序列
準備輸出
上面那個while循環跳出意味著已經生成了長度為max_length的文本,比較理想的情況是所有的句子都已經生成出了eos_token_id,即句子生成結束了。但並不是所有情況都這樣,對於那些”意猶未盡“的樣本,我們需要先手動結束。
經過上面的處理,所有生成好的句子都已經保存在generated_hyps容器中,每個容器內保存著num_beams個序列,最後就是輸出期望個數的句子。
總結
好了,上面就是最基礎的beam search算法。這樣生成出來的結果已經會比貪心搜索好壹些,但還是會遇到諸如詞語重復這樣的問題。其實已經有很多針對重復問題的研究,我們在代碼中也已經留出了位置。