Source code for forte.models.da_rl.aug_wrapper
# Copyright 2020 The Forte Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A wrapper adding data augmentation to a Bert model with arbitrary tasks.
"""
import random
import math
from typing import Tuple, Dict, Generator
from forte.utils import create_import_error_msg
from forte.models.da_rl.magic_model import MetaModule
try:
import torch
from torch import nn
from torch.nn import functional as F
from torch.optim import Optimizer
except ImportError as e:
raise ImportError(
create_import_error_msg("torch", "models", "Augmentation Wrapper")
) from e
try:
import texar.torch as tx
except ImportError as e:
raise ImportError(
create_import_error_msg(
"texar-pytorch", "models", "Augmentation Wrapper"
)
) from e
__all__ = ["MetaAugmentationWrapper"]
class MetaAugmentationWrapper:
# pylint: disable=line-too-long
r"""
A wrapper adding data augmentation to a Bert model with arbitrary tasks.
This is used to perform reinforcement learning for joint data augmentation
learning and model training.
See: https://arxiv.org/pdf/1910.12795.pdf
This code is adapted from:
https://github.com/tanyuqian/learning-data-manipulation/blob/master/augmentation/generator.py
Let :math:`\theta` be the parameters of the downstream (classifier) model.
Let :math:`\phi` be the parameters of the augmentation model.
Equations to update :math:`\phi`:
.. math::
\theta'(\phi) = \theta - \nabla_{\theta} L_{train}(\theta, \phi)
\phi = \phi - \nabla_{\phi} L_{val}(\theta'(\phi))
Args:
augmentation_model:
A Bert-based model for data augmentation.
E.g. BertForMaskedLM.
Model requirement: masked language modeling; the output logits
of this model have shape `[batch_size, seq_length, token_size]`.
augmentation_optimizer:
An optimizer that is associated with `augmentation_model`.
E.g. Adam optimizer.
input_mask_ids:
Bert token id of `'[MASK]'`. This is used to randomly mask out
tokens from the input sentence during training.
device:
The CUDA device to run the model on.
num_aug:
The number of samples from the augmentation model
for every augmented training instance.
Example usage:
.. code-block:: python
aug_wrapper = MetaAugmentationWrapper(
aug_model, aug_optim, mask_id, device, num_aug)
for batch in training_data:
# Train augmentation model params.
aug_wrapper.reset_model()
for instance in batch:
# Augmented example with params phi exposed
aug_instance_features = \
aug_wrapper.augment_instance(instance_features)
# Model is the downstream Bert model.
model.zero_grad()
loss = model(aug_instance_features)
meta_model = MetaModule(model)
meta_model = aug_wrapper.update_meta_model(
meta_model, loss, model, optim)
# Compute gradient of the augmentation model on validation data.
for val_batch in validation_data:
val_loss = meta_model(val_batch_features)
val_loss = val_loss / num_training_instance / num_aug \
/ num_val_batch
val_loss.backward()
# update augmentation model params.
aug_wrapper.update_phi()
# train classifier with augmented batch
aug_batch_features = aug_wrapper.augment_batch(batch_features)
optim.zero_grad()
loss = model(aug_batch_features)
loss.backward()
optim.step()
"""
def __init__(
self,
augmentation_model: nn.Module,
augmentation_optimizer: Optimizer,
input_mask_ids: int,
device: torch.device,
num_aug: int,
):
self._aug_model = augmentation_model
self._aug_optimizer = augmentation_optimizer
self._input_mask_ids = input_mask_ids
self._device = device
self._num_aug = num_aug
def reset_model(self):
self._aug_model.train()
self._aug_model.zero_grad()
def _augment_instance(
self, features: Tuple[torch.Tensor, ...], num_aug: int
) -> torch.Tensor:
r"""Augment a training instance. Randomly mask out some tokens in the
input sentence and use the logits of the augmentation model as the
augmented bert token soft embedding.
Args:
features: A tuple of Bert features of one training instance.
`(input_ids, input_mask, segment_ids, label_ids)`.
`input_ids` is a tensor of Bert token ids.
It has shape `[seq_len]`.
`input_mask` is a tensor of shape `[seq_len]`, with 1 indicating a
token that is not masked out and 0 indicating a masked-out token.
`segment_ids` is a tensor of shape `[seq_len]`.
`label_ids` is a tensor of shape `[seq_len]`.
num_aug: The number of samples from the augmentation model.
Returns:
aug_probs: A tensor of shape `[num_aug, seq_len, token_size]`.
It is the augmented bert token soft embedding.
"""
feature: Generator[torch.Tensor, torch.Tensor, torch.Tensor] = (
t.view(1, -1).to(self._device) for t in features
)
init_ids, input_mask, segment_ids, _ = feature
length = int(torch.sum(input_mask).item())
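# Choose positions to mask: roughly one in seven of the interior tokens
# (at least two) when the sequence has at least four real tokens;
# shorter sequences just mask position 1.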
if length >= 4:
mask_idx = sorted(
random.sample(list(range(1, length - 1)), max(length // 7, 2))
)
else:
mask_idx = [1]
init_ids[0][mask_idx] = self._input_mask_ids
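# Feed the partially masked input to the augmentation model; the
# returned logits have shape `[1, seq_len, token_size]`.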
logits = self._aug_model(
init_ids, token_type_ids=segment_ids, attention_mask=input_mask
)[0]
# Get samples
aug_probs_all = []
for _ in range(num_aug):
# Use the Gumbel-softmax trick here to keep phi as differentiable
# variables, enabling gradient propagation through theta' to phi.
probs = F.gumbel_softmax(logits.squeeze(0), hard=False)
aug_probs = torch.zeros_like(probs).scatter_(
1, init_ids[0].unsqueeze(1), 1.0
)
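# `aug_probs` starts as a one-hot encoding of the input ids; each masked
# position is then replaced with its Gumbel-softmax sample so that
# gradients can flow back to the augmentation model parameters phi.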
for t in mask_idx:
aug_probs = tx.utils.pad_and_concat(
[aug_probs[:t], probs[t : t + 1], aug_probs[t + 1 :]],
axis=0,
)
aug_probs_all.append(aug_probs)
aug_probs = tx.utils.pad_and_concat(
[ap.unsqueeze(0) for ap in aug_probs_all], axis=0
)
return aug_probs
def augment_instance(
self, features: Tuple[torch.Tensor, ...]
) -> Tuple[torch.Tensor, ...]:
r"""Augment a training instance.
Args:
features: A tuple of Bert features of one training instance.
`(input_ids, input_mask, segment_ids, label_ids)`.
`input_ids` is a tensor of Bert token ids.
It has shape `[seq_len]`.
`input_mask` is a tensor of shape `[seq_len]`, with 1 indicating a
token that is not masked out and 0 indicating a masked-out token.
`segment_ids` is a tensor of shape `[seq_len]`.
`label_ids` is a tensor of shape `[seq_len]`.
Returns:
A tuple of Bert features of augmented training instances.
`(input_probs_aug, input_mask_aug, segment_ids_aug, label_ids_aug)`.
`input_probs_aug` is a tensor of soft Bert embeddings,
distributions over vocabulary.
It has shape `[num_aug, seq_len, token_size]`.
It keeps :math:`\phi` as a variable so that, after it is passed as an
input to the classifier, the gradients of :math:`\theta` also
propagate to :math:`\phi`.
`input_mask_aug` is a tensor of shape `[num_aug, seq_len]`. It
concatenates `num_aug` copies of the input `input_mask` so that it
corresponds to the mask of each token in `input_probs_aug`.
`segment_ids_aug` is a tensor of shape `[num_aug, seq_len]`. It
concatenates `num_aug` copies of the input `segment_ids` so that it
corresponds to the token type of each token in `input_probs_aug`.
`label_ids_aug` is a tensor of shape `[num_aug, seq_len]`. It
concatenates `num_aug` copies of the input `label_ids` so that it
corresponds to the label of each token in `input_probs_aug`.
"""
aug_probs = self._augment_instance(features, self._num_aug)
_, input_mask, segment_ids, label_ids = (
t.to(self._device).unsqueeze(0) for t in features
)
input_mask_aug = tx.utils.pad_and_concat(
[input_mask] * self._num_aug, axis=0
)
segment_ids_aug = tx.utils.pad_and_concat(
[segment_ids] * self._num_aug, axis=0
)
label_ids_aug = tx.utils.pad_and_concat(
[label_ids] * self._num_aug, axis=0
)
return aug_probs, input_mask_aug, segment_ids_aug, label_ids_aug
def augment_batch(
self, batch_features: Tuple[torch.Tensor, ...]
) -> Tuple[torch.Tensor, ...]:
r"""Augment a batch of training instances. Append augmented instances
to the input instances.
Args:
batch_features: A tuple of Bert features of a batch of training
instances: `(input_ids, input_mask, segment_ids, label_ids)`.
`input_ids` is a tensor of Bert token ids.
It has shape `[batch_size, seq_len]`.
`input_mask`, `segment_ids`, `label_ids` are all tensors of
shape `[batch_size, seq_len]`.
Returns:
A tuple of Bert features of augmented training instances.
`(input_probs_aug, input_mask_aug, segment_ids_aug, label_ids_aug)`.
`input_probs_aug` is a tensor of soft Bert embeddings with shape
`[batch_size * 2, seq_len, token_size]`.
`input_mask_aug` is a tensor of shape `[batch_size * 2, seq_len]`.
It concatenates two copies of the input `input_mask`: the first
corresponds to the mask of the tokens in the original Bert instances,
the second to the mask of the augmented Bert instances.
`segment_ids_aug` is a tensor of shape `[batch_size * 2, seq_len]`.
It concatenates two copies of the input `segment_ids`: the first
corresponds to the segment ids of the tokens in the original Bert
instances, the second to the segment ids of the augmented Bert
instances.
`label_ids_aug` is a tensor of shape `[batch_size * 2, seq_len]`.
It concatenates two copies of the input `label_ids`: the first
corresponds to the labels of the original Bert instances, the second
to the labels of the augmented Bert instances.
"""
input_ids, input_mask, segment_ids, labels = batch_features
self._aug_model.eval()
aug_instances = []
features = []
num_instance = len(input_ids)
for i in range(num_instance):
feature = (input_ids[i], input_mask[i], segment_ids[i], labels[i])
features.append(feature)
with torch.no_grad():
aug_probs = self._augment_instance(feature, num_aug=1)
aug_instances.append(aug_probs)
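# Augmentation here only produces extra training inputs for the
# classifier, so no gradients w.r.t. phi are needed.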
input_ids_or_probs, input_masks, segment_ids, label_ids = [
tx.utils.pad_and_concat(
[t[i].unsqueeze(0) for t in features], axis=0
).to(self._device)
for i in range(4)
]
num_aug = len(aug_instances[0])
input_ids_or_probs_aug = []
for i in range(num_aug):
for aug_probs in aug_instances:
input_ids_or_probs_aug.append(aug_probs[i : i + 1])
input_ids_or_probs_aug = tx.utils.pad_and_concat(
input_ids_or_probs_aug, axis=0
).to(self._device)
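# Convert the original token ids to one-hot vectors so they have the
# same `[batch_size, seq_len, token_size]` shape as the augmented soft
# embeddings and can be concatenated along the batch dimension.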
inputs_onehot = torch.zeros_like(
input_ids_or_probs_aug[: len(input_ids_or_probs)]
).scatter_(2, input_ids_or_probs.unsqueeze(2), 1.0)
input_probs_aug = tx.utils.pad_and_concat(
[inputs_onehot, input_ids_or_probs_aug], axis=0
).to(self._device)
input_mask_aug = tx.utils.pad_and_concat(
[input_masks] * (num_aug + 1), axis=0
).to(self._device)
segment_ids_aug = tx.utils.pad_and_concat(
[segment_ids] * (num_aug + 1), axis=0
).to(self._device)
label_ids_aug = tx.utils.pad_and_concat(
[label_ids] * (num_aug + 1), axis=0
).to(self._device)
return input_probs_aug, input_mask_aug, segment_ids_aug, label_ids_aug
def eval_batch(
self, batch_features: Tuple[torch.Tensor, ...]
) -> torch.FloatTensor:
r"""Evaluate a batch of training instances.
Args:
batch_features: A tuple of Bert features of a batch of training
instances: `(input_ids, input_mask, segment_ids, label_ids)`.
`input_ids` is a tensor of Bert token ids.
It has shape `[batch_size, seq_len]`.
`input_mask`, `segment_ids`, `label_ids` are all tensors of
shape `[batch_size, seq_len]`.
Returns:
The masked language modeling loss of one evaluation batch.
It is a `torch.FloatTensor` of shape `[1,]`.
"""
self._aug_model.eval()
batch = tuple(t.to(self._device) for t in batch_features)
input_ids, input_mask, segment_ids, labels = batch
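# For a masked-LM model such as BertForMaskedLM, passing `labels` makes
# the first output the masked language modeling loss.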
loss = self._aug_model(
input_ids,
token_type_ids=segment_ids,
attention_mask=input_mask,
labels=labels,
)[0]
return loss
def update_meta_model(
self,
meta_model: MetaModule,
loss: torch.Tensor,
model: nn.Module,
optimizer: Optimizer,
) -> MetaModule:
r"""Update the parameters within the `MetaModel`
according to the downstream model loss.
`MetaModel` is used to calculate
:math:`\nabla_{\phi} L_{val}(\theta'(\phi))`,
where it needs gradients applied to :math:`\phi`.
Perform parameter updates in this function, and later applies gradient
change to :math:`\theta` and :math:`\phi` using validation data.
Args:
meta_model: A meta model whose parameters will be updated in-place
by the deltas calculated from the input `loss`.
loss: The loss of the downstream model that has taken
the augmented training instances as input.
model: The downstream Bert model.
optimizer: The optimizer that is associated with the `model`.
Returns:
The same input `meta_model` with the updated parameters.
"""
# grads_theta(phi) = \nabla_{theta} L_{train}(theta, phi)
grads_theta = self._calculate_grads(loss, model, optimizer)
# theta'(phi) = theta - grads_theta(phi)
meta_model.update_params(grads_theta)
return meta_model
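# A minimal sketch of the intended call pattern; `wrapper`, `model`,
# `optim`, `aug_features` and `val_features` are placeholder names,
# not part of this module:
#
#   meta_model = MetaModule(model)
#   loss = model(aug_features)           # L_train(theta, phi)
#   meta_model = wrapper.update_meta_model(meta_model, loss, model, optim)
#   val_loss = meta_model(val_features)  # L_val(theta'(phi))
#   val_loss.backward()                  # gradients reach phi
#   wrapper.update_phi()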
@staticmethod
def _calculate_grads(
loss: torch.Tensor, model: nn.Module, optimizer: Optimizer
) -> Dict[str, torch.Tensor]:
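# `create_graph=True` keeps the graph of these gradients, so the
# validation loss can later be differentiated through theta'(phi) with
# respect to the augmentation parameters phi.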
grads = torch.autograd.grad(
loss,
[param for name, param in model.named_parameters()],
create_graph=True,
)
grads = {
param: grads[i]
for i, (name, param) in enumerate(model.named_parameters())
}
if isinstance(optimizer, tx.core.BertAdam):
deltas = _texar_bert_adam_delta(grads, model, optimizer)
else:
deltas = _torch_adam_delta(grads, model, optimizer)
return deltas
def update_phi(self):
# L_{val}(theta'(phi))
# apply gradients to phi
# phi = phi - \nabla_{phi} L_{val}(theta'(phi))
self._aug_optimizer.step()
def _texar_bert_adam_delta(
grads: Dict[nn.parameter.Parameter, torch.Tensor],
model: nn.Module,
optimizer: Optimizer,
) -> Dict[str, torch.Tensor]:
# pylint: disable=line-too-long
r"""Compute parameter delta function for texar-pytorch
core.BertAdam optimizer.
This function is adapted from:
https://github.com/asyml/texar-pytorch/blob/master/texar/torch/core/optimization.py#L398
"""
assert isinstance(optimizer, tx.core.BertAdam)
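# Re-derive BertAdam's update step on out-of-place copies of the moment
# estimates so the resulting deltas stay part of the autograd graph.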
deltas = {}
for group in optimizer.param_groups:
for param in group["params"]:
grad = grads[param]
state = optimizer.state[param]
if len(state) == 0:
# Exponential moving average of gradient values
state["next_m"] = torch.zeros_like(param.data)
# Exponential moving average of squared gradient values
state["next_v"] = torch.zeros_like(param.data)
exp_avg, exp_avg_sq = state["next_m"], state["next_v"]
beta1, beta2 = group["betas"]
if group["weight_decay"] != 0:
grad = grad + group["weight_decay"] * param.data
exp_avg = exp_avg * beta1 + (1.0 - beta1) * grad
exp_avg_sq = exp_avg_sq * beta2 + (1.0 - beta2) * grad * grad
denom = exp_avg_sq.sqrt() + group["eps"]
step_size = group["lr"]
deltas[param] = -step_size * exp_avg / denom
param_to_name = {param: name for name, param in model.named_parameters()}
return {param_to_name[param]: delta for param, delta in deltas.items()}
def _torch_adam_delta(
grads: Dict[nn.parameter.Parameter, torch.Tensor],
model: nn.Module,
optimizer: Optimizer,
) -> Dict[str, torch.Tensor]:
r"""Compute parameter delta function for Torch Adam optimizer."""
assert issubclass(type(optimizer), Optimizer)
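# Same idea as above: recompute one Adam step out of place, including
# bias correction, so each delta remains differentiable w.r.t. phi.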
deltas = {}
for group in optimizer.param_groups:
for param in group["params"]:
grad = grads[param]
state = optimizer.state[param]
if len(state) == 0:
state["exp_avg"] = torch.zeros_like(param.data)
state["exp_avg_sq"] = torch.zeros_like(param.data)
state["step"] = 0
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
step = state["step"] + 1
if group["weight_decay"] != 0:
grad = grad + group["weight_decay"] * param.data
exp_avg = exp_avg * beta1 + (1.0 - beta1) * grad
exp_avg_sq = exp_avg_sq * beta2 + (1.0 - beta2) * grad * grad
denom = exp_avg_sq.sqrt() + group["eps"]
bias_correction1 = 1.0 - beta1**step
bias_correction2 = 1.0 - beta2**step
step_size = (
group["lr"] * math.sqrt(bias_correction2) / bias_correction1
)
deltas[param] = -step_size * exp_avg / denom
param_to_name = {param: name for name, param in model.named_parameters()}
return {param_to_name[param]: delta for param, delta in deltas.items()}