Probabilistic Embeddings

This repository contains an implementation of all Probabilistic Metric Learning (PML) approaches from the Probabilistic Embeddings Revisited paper. It fully supports the probabilistic methods proposed in previous works, as well as classical (deterministic) Metric Learning (ML) methods.

Getting Started

Installation

  1. Clone this repository:
    git clone git@github.com:tinkoff-ai/probabilistic-embeddings.git
    cd probabilistic-embeddings
  2. We recommend building our Docker image from the provided Dockerfile (a minimal build command is sketched after this list).
  3. The library must be installed before execution. An editable installation is recommended:
    pip install -e .
  4. You can check the installation using tests:
    tox -e py38 -r
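
For step 2, a minimal Docker sketch, assuming the Dockerfile sits at the repository root; the image tag and run options are illustrative, not part of the repository:

# Build the image from the repository root (the tag is an arbitrary choice).
docker build -t probabilistic-embeddings .
# Start an interactive container with GPU access (requires the NVIDIA Container Toolkit).
docker run --gpus all -it probabilistic-embeddings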

Quick Start

  1. Prepare an experiment .yaml config. In this example, a simple ArcFace model is trained on the LFW dataset:

    dataset_params:
      name: lfw-openset
      samples_per_class: null
    
    model_params:
      # Embedder maps input image to embedding vector space.
      embedder_params:
        model_type: resnet18
        # Use ImageNet pretrain.
        pretrained: true
      distribution_params:
        # Spherical 512D embeddings.
        spherical: true
        dim: 512
      # For deterministic embeddings specify Dirac distribution (default).
      distribution_type: dirac
      classifier_type: arcface
    
    trainer_params:
      optimizer_type: adam
      optimizer_params:
        lr: 3.0e-4
  2. Download and unpack the LFW dataset.

  3. Run training with the following command:

    python3 -m probabilistic_embeddings train \
    --config <path-to-yaml-config> \
    --train-root <logs-and-checkpoints-root> \
    <path-to-lfw-data-root>
  4. Logs and checkpoints will be saved to ./<logs-and-checkpoints-root>. The default logging format is TensorBoard.
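
Logs can be inspected with standard TensorBoard tooling; the invocation below is generic TensorBoard usage, assuming TensorBoard is installed in your environment, not a command provided by this repository:

# Point TensorBoard at the training root, then open http://localhost:6006.
tensorboard --logdir <logs-and-checkpoints-root>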

WandB support

To enable WandB logging, run the experiment with the following command:

WANDB_ENTITY=<entity-name> \
WANDB_API_KEY=<api-key> \
CUDA_VISIBLE_DEVICES=<gpu-index> \
python3 -m probabilistic_embeddings train \
--config <path-to-yaml-config> \
--logger wandb:<project-name>:<experiment-name> \
--train-root <logs-and-checkpoints-root> \
<path-to-dataset-root>

Supported commands

Training

train runs the standard training pipeline:

python3 -m probabilistic_embeddings train \
--config <path-to-yaml-config> \
--train-root <logs-and-checkpoints-root> \
<path-to-data-root>

To apply a K-fold cross-validation scheme, use the cval command.
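
This README does not spell out the cval CLI; the sketch below assumes cval accepts the same arguments as train:

python3 -m probabilistic_embeddings cval \
--config <path-to-yaml-config> \
--train-root <logs-and-checkpoints-root> \
<path-to-data-root>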

Evaluation

test computes metrics for a given checkpoint:

CUDA_VISIBLE_DEVICES=<gpu-index> \
python3 -m probabilistic_embeddings test \
--config <path-to-config> \
--checkpoint <path-to-checkpoint> \
<path-to-data-root>

evaluate performs model evaluation over multiple random seeds. Add the num_evaluation_seeds field to the experiment config to specify the number of random seeds. Use the evaluate-cval command to evaluate with cross-validation, and add num_validation_folds to dataset_params to set the number of folds.
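
A minimal config sketch for seeded evaluation with cross-validation; placing num_evaluation_seeds at the top level follows the description above, and the field values are illustrative:

# Number of random seeds used by evaluate (illustrative value).
num_evaluation_seeds: 5
dataset_params:
  name: lfw-openset
  # Number of folds used by evaluate-cval (illustrative value).
  num_validation_folds: 5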

Hyperparameter tuning

In order to run WandB sweeps, use the hopt and hopt-cval commands. Hyperparameter tuning is only supported with the WandB logger.

CUDA_VISIBLE_DEVICES=<gpu-index> \
python3 -m probabilistic_embeddings hopt \
--config <path-to-config> \
--logger wandb:<project-name>:<experiment-name> \
--train-root <training-root> <path-to-data-root>

Hyperparameters to search and their ranges should be specified in the config, as in this example:

...
model_params:
  ...
  classifier_type: arcface
  classifier_params:
    _hopt:
      scale:
        min: 1.0
        max: 64.0
      margin:
        min: 0.0
        max: 1.0
...

Reproducing paper results

To reproduce all the results of the paper, you first need to generate configs for all experiments:

mkdir configs/reality/generated
python scripts/configs/generate-reality.py \
configs/reality/templates/ \
configs/reality/generated/ \
--best configs/reality/best/

Our hyperparameter search results are stored in configs/reality/best. You can reproduce the hyperparameter search with the hopt command and download the best parameters from WandB. To reproduce training and evaluation, refer to the commands above.

Supported Datasets

The repository supports multiple datasets. Face recognition: MS1MV2, MS1MV3, LFW, and CASIA. Image retrieval: Cars196, CUB200, In-shop Clothes (Inshop), and Stanford Online Products (SOP). Multiple image classification datasets are also implemented; refer to ./src/probabilistic_embeddings/dataset for more details.

Serialized datasets used in the reality configs can be downloaded via the following links: Cars196, CUB200, In-shop Clothes (Inshop), and Stanford Online Products (SOP).

Citing

If you use code from this repository in your project, please cite our paper:

@inproceedings{pml2022,
  title={Probabilistic Embeddings Revisited},
  author={Ivan Karpukhin and Stanislav Dereka and Sergey Kolesnikov},
  year={2022},
  url={https://arxiv.org/pdf/2202.06768.pdf}
}
