Training

Basic Training

To train a simple dense retriever, call the tevatron.driver.train module.

Here we use Natural Questions as an example.

We train on a machine with 4x V100 GPUs; if your GPU resources are limited, please train with gradient cache instead (see the GradCache section below).

python -m torch.distributed.launch --nproc_per_node=4 -m tevatron.driver.train \
  --output_dir model_nq \
  --dataset_name Tevatron/wikipedia-nq \
  --model_name_or_path bert-base-uncased \
  --do_train \
  --save_steps 20000 \
  --fp16 \
  --per_device_train_batch_size 32 \
  --train_n_passages 2 \
  --learning_rate 1e-5 \
  --q_max_len 32 \
  --p_max_len 156 \
  --num_train_epochs 40 \
  --negatives_x_device

Here we are using our self-contained datasets for training. To use a custom dataset, replace --dataset_name Tevatron/wikipedia-nq with --train_dir <train data dir> (see here for details), as sketched below.
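For example, a minimal sketch of the same command pointing at custom training data; only the dataset flag changes, and <train data dir> stands in for your own path:

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 -m tevatron.driver.train \
  --output_dir model_nq \
  --train_dir <train data dir> \
  --model_name_or_path bert-base-uncased \
  --do_train \
  --save_steps 20000 \
  --fp16 \
  --per_device_train_batch_size 32 \
  --train_n_passages 2 \
  --learning_rate 1e-5 \
  --q_max_len 32 \
  --p_max_len 156 \
  --num_train_epochs 40 \
  --negatives_x_device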

Here we picked the bert-base-uncased BERT weights from the Hugging Face Hub and turned on AMP with --fp16 to speed up training. Several additional command-line flags are provided to configure the learned model, e.g. --add_pooler, which adds a linear projection. A full list of command-line arguments can be found in tevatron.arguments.
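For instance, a hedged sketch of turning on the linear pooler: append flags like the following to the training command above (the output dimension of 128 is purely illustrative; see the argument table at the end of this page):

  --add_pooler \
  --projection_in_dim 768 \
  --projection_out_dim 128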

GradCache

Tevatron adopts the gradient cache technique to allow large-batch training of dense retrievers on memory-limited GPUs.

Details are described in the paper Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup.

Add the following three flags to the training command to enable gradient cache:

- --grad_cache: enable gradient caching
- --gc_q_chunk_size: sub-batch size for queries
- --gc_p_chunk_size: sub-batch size for passages

For example, the following command trains a dense retrieval model for Natural Questions with a batch size of 128 on a single GPU.

CUDA_VISIBLE_DEVICES=0 python -m tevatron.driver.train \
  --output_dir model_nq \
  --dataset_name Tevatron/wikipedia-nq \
  --model_name_or_path bert-base-uncased \
  --do_train \
  --save_steps 20000 \
  --fp16 \
  --per_device_train_batch_size 128 \
  --train_n_passages 2 \
  --learning_rate 1e-5 \
  --q_max_len 32 \
  --p_max_len 156 \
  --num_train_epochs 40 \
  --grad_cache \
  --gc_q_chunk_size 32 \
  --gc_p_chunk_size 16

Note that GradCache also supports multi-GPU settings; a sketch combining it with distributed launch follows.
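As an illustration, a hedged sketch of gradient cache in a multi-GPU run; this simply combines the distributed launch from the first example with the gradient cache flags above (2 GPUs with a per-device batch size of 64 keeps the global batch size at 128):

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 -m tevatron.driver.train \
  --output_dir model_nq \
  --dataset_name Tevatron/wikipedia-nq \
  --model_name_or_path bert-base-uncased \
  --do_train \
  --save_steps 20000 \
  --fp16 \
  --per_device_train_batch_size 64 \
  --train_n_passages 2 \
  --learning_rate 1e-5 \
  --q_max_len 32 \
  --p_max_len 156 \
  --num_train_epochs 40 \
  --negatives_x_device \
  --grad_cache \
  --gc_q_chunk_size 32 \
  --gc_p_chunk_size 16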

Training with TPU

Tevatron implements TPU training via Jax/Flax. We provide a separate module, tevatron.driver.jax_train, for training on TPU. Argument management aligns with the PyTorch training driver above.

Running the following command on a v3-8 TPU VM is equivalent to the commands above (a per-device batch size of 16 across 8 TPU cores gives a global batch size of 128).

python -m tevatron.driver.jax_train \
  --output_dir model_nq \
  --dataset_name Tevatron/wikipedia-nq \
  --model_name_or_path bert-base-uncased \
  --do_train \
  --per_device_train_batch_size 16 \
  --train_n_passages 2 \
  --learning_rate 1e-5 \
  --q_max_len 32 \
  --p_max_len 156 \
  --num_train_epochs 40

Note that our Jax training driver also supports gradient cache via the --grad_cache option, as sketched below.
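For example, a hedged sketch of a larger per-device batch on TPU with gradient cache enabled (the chunk sizes mirror the PyTorch example and are illustrative):

python -m tevatron.driver.jax_train \
  --output_dir model_nq \
  --dataset_name Tevatron/wikipedia-nq \
  --model_name_or_path bert-base-uncased \
  --do_train \
  --per_device_train_batch_size 128 \
  --train_n_passages 2 \
  --learning_rate 1e-5 \
  --q_max_len 32 \
  --p_max_len 156 \
  --num_train_epochs 40 \
  --grad_cache \
  --gc_q_chunk_size 32 \
  --gc_p_chunk_size 16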

Arguments Description

Our argument parser inherits from TrainingArguments in Hugging Face's transformers. For common training arguments such as learning rate and batch size configuration, please check the Hugging Face documentation for details.

Here we describe the arguments additionally defined for Tevatron's CLI.

| name | description | type | default | supported driver |
|------|-------------|------|---------|------------------|
| do_train | Whether to run training | bool | required | pytorch, jax |
| model_name_or_path | Model backbone to initialize the dense retriever. Either a model name available on the Hugging Face model hub, or a path to a model directory | str | required | pytorch, jax |
| tokenizer_name | Tokenizer name or path, if not the same as model_name_or_path | str | same as model_name_or_path | pytorch, jax |
| cache_dir | Path to the directory for caching models and datasets | str | ~/.cache/ | pytorch, jax |
| untie_encoder | Whether to use separate (untied) parameters for the query encoder and passage encoder | bool | False | pytorch, jax |
| add_pooler | Whether to add a pooler on top of the last layer output | bool | False | pytorch |
| projection_in_dim | The input dim of the pooler | int | 768 | pytorch |
| projection_out_dim | The output dim of the pooler | int | 768 | pytorch |
| dataset_name | Dataset name available on Hugging Face | str | json | pytorch, jax |
| train_dir | Directory that stores custom training data | str | None | pytorch, jax |
| dataset_proc_num | Number of threads used to preprocess/tokenize data | int | 12 | pytorch, jax |
| train_n_passages | Number of passages per anchor query during training. Each example loads 1 positive passage + (train_n_passages - 1) negative passages | int | 8 | pytorch, jax |
| passage_field_separator | The token used to separate the title and text fields of passages | str | " " | pytorch |
| q_max_len | Maximum query length | int | 32 | pytorch, jax |
| p_max_len | Maximum passage length | int | 128 | pytorch, jax |
| negatives_x_device | Whether to gather in-batch negative passages across devices | bool | False | pytorch |
| grad_cache | Whether to use the gradient cache feature, which supports large batch sizes when GPU/TPU memory is limited | bool | False | pytorch, jax |
| gc_q_chunk_size | Sub-batch size for queries with grad_cache | int | 4 | pytorch, jax |
| gc_p_chunk_size | Sub-batch size for passages with grad_cache | int | 32 | pytorch, jax |
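To close, a hedged sketch that combines several of the Tevatron-specific arguments above with the basic PyTorch command; the output directory name and the pooler output dimension are illustrative choices, not recommendations:

python -m tevatron.driver.train \
  --output_dir model_nq_untied \
  --dataset_name Tevatron/wikipedia-nq \
  --model_name_or_path bert-base-uncased \
  --do_train \
  --fp16 \
  --per_device_train_batch_size 32 \
  --train_n_passages 8 \
  --dataset_proc_num 12 \
  --untie_encoder \
  --add_pooler \
  --projection_in_dim 768 \
  --projection_out_dim 128 \
  --q_max_len 32 \
  --p_max_len 128 \
  --learning_rate 1e-5 \
  --num_train_epochs 40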