
Contents

EMDR2 is an end-to-end training algorithm developed for the task of open-domain question answering.

- Comparison with other approaches
- End-to-End Training
- Results

Setup

To use this repo, we recommend using one of NGC's more recent PyTorch containers. The image version used in the paper can be pulled with:

```shell
docker pull nvcr.io/nvidia/pytorch:20.03-py3
```

An installation of the NVIDIA Container Toolkit may also be required.

Additional dependencies also need to be installed. We provide a Dockerfile for this under the docker directory. For instance, to build a new Docker image (nvcr.io/nvidia/pytorch:20.10-py3-faiss-compiled) on top of the base container, use:

```shell
cd docker
sudo docker build -t nvcr.io/nvidia/pytorch:20.10-py3-faiss-compiled .
```

To run the above image in interactive mode, use:

```shell
sudo docker run --ipc=host --gpus all -it --rm -v /mnt/disks:/mnt/disks nvcr.io/nvidia/pytorch:20.10-py3-faiss-compiled bash
```

where /mnt/disks is the directory to be mounted.

Downloading Data and Checkpoints

We provide pretrained checkpoints and datasets on Dropbox that can be used to train models for open-domain QA and dense retrieval tasks. These files can be downloaded with the wget command-line utility using the links provided below.

- Required data files for training / inference
- Required checkpoints and embeddings
Data for Masked Salient Spans (MSS) training (Optional)

We also provide data for Masked Salient Spans training (URL). This file contains around 20M sentences extracted from the Wikipedia passages file, along with the positions of the named entities in each sentence. To obtain these named entities, we used the pre-trained OntoNotes-5.0 model provided by the Stanza toolkit.

An example line from the file, in jsonlines format:

```json
{"doc_id": 209, "sent_text": "Karpov 's outstanding classical tournament play has been seriously limited since 1997 , since he prefers to be more involved in the politics of his home country of Russia .", "bert_ent_pos": [[14, 14], [31, 31]], "linguistic_ent": [["1997", "DATE", 11, 11], ["Russia", "GPE", 28, 28]]}
```

Here, the field doc_id indicates the passage id in the evidence, sent_text is the sentence text as obtained after Stanza tokenization, linguistic_ent contains the list of named entities in the format (named entity text, entity type, entity start position, entity end position), and bert_ent_pos contains the entities' start and end positions after BERT tokenization.
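As a quick illustration of these fields, the example line above can be parsed and the entity spans recovered with a few lines of Python (a minimal sketch; it assumes only the jsonlines layout described above):

```python
import json

# The example MSS line from above, as one jsonlines record.
line = (
    '{"doc_id": 209, "sent_text": "Karpov \'s outstanding classical '
    'tournament play has been seriously limited since 1997 , since he '
    'prefers to be more involved in the politics of his home country of '
    'Russia .", "bert_ent_pos": [[14, 14], [31, 31]], '
    '"linguistic_ent": [["1997", "DATE", 11, 11], ["Russia", "GPE", 28, 28]]}'
)

record = json.loads(line)
# Stanza tokens are space-separated in sent_text, so splitting recovers them.
tokens = record["sent_text"].split()

# Each linguistic_ent entry is (text, type, start, end), positions inclusive.
for text, ent_type, start, end in record["linguistic_ent"]:
    span = " ".join(tokens[start:end + 1])
    assert span == text
    print(f"{ent_type}: {span}")  # DATE: 1997, then GPE: Russia
```

Note that the bert_ent_pos positions index into the BERT-tokenized sequence instead, so they cannot be checked against the whitespace tokens here.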

For more details on MSS training, please refer to the papers below.

Usage

We provide several scripts for training models for both dense retrieval and open-domain QA tasks in the examples directory. Please make sure to update the data and checkpoint paths in these scripts.

To replicate the answer generation results on the Natural Questions (NQ) dataset, run:

```shell
bash examples/openqa/emdr2_nq.sh
```

Similar scripts are provided for TriviaQA and WebQuestions, as well as for training the dense retriever.

Training

For end-to-end training, we used a single node of 16 A100 GPUs with 40GB GPU memory.

In the codebase, the first 8 GPUs are used for model training, the next 8 GPUs compute evidence embeddings asynchronously, and all 16 GPUs perform online retrieval at every step.

The code can also be run on a node with 8 GPUs by disabling asynchronous evidence embedding computation. However, this can lead to some loss in performance.

Pre-trained Checkpoints

| Dataset | Dev EM | Test EM | Checkpoint | Precomputed Evidence Embedding |
|---|---|---|---|---|
| Natural Questions | 50.42 | 52.49 | link | link |
| TriviaQA | 71.13 | 71.43 | link | link |
| WebQuestions | 49.86 | 48.67 | link | link |

To use these checkpoints, set the CHECKPOINT_PATH and EMBEDDING_PATH variables to point to the above checkpoint and embedding index, respectively. Also, add the --no-load-optim option and remove the --emdr2-training, --async-indexer, and --index-reload-interval 500 options from the example script, so that it runs in inference mode. As the memory requirement for inference is lower, evaluation can also be performed on 4-8 GPUs.
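Concretely, those changes amount to something like the following edits in the example script (a sketch only; the paths are placeholders, and only the variables and flags named above are assumed):

```shell
# Point these at the downloaded checkpoint and precomputed embedding index
# (placeholder paths -- substitute your own download locations).
CHECKPOINT_PATH=/mnt/disks/checkpoints/nq-emdr2
EMBEDDING_PATH=/mnt/disks/embeddings/nq-evidence-index

# In the script's option list:
#   add:    --no-load-optim
#   remove: --emdr2-training --async-indexer --index-reload-interval 500
```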

Helper Scripts

- Sometimes we need to save just the retriever model, for tasks such as top-K recall evaluation. To extract the retriever model from a checkpoint, use:

  ```shell
  python tools/save_emdr2_models.py --submodel-name retriever --load e2eqa/trivia --save e2eqa/trivia/retriever/
  ```

- To create evidence embeddings from a retriever checkpoint and perform top-K recall evaluation, use the script below. Make sure to correctly set the paths of the datasets and checkpoints.

  ```shell
  bash examples/helper-scripts/create_wiki_indexes_and_evaluate.sh
  ```

Issues

For any errors or bugs in the codebase, please open a new issue or send an email to Devendra Singh Sachan ([email protected]).

Citation

If you find this code or data useful, please consider citing our paper:

@inproceedings{sachan2021endtoend,
    title={End-to-End Training of Multi-Document Reader and Retriever for Open-Domain Question Answering},
    author={Devendra Singh Sachan and Siva Reddy and William L. Hamilton and Chris Dyer and Dani Yogatama},
    booktitle={Advances in Neural Information Processing Systems},
    editor={A. Beygelzimer and Y. Dauphin and P. Liang and J. Wortman Vaughan},
    year={2021},
    url={https://openreview.net/forum?id=5KWmB6JePx}
}

Some of the ideas and implementations in this work are based on a previous paper. If the code is helpful, please also consider citing it:

@inproceedings{sachan-etal-2021-end,
    title = "End-to-End Training of Neural Retrievers for Open-Domain Question Answering",
    author = "Sachan, Devendra and Patwary, Mostofa and Shoeybi, Mohammad and Kant, Neel and Ping, Wei and Hamilton, William L. and Catanzaro, Bryan",
    booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers)",
    month = aug,
    year = "2021",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2021.acl-long.519",
    doi = "10.18653/v1/2021.acl-long.519",
    pages = "6648--6662"
}