# arithmetic.py
# Module metadata: authorship, copyright, license, version, and contact
# information for this file.
__author__ = "Bill MacCartney"
__copyright__ = "Copyright 2015, Bill MacCartney"
__credits__ = []
__license__ = "GNU General Public License, version 2.0"
__version__ = "0.9"
__maintainer__ = "Bill MacCartney"
__email__ = "See the author's website"
from collections import defaultdict
from numbers import Number
from domain import Domain
from example import Example
from experiment import evaluate_for_domain, evaluate_dev_examples_for_domain, train_test, train_test_for_domain, interact, learn_lexical_semantics, generate
from metrics import DenotationAccuracyMetric
from parsing import Grammar, Rule
from scoring import rule_features
# ArithmeticDomain =============================================================
# TODO: comment.
class ArithmeticDomain(Domain):
    """
    Domain of arithmetic expressions written out in words, using the numerals
    "one" through "four" and the operators "plus", "minus", and "times".

    Semantics are prefix-notation tuples such as ('+', 3, ('~', 2)): the first
    element is an operator symbol ('~' is unary negation; '+', '-', '*' are
    binary) and the remaining elements are operands, which are either numbers
    or nested tuples.  The denotation is the number produced by execute().
    """

    def train_examples(self):
        """
        Returns the training examples, each with an input phrase, its tuple
        semantics, and its numeric denotation.
        """
        table = [
            ("one plus one", ('+', 1, 1), 2),
            ("one plus two", ('+', 1, 2), 3),
            ("one plus three", ('+', 1, 3), 4),
            ("two plus two", ('+', 2, 2), 4),
            ("two plus three", ('+', 2, 3), 5),
            ("three plus one", ('+', 3, 1), 4),
            ("three plus minus two", ('+', 3, ('~', 2)), 1),
            ("two plus two", ('+', 2, 2), 4),
            ("three minus two", ('-', 3, 2), 1),
            ("minus three minus two", ('-', ('~', 3), 2), -5),
            ("two times two", ('*', 2, 2), 4),
            ("two times three", ('*', 2, 3), 6),
            ("three plus three minus two", ('-', ('+', 3, 3), 2), 4),
        ]
        return [
            Example(input=phrase, semantics=tree, denotation=value)
            for phrase, tree, value in table
        ]

    def test_examples(self):
        """
        Returns the test examples, in the same format as the training
        examples.
        """
        table = [
            ("minus three", ('~', 3), -3),
            ("three plus two", ('+', 3, 2), 5),
            ("two times two plus three", ('+', ('*', 2, 2), 3), 7),
            ("minus four", ('~', 4), -4),
        ]
        return [
            Example(input=phrase, semantics=tree, denotation=value)
            for phrase, tree, value in table
        ]

    def dev_examples(self):
        """Returns the module-level list arithmetic_dev_examples."""
        return arithmetic_dev_examples

    # Lexical rules pairing each numeral word with the number it stands for.
    numeral_rules = [
        Rule('$E', 'one', 1),
        Rule('$E', 'two', 2),
        Rule('$E', 'three', 3),
        Rule('$E', 'four', 4),
    ]

    # Lexical rules pairing operator words with operator symbols.  The word
    # 'minus' is listed twice: as unary negation ('~') and as binary
    # subtraction ('-').
    operator_rules = [
        Rule('$UnOp', 'minus', '~'),
        Rule('$BinOp', 'plus', '+'),
        Rule('$BinOp', 'minus', '-'),
        Rule('$BinOp', 'times', '*'),
    ]

    # Rules that assemble larger expressions.  Each lambda receives `sems`,
    # presumably the semantics of the right-hand-side children in order.
    # A binary expression is built in two steps through the intermediate
    # category $EBO, which holds (operator, left operand).
    compositional_rules = [
        # Unary operator applied to an expression: (op, operand).
        Rule('$E', '$UnOp $E', lambda sems: (sems[0], sems[1])),
        # Left operand followed by a binary operator: (op, left).
        Rule('$EBO', '$E $BinOp', lambda sems: (sems[1], sems[0])),
        # Partial binary expression followed by the right operand:
        # (op, left, right).
        Rule('$E', '$EBO $E', lambda sems: (sems[0][0], sems[0][1], sems[1])),
    ]

    def rules(self):
        """
        Returns a new list holding the numeral rules, then the operator rules,
        then the compositional rules.
        """
        combined = []
        for group in (self.numeral_rules,
                      self.operator_rules,
                      self.compositional_rules):
            combined.extend(group)
        return combined

    def operator_precedence_features(self, parse):
        """
        Walks the expression tree stored in parse.semantics and, for every
        operator node whose direct parent is a different operator, adds 1.0
        to the feature (child_op, parent_op).  A child sits lower in the tree
        than its parent, i.e. it is applied first.

        Returns a defaultdict(float) mapping (child_op, parent_op) to counts.
        """
        counts = defaultdict(float)

        def visit(node):
            # Leaves (plain numbers) contribute no features.
            if not isinstance(node, tuple):
                return
            for operand in node[1:]:
                # Descend first so that features of deeper nodes are recorded
                # before the feature linking this operand to its parent.
                visit(operand)
                if isinstance(operand, tuple) and operand[0] != node[0]:
                    counts[(operand[0], node[0])] += 1.0

        visit(parse.semantics)
        return counts

    def features(self, parse):
        """
        Returns the features from rule_features(parse), updated with the
        operator precedence features for the same parse.
        """
        combined = rule_features(parse)
        combined.update(self.operator_precedence_features(parse))
        return combined

    def weights(self):
        """
        Returns a defaultdict(float) of weights for the operator precedence
        features: +1.0 when '*' or '~' is the child of '+' or '-', and -1.0
        for the reverse nesting.  All other features default to 0.0.
        """
        table = defaultdict(float)
        tight_ops = ('*', '~')
        loose_ops = ('+', '-')
        # Reward the tighter operator appearing beneath the looser one.
        for tight in tight_ops:
            for loose in loose_ops:
                table[(tight, loose)] = 1.0
        # Penalize the looser operator appearing beneath the tighter one.
        for tight in tight_ops:
            for loose in loose_ops:
                table[(loose, tight)] = -1.0
        return table

    def grammar(self):
        """Returns a Grammar over this domain's rules with start symbol $E."""
        return Grammar(rules=self.rules(), start_symbol='$E')

    # Maps each operator symbol to the function that evaluates it.
    ops = {
        '~': lambda x: -x,
        '+': lambda x, y: x + y,
        '-': lambda x, y: x - y,
        '*': lambda x, y: x * y,
    }

    def execute(self, semantics):
        """
        Evaluates semantics recursively and returns the result.  A tuple is
        treated as (operator, operand, ...) and evaluated via the ops table;
        anything else is returned unchanged.
        """
        if not isinstance(semantics, tuple):
            return semantics
        operation = self.ops[semantics[0]]
        operands = [self.execute(operand) for operand in semantics[1:]]
        return operation(*operands)

    def training_metric(self):
        """Returns a new DenotationAccuracyMetric instance."""
        return DenotationAccuracyMetric()
