# Copyright 2019 The Forte Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from typing import Dict, List, Tuple, Optional
from forte.utils import create_import_error_msg
from forte.common.configuration import Config
from forte.common.resources import Resources
from forte.data.data_pack import DataPack
from forte.data.ontology import Annotation
from forte.data.span import Span
from forte.models.srl.model import LabeledSpanGraphNetwork
from forte.processors.base.batch_processor import RequestPackingProcessor
from ft.onto.base_ontology import (
    PredicateLink,
    PredicateMention,
    PredicateArgument,
)

try:
    import texar.torch as tx
except ImportError as e:
    raise ImportError(
        create_import_error_msg("texar-pytorch", "nlp", "Texar models")
    ) from e


try:
    import torch
except ImportError as e:
    raise ImportError(
        create_import_error_msg("torch", "nlp", "nlp processors")
    ) from e


logger = logging.getLogger(__name__)

__all__ = [
    "SRLPredictor",
]

Prediction = List[Tuple[Span, List[Tuple[Span, str]]]]
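# A `Prediction` pairs the span of one predicate with the list of
# (argument span, role label) tuples attached to that predicate.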


class SRLPredictor(RequestPackingProcessor):
    """
    A semantic role labeler trained as described in `He, Luheng, et al.
    "Jointly predicting predicates and arguments in neural semantic role
    labeling." <https://aclweb.org/anthology/P18-2058>`_.
    """

    word_vocab: tx.data.Vocab
    char_vocab: tx.data.Vocab
    model: LabeledSpanGraphNetwork

    def __init__(self):
        super().__init__()
        # Run on the current CUDA device if one is available; otherwise
        # fall back to the CPU.
        self.device = torch.device(
            torch.cuda.current_device()
            if torch.cuda.is_available()
            else "cpu"
        )
    def initialize(self, resources: Resources, configs: Optional[Config]):
        super().initialize(resources, configs)

        model_dir = configs.storage_path if configs is not None else None
        logger.info("restoring SRL model from %s", model_dir)

        # Initialize the batcher.
        if configs:
            self.batcher.initialize(configs.batcher)

        # Load the word and character vocabularies shipped with the model.
        self.word_vocab = tx.data.Vocab(
            os.path.join(model_dir, "embeddings/word_vocab.english.txt")
        )
        self.char_vocab = tx.data.Vocab(
            os.path.join(model_dir, "embeddings/char_vocab.english.txt")
        )

        # Resolve the embedding paths relative to the model directory.
        model_hparams = LabeledSpanGraphNetwork.default_hparams()
        model_hparams["context_embeddings"]["path"] = os.path.join(
            model_dir, model_hparams["context_embeddings"]["path"]
        )
        model_hparams["head_embeddings"]["path"] = os.path.join(
            model_dir, model_hparams["head_embeddings"]["path"]
        )

        # Restore the pretrained weights and switch to inference mode.
        self.model = LabeledSpanGraphNetwork(
            self.word_vocab, self.char_vocab, model_hparams
        )
        self.model.load_state_dict(
            torch.load(
                os.path.join(model_dir, "pretrained/model.pt"),
                map_location=self.device,
            )
        )
        self.model.eval()
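    # Expected layout of ``storage_path`` (inferred from ``initialize``
    # above; the embedding files referenced by the model's default
    # hyperparameters are elided here):
    #
    #     <storage_path>/
    #         embeddings/word_vocab.english.txt
    #         embeddings/char_vocab.english.txt
    #         embeddings/...      # context/head embedding files
    #         pretrained/model.pt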
    def predict(self, data_batch: Dict) -> Dict[str, List[Prediction]]:
        text: List[List[str]] = [
            sentence.tolist() for sentence in data_batch["Token"]["text"]
        ]

        # Map tokens to vocabulary ids and pad the sentences into a
        # rectangular batch.
        text_ids, length = tx.data.padded_batch(
            [
                self.word_vocab.map_tokens_to_ids_py(sentence)
                for sentence in text
            ]
        )
        text_ids = torch.from_numpy(text_ids).to(device=self.device)
        length = torch.tensor(length, dtype=torch.long, device=self.device)
        batch_size = len(text)
        batch = tx.data.Batch(
            batch_size,
            text=text,
            text_ids=text_ids,
            length=length,
            srl=[[]] * batch_size,
        )
        self.model = self.model.to(self.device)
        batch_srl_spans = self.model.decode(batch)

        # Convert predictions into annotations.
        batch_predictions: List[Prediction] = []
        for idx, srl_spans in enumerate(batch_srl_spans):
            word_spans = data_batch["Token"]["span"][idx]
            predictions: Prediction = []
            for pred_idx, pred_args in srl_spans.items():
                begin, end = word_spans[pred_idx]
                # TODO: cannot create the annotation here; need to convert
                #  the Numpy numbers to int first.
                pred_span = Span(begin.item(), end.item())
                arguments = []
                for arg in pred_args:
                    # Translate word-level argument indices back into
                    # character offsets.
                    begin = word_spans[arg.start][0].item()
                    end = word_spans[arg.end][1].item()
                    arg_annotation = Span(begin, end)
                    arguments.append((arg_annotation, arg.label))
                predictions.append((pred_span, arguments))
            batch_predictions.append(predictions)
        return {"predictions": batch_predictions}
    def pack(
        self,
        pack: DataPack,
        predict_results: Dict[str, List[Prediction]],
        _: Optional[Annotation] = None,
    ):
        batch_predictions = predict_results["predictions"]
        for predictions in batch_predictions:
            for pred_span, arg_result in predictions:
                pred = PredicateMention(pack, pred_span.begin, pred_span.end)
                for arg_span, label in arg_result:
                    arg = PredicateArgument(pack, arg_span.begin, arg_span.end)
                    link = PredicateLink(pack, pred, arg)
                    link.arg_type = label
    @classmethod
    def default_configs(cls):
        """
        This defines the default configuration structure for the predictor.
        """
        return {
            "storage_path": None,
            "batcher": {
                "batch_size": 4,
                "context_type": "ft.onto.base_ontology.Sentence",
                "requests": {"ft.onto.base_ontology.Token": []},
            },
        }
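
# A minimal usage sketch (an illustration, not part of the original module).
# It assumes a pipeline with an upstream processor that adds ``Sentence`` and
# ``Token`` annotations (``MyTokenizer`` below is a hypothetical placeholder)
# and a local ``srl_model_dir`` holding the pretrained checkpoint laid out as
# ``initialize`` expects:
#
#     from forte.pipeline import Pipeline
#     from forte.data.readers import StringReader
#
#     pipe = Pipeline[DataPack]()
#     pipe.set_reader(StringReader())
#     pipe.add(MyTokenizer())  # hypothetical sentence/token annotator
#     pipe.add(SRLPredictor(), config={"storage_path": "srl_model_dir"})
#     pipe.initialize()
#
#     pack = pipe.process("The quick brown fox jumped over the lazy dog.")
#     for link in pack.get(PredicateLink):
#         print(link.get_parent().text, link.arg_type, link.get_child().text)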