import torch


class Beam:
    """Tracks the hypotheses of a single beam during beam-search decoding."""

    def __init__(self, beam_size=8, min_length=0, n_top=1, ranker=None,
                 start_token_id=1, end_token_id=2):
        self.beam_size = beam_size
        self.min_length = min_length
        self.ranker = ranker
        self.end_token_id = end_token_id
        self.top_sentence_ended = False

        # Backpointers and emitted tokens, one entry per timestep.
        self.prev_ks = []
        self.next_ys = [torch.LongTensor(beam_size).fill_(start_token_id)]

        # Cumulative log-probability of each beam.
        self.current_scores = torch.FloatTensor(beam_size).zero_()
        self.all_scores = []

        # (score, timestep, beam index) triples for finished hypotheses.
        self.finished = []
        self.n_top = n_top
    def advance(self, next_log_probs):
        # next_log_probs: (beam_size, vocabulary_size) log-probabilities for the next token.
        vocabulary_size = next_log_probs.size(1)
        current_length = len(self.next_ys)

        # Block EOS until the minimum length has been generated.
        if current_length < self.min_length:
            for beam_index in range(len(next_log_probs)):
                next_log_probs[beam_index][self.end_token_id] = -1e10

        if len(self.prev_ks) > 0:
            # Add the running beam scores to every candidate continuation.
            beam_scores = next_log_probs + self.current_scores.unsqueeze(1).expand_as(next_log_probs)
            # Don't let EOS have children.
            last_y = self.next_ys[-1]
            for beam_index in range(last_y.size(0)):
                if last_y[beam_index] == self.end_token_id:
                    beam_scores[beam_index] = -1e10  # -1e20 raises an error when executing
        else:
            # First step: all beams hold the same start token, so only row 0 is needed.
            beam_scores = next_log_probs[0]

        # Pick the top beam_size candidates across all (beam, token) pairs.
        flat_beam_scores = beam_scores.view(-1)
        top_scores, top_score_ids = flat_beam_scores.topk(k=self.beam_size, dim=0, largest=True, sorted=True)

        self.current_scores = top_scores
        self.all_scores.append(self.current_scores)

        # Recover which beam each candidate came from and which token it emits.
        prev_k = top_score_ids // vocabulary_size  # (beam_size, )
        next_y = top_score_ids - prev_k * vocabulary_size  # (beam_size, )
        self.prev_ks.append(prev_k)
        self.next_ys.append(next_y)

        # Record hypotheses that just emitted EOS.
        for beam_index, last_token_id in enumerate(next_y):
            if last_token_id == self.end_token_id:
                self.finished.append((self.current_scores[beam_index], len(self.next_ys) - 1, beam_index))

        if next_y[0] == self.end_token_id:
            self.top_sentence_ended = True
    def get_current_state(self):
        """Get the output tokens for the current timestep, one row per beam."""
        return torch.stack(self.next_ys, dim=1)

    def get_current_origin(self):
        """Get the backpointers for the current timestep."""
        return self.prev_ks[-1]

    def done(self):
        """True once the best beam has ended and enough hypotheses have finished."""
        return self.top_sentence_ended and len(self.finished) >= self.n_top
    def get_hypothesis(self, timestep, k):
        """Follow the backpointers from (timestep, k) to reconstruct a full hypothesis."""
        hypothesis = []
        for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1):
            hypothesis.append(self.next_ys[j + 1][k])
            # for RNN, [:, k, :], and for transformer, [k, :, :]
            k = self.prev_ks[j][k]
        return hypothesis[::-1]
    def sort_finished(self, minimum=None):
        if minimum is not None:
            i = 0
            # Pad with still-active beams until we have at least `minimum` outputs.
            while len(self.finished) < minimum:
                s = self.current_scores[i]
                self.finished.append((s, len(self.next_ys) - 1, i))
                i += 1

        # Sort finished hypotheses by score, best first.
        self.finished = sorted(self.finished, key=lambda a: a[0], reverse=True)
        scores = [sc for sc, _, _ in self.finished]
        ks = [(t, k) for _, t, k in self.finished]
        return scores, ks
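

# --- Usage sketch (not part of the original file): a minimal example of driving
# Beam with random scores. `vocab_size` and `fake_log_probs` are illustrative
# assumptions; a real decoder would compute next_log_probs of shape
# (beam_size, vocab_size) from a model fed get_current_state(), reordering its
# hidden state with get_current_origin() after every advance().
if __name__ == "__main__":
    torch.manual_seed(0)
    vocab_size = 10
    beam = Beam(beam_size=4, n_top=2, start_token_id=1, end_token_id=2)

    for _ in range(5):
        # Stand-in for model output: per-beam log-probabilities over the vocabulary.
        fake_log_probs = torch.log_softmax(torch.randn(beam.beam_size, vocab_size), dim=-1)
        beam.advance(fake_log_probs)
        if beam.done():
            break

    scores, ks = beam.sort_finished(minimum=beam.n_top)
    for score, (timestep, k) in zip(scores[:beam.n_top], ks[:beam.n_top]):
        print(float(score), [int(t) for t in beam.get_hypothesis(timestep, k)])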