Shine #2 (Draft)

Wants to merge 48 commits into base: master.

Commits (48)
b1f3637
added shine idea to hoag lbfgs
zaccharieramzi Mar 23, 2021
d0199f7
added shine flag in logistic clf
zaccharieramzi Mar 23, 2021
ee387a1
added extra kwargs for lbfgs in the logistic clf
zaccharieramzi Mar 23, 2021
0941481
allowed to use shine and extra lbfgs extra kwargs in multilogistic re…
zaccharieramzi Mar 23, 2021
4d335ee
added a debug flag in hoag
zaccharieramzi Mar 31, 2021
095b0d6
added forward and backward times in hoag lbfgs
zaccharieramzi Mar 31, 2021
448f16a
divided the maximum inner iterations in 2 parameters: one for the lfb…
zaccharieramzi Mar 31, 2021
537ff53
added the refining scheme for shine
zaccharieramzi Apr 1, 2021
495eb76
started a file with the benchmark utils
zaccharieramzi Apr 1, 2021
44460dc
changed bench result to a dataclass
zaccharieramzi Apr 1, 2021
65fde48
allowed to set verbose
zaccharieramzi Apr 1, 2021
1bf5f7c
final iteration on dataclass
zaccharieramzi Apr 1, 2021
f1d95a0
corrected computation of median and quantiles
zaccharieramzi Apr 1, 2021
17c27ec
added grid search to logistic cv
zaccharieramzi Apr 2, 2021
a043af8
added random search to logistic cv
zaccharieramzi Apr 2, 2021
200f161
added random and grid search to benchmark
zaccharieramzi Apr 2, 2021
c249fc1
allowed to change number of points in grid for grid and random search
zaccharieramzi Apr 2, 2021
097bee4
added warm restart to grid search
zaccharieramzi Apr 2, 2021
a58ec06
corrected maxiter in lbfgs for grid search
zaccharieramzi Apr 2, 2021
3f66be1
corrected typo in grid search callback call
zaccharieramzi Apr 2, 2021
a618098
changed callback calling in grid search to make sure the init is take…
zaccharieramzi Apr 2, 2021
7bf7b7e
added reqs file
zaccharieramzi Apr 2, 2021
6184dcf
added real-sim to benchmark and refactored util to have equally-sized…
zaccharieramzi Apr 2, 2021
a92f668
corrected stacking for 20news dataset
zaccharieramzi Apr 2, 2021
61636bd
corrected splitting in benchmark
zaccharieramzi Apr 2, 2021
cd02d3d
added possibility to use a certain train proportion in benchmark
zaccharieramzi Apr 2, 2021
6c2812c
corrected random search
zaccharieramzi Apr 6, 2021
438c1fb
added the jacobian free idea from fpn
zaccharieramzi Apr 9, 2021
6148e2a
made the multilogistic alpha use only 10 dimensions
zaccharieramzi Apr 9, 2021
f55c153
added keras to reqs for mnist
zaccharieramzi Apr 9, 2021
d399b9f
added tf to reqs for mnist manipulation
zaccharieramzi Apr 9, 2021
fae2ec7
corrected hessian
zaccharieramzi Apr 10, 2021
77706e1
allowed hoag lbfgs to handle grouped regularisation for multivariate
zaccharieramzi Apr 10, 2021
9c5fefa
added multilogistic regression to benchmark
zaccharieramzi Apr 10, 2021
1305aa0
added scaling for mnist in benchmark
zaccharieramzi Apr 10, 2021
4977e65
allowed to have a different exponent for backward iterations in multi…
zaccharieramzi Apr 10, 2021
534a462
corrected stacking in mnist dataset
zaccharieramzi Apr 10, 2021
71d3182
added binarization for benchmark mnist
zaccharieramzi Apr 16, 2021
d2641f9
Added pure python LBFGS implem and OPA (#5)
zaccharieramzi May 31, 2021
cb72b65
Figure and Readme (#6)
zaccharieramzi May 31, 2021
3c0370a
ENH update results
tomMoral Sep 30, 2021
df15d70
CLN last changes
tomMoral Oct 3, 2021
905e76c
changed the y label
zaccharieramzi Oct 3, 2021
c15b01f
changed name of add direction in inversion plot to prescribed
zaccharieramzi Oct 4, 2021
1b2e464
corrected bug in appendix figure and changed the y label
zaccharieramzi Oct 4, 2021
285662d
updated readme with new figure layout
zaccharieramzi Oct 4, 2021
baaeb6a
ENH add script for plot
tomMoral Nov 12, 2021
f3b7afe
corrected appendix figure and opa interpolation
zaccharieramzi Nov 12, 2021
6 changes: 6 additions & 0 deletions .gitignore
@@ -127,3 +127,9 @@ dmypy.json

# Pyre type checker
.pyre/


# Saved results
*.csv
*.pdf
*.ipynb
60 changes: 60 additions & 0 deletions README.md
@@ -0,0 +1,60 @@
# HOAG - SHINE

This is the first part of the code for the paper "SHINE: SHaring the INverse Estimate from the forward pass for bi-level optimization and implicit models", submitted to the 2022 ICLR conference.
This source code allows you to reproduce the experiments on logistic regression, i.e. Figures 1-2 and Figure E.1 in the Appendix.

## General instructions

You need Python 3.7 or above to run this code.
You can then install the requirements with: `pip install -r requirements.txt`.

When running the scripts, you will see the following warning printed: `CG did not converge to the desired precision`.
This does not indicate a problem with your current run.

## Reproducing Figure 1, Bi-level optimization

Figure 1 can be reproduced by running the `main_plots.py` script:

```
python main_plots.py
```

By default, the results will be re-computed and saved each time.
If you want to use the results saved from a previous run, you can use the `--no_recomp` flag.
If you want to do a test run without saving the results, you can use the `--no_save` flag.
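
For example, to regenerate the figure from the results saved by a previous run:

```
python main_plots.py --no_recomp
```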

Running this script in full takes about 2 hours; with saved results, it takes about 2 seconds.


## Reproducing Figure 2, Bi-level optimization with OPA

Figure 2 can be reproduced by running the `main_plots_opa_df.py` script:

```
python main_plots_opa_df.py
```

By default, the results will be re-computed and saved each time.
If you want to use the results saved from a previous run, you can use the `--no_recomp` flag.
If you want to do a test run without saving the results, you can use the `--no_save` flag.

Running this script in full takes about 1 hour; with saved results, it takes about 10 seconds.

## Reproducing Figure E.1, Bi-level optimization

Figure E.1 can be reproduced by running the `main_plots.py` script:

```
python main_plots.py --appendix_figure
```

By default, the results will be re-computed and saved each time.
If you want to use the results saved from a previous run, you can use the `--no_recomp` flag.
If you want to do a test run without saving the results, you can use the `--no_save` flag.

In this case, the raw results are the same as for Figure 1, so you can reuse them.

Running this script in full takes about 2 hours; with saved results, it takes about 2 seconds.
135 changes: 12 additions & 123 deletions doc/example_usage.ipynb

(large diff not rendered)

221 changes: 221 additions & 0 deletions hoag/benchmark.py
@@ -0,0 +1,221 @@
from dataclasses import dataclass, fields, field
import time
from typing import List

from libsvmdata import fetch_libsvm
import numpy as np
import pandas as pd
import scipy.sparse as sp
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelBinarizer

from hoag import LogisticRegressionCV, MultiLogisticRegressionCV
from hoag.logistic import _intercept_dot, log_logistic
from hoag.multilogistic import _multinomial_loss


@dataclass
class BenchResult:
    lambda_traces: List = field(default_factory=list)
    lambda_times: List = field(default_factory=list)
    beta_traces: List = field(default_factory=list)
    val_losses: List = field(default_factory=list)
    test_losses: List = field(default_factory=list)

    def __getitem__(self, field_key):
        return self.__getattribute__(field_key)

    def __setitem__(self, field_key, value):
        return self.__setattr__(field_key, value)

    def append(self, bench_res):
        # append each field of a single-run result to the aggregated lists
        for f in fields(self):
            self[f.name] += [bench_res[f.name]]

    def freeze(self):
        # stack the per-run lists into arrays so that statistics can be
        # computed across runs (axis 0)
        for f in fields(self):
            self[f.name] = np.array(self[f.name])

    def median(self):
        return BenchResult(
            *tuple(np.median(self[f.name], axis=0)
                   for f in fields(self))
        )

    def quantile(self, q):
        return BenchResult(
            *tuple(np.quantile(self[f.name], q, axis=0)
                   for f in fields(self))
        )
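
# Illustrative usage sketch (hypothetical values): one BenchResult per run is
# appended to an aggregate, which is then frozen and summarized per iteration:
#
#     overall = BenchResult()
#     for seed in range(3):
#         overall.append(results_for_kwargs(random_state=seed))
#     overall.freeze()  # -> (n_runs, n_iters) arrays
#     med, q1 = overall.median(), overall.quantile(0.1)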


def train_test_val_split(X, y, random_state=0, train_prop=1/3):
    # split off the training set, then split the remainder evenly into
    # test and validation sets (train_prop=1/3 gives three equal parts)
    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y,
        test_size=1-train_prop,
        random_state=random_state,
    )
    X_val, X_test, y_val, y_test = train_test_split(
        X_test,
        y_test,
        test_size=1/2,
        random_state=random_state,
    )
    return X_train, y_train, X_test, y_test, X_val, y_val

def get_20_news(random_state=0, train_prop=1/3):
    # get a training set and test set
    data_train = datasets.fetch_20newsgroups_vectorized(subset='train')
    data_test = datasets.fetch_20newsgroups_vectorized(subset='test')

    X_train = data_train.data
    X_test = data_test.data
    y_train = data_train.target
    y_test = data_test.target

    # binarize labels
    y_train[data_train.target < 10] = -1
    y_train[data_train.target >= 10] = 1
    y_test[data_test.target < 10] = -1
    y_test[data_test.target >= 10] = 1

    # Regroup all
    X = sp.vstack([X_train, X_test])
    y = np.hstack([y_train, y_test])
    # Equally-sized split
    X_train, y_train, X_test, y_test, X_val, y_val = train_test_val_split(
        X,
        y,
        random_state=random_state,
        train_prop=train_prop,
    )
    return X_train, y_train, X_test, y_test, X_val, y_val

def get_realsim(random_state, train_prop=1/3):
    X, y = fetch_libsvm("real-sim")
    X_train, y_train, X_test, y_test, X_val, y_val = train_test_val_split(
        X,
        y,
        random_state=random_state,
        train_prop=train_prop,
    )
    return X_train, y_train, X_test, y_test, X_val, y_val

def get_mnist(random_state, train_prop=1/3):
    from keras.datasets import mnist
    import tensorflow as tf

    # downsample the images to 12x12 and flatten them to 144 features
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    X_train = np.reshape(
        tf.image.resize(X_train[..., None], (12, 12)).numpy()[..., 0],
        (X_train.shape[0], -1),
    )
    X_test = np.reshape(
        tf.image.resize(X_test[..., None], (12, 12)).numpy()[..., 0],
        (X_test.shape[0], -1),
    )
    X = np.vstack([X_train, X_test])
    y = np.hstack([y_train, y_test])
    X_train, y_train, X_test, y_test, X_val, y_val = train_test_val_split(
        X,
        y,
        random_state=random_state,
        train_prop=train_prop,
    )
    # scaling
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)
    X_val = scaler.transform(X_val)
    return X_train, y_train, X_test, y_test, X_val, y_val


def val_loss_univariate(X, y, beta):
    # unregularized logistic loss of beta on (X, y)
    _, _, yz = _intercept_dot(beta, X, y)
    out = -np.sum(log_logistic(yz))
    return out

def val_loss_multivariate(X, y, beta):
    # unregularized multinomial loss of beta on (X, y)
    n_samples, n_classes = y.shape
    out, _, _ = _multinomial_loss(
        beta, X, y, np.zeros((n_classes,)), np.ones((n_samples,)))
    return out

def results_for_kwargs(train_prop=1/3, dataset='20news', random_state=0, search=None, **kwargs):
    if dataset == '20news':
        get_fun = get_20_news
    elif dataset == 'real-sim':
        get_fun = get_realsim
    elif dataset == 'mnist':
        get_fun = get_mnist
    else:
        raise NotImplementedError(f'Dataset {dataset} not implemented')
    X_train, y_train, X_test, y_test, X_val, y_val = get_fun(
        random_state, train_prop=train_prop)
    np.random.seed(random_state)
    lambda_traces = []
    lambda_times = []
    beta_traces = []
    start = time.time()

    def lambda_tracing(x, lmbd):
        # record the hyperparameter, the elapsed time and the model
        # parameters at each outer iteration
        delta = time.time() - start
        lambda_traces.append(np.copy(lmbd)[0])
        lambda_times.append(delta)
        beta_traces.append(x.copy())

    # optimize model parameters and hyperparameters jointly using HOAG
    if dataset != 'mnist':
        # only 2 classes
        clf = LogisticRegressionCV(**kwargs)
        val_loss = val_loss_univariate
    else:
        # multiclass case
        clf = MultiLogisticRegressionCV(**kwargs)
        val_loss = val_loss_multivariate
    if search is None:
        clf.fit(X_train, y_train, X_test, y_test, callback=lambda_tracing)
    else:
        random = search == 'random'
        clf.grid_search(
            X_train,
            y_train,
            X_test,
            y_test,
            callback=lambda_tracing,
            random=random,
        )
    if dataset == 'mnist':
        # one-hot encode the labels so the multinomial loss can be evaluated
        lbin = LabelBinarizer()
        lbin.fit(y_train)
        y_val = lbin.transform(y_val)
        y_test = lbin.transform(y_test)
    val_losses = [val_loss(X_val, y_val, beta) for beta in beta_traces]
    test_losses = [val_loss(X_test, y_test, beta) for beta in beta_traces]
    res = BenchResult(lambda_traces, lambda_times, beta_traces, val_losses, test_losses)
    return res

def randomized_results_for_kwargs(n_random_seed=10, **kwargs):
    overall_res = BenchResult()
    for seed in range(n_random_seed):
        res = results_for_kwargs(random_state=seed, **kwargs)
        overall_res.append(res)
    return overall_res

def framed_results_for_kwargs(n_random_seed=10, **kwargs):
    overall_res = randomized_results_for_kwargs(n_random_seed=n_random_seed, **kwargs)
    data = [
        {
            'seed': i_seed,
            'i_iter': i_iter,
            'time': overall_res.lambda_times[i_seed][i_iter],
            'val_loss': overall_res.val_losses[i_seed][i_iter],
            'test_loss': overall_res.test_losses[i_seed][i_iter],
            **kwargs,
        }
        for i_seed in range(n_random_seed)
        for i_iter in range(len(overall_res.lambda_times[i_seed]))
    ]
    df_res = pd.DataFrame(data)
    return df_res

def quantized_results_for_kwargs(**kwargs):
    overall_res = randomized_results_for_kwargs(**kwargs)
    overall_res.freeze()
    median_res = overall_res.median()
    q1_res = overall_res.quantile(0.1)
    q9_res = overall_res.quantile(0.9)
    return median_res, q1_res, q9_res
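
# Example use (hypothetical choice of dataset; any extra keyword arguments are
# forwarded to the CV estimators): median and 10%/90% quantile curves over the
# default 10 seeded runs:
#
#     med, q1, q9 = quantized_results_for_kwargs(dataset='20news')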