# EagerArithmeticDomain ========================================================
# TODO: add comment.
class EagerArithmeticDomain(Domain):
    """
    Variant of ArithmeticDomain in which the rule semantics compute numbers
    directly: operator rules carry functions rather than symbols, and the
    compositional rules apply those functions.  The semantics of a complete
    expression is therefore already a number, and execute() returns it
    unchanged.
    """

    def train_examples(self):
        """
        Returns ArithmeticDomain's training examples, each passed through
        convert_example.
        """
        return list(map(convert_example, ArithmeticDomain().train_examples()))

    def test_examples(self):
        """
        Returns ArithmeticDomain's test examples, each passed through
        convert_example.
        """
        return list(map(convert_example, ArithmeticDomain().test_examples()))

    def dev_examples(self):
        """
        Returns ArithmeticDomain's dev examples, each passed through
        convert_example.
        """
        return list(map(convert_example, ArithmeticDomain().dev_examples()))

    # The numeral rules are shared with ArithmeticDomain (same list object).
    numeral_rules = ArithmeticDomain.numeral_rules

    # Binary operators are curried: given the left operand x they return a
    # function awaiting the right operand y.  Unary minus negates directly.
    operator_rules = [
        Rule('$BinOp', 'plus', lambda x: (lambda y: x + y)),
        Rule('$BinOp', 'minus', lambda x: (lambda y: x - y)),
        Rule('$BinOp', 'times', lambda x: (lambda y: x * y)),
        Rule('$UnOp', 'minus', lambda x: -1 * x),
    ]

    # Each lambda receives `sems`, presumably the semantics of the
    # right-hand-side children in order.
    compositional_rules = [
        # Apply the partially applied operator to the right operand.
        Rule('$E', '$EBO $E', lambda sems: sems[0](sems[1])),
        # Feed the left operand to the curried binary operator.
        Rule('$EBO', '$E $BinOp', lambda sems: sems[1](sems[0])),
        # Apply the unary operator to its operand.
        Rule('$E', '$UnOp $E', lambda sems: sems[0](sems[1])),
    ]

    def rules(self):
        """
        Returns a new list holding the numeral rules, then the operator rules,
        then the compositional rules.
        """
        combined = []
        for group in (self.numeral_rules,
                      self.operator_rules,
                      self.compositional_rules):
            combined.extend(group)
        return combined

    def grammar(self):
        """Returns a Grammar over this domain's rules with start symbol $E."""
        return Grammar(rules=self.rules(), start_symbol='$E')

    def execute(self, semantics):
        """Returns semantics unchanged."""
        return semantics

    def training_metric(self):
        """Returns a new DenotationAccuracyMetric instance."""
        return DenotationAccuracyMetric()
