Source code for forte.processors.data_augment.algorithms.UDA

# Copyright 2020 The Forte Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file implements the UDA (Unsupervised Data Augmentation)
algorithm as described in
Unsupervised Data Augmentation for Consistency Training
(https://arxiv.org/abs/1904.12848)
(https://github.com/google-research/uda)
"""
from typing import Optional, List, Tuple, Iterator
from torch import Tensor
from forte.utils import create_import_error_msg

try:
    from texar.torch.losses.info_loss import kl_divg_loss_with_logits
    from texar.torch.data import DataIterator
    from texar.torch.data.data.dataset_utils import Batch
except ImportError as err:
    raise ImportError(
        create_import_error_msg(
            "texar-pytorch", "models", "Unsupervised Data Augmentation(UDA)"
        )
    ) from err


__all__ = [
    "UDAIterator",
]


class UDAIterator:
    r"""
    This iterator wraps the Unsupervised Data Augmentation (UDA)
    algorithm by calculating the unsupervised loss automatically
    during each iteration. It takes both supervised and unsupervised
    data iterators as input.

    The unsupervised data should contain the original input and the
    augmented input. The original and augmented inputs should be in
    the same training example.

    During each iteration, the iterator will return the supervised and
    unsupervised batches. Users can call :func:`calculate_uda_loss`
    to get the UDA loss and combine it with the supervised loss
    for model training.

    It uses tricks such as prediction sharpening and confidence masking.
    Please refer to the UDA paper for more details.
    (https://arxiv.org/abs/1904.12848)

    Args:
        sup_iterator: The iterator for supervised data. Each item is a
            training/evaluation/test example with key-value pairs as inputs.
        unsup_iterator: The iterator for unsupervised data. Each training
            example in it should contain both the original and augmented data.
        softmax_temperature: The softmax temperature for sharpening the
            distribution. The value should be larger than 0. Defaults to 1.
        confidence_threshold: The threshold for confidence-masking.
            It is a threshold of the probability in [0, 1],
            rather than of the logit. If set to -1, the threshold
            will be ignored. Defaults to -1.
        reduction: Default: 'mean'. This is the same as the `reduction`
            argument in
            :func:`texar.torch.losses.info_loss.kl_divg_loss_with_logits`.
            The loss will be a scalar tensor if the `reduction`
            is not :attr:`'none'`.
            Specifies the reduction to apply to the output:

            - ``'none'``: no reduction will be applied.
            - ``'batchmean'``: the sum of the output will be divided
              by the batch size.
            - ``'sum'``: the output will be summed.
            - ``'mean'``: the output will be divided by the number of
              elements in the output.
    """

    def __init__(
        self,
        sup_iterator: DataIterator,
        unsup_iterator: DataIterator,
        softmax_temperature: float = 1.0,
        confidence_threshold: float = -1,
        reduction: str = "mean",
    ):
        self._sup_iterator: DataIterator = sup_iterator
        self._unsup_iterator: DataIterator = unsup_iterator
        self._softmax_temperature = softmax_temperature
        self._confidence_threshold = confidence_threshold
        self._reduction = reduction

        # The flag for returning the unsupervised data.
        self._use_unsup = True
        self._sup_iter: Iterator[Batch]
        self._unsup_iter: Iterator[Batch]

    def __len__(self):
        return self._sup_iterator.__len__()

    def switch_to_dataset(
        self, dataset_name: Optional[str] = None, use_unsup: bool = True
    ):
        # Set the flag of using unsupervised data.
        self._use_unsup = use_unsup
        self._sup_iterator.switch_to_dataset(dataset_name)

    def switch_to_dataset_unsup(self, dataset_name: Optional[str] = None):
        self._unsup_iterator.switch_to_dataset(dataset_name)

    @property
    def num_datasets(self) -> int:
        return (
            self._sup_iterator.num_datasets
            + self._unsup_iterator.num_datasets
        )

    @property
    def dataset_names(self) -> List[str]:
        return (
            self._sup_iterator.dataset_names
            + self._unsup_iterator.dataset_names
        )
    def calculate_uda_loss(
        self, logits_orig: Tensor, logits_aug: Tensor
    ) -> Tensor:
        r"""
        This function calculates the KL divergence between the output
        probabilities of the original input and the augmented input.
        The two inputs should have the same shape, and their last
        dimension should be the probability distribution.

        Args:
            logits_orig: A tensor containing the logits of the original data.
            logits_aug: A tensor containing the logits of the augmented data.
                Must have the same shape as `logits_orig`.
        Returns:
            The loss, as a pytorch scalar float tensor if the `reduction`
            is not :attr:`'none'`, otherwise a tensor with the same shape
            as the `logits_orig`.
        """
        uda_loss = kl_divg_loss_with_logits(
            target_logits=logits_orig,
            input_logits=logits_aug,
            softmax_temperature=self._softmax_temperature,
            confidence_threshold=self._confidence_threshold,
            reduction=self._reduction,
        )
        return uda_loss
    def __iter__(self):
        r"""
        The :class:`texar.torch.data.DataIterator` does not inherit from
        :class:`Iterator`, so we have to get the iterators explicitly.
        """
        self._sup_iter = iter(self._sup_iterator)
        self._unsup_iter = iter(self._unsup_iterator)
        return self

    def __next__(self) -> Tuple[Batch, Optional[Batch]]:
        r"""
        When :attr:`_use_unsup` is False, the iterator behaves like a
        normal iterator over the supervised data. Otherwise, it yields
        an unsupervised batch in addition to the supervised batch.

        The iterator only raises StopIteration when the supervised
        dataset reaches its end. The unsupervised data will always be
        yielded as long as the iterator still has supervised data.
        """
        if not self._use_unsup:
            return next(self._sup_iter), None
        try:
            sup_batch = next(self._sup_iter)
        except StopIteration as e:
            raise StopIteration from e
        try:
            unsup_batch = next(self._unsup_iter)
        except StopIteration:
            # Restart the unsupervised iterator; it cycles until the
            # supervised data is exhausted.
            self._unsup_iter = iter(self._unsup_iterator)
            unsup_batch = next(self._unsup_iter)

        return sup_batch, unsup_batch
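
# ----------------------------------------------------------------------
# A minimal training-loop sketch (illustration only, not part of the
# original module). The helper below is hypothetical: it assumes a
# classifier ``model`` mapping input ids to logits, an ``optimizer``, a
# supervised ``criterion``, and that batches expose fields such as
# ``input_ids``, ``labels``, ``orig_ids``, and ``aug_ids``. The real
# field names depend on how the supervised and unsupervised datasets
# are built; adapt them to your data.
def _example_uda_training_epoch(
    model,
    optimizer,
    criterion,
    uda_iterator: "UDAIterator",
    unsup_weight: float = 1.0,
):
    # Iterate over supervised batches; an unsupervised batch is yielded
    # alongside each one as long as `use_unsup` is True.
    uda_iterator.switch_to_dataset("train", use_unsup=True)
    for sup_batch, unsup_batch in uda_iterator:
        optimizer.zero_grad()

        # Supervised loss on the labeled batch.
        sup_logits = model(sup_batch.input_ids)
        loss = criterion(sup_logits, sup_batch.labels)

        if unsup_batch is not None:
            # Consistency (UDA) loss between the original and augmented
            # inputs. The UDA paper treats the predictions on the original
            # example as a fixed target, hence the detach (stop-gradient).
            logits_orig = model(unsup_batch.orig_ids).detach()
            logits_aug = model(unsup_batch.aug_ids)
            loss = loss + unsup_weight * uda_iterator.calculate_uda_loss(
                logits_orig, logits_aug
            )

        loss.backward()
        optimizer.step()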