Skip to content

Commit

Permalink
fixes for logreg predict_proba, knnreg, inc cov, inc pca
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanglaser committed Jan 18, 2025
1 parent bb5206f commit 8211a23
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 35 deletions.
4 changes: 2 additions & 2 deletions onedal/linear_model/logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,8 @@ def _predict_proba(self, X, module, queue):
result = self._infer(X, module, queue, sua_iface)

y = from_table(result.probabilities, sua_iface=sua_iface, sycl_queue=queue, xp=xp)
y = y.reshape(-1, 1)
return xp.hstack([1 - y, y])
y = xp.reshape(y, (-1, 1))
return xp.concat([1 - y, y], axis=0)

def _predict_log_proba(self, X, module, queue):
_, xp, _ = _get_sycl_namespace(X)
Expand Down
17 changes: 11 additions & 6 deletions onedal/neighbors/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
kdtree_knn_classification_prediction,
kdtree_knn_classification_training,
)
from ..utils._array_api import _get_sycl_namespace

from .._config import _get_config
from ..common._base import BaseEstimator
Expand Down Expand Up @@ -205,11 +206,14 @@ def _fit(self, X, y, queue):
self, "effective_metric_params_", self.metric_params
)

_, xp, _ = _get_sycl_namespace(X)
use_raw_input = _get_config().get("use_raw_input", False) is True
if y is not None or self.requires_y:
shape = getattr(y, "shape", None)
X, y = super()._validate_data(
X, y, dtype=[np.float64, np.float32], accept_sparse="csr"
)
if not use_raw_input:
X, y = super()._validate_data(
X, y, dtype=[np.float64, np.float32], accept_sparse="csr"
)
self._shape = shape if shape is not None else y.shape

if _is_classifier(self):
Expand All @@ -233,7 +237,7 @@ def _fit(self, X, y, queue):
self._validate_n_classes()
else:
self._y = y
else:
elif not use_raw_input:
X, _ = super()._validate_data(X, dtype=[np.float64, np.float32])

self.n_samples_fit_ = X.shape[0]
Expand Down Expand Up @@ -261,7 +265,7 @@ def _fit(self, X, y, queue):
result = self._onedal_fit(X, _fit_y, queue)

if y is not None and _is_regressor(self):
self._y = y if self._shape is None else y.reshape(self._shape)
self._y = y if self._shape is None else xp.reshape(y, self._shape)

self._onedal_model = result
result = self
Expand Down Expand Up @@ -625,7 +629,8 @@ def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None)
return super()._kneighbors(X, n_neighbors, return_distance, queue=queue)

def _predict_gpu(self, X, queue=None):
X = _check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])
if _get_config()["use_raw_input"] is False:
X = _check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])
onedal_model = getattr(self, "_onedal_model", None)
n_features = getattr(self, "n_features_in_", None)
n_samples_fit_ = getattr(self, "n_samples_fit_", None)
Expand Down
38 changes: 22 additions & 16 deletions sklearnex/covariance/incremental_covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from sklearnex import config_context

from .._config import get_config
from .._device_offload import dispatch, wrap_output_data
from .._utils import IntelEstimator, PatchingConditionsChain, register_hyperparameters
from ..metrics import pairwise_distances
Expand Down Expand Up @@ -186,6 +187,9 @@ def location_(self):
def _onedal_partial_fit(self, X, queue=None, check_input=True):
first_pass = not hasattr(self, "n_samples_seen_") or self.n_samples_seen_ == 0

use_raw_input = get_config()["use_raw_input"]
# never check input when using raw input
check_input &= use_raw_input is False
# finite check occurs on onedal side
if check_input:
if sklearn_check_version("1.2"):
Expand Down Expand Up @@ -333,23 +337,25 @@ def _onedal_fit(self, X, queue=None):
if hasattr(self, "_onedal_estimator"):
self._onedal_estimator._reset()

if sklearn_check_version("1.2"):
self._validate_params()
use_raw_input = get_config()["use_raw_input"]
if not use_raw_input:
if sklearn_check_version("1.2"):
self._validate_params()

# finite check occurs on onedal side
if sklearn_check_version("1.0"):
X = validate_data(
self,
X,
dtype=[np.float64, np.float32],
copy=self.copy,
force_all_finite=False,
)
else:
X = check_array(
X, dtype=[np.float64, np.float32], copy=self.copy, force_all_finite=False
)
self.n_features_in_ = X.shape[1]
# finite check occurs on onedal side
if sklearn_check_version("1.0"):
X = validate_data(
self,
X,
dtype=[np.float64, np.float32],
copy=self.copy,
force_all_finite=False,
)
else:
X = check_array(
X, dtype=[np.float64, np.float32], copy=self.copy, force_all_finite=False
)
self.n_features_in_ = X.shape[1]

self.batch_size_ = self.batch_size if self.batch_size else 5 * self.n_features_in_

Expand Down
26 changes: 15 additions & 11 deletions sklearnex/preview/decomposition/incremental_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def _onedal_transform(self, X, queue=None):
assert hasattr(self, "_onedal_estimator")
if self._need_to_finalize:
self._onedal_finalize_fit()
X = check_array(X, dtype=[np.float64, np.float32])
use_raw_input = get_config()["use_raw_input"]
if not use_raw_input:
X = check_array(X, dtype=[np.float64, np.float32])
return self._onedal_estimator.predict(X, queue)

def _onedal_fit_transform(self, X, queue=None):
Expand Down Expand Up @@ -125,17 +127,19 @@ def _onedal_finalize_fit(self, queue=None):
self._need_to_finalize = False

def _onedal_fit(self, X, queue=None):
if sklearn_check_version("1.2"):
self._validate_params()
use_raw_input = get_config()["use_raw_input"]
if not use_raw_input:
if sklearn_check_version("1.2"):
self._validate_params()

if sklearn_check_version("1.0"):
X = validate_data(self, X, dtype=[np.float64, np.float32], copy=self.copy)
else:
X = check_array(
X,
dtype=[np.float64, np.float32],
copy=self.copy,
)
if sklearn_check_version("1.0"):
X = validate_data(self, X, dtype=[np.float64, np.float32], copy=self.copy)
else:
X = check_array(
X,
dtype=[np.float64, np.float32],
copy=self.copy,
)

n_samples, n_features = X.shape

Expand Down

0 comments on commit 8211a23

Please sign in to comment.