def convert_example(example):
    """
    Returns a new Example with the same input and denotation as the given
    example, and with its semantics set to that denotation.
    """
    text = example.input
    value = example.denotation
    return Example(input=text, semantics=value, denotation=value)
# ==============================================================================
# Development examples for the arithmetic domain.  Each example pairs an input
# phrase with its expected numeric denotation; no semantics are supplied.
# The denotations appear to follow the usual conventions (unary minus binds
# tightest, 'times' binds tighter than 'plus' and binary 'minus', and binary
# operators associate to the left).
# This list is returned by ArithmeticDomain.dev_examples() and used as
# training data by train_on_dev_experiment().
arithmetic_dev_examples = [
    Example(input='three plus four', denotation=7),
    Example(input='one times one', denotation=1),
    Example(input='four plus one plus four', denotation=9),
    Example(input='minus three plus two', denotation=-1),
    Example(input='minus three plus three', denotation=0),
    Example(input='minus four minus minus three', denotation=-1),
    Example(input='four minus minus three', denotation=7),
    Example(input='two plus one', denotation=3),
    Example(input='minus one minus four minus four', denotation=-9),
    Example(input='one plus minus one plus minus one', denotation=-1),
    Example(input='two times minus two plus three', denotation=-1),
    Example(input='two times minus three', denotation=-6),
    Example(input='four times two', denotation=8),
    Example(input='one plus four', denotation=5),
    Example(input='four minus one', denotation=3),
    Example(input='minus one times one times one', denotation=-1),
    Example(input='minus minus two', denotation=2),
    Example(input='one minus three times minus two times three', denotation=19),
    Example(input='minus two minus four', denotation=-6),
    Example(input='one minus two', denotation=-1),
    Example(input='three minus one', denotation=2),
    Example(input='minus three minus minus minus four', denotation=-7),
    Example(input='minus three plus four', denotation=1),
    Example(input='minus four minus four times minus one', denotation=0),
    Example(input='minus minus three plus two', denotation=5),
    Example(input='four plus three', denotation=7),
    Example(input='minus three plus one', denotation=-2),
    Example(input='minus two times one minus minus two', denotation=0),
    Example(input='one plus minus two', denotation=-1),
    Example(input='four plus four', denotation=8),
    Example(input='two minus one', denotation=1),
    Example(input='one plus minus three times four plus four times four', denotation=5),
    Example(input='minus one times three plus two', denotation=-1),
    Example(input='minus three times one minus minus three', denotation=0),
    Example(input='four plus four minus four minus one', denotation=3),
    Example(input='four minus one times three minus one', denotation=0),
    Example(input='two minus minus one', denotation=3),
    Example(input='minus minus three minus two', denotation=1),
    Example(input='one times minus four minus four plus one plus one', denotation=-6),
    Example(input='two plus four times two', denotation=10),
    Example(input='one plus two times one', denotation=3),
    Example(input='three minus four', denotation=-1),
    Example(input='two times two', denotation=4),
    Example(input='three minus minus three plus two minus minus three', denotation=11),
    Example(input='three minus minus three times two', denotation=9),
    Example(input='minus three times four times two', denotation=-24),
    Example(input='minus four minus four', denotation=-8),
    Example(input='three minus minus minus one', denotation=2),
    Example(input='two minus four', denotation=-2),
    Example(input='four times four minus one times three', denotation=13),
    Example(input='four minus three times three', denotation=-5),
    Example(input='minus three plus minus one', denotation=-4),
    Example(input='one minus three', denotation=-2),
    Example(input='minus one minus two', denotation=-3),
    Example(input='one times four times three', denotation=12),
    Example(input='minus three times one', denotation=-3),
    Example(input='three minus minus three', denotation=6),
    Example(input='three times minus minus minus four', denotation=-12),
    Example(input='minus one minus three', denotation=-4),
    Example(input='minus four plus one times three times four minus four', denotation=4),
    Example(input='minus minus four plus four plus minus three', denotation=5),
    Example(input='two minus minus three', denotation=5),
    Example(input='four plus one minus one times four', denotation=1),
    Example(input='three times two', denotation=6),
    Example(input='four plus three times minus two plus minus one', denotation=-3),
    Example(input='minus three minus one', denotation=-4),
    Example(input='minus minus two times four', denotation=8),
    Example(input='one plus three minus minus two minus minus minus four', denotation=2),
    Example(input='minus one minus one plus four plus three', denotation=5),
    Example(input='three times three minus one', denotation=8),
    Example(input='two minus four minus minus three', denotation=1),
    Example(input='minus minus three minus minus one minus three', denotation=1),
    Example(input='three plus two', denotation=5),
    Example(input='minus minus three', denotation=3),
    Example(input='minus minus three times one', denotation=3),
    Example(input='minus two plus four', denotation=2),
    Example(input='two minus minus two', denotation=4),
    Example(input='one plus three', denotation=4),
    Example(input='one times four', denotation=4),
    Example(input='minus three minus minus minus four plus four plus one', denotation=-2),
    Example(input='three times four minus two minus two minus three', denotation=5),
    Example(input='minus three minus three times minus minus minus minus two', denotation=-9),
    Example(input='minus four times minus two', denotation=8),
    Example(input='minus three plus two times three minus minus minus four', denotation=-1),
    Example(input='four times three', denotation=12),
    Example(input='minus minus three plus minus four', denotation=-1),
    Example(input='minus four times four', denotation=-16),
    Example(input='two plus minus one', denotation=1),
    Example(input='minus minus minus three plus minus one', denotation=-4),
    Example(input='three plus one minus minus two', denotation=6),
    Example(input='minus four times minus four', denotation=16),
    Example(input='four plus minus two', denotation=2),
    Example(input='two times four', denotation=8),
    Example(input='minus minus minus four minus one times three plus two', denotation=-5),
    Example(input='one minus one', denotation=0),
    Example(input='minus minus one', denotation=1),
    Example(input='minus minus minus four', denotation=-4),
    Example(input='four plus two', denotation=6),
    Example(input='two minus three', denotation=-1),
    Example(input='minus four plus two', denotation=-2),
]
def train_on_dev_experiment():
    """
    Calls train_test on the ArithmeticDomain model, using
    arithmetic_dev_examples as the training examples and the domain's own
    test examples as the test examples.

    The call passes the result of denotation_match_metrics() as metrics, a
    DenotationAccuracyMetric as the training metric, seed=1, and
    print_examples=False.
    """
    from metrics import denotation_match_metrics

    domain = ArithmeticDomain()
    # The intermediates below are evaluated in the same order in which the
    # keyword arguments are passed.
    model = domain.model()
    held_out_examples = domain.test_examples()
    reporting_metrics = denotation_match_metrics()
    accuracy_metric = DenotationAccuracyMetric()
    train_test(model=model,
               train_examples=arithmetic_dev_examples,
               test_examples=held_out_examples,
               metrics=reporting_metrics,
               training_metric=accuracy_metric,
               seed=1,
               print_examples=False)
# ==============================================================================
if __name__ == '__main__':
    # When run as a script, call evaluate_for_domain on a fresh
    # ArithmeticDomain.  The commented-out calls below are alternative
    # experiments; uncomment one to run it.
    arithmetic_domain = ArithmeticDomain()
    evaluate_for_domain(arithmetic_domain)
    # train_test_for_domain(ArithmeticDomain(), seed=1, print_examples=False)
    # train_on_dev_experiment()
    # evaluate_for_domain(EagerArithmeticDomain())
    # evaluate_dev_examples_for_domain(ArithmeticDomain())
    # interact(ArithmeticDomain(), "two times two plus three")
    # learn_lexical_semantics(ArithmeticDomain(), seed=1)
    # generate(ArithmeticDomain().rules(), '$E')