# Copyright 2020 The Forte Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A wrapper adding data augmentation to a Bert model with arbitrary tasks.
"""

import random
import math
from typing import Tuple, Dict, Generator
from forte.utils import create_import_error_msg
from forte.models.da_rl.magic_model import MetaModule

try:
    import torch
    from torch import nn
    from torch.nn import functional as F
    from torch.optim import Optimizer
except ImportError as e:
    raise ImportError(
        create_import_error_msg("torch", "models", "Augmentation Wrapper")
    ) from e


try:
    import texar.torch as tx
except ImportError as e:
    raise ImportError(
        create_import_error_msg(
            "texar-pytorch", "models", "Augmentation Wrapper"
        )
    ) from e

__all__ = ["MetaAugmentationWrapper"]


class MetaAugmentationWrapper:
    # pylint: disable=line-too-long
    r"""
    A wrapper adding data augmentation to a Bert model with arbitrary tasks.
    This is used to perform reinforcement learning for joint data
    augmentation learning and model training.

    See: https://arxiv.org/pdf/1910.12795.pdf

    This code is adapted from:
    https://github.com/tanyuqian/learning-data-manipulation/blob/master/augmentation/generator.py

    Let :math:`\theta` be the parameters of the downstream (classifier) model.
    Let :math:`\phi` be the parameters of the augmentation model.
    Equations to update :math:`\phi`:

    .. math::

        \theta'(\phi) = \theta - \nabla_{\theta} L_{train}(\theta, \phi)

        \phi = \phi - \nabla_{\phi} L_{val}(\theta'(\phi))

    Args:
        augmentation_model:
            A Bert-based model for data augmentation, e.g. BertForMaskedLM.
            Model requirement: masked language modeling; the output logits
            of this model are of shape
            `[batch_size, seq_length, token_size]`.
        augmentation_optimizer:
            An optimizer that is associated with `augmentation_model`,
            e.g. an Adam optimizer.
        input_mask_ids:
            Bert token id of `'[MASK]'`. This is used to randomly mask out
            tokens from the input sentence during training.
        device:
            The CUDA device to run the model on.
        num_aug:
            The number of samples drawn from the augmentation model
            for every augmented training instance.

    Example usage:

        .. code-block:: python

            aug_wrapper = MetaAugmentationWrapper(
                aug_model, aug_optim, mask_id, device, num_aug)

            for batch in training_data:
                # Train augmentation model params.
                aug_wrapper.reset_model()
                for instance in batch:
                    # Augmented example with params phi exposed.
                    aug_instance_features = \
                        aug_wrapper.augment_instance(instance_features)
                    # Model is the downstream Bert model.
                    model.zero_grad()
                    loss = model(aug_instance_features)
                    meta_model = MetaModule(model)
                    meta_model = aug_wrapper.update_meta_model(
                        meta_model, loss, model, optim)

                    # Compute the gradient of the augmentation model
                    # on validation data.
                    for val_batch in validation_data:
                        val_loss = meta_model(val_batch_features)
                        val_loss = val_loss / num_training_instance / num_aug \
                            / num_val_batch
                        val_loss.backward()

                # Update augmentation model params.
                aug_wrapper.update_phi()

                # Train the classifier with the augmented batch.
                aug_batch_features = aug_wrapper.augment_batch(batch_features)
                optim.zero_grad()
                loss = model(aug_batch_features)
                loss.backward()
                optim.step()
    """

    def __init__(
        self,
        augmentation_model: nn.Module,
        augmentation_optimizer: Optimizer,
        input_mask_ids: int,
        device: torch.device,
        num_aug: int,
    ):
        self._aug_model = augmentation_model
        self._aug_optimizer = augmentation_optimizer
        self._input_mask_ids = input_mask_ids
        self._device = device
        self._num_aug = num_aug

    def reset_model(self):
        self._aug_model.train()
        self._aug_model.zero_grad()

    def _augment_instance(
        self, features: Tuple[torch.Tensor, ...], num_aug: int
    ) -> torch.Tensor:
        r"""Augment a training instance.

        Randomly mask out some tokens in the input sentence and use the
        logits of the augmentation model as the augmented bert token
        soft embedding.

        Args:
            features: A tuple of Bert features of one training instance.
                `(input_ids, input_mask, segment_ids, label_ids)`.

                `input_ids` is a tensor of Bert token ids.
                It has shape `[seq_len]`.

                `input_mask` is a tensor of shape `[seq_len]`, with 1
                indicating an unmasked token and 0 a masked one.

                `segment_ids` is a tensor of shape `[seq_len]`.
                `label_ids` is a tensor of shape `[seq_len]`.

            num_aug: The number of samples from the augmentation model.

        Returns:
            aug_probs: A tensor of shape `[num_aug, seq_len, token_size]`.
            It is the augmented bert token soft embedding.
        """
        feature: Generator[torch.Tensor, torch.Tensor, torch.Tensor] = (
            t.view(1, -1).to(self._device) for t in features
        )
        init_ids, input_mask, segment_ids, _ = feature

        length = int(torch.sum(input_mask).item())
        if length >= 4:
            mask_idx = sorted(
                random.sample(list(range(1, length - 1)), max(length // 7, 2))
            )
        else:
            mask_idx = [1]
        init_ids[0][mask_idx] = self._input_mask_ids

        logits = self._aug_model(
            init_ids, token_type_ids=segment_ids, attention_mask=input_mask
        )[0]

        # Get samples
        aug_probs_all = []
        for _ in range(num_aug):
            # Need a gumbel trick here in order to keep phi as variables.
            # Enable efficient gradient propagation through theta' to phi.
            probs = F.gumbel_softmax(logits.squeeze(0), hard=False)

            aug_probs = torch.zeros_like(probs).scatter_(
                1, init_ids[0].unsqueeze(1), 1.0
            )
            for t in mask_idx:
                aug_probs = tx.utils.pad_and_concat(
                    [aug_probs[:t], probs[t : t + 1], aug_probs[t + 1 :]],
                    axis=0,
                )
            aug_probs_all.append(aug_probs)

        aug_probs = tx.utils.pad_and_concat(
            [ap.unsqueeze(0) for ap in aug_probs_all], axis=0
        )

        return aug_probs

    def augment_instance(
        self, features: Tuple[torch.Tensor, ...]
    ) -> Tuple[torch.Tensor, ...]:
        r"""Augment a training instance.

        Args:
            features: A tuple of Bert features of one training instance.
                `(input_ids, input_mask, segment_ids, label_ids)`.

                `input_ids` is a tensor of Bert token ids.
                It has shape `[seq_len]`.

                `input_mask` is a tensor of shape `[seq_len]`, with 1
                indicating an unmasked token and 0 a masked one.

                `segment_ids` is a tensor of shape `[seq_len]`.
                `label_ids` is a tensor of shape `[seq_len]`.

        Returns:
            A tuple of Bert features of augmented training instances.
            `(input_probs_aug, input_mask_aug, segment_ids_aug, label_ids_aug)`.

            `input_probs_aug` is a tensor of soft Bert embeddings,
            i.e. distributions over the vocabulary. It has shape
            `[num_aug, seq_len, token_size]`. It keeps :math:`\phi` as a
            variable, so that after passing it as an input to the classifier,
            gradients back-propagated through :math:`\theta` also reach
            :math:`\phi`.

            `input_mask_aug` is a tensor of shape `[num_aug, seq_len]`.
            It stacks `num_aug` copies of the input `input_mask`, so that it
            corresponds to the mask of each token in `input_probs_aug`.

            `segment_ids_aug` is a tensor of shape `[num_aug, seq_len]`.
            It stacks `num_aug` copies of the input `segment_ids`, so that it
            corresponds to the token type of each token in `input_probs_aug`.

            `label_ids_aug` is a tensor of shape `[num_aug, seq_len]`.
            It stacks `num_aug` copies of the input `label_ids`, so that it
            corresponds to the label of each token in `input_probs_aug`.
        """
        aug_probs = self._augment_instance(features, self._num_aug)

        _, input_mask, segment_ids, label_ids = (
            t.to(self._device).unsqueeze(0) for t in features
        )

        input_mask_aug = tx.utils.pad_and_concat(
            [input_mask] * self._num_aug, axis=0
        )
        segment_ids_aug = tx.utils.pad_and_concat(
            [segment_ids] * self._num_aug, axis=0
        )
        label_ids_aug = tx.utils.pad_and_concat(
            [label_ids] * self._num_aug, axis=0
        )

        return aug_probs, input_mask_aug, segment_ids_aug, label_ids_aug
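
    # A small worked illustration of the two methods above (hypothetical
    # values: num_aug = 2, seq_len = 128, one masked position t).
    # ``_augment_instance`` builds ``seq_len`` one-hot rows from ``init_ids``
    # with ``scatter_`` and replaces row t by the Gumbel-softmax sample
    # ``probs[t:t + 1]``, so the result is one-hot everywhere except the
    # masked position, which carries a soft distribution and therefore a
    # gradient path back to the augmentation parameters phi.
    # ``augment_instance`` then returns ``aug_probs`` of shape
    # ``[2, 128, token_size]`` plus ``input_mask_aug``, ``segment_ids_aug``
    # and ``label_ids_aug`` tiled to ``[2, 128]``, pairing each soft instance
    # with a copy of the original mask, segment and label tensors.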

    def augment_batch(
        self, batch_features: Tuple[torch.Tensor, ...]
    ) -> Tuple[torch.Tensor, ...]:
        r"""Augment a batch of training instances. Append the augmented
        instances to the input instances.

        Args:
            batch_features: A tuple of Bert features of a batch of training
                instances. `(input_ids, input_mask, segment_ids, label_ids)`.

                `input_ids` is a tensor of Bert token ids.
                It has shape `[batch_size, seq_len]`.

                `input_mask`, `segment_ids`, `label_ids` are all tensors of
                shape `[batch_size, seq_len]`.

        Returns:
            A tuple of Bert features of augmented training instances.
            `(input_probs_aug, input_mask_aug, segment_ids_aug, label_ids_aug)`.

            `input_probs_aug` is a tensor of soft Bert embeddings.
            It has shape `[batch_size * 2, seq_len, token_size]`.

            `input_mask_aug` is a tensor of shape `[batch_size * 2, seq_len]`.
            It concatenates two copies of the input `input_mask`: the first
            corresponds to the mask of the tokens in the original bert
            instances, the second to the mask of the augmented bert instances.

            `segment_ids_aug` is a tensor of shape `[batch_size * 2, seq_len]`.
            It concatenates two copies of the input `segment_ids`: the first
            corresponds to the segment ids of the tokens in the original bert
            instances, the second to the segment ids of the augmented bert
            instances.

            `label_ids_aug` is a tensor of shape `[batch_size * 2, seq_len]`.
            It concatenates two copies of the input `label_ids`: the first
            corresponds to the labels of the original bert instances, the
            second to the labels of the augmented bert instances.
        """
        input_ids, input_mask, segment_ids, labels = batch_features
        self._aug_model.eval()

        aug_instances = []
        features = []
        num_instance = len(input_ids)
        for i in range(num_instance):
            feature = (input_ids[i], input_mask[i], segment_ids[i], labels[i])
            features.append(feature)
            with torch.no_grad():
                aug_probs = self._augment_instance(feature, num_aug=1)
            aug_instances.append(aug_probs)

        input_ids_or_probs, input_masks, segment_ids, label_ids = [
            tx.utils.pad_and_concat(
                [t[i].unsqueeze(0) for t in features], axis=0
            ).to(self._device)
            for i in range(4)
        ]

        num_aug = len(aug_instances[0])
        input_ids_or_probs_aug = []
        for i in range(num_aug):
            for aug_probs in aug_instances:
                input_ids_or_probs_aug.append(aug_probs[i : i + 1])
        input_ids_or_probs_aug = tx.utils.pad_and_concat(
            input_ids_or_probs_aug, axis=0
        ).to(self._device)

        inputs_onehot = torch.zeros_like(
            input_ids_or_probs_aug[: len(input_ids_or_probs)]
        ).scatter_(2, input_ids_or_probs.unsqueeze(2), 1.0)

        input_probs_aug = tx.utils.pad_and_concat(
            [inputs_onehot, input_ids_or_probs_aug], axis=0
        ).to(self._device)
        input_mask_aug = tx.utils.pad_and_concat(
            [input_masks] * (num_aug + 1), axis=0
        ).to(self._device)
        segment_ids_aug = tx.utils.pad_and_concat(
            [segment_ids] * (num_aug + 1), axis=0
        ).to(self._device)
        label_ids_aug = tx.utils.pad_and_concat(
            [label_ids] * (num_aug + 1), axis=0
        ).to(self._device)

        return input_probs_aug, input_mask_aug, segment_ids_aug, label_ids_aug
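
    # Shape sketch for ``augment_batch`` above (hypothetical values:
    # batch_size = 8, one augmentation per instance). ``input_probs_aug`` has
    # shape ``[16, seq_len, token_size]``: rows 0-7 are one-hot encodings of
    # the original ``input_ids`` (built with ``scatter_``), rows 8-15 are the
    # soft augmented instances, and the mask/segment/label tensors are
    # repeated so that both halves share the same metadata.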

    def eval_batch(
        self, batch_features: Tuple[torch.Tensor, ...]
    ) -> torch.FloatTensor:
        r"""Evaluate a batch of training instances.

        Args:
            batch_features: A tuple of Bert features of a batch of training
                instances. `(input_ids, input_mask, segment_ids, label_ids)`.

                `input_ids` is a tensor of Bert token ids.
                It has shape `[batch_size, seq_len]`.

                `input_mask`, `segment_ids`, `label_ids` are all tensors of
                shape `[batch_size, seq_len]`.

        Returns:
            The masked language modeling loss of one evaluation batch.
            It is a `torch.FloatTensor` of shape `[1,]`.
        """
        self._aug_model.eval()

        batch = tuple(t.to(self._device) for t in batch_features)
        input_ids, input_mask, segment_ids, labels = batch
        loss = self._aug_model(
            input_ids,
            token_type_ids=segment_ids,
            attention_mask=input_mask,
            labels=labels,
        )[0]

        return loss
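
    # Usage sketch for ``eval_batch`` above: calling
    # ``aug_wrapper.eval_batch(val_batch_features)`` on a held-out batch
    # returns the augmentation model's masked-LM loss, which can be logged to
    # monitor how well the augmentation model itself fits the data.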

    def update_meta_model(
        self,
        meta_model: MetaModule,
        loss: torch.Tensor,
        model: nn.Module,
        optimizer: Optimizer,
    ) -> MetaModule:
        r"""Update the parameters within the `MetaModule`
        according to the downstream model loss.

        `MetaModule` is used to calculate
        :math:`\nabla_{\phi} L_{val}(\theta'(\phi))`,
        so it needs gradients that can be applied to :math:`\phi`.

        This function performs the parameter update of the meta model;
        the resulting gradient change is later applied to :math:`\theta`
        and :math:`\phi` using validation data.

        Args:
            meta_model: A meta model whose parameters will be updated
                in-place by the deltas calculated from the input `loss`.
            loss: The loss of the downstream model that has taken the
                augmented training instances as input.
            model: The downstream Bert model.
            optimizer: The optimizer that is associated with the `model`.

        Returns:
            The same input `meta_model` with the updated parameters.
        """
        # grads_theta(phi) = \nabla_{theta} L_{train}(theta, phi)
        grads_theta = self._calculate_grads(loss, model, optimizer)
        # theta'(phi) = theta - grads_theta(phi)
        meta_model.update_params(grads_theta)
        return meta_model
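
    # In other words, ``update_meta_model`` above applies one simulated
    # optimizer step to a copy of theta: the step is the optimizer-specific
    # delta computed by ``_calculate_grads`` below, and it stays a
    # differentiable function of phi because the training-loss gradients are
    # taken with ``create_graph=True`` on the soft inputs produced by the
    # augmentation model.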

    @staticmethod
    def _calculate_grads(
        loss: torch.Tensor, model: nn.Module, optimizer: Optimizer
    ) -> Dict[str, torch.Tensor]:
        grads = torch.autograd.grad(
            loss,
            [param for name, param in model.named_parameters()],
            create_graph=True,
        )
        grads = {
            param: grads[i]
            for i, (name, param) in enumerate(model.named_parameters())
        }

        if isinstance(optimizer, tx.core.BertAdam):
            deltas = _texar_bert_adam_delta(grads, model, optimizer)
        else:
            deltas = _torch_adam_delta(grads, model, optimizer)

        return deltas

    def update_phi(self):
        # L_{val}(theta'(phi))
        # apply gradients to phi
        # phi = phi - \nabla_{phi} L_{val}(theta'(phi))
        self._aug_optimizer.step()
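
# Note: the two helpers below re-derive the optimizer update rule
# symbolically instead of calling ``optimizer.step()``. They read the
# optimizer state (moment estimates) and hyperparameters, combine them with
# the graph-carrying gradients, and return name-keyed deltas that
# ``MetaModule.update_params`` can apply without breaking the gradient path
# back to phi.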

def _texar_bert_adam_delta(
    grads: Dict[nn.parameter.Parameter, torch.Tensor],
    model: nn.Module,
    optimizer: Optimizer,
) -> Dict[str, torch.Tensor]:
    # pylint: disable=line-too-long
    r"""Compute parameter delta function for the texar-pytorch core.BertAdam
    optimizer.

    This function is adapted from:
    https://github.com/asyml/texar-pytorch/blob/master/texar/torch/core/optimization.py#L398
    """
    assert isinstance(optimizer, tx.core.BertAdam)

    deltas = {}
    for group in optimizer.param_groups:
        for param in group["params"]:
            grad = grads[param]
            state = optimizer.state[param]

            if len(state) == 0:
                # Exponential moving average of gradient values
                state["next_m"] = torch.zeros_like(param.data)
                # Exponential moving average of squared gradient values
                state["next_v"] = torch.zeros_like(param.data)

            exp_avg, exp_avg_sq = state["next_m"], state["next_v"]
            beta1, beta2 = group["betas"]

            if group["weight_decay"] != 0:
                grad = grad + group["weight_decay"] * param.data

            exp_avg = exp_avg * beta1 + (1.0 - beta1) * grad
            exp_avg_sq = exp_avg_sq * beta2 + (1.0 - beta2) * grad * grad
            denom = exp_avg_sq.sqrt() + group["eps"]

            step_size = group["lr"]

            deltas[param] = -step_size * exp_avg / denom

    param_to_name = {param: name for name, param in model.named_parameters()}

    return {param_to_name[param]: delta for param, delta in deltas.items()}


def _torch_adam_delta(
    grads: Dict[nn.parameter.Parameter, torch.Tensor],
    model: nn.Module,
    optimizer: Optimizer,
) -> Dict[str, torch.Tensor]:
    r"""Compute parameter delta function for the Torch Adam optimizer."""
    assert issubclass(type(optimizer), Optimizer)

    deltas = {}
    for group in optimizer.param_groups:
        for param in group["params"]:
            grad = grads[param]
            state = optimizer.state[param]

            if len(state) == 0:
                state["exp_avg"] = torch.zeros_like(param.data)
                state["exp_avg_sq"] = torch.zeros_like(param.data)
                state["step"] = 0

            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
            beta1, beta2 = group["betas"]
            step = state["step"] + 1

            if group["weight_decay"] != 0:
                grad = grad + group["weight_decay"] * param.data

            exp_avg = exp_avg * beta1 + (1.0 - beta1) * grad
            exp_avg_sq = exp_avg_sq * beta2 + (1.0 - beta2) * grad * grad
            denom = exp_avg_sq.sqrt() + group["eps"]

            bias_correction1 = 1.0 - beta1**step
            bias_correction2 = 1.0 - beta2**step
            step_size = (
                group["lr"] * math.sqrt(bias_correction2) / bias_correction1
            )

            deltas[param] = -step_size * exp_avg / denom

    param_to_name = {param: name for name, param in model.named_parameters()}

    return {param_to_name[param]: delta for param, delta in deltas.items()}
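
# Worked form of the deltas above (read directly from the code; that
# ``MetaModule.update_params`` adds them to the matching parameters is an
# assumption based on ``update_meta_model``). For ``torch.optim.Adam``:
#
#     delta = -lr * sqrt(1 - beta2**t) / (1 - beta1**t)
#             * exp_avg / (exp_avg_sq.sqrt() + eps)
#
# ``tx.core.BertAdam`` uses the same expression without the bias-correction
# factors. ``update_meta_model`` passes these name-keyed deltas to
# ``MetaModule.update_params`` so that theta'(phi) keeps a gradient path back
# to the augmentation parameters phi.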