# Source code for forte.data.extractors.relation_extractor
# Copyright 2020 The Forte Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Union, List, Optional, Tuple, Type
from forte.common import ProcessorConfigError
from forte.common.configuration import Config
from forte.data.base_extractor import BaseExtractor
from forte.data.converter.feature import Feature
from forte.data.data_pack import DataPack
from forte.data.ontology import Annotation, Link
from forte.data.ontology.core import Entry
__all__ = ["LinkExtractor"]
from forte.utils import get_class
def get_index(
    pack: "DataPack",
    index_entries: List["Annotation"],
    context_entry: "Annotation",
) -> List[int]:
    """
    Locate the span of index entries that are covered by a context entry.

    Args:
        pack: The data pack whose ``covers`` method decides whether
            ``context_entry`` covers a given index entry.
        index_entries: The candidate entries; positions in this list are
            the indices that get returned.
        context_entry: The entry whose covered index entries are looked up.

    Returns:
        A two-element list ``[first, last + 1]`` where ``first`` and
        ``last`` are the positions of the first and the last covered index
        entry, i.e. a half-open range over ``index_entries``.

    Raises:
        IndexError: If none of ``index_entries`` is covered by
            ``context_entry``.
    """
    first: Optional[int] = None
    last = -1
    for position, entry in enumerate(index_entries):
        if pack.covers(context_entry, entry):
            if first is None:
                first = position
            last = position

    if first is None:
        # Keep the exception type an empty lookup has always produced, but
        # say what actually went wrong instead of "list index out of range".
        raise IndexError(
            f"None of the {len(index_entries)} index entries is covered by "
            "the given context entry, so no index span can be computed."
        )

    # Only the outermost matches matter; the end is exclusive.
    return [first, last + 1]
