Source code for forte.data.extractors.attribute_extractor
# Copyright 2020 The Forte Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file implements AttributeExtractor, which is used to extract features
from an attribute of entries.
"""
from typing import (
Any,
Union,
Iterable,
Hashable,
SupportsInt,
Dict,
Optional,
List,
)
from forte.common import ProcessorConfigError
from forte.common.configuration import Config
from forte.data.base_extractor import BaseExtractor
from forte.data.converter.feature import Feature
from forte.data.data_pack import DataPack
from forte.data.ontology.core import Entry
from forte.data.ontology.top import Annotation
__all__ = ["AttributeExtractor"]
class AttributeExtractor(BaseExtractor):
    r"""`AttributeExtractor` extracts features from an attribute of entries.

    Most of the time, a user will not need to call this class explicitly;
    it will be called by the framework.
    """

    def initialize(self, config: Union[Dict, Config]):
        r"""Initialize the extractor and validate the required configs.

        Args:
            config: The extractor configuration. It is forwarded to the
                parent class ``initialize``.

        Raises:
            ProcessorConfigError: If no configuration is available after
                the parent initialization, or if either ``entry_type`` or
                ``attribute`` is None.
        """
        super().initialize(config)
        if self.config is None:
            raise ProcessorConfigError(
                "Configuration for the extractor not found."
            )
        if self.config.entry_type is None:
            raise ProcessorConfigError(
                "The ``entry_type`` configuration must be "
                "provided and cannot be None."
            )
        if self.config.attribute is None:
            raise ProcessorConfigError(
                "The `attribute` configuration must be "
                "provided and cannot be None."
            )

    @classmethod
    def default_configs(cls):
        r"""Returns a dictionary of default hyper-parameters.

        The defaults of the parent class are extended with:

        - "`attribute`": str
            The name of the attribute to extract from the entry. This
            attribute should be present in the entry definition. There are
            some built-in attributes for some entries, such as `text` for
            `Annotation` entries. ``tid`` should also be available for any
            entry. The default value is ``tid``.

        - "`entry_type`": str
            The fully qualified name of the entry to extract attributes
            from. The default value is None, but a value must be provided
            or a `ProcessorConfigError` will be raised by
            :meth:`initialize`.
        """
        config = super().default_configs()
        config.update(
            {
                "attribute": "tid",
                "entry_type": None,
            }
        )
        return config

    @classmethod
    def _get_attribute(cls, entry: Entry, attr: str) -> Any:
        r"""Get the attribute from an entry. Override this function if
        the attribute has to be obtained in a special way.

        Args:
            entry: An instance of Entry type, where the
                attribute will be extracted from.
            attr: The name of the attribute.

        Returns:
            Any. The attribute extracted from the entry.
        """
        return getattr(entry, attr)

    @classmethod
    def _set_attribute(cls, entry: Entry, attr: str, value: Any):
        r"""Set the attribute of an entry to a value. Override this
        function if the attribute has to be set in a special way.

        Args:
            entry: An instance of Entry type, where the
                attribute will be set.
            attr: The name of the attribute.
            value: The value to be set for the attribute.

        Raises:
            AttributeError: If ``attr`` is ``"text"``, which is rejected
                here because it cannot be changed.
        """
        if attr == "text":
            raise AttributeError("text attribute of entry cannot be changed.")
        setattr(entry, attr, value)

    def update_vocab(
        self, pack: DataPack, context: Optional[Annotation] = None
    ):
        r"""Get the configured attribute of every entry of the configured
        type and add the values into the vocabulary.

        Args:
            pack: The data pack input to extract vocabulary.
            context: The context is an Annotation entry where
                features will be extracted within its range. If None, then
                the whole data pack will be used as the context. Default is
                None.

        Raises:
            ProcessorConfigError: If the extractor has no configuration.
            AttributeError: If an attribute value is not hashable.
        """
        if self.config is None:
            raise ProcessorConfigError(
                "Configuration for the extractor not found."
            )
        entry: Entry
        for entry in pack.get(self.config.entry_type, context):
            # The following pylint skip due to a bug:
            # https://github.com/PyCQA/pylint/issues/3507
            # Hashable is not recognized the type.
            # pylint: disable=isinstance-second-argument-not-valid-type
            element = self._get_attribute(entry, self.config.attribute)
            if not isinstance(element, Hashable):
                # Each fragment ends with a space: adjacent string literals
                # are joined verbatim, without any separator.
                raise AttributeError(
                    "Only hashable element can be "
                    "added into the vocabulary. Consider setting "
                    "vocab_method to be raw and do not call update_vocab "
                    "if you only need the raw attribute value without "
                    "converting them into index."
                )
            self.add(element)

    def extract(
        self, pack: DataPack, context: Optional[Annotation] = None
    ) -> Feature:
        r"""Extract the attribute of the entries of the configured entry
        type. The entry type is passed in via the extractor config
        `entry_type`.

        Args:
            pack: The datapack that contains the current instance.
            context: The context is an Annotation entry where
                features will be extracted within its range. If None, then
                the whole data pack will be used as the context. Default is
                None.

        Returns:
            Features (attributes) for the instances within the provided
            context. When a vocabulary is available, every value is
            converted with ``element2repr``; otherwise the raw attribute
            values are kept.

        Raises:
            ProcessorConfigError: If the extractor has no configuration.
        """
        if self.config is None:
            raise ProcessorConfigError(
                "Configuration for the extractor not found."
            )

        data = []
        # The configured entry type is not necessarily an Annotation, so the
        # loop variable is typed as the general Entry.
        instance: Entry
        for instance in pack.get(self.config.entry_type, context):
            value = self._get_attribute(instance, self.config.attribute)
            rep = self.element2repr(value) if self.vocab else value
            data.append(rep)

        meta_data = {
            "need_pad": self.config.need_pad,
            "pad_value": self.get_pad_value(),
            "dim": 1,
            # Without a vocabulary the raw values are stored, so no specific
            # dtype can be promised.
            "dtype": int if self.vocab else Any,
        }
        return Feature(data=data, metadata=meta_data, vocab=self.vocab)

    def pre_evaluation_action(
        self, pack: DataPack, context: Optional[Annotation] = None
    ):
        r"""This function is performed on the pack before the evaluation
        stage, allowing one to perform some actions before the evaluation.
        By default, this function sets the configured attribute of every
        entry of the configured type to None. Override this function to
        customize the behavior.

        Args:
            pack: The datapack that contains the current
                instance.
            context: The context is an Annotation entry where
                data are extracted within its range. If None, then the
                whole data pack will be used as the context. Default is
                None.

        Raises:
            ProcessorConfigError: If the extractor has no configuration.
            AttributeError: If the configured attribute is ``"text"``
                (raised by :meth:`_set_attribute`).
        """
        if self.config is None:
            raise ProcessorConfigError(
                "Configuration for the extractor not found."
            )
        entry: Entry
        for entry in pack.get(self.config.entry_type, context):
            self._set_attribute(entry, self.config.attribute, None)

    def add_to_pack(
        self,
        pack: DataPack,
        predictions: Iterable[SupportsInt],
        context: Optional[Annotation] = None,
    ):
        r"""Add the prediction for attributes to the data pack.
        We assume the number of predictions in the iterable to be the same
        as the number of the entries of the defined type in the data pack.
        If the two differ, the predictions and entries are paired in order
        and the surplus on the longer side is ignored.

        Args:
            pack: The datapack that contains the current
                instance.
            predictions: This is the output
                of the model, which should be the class index for the
                attribute.
            context: The context is an Annotation
                entry where predictions will be added to. This has the same
                meaning with `context` as in
                :meth:`~forte.data.base_extractor.BaseExtractor.extract`.
                If None, then the whole data pack will be used as the
                context. Default is None.

        Raises:
            ProcessorConfigError: If the extractor has no configuration.
            AttributeError: If the configured attribute is ``"text"``
                (raised by :meth:`_set_attribute`).
        """
        if self.config is None:
            raise ProcessorConfigError(
                "Configuration for the extractor not found."
            )

        # Collect the entries up front so that the attributes are only
        # modified after the iteration over the pack has finished.
        instance_entries: List[Entry] = list(
            pack.get(self.config.entry_type, context)
        )

        # Map every predicted class index back to its vocabulary element.
        values = [self.id2element(int(x)) for x in predictions]
        for entry, value in zip(instance_entries, values):
            self._set_attribute(entry, self.config.attribute, value)