diff --git a/tpe/sampler.go b/tpe/sampler.go index cf3da4fc..1530e3a1 100644 --- a/tpe/sampler.go +++ b/tpe/sampler.go @@ -542,6 +542,9 @@ func getObservationPairs(study *goptuna.Study, paramName string) ([]float64, [][ values := make([]float64, 0, len(trials)) scores := make([][2]float64, 0, len(trials)) for _, trial := range trials { + if trial.State != goptuna.TrialStateComplete && trial.State != goptuna.TrialStatePruned { + continue + } ir, ok := trial.InternalParams[paramName] if !ok { continue @@ -552,7 +555,7 @@ func getObservationPairs(study *goptuna.Study, paramName string) ([]float64, [][ if trial.State == goptuna.TrialStateComplete { score0 = math.Inf(-1) score1 = sign * trial.Value - } else if trial.State == goptuna.TrialStatePruned { + } else { if len(trial.IntermediateValues) > 0 { var step int var intermediateValue float64 @@ -569,8 +572,6 @@ func getObservationPairs(study *goptuna.Study, paramName string) ([]float64, [][ score0 = math.Inf(1) score1 = 0.0 } - } else { - continue } values = append(values, paramValue) scores = append(scores, [2]float64{score0, score1})