# Copyright 2020 The Forte Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file wraps a machine translation model.
It could be used for back translation.
For simplicity, the model is not wrapped as a processor.
"""
from typing import List
from abc import abstractmethod
from forte.utils import create_import_error_msg

__all__ = [
    "MachineTranslator",
    "MarianMachineTranslator",
]


class MachineTranslator:
    r"""
    This class is a wrapper for machine translation models.

    Args:
        src_lang: The source language.
        tgt_lang: The target language.
        device: "cuda" for gpu, "cpu" otherwise.
    """

    def __init__(self, src_lang: str, tgt_lang: str, device: str):
        self.src_lang: str = src_lang
        self.tgt_lang: str = tgt_lang
        self.device = device
    @abstractmethod
    def translate(self, src_text: str) -> str:
        r"""
        This function translates the input text into the target language.

        Args:
            src_text: The input text in the source language.

        Returns:
            The output text in the target language.
        """
        raise NotImplementedError
class MarianMachineTranslator(MachineTranslator):
    r"""
    This class is a wrapper for the Marian Machine Translator
    (https://huggingface.co/transformers/model_doc/marian.html).
    Please refer to their doc for supported languages.
    """

    def __init__(
        self, src_lang: str = "en", tgt_lang: str = "fr", device: str = "cpu"
    ):
        super().__init__(src_lang, tgt_lang, device)
        self.model_name = f"Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}"

        try:
            from transformers import (  # pylint:disable=import-outside-toplevel
                MarianMTModel,
                MarianTokenizer,
            )
        except ImportError as err:
            raise ImportError(
                create_import_error_msg(
                    "transformers", "data_aug", "Machine Translator"
                )
            ) from err

        self.tokenizer = MarianTokenizer.from_pretrained(self.model_name)
        self.model = MarianMTModel.from_pretrained(self.model_name)
        self.model = self.model.to(self.device)
    def translate(self, src_text: str) -> str:
        # TODO: `prepare_seq2seq_batch` is deprecated and should be replaced.
        translated: List[str] = self.model.generate(
            **self.tokenizer.prepare_seq2seq_batch(
                # `convert_to_tensors` has to be called explicitly so that
                # this works in both transformers 3 and 4; it probably
                # won't work in 5.
                [src_text]
            )
            .convert_to_tensors("pt")
            .to(self.device)
        )
        tgt_texts: List[str] = [
            self.tokenizer.decode(t, skip_special_tokens=True)
            for t in translated
        ]
        return tgt_texts[0]
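

# A minimal back-translation sketch: chain two MarianMachineTranslator
# instances (en -> fr, then fr -> en) to paraphrase an input sentence.
# This assumes the corresponding Helsinki-NLP opus-mt checkpoints
# ("opus-mt-en-fr" and "opus-mt-fr-en") can be downloaded; the sample
# sentence is only illustrative.
if __name__ == "__main__":
    forward_mt = MarianMachineTranslator(
        src_lang="en", tgt_lang="fr", device="cpu"
    )
    backward_mt = MarianMachineTranslator(
        src_lang="fr", tgt_lang="en", device="cpu"
    )

    original = "Forte is a toolkit for building NLP pipelines."
    intermediate = forward_mt.translate(original)  # English -> French
    augmented = backward_mt.translate(intermediate)  # French -> English

    # The back-translated sentence is typically a paraphrase of the
    # original, which is what makes it useful for data augmentation.
    print(augmented)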