class LinkExtractor(BaseExtractor):
    """
    This extractor extracts relation type features from data packs. This
    extractor expects the parent and child of the relation to be Annotation
    entries.

    The relation entry type, the attribute to read from it and the
    annotation type used to index parents and children are all taken from
    the configuration (see :meth:`default_configs`).
    """

    def initialize(self, config: Union[Dict, Config]):
        """
        Validate the configuration and resolve the relation classes.

        After the base class has processed ``config``, this resolves
        ``entry_type`` to a class and stores it together with its
        ``ParentType`` and ``ChildType`` for later use in
        :meth:`add_to_pack`.

        Args:
            config: The configuration for this extractor.

        Raises:
            ProcessorConfigError: If the configuration is None, if any of
                ``attribute``, ``index_annotation`` or ``entry_type`` is
                None, if ``entry_type`` is not a ``Link`` subclass, or if
                its parent or child type is not an ``Annotation`` subclass.
        """
        # pylint: disable=attribute-defined-outside-init
        super().initialize(config)

        if self.config is None:
            raise ProcessorConfigError(
                "Configuration for the extractor cannot be None."
            )
        if self.config.attribute is None:
            raise ProcessorConfigError(
                "'attribute' is required in this extractor."
            )
        if self.config.index_annotation is None:
            raise ProcessorConfigError(
                "'index_annotation' is required in this extractor."
            )
        if self.config.entry_type is None:
            raise ProcessorConfigError(
                "'entry_type' is required in this extractor."
            )

        # Every check above raises on failure, so from here on all three
        # required options are known to be set.
        self._entry_class: Type[Link] = get_class(self.config.entry_type)
        if not issubclass(self._entry_class, Link):
            raise ProcessorConfigError(
                "`entry_type` to this extractor must be a Link type."
            )

        self._parent_class: Type[Annotation] = self._entry_class.ParentType  # type: ignore
        if not issubclass(self._parent_class, Annotation):
            raise ProcessorConfigError(
                f"The parent class of the provided {self.config.entry_type}"
                " must be an Annotation."
            )

        self._child_class: Type[Annotation] = self._entry_class.ChildType  # type: ignore
        if not issubclass(self._child_class, Annotation):
            raise ProcessorConfigError(
                f"The child class of the provided {self.config.entry_type}"
                " must be an Annotation."
            )

    @classmethod
    def default_configs(cls):
        r"""Returns a dictionary of default hyper-parameters.

        Here:

        - "`entry_type`": The target relation entry type, should be a Link
          entry.
        - "`attribute`": The attribute of the relation to extract.
        - "`index_annotation`": The annotation object used to index the
          head and child node of the relations.

        All three default to None and must be provided by the user;
        :meth:`initialize` rejects a configuration in which any of them is
        still None.
        """
        config = super().default_configs()
        config.update(
            {
                "entry_type": None,
                "attribute": None,
                "index_annotation": None,
            }
        )
        return config

    def update_vocab(
        self, pack: DataPack, context: Optional[Annotation] = None
    ):
        """
        Update values of relation attributes to the vocabulary.

        Args:
            pack: The input data pack.
            context: The context is an Annotation entry where
                features will be extracted within its range. If None, then
                the whole data pack will be used as the context. Default is
                None.

        Returns:
            None

        Raises:
            ProcessorConfigError: If the extractor has no configuration.
        """
        if self.config is None:
            raise ProcessorConfigError(
                "Configuration for the extractor not found."
            )

        entry: Entry
        for entry in pack.get(self.config.entry_type, context):
            self.add(getattr(entry, self.config.attribute))

    def extract(
        self, pack: DataPack, context: Optional[Annotation] = None
    ) -> Feature:
        """Extract link data as features from the context.

        Args:
            pack: The input data pack that contains the features.
            context: The context is an Annotation entry where
                features will be extracted within its range. If None, then
                the whole data pack will be used as the context. Default is
                None.

        Returns:
            A Feature whose data is the list of relation attribute values,
            one per relation found in the context. The values are converted
            with ``element2repr`` when a vocabulary is set and are kept raw
            otherwise. The metadata holds ``parent_unit_span`` and
            ``child_unit_span``: for each relation, the ``[begin, end)``
            positions of its parent / child among the ``index_annotation``
            entries of the context, as computed by ``get_index``.

        Raises:
            ProcessorConfigError: If the extractor has no configuration.
        """
        if self.config is None:
            raise ProcessorConfigError(
                "Configuration for the extractor not found."
            )

        index_annotations: List[Annotation] = list(
            pack.get(self.config.index_annotation, context)
        )

        # First pass: collect both ends and the attribute of each relation.
        parent_nodes: List[Annotation] = []
        child_nodes: List[Annotation] = []
        relation_atts = []

        r: Link
        for r in pack.get(self.config.entry_type, context):
            parent_nodes.append(r.get_parent())  # type: ignore
            child_nodes.append(r.get_child())  # type: ignore

            raw_att = getattr(r, self.config.attribute)
            relation_atts.append(
                self.element2repr(raw_att) if self.vocab else raw_att
            )

        # Second pass: express each end as a span over the index entries.
        parent_unit_span = []
        child_unit_span = []
        for p, c in zip(parent_nodes, child_nodes):
            parent_unit_span.append(get_index(pack, index_annotations, p))
            child_unit_span.append(get_index(pack, index_annotations, c))

        meta_data = {
            "parent_unit_span": parent_unit_span,
            "child_unit_span": child_unit_span,
            "pad_value": self.get_pad_value(),
            "dim": 1,
            # Values are vocabulary representations when a vocab is set,
            # raw attribute values otherwise.
            "dtype": int if self.vocab else str,
        }

        return Feature(data=relation_atts, metadata=meta_data, vocab=self.vocab)

    def add_to_pack(
        self,
        pack: DataPack,
        predictions: List[Tuple[Tuple[int, int], Tuple[int, int], int]],
        context: Optional[Annotation] = None,
    ):
        """
        Convert prediction back to Links inside the data pack.

        For every prediction a parent annotation, a child annotation and a
        link between them are created, and the predicted relation value is
        set on the link under the configured ``attribute``.

        Args:
            pack: The datapack to add predictions back.
            predictions: This is the output of the model, it is a
                triplet, the first element shows the parent, the second
                element shows the child. These two are indexed by the
                `index_annotation` of this extractor. The last element is
                the index of the relation attribute.
            context: The context is an Annotation
                entry where predictions will be added to. This has the same
                meaning with `context` as in
                :meth:`~forte.data.base_extractor.BaseExtractor.extract`.
                If None, then the whole data pack will be used as the
                context. Default is None.

        Raises:
            ProcessorConfigError: If the extractor has no configuration.
        """
        if self.config is None:
            raise ProcessorConfigError(
                "Configuration for the extractor not found."
            )

        index_entries: List[Annotation] = list(
            pack.get(self.config.index_annotation, context)
        )

        for parent, child, rel_index in predictions:
            parent_begin_entry_index, parent_end_entry_index = parent
            child_begin_entry_index, child_end_entry_index = child

            # NOTE(review): the end index is used inclusively here (the
            # entry at the end index contributes its `.end`), whereas
            # `get_index` in this module returns an exclusive end
            # (`last + 1`). Confirm which convention the predictions use.
            parent_start = index_entries[parent_begin_entry_index].begin
            parent_end = index_entries[parent_end_entry_index].end

            child_start = index_entries[child_begin_entry_index].begin
            child_end = index_entries[child_end_entry_index].end

            # Map the predicted index back to the relation value.
            rel_value = self.id2element(rel_index)

            child_anno = self._child_class(
                pack, child_start, child_end  # type:ignore
            )
            parent_anno = self._parent_class(
                pack, parent_start, parent_end  # type:ignore
            )
            link = self._entry_class(
                pack, parent_anno, child_anno  # type:ignore
            )
            setattr(link, self.config.attribute, rel_value)