
yakt00/IRGen


IRGen

A PyTorch implementation of IRGen based on the paper IRGen: Generative Modeling for Image Retrieval.

[Figure: network architecture, image from the paper]

Requirements

conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
pip install -r requirements.txt

Datasets

This repo uses the CARS196, CUB200-2011, and In-shop Clothes datasets.

Download these datasets yourself and extract them into the data/car, data/cub, and data/isc directories, respectively. For each dataset, run the following command to generate the ground-truth file.

python gnd_generater.py 
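The format of the file written by gnd_generater.py is not documented here. As a rough illustration only, a retrieval ground truth typically lists, for each query, the gallery items that share its class label. The helper name build_ground_truth and the list-of-lists layout below are hypothetical and are not taken from the repo.

```python
from collections import defaultdict


def build_ground_truth(query_labels, gallery_labels, same_set=False):
    """Hypothetical helper: for each query, list the gallery indices with the same label.

    Set same_set=True when queries and gallery are the same images (the usual
    metric-learning protocol for CUB200 and CARS196), so that a query is not
    counted as its own match.
    """
    # group gallery indices by class label
    by_label = defaultdict(list)
    for g_idx, label in enumerate(gallery_labels):
        by_label[label].append(g_idx)

    gnd = []
    for q_idx, label in enumerate(query_labels):
        positives = list(by_label.get(label, []))
        if same_set:
            # drop the query itself when it also appears in the gallery
            positives = [g for g in positives if g != q_idx]
        gnd.append(positives)
    return gnd


print(build_ground_truth(["a", "b"], ["b", "a", "a", "c"]))  # [[1, 2], [0]]
```

For In-shop Clothes, which has separate query and gallery splits, same_set would stay False.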

Usage

Train Tokenizer

python train_tokenizer.py 
optional arguments:
--data_path                   dataset root path [default value is 'data']
--data_name                   dataset name [default value is 'isc'; choices: 'car', 'cub', 'isc', 'imagenet', 'places']
--feats                       features used to initialize quantization [default value is '']
--output_dir                  directory for saving outputs [default value is 'results']
--lr                          learning rate [default value is 5e-4]
--batch_size                  training batch size [default value is 128]
--num_epochs                  number of training epochs [default value is 200]
--rq_weight                   loss weight for the RQ-reconstructed features
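The tokenizer turns each image feature into a short sequence of discrete codes by residual quantization (RQ): at every level the current residual is matched to its nearest codeword, and that codeword is subtracted before moving to the next level. The NumPy sketch below shows only the encode/decode step with codebooks passed in as fixed arrays; in the repo the codebooks are trained by train_tokenizer.py, and the function names rq_encode and rq_decode are hypothetical.

```python
import numpy as np


def rq_encode(x, codebooks):
    """Residual quantization of features x (N, D).

    codebooks: list of (K, D) arrays, one per level (hypothetical layout).
    Returns codes (N, M) with one index per level, and the reconstruction (N, D).
    """
    residual = x.astype(np.float64).copy()
    recon = np.zeros_like(residual)
    codes = []
    for cb in codebooks:
        # squared Euclidean distance from each residual to every codeword: (N, K)
        dists = ((residual[:, None, :] - cb[None, :, :]) ** 2).sum(axis=-1)
        idx = dists.argmin(axis=1)
        chosen = cb[idx]
        recon += chosen        # reconstruction is the sum of chosen codewords
        residual -= chosen     # the next level quantizes what is left over
        codes.append(idx)
    return np.stack(codes, axis=1), recon


def rq_decode(codes, codebooks):
    """Rebuild features by summing the selected codeword from each level."""
    return sum(cb[codes[:, m]] for m, cb in enumerate(codebooks))


rng = np.random.default_rng(0)
feats = rng.normal(size=(5, 8))
codebooks = [rng.normal(size=(16, 8)) for _ in range(3)]
codes, recon = rq_encode(feats, codebooks)
print(codes.shape)  # (5, 3)
```

Each row of codes is the image's token sequence; a training objective weighted by --rq_weight would act on reconstructions like recon, though the repo's exact loss is not shown here.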

Get image tokens

python rq.py --features 'isc_features.npy' --file_name 'isc_rq.pkl' --data_dir 'data/isc'
optional arguments:
--data_name                   dataset name [default value is 'isc'; choices: 'car', 'cub', 'isc', 'imagenet', 'places']
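Several images can receive the same code sequence, so mapping a generated identifier back to images requires a code-to-images lookup. The layout of isc_rq.pkl is not documented here; the sketch below assumes a hypothetical dictionary that stores both directions, and the helper names build_code_index and save_codes are illustrative only.

```python
import pickle
from collections import defaultdict


def build_code_index(image_ids, codes):
    """Hypothetical helper: group image ids by their code tuple."""
    index = defaultdict(list)
    for img, code in zip(image_ids, codes):
        # tuples are hashable, so a code sequence can serve as a dict key
        index[tuple(int(c) for c in code)].append(img)
    return dict(index)


def save_codes(path, image_ids, codes):
    """Hypothetical file layout: id -> code and code -> list of ids."""
    payload = {
        "id_to_code": {img: tuple(int(c) for c in code) for img, code in zip(image_ids, codes)},
        "code_to_ids": build_code_index(image_ids, codes),
    }
    with open(path, "wb") as f:
        pickle.dump(payload, f)


print(build_code_index(["img0", "img1", "img2"], [[3, 1], [3, 1], [0, 2]]))
```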

Train IRGen

python -m torch.distributed.launch --nproc_per_node=8 train_ar.py --file_name 'in-shop_clothes_retrieval_trainval.pkl' --codes 'isc_rq.pkl'
optional arguments:
--data_dir                    dataset image directory [default value is 'data/isc/Img']
--data_name                   dataset name [default value is 'isc'; choices: 'car', 'cub', 'isc', 'imagenet', 'places']
--output_dir                  directory for saving outputs [default value is 'results']
--lr                          learning rate [default value is 8e-5]
--batch_size                  training batch size [default value is 64]
--num_epochs                  number of training epochs [default value is 200]
--smoothing                   label smoothing factor [default value is 0.1]
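The --smoothing option controls label smoothing for the autoregressive decoder's next-token loss. The NumPy sketch below uses one common convention (the one behind the label_smoothing option of PyTorch's CrossEntropyLoss): the true token keeps 1 - smoothing of the mass and smoothing is spread uniformly over all V tokens. Whether train_ar.py uses exactly this variant is an assumption.

```python
import numpy as np


def label_smoothing_ce(logits, targets, smoothing=0.1):
    """Cross-entropy with uniform label smoothing.

    logits: (N, V) scores, e.g. one row per (sequence, position) pair.
    targets: (N,) integer token indices.
    """
    n, v = logits.shape
    # numerically stable log-softmax over the vocabulary
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # smoothed target distribution: smoothing / V everywhere, plus 1 - smoothing on the true token
    target_dist = np.full((n, v), smoothing / v)
    target_dist[np.arange(n), targets] += 1.0 - smoothing
    return float(-(target_dist * log_probs).sum(axis=1).mean())


logits = np.array([[2.0, 0.0, -1.0], [0.5, 0.5, 3.0]])
targets = np.array([0, 2])
print(label_smoothing_ce(logits, targets, smoothing=0.1))
```

With smoothing=0 this reduces to ordinary cross-entropy; for a sequence model, the (batch, length, V) logits would be flattened to (batch * length, V) first.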

Test IRGen

python test_ar.py --file_name 'in-shop_clothes_retrieval_trainval.pkl' --codes 'isc_rq.pkl' --model_dir 'results/isc_rq_e200.pkl' 
optional arguments:
--data_dir                    dataset image directory [default value is 'data/isc/Img']
--data_name                   dataset name [default value is 'isc'; choices: 'car', 'cub', 'isc', 'imagenet', 'places']
--beam_size                   beam width for beam search [default value is 30]
--ks                          values of k for evaluation at k [default value is [1,10,20,30]]
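At test time the decoder generates identifiers with beam search, and --beam_size sets how many partial sequences are kept per step. A standard way to make sure every generated sequence is a real image identifier is to restrict expansions with a prefix tree (trie) built from all codes. The sketch below is a generic version of that idea, not the code in test_ar.py; log_prob_fn is a hypothetical stand-in for the model's next-token log-probabilities.

```python
import math


def build_trie(code_seqs):
    """Nested dicts: each path from the root spells a valid code prefix."""
    trie = {}
    for seq in code_seqs:
        node = trie
        for tok in seq:
            node = node.setdefault(tok, {})
    return trie


def constrained_beam_search(log_prob_fn, trie, length, beam_size):
    """log_prob_fn(prefix) -> {token: log-probability of that next token}.

    Only expansions that stay inside the trie are considered, so every
    returned sequence is a prefix of (here: equal to) a valid identifier.
    """
    beams = [((), 0.0, trie)]
    for _ in range(length):
        candidates = []
        for prefix, score, node in beams:
            scores = log_prob_fn(prefix)
            for tok, child in node.items():
                candidates.append((prefix + (tok,), score + scores.get(tok, -math.inf), child))
        # keep the beam_size highest-scoring partial sequences
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]
    return [(prefix, score) for prefix, score, _ in beams]


def toy_model(prefix):
    # hypothetical next-token distributions; (0, 0) is likely but not a valid code
    table = {(): {0: 0.6, 1: 0.4}, (0,): {0: 0.5, 1: 0.1, 2: 0.4}, (1,): {0: 0.9, 1: 0.1}}
    return {t: math.log(p) for t, p in table[prefix].items()}


trie = build_trie([(0, 1), (0, 2), (1, 0)])
print([p for p, _ in constrained_beam_search(toy_model, trie, length=2, beam_size=2)])
# [(1, 0), (0, 2)]
```

Each surviving identifier can then be expanded into its images through a code-to-images lookup, giving the ranked list used for evaluation at k.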

Benchmarks

The models are trained on 8 NVIDIA Tesla V100 (32 GB) GPUs.

In-shop

P@k refers to precision at k and R@k refers to recall at k; the same notation is used for all three datasets.

P@1 P@10 P@20 P@30 R@1 R@10 R@20 R@30
92.4% 87.4% 87.0% 86.9% 92.4% 96.8% 97.6% 97.9%

CUB200

P@1 P@2 P@4 P@8 R@1 R@2 R@4 R@8
82.7% 82.7% 83.0% 82.8% 82.7% 86.4% 89.2% 91.4%

CARS196

P@1 P@2 P@4 P@8 R@1 R@2 R@4 R@8
90.1% 89.9% 90.2% 90.5% 90.1% 92.1% 93.2% 93.7%
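The sketch below computes P@k and R@k under one common deep-metric-learning convention: P@k is the fraction of the top-k results that share the query's class, and R@k is 1 if at least one of the top-k results does, averaged over queries. This convention is consistent with P@1 = R@1 in every table above, but the repo's exact evaluation code may differ; the function names are hypothetical.

```python
def precision_recall_at_k(ranked_labels, query_label, ks):
    """ranked_labels: class labels of the retrieved items, best match first."""
    out = {}
    for k in ks:
        top = ranked_labels[:k]
        hits = sum(1 for lbl in top if lbl == query_label)
        out[f"P@{k}"] = hits / k                     # fraction of relevant items in the top k
        out[f"R@{k}"] = 1.0 if hits > 0 else 0.0     # any relevant item in the top k
    return out


def mean_metrics(per_query):
    """Average each metric over all queries."""
    keys = per_query[0].keys()
    return {key: sum(m[key] for m in per_query) / len(per_query) for key in keys}


q1 = precision_recall_at_k(["a", "b", "a", "a"], "a", [1, 2])
q2 = precision_recall_at_k(["b", "a"], "a", [1, 2])
print(mean_metrics([q1, q2]))
# {'P@1': 0.5, 'R@1': 0.5, 'P@2': 0.5, 'R@2': 1.0}
```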

Results

The precision-recall curves:

[Figure: precision-recall curve, In-shop Clothes]

[Figure: precision-recall curve, CUB200]

[Figure: precision-recall curve, Cars196]
