-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathlstm_test.go
38 lines (33 loc) · 1.08 KB
/
lstm_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
package automata_test
import (
"github.com/Kegsay/automata"
"testing"
)
// TestLSTM_ShortTerm trains an LSTM on a short sequence in which the same
// input must produce different outputs depending on its position. Input 1 is
// expected to yield 1 the first time and 0 the second; input 0 is expected
// to yield 0, then 1, then 0. The network is then replayed over that same
// sequence to check it learned the mapping.
//
// The network is built with automata.NewLSTM(table, 1, []int{6}, 1).
// Presumably that means 1 input, one hidden layer of size 6 and 1 output;
// confirm against NewLSTM's documentation.
func TestLSTM_ShortTerm(t *testing.T) {
	testLookupTable := &automata.LookupTable{}
	lstm := automata.NewLSTM(testLookupTable, 1, []int{6}, 1)
	trainer := automata.Trainer{
		Network:      lstm,
		MaxErrorRate: 0.001,
		LearnRate:    0.2,
		Iterations:   10000,
		CostFunction: &automata.MeanSquaredErrorCost{},
	}

	// LSTM must remember where in the sequence it is to produce the right output.
	// The sequence is declared once so the verification pass below always
	// replays exactly what the network was trained on.
	sequence := []struct {
		input, want []float64
	}{
		{[]float64{0}, []float64{0}},
		{[]float64{1}, []float64{1}},
		{[]float64{1}, []float64{0}},
		{[]float64{0}, []float64{1}},
		{[]float64{0}, []float64{0}},
	}

	// TrainSet is built positionally (input, then expected output), matching
	// the field order the original literals relied on.
	trainSets := make([]automata.TrainSet, 0, len(sequence))
	for _, step := range sequence {
		trainSets = append(trainSets, automata.TrainSet{step.input, step.want})
	}
	if err := trainer.Train(trainSets); err != nil {
		t.Fatalf("trainer.Train threw error: %v", err)
	}

	// Replay the sequence in order, passing each step's input and expected
	// output to activateNetwork (a helper defined elsewhere in this package,
	// presumably failing t on a mismatch).
	// NOTE(review): nothing here resets the network's recurrent state between
	// training and this pass; confirm Train leaves the network ready to start
	// the sequence from the beginning.
	for _, step := range sequence {
		activateNetwork(t, lstm, step.input, step.want)
	}
}