Source code for forte.data.ontology.top

# Copyright 2019 The Forte Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from functools import total_ordering
from pathlib import Path
from typing import (
    Optional,
    Sequence,
    Tuple,
    Type,
    Any,
    Dict,
    Union,
    Iterable,
    List,
    Callable,
)
from dataclasses import dataclass

import numpy as np

from forte.common.aliases import URL
from forte.data.base_pack import PackType
from forte.data.modality import Modality
from forte.data.ontology.core import (
    Entry,
    BaseLink,
    BaseGroup,
    MultiEntry,
    EntryType,
    FList,
)
from forte.data.span import Span
from forte.utils.utils import get_full_module_name

# Public names of this module, i.e. what ``from ... import *`` exposes.
__all__ = [
    "Generics",
    "Annotation",
    "Group",
    "Link",
    "MultiPackGeneric",
    "MultiPackGroup",
    "MultiPackLink",
    "Query",
    "SinglePackEntries",
    "MultiPackEntries",
    "AudioAnnotation",
    "ImageAnnotation",
    "Region",
    "Box",
    "Payload",
    "TextPayload",
    "AudioPayload",
    "ImagePayload",
    "load_func",
]

# Type of the value a ``Query`` may hold: either a plain dictionary or a
# numpy array.
QueryType = Union[Dict[str, Any], np.ndarray]

"""
To create a new top level entry, the following steps are required to
make sure it is available across the ontology system:
    1. Create a new top level class that inherits from `Entry` or `MultiEntry`
    2. Add the new class to `SinglePackEntries` or `MultiPackEntries`
    3. Register a new method in `DataStore`: `add_<new_entry>_raw()`
    4. Insert a new conditional branch in `EntryConverter.save_entry_object()`
"""


class Generics(Entry):
    """A top level entry that adds no fields of its own on top of ``Entry``.

    Args:
        pack: The container that this entry will be added to.
    """

    def __init__(self, pack: PackType):
        # The pack is forwarded by keyword, exactly as the base class
        # constructor is invoked everywhere else for this entry type.
        super().__init__(pack=pack)
@dataclass
@total_ordering
class Annotation(Entry):
    r"""Annotation type entries, such as "token", "entity mention" and
    "sentence". Each annotation has a :class:`~forte.data.span.Span`
    corresponding to its offset in the text.

    Args:
        pack: The container that this annotation will be added to.
        begin: The offset of the first character in the annotation.
        end: The offset of the last character in the annotation + 1.
    """

    begin: int
    end: int
    # NOTE(review): declared as a field but never assigned in ``__init__``
    # below -- presumably populated elsewhere; confirm against the callers.
    payload_idx: int

    def __init__(self, pack: PackType, begin: int, end: int):
        self._span: Optional[Span] = None
        self.begin: int = begin
        self.end: int = end
        super().__init__(pack)

    @property
    def span(self) -> Span:
        """The ``Span`` covering ``[begin, end)``, created lazily on first
        access and then reused."""
        # Delay span creation at usage.
        if self._span is None:
            self._span = Span(self.begin, self.end)
        return self._span

    def __eq__(self, other):
        r"""The eq function of :class:`Annotation`.

        Two annotations are regarded as the same if they have exactly the
        same type and the same ``begin`` and ``end`` offsets. Any object of a
        different type (including ``None``) compares unequal.

        Users can define their own eq function by themselves but this must
        be consistent to :meth:`hash`.
        """
        # Check the type first: objects of another type may not even have
        # ``begin``/``end``, and reading those attributes would raise an
        # ``AttributeError`` instead of simply reporting "not equal".
        if type(self) is not type(other):
            return False
        return (self.begin, self.end) == (other.begin, other.end)

    def __lt__(self, other):
        r"""To support total_ordering, `Annotation` must implement
        `__lt__`. The ordering is defined in the following way:

        1. If the `begin` of the annotations are different, the one with
           larger begin will be larger.
        2. In the case where the begins are the same, the one with larger
           end will be larger.
        3. In the case where both offsets are the same, we break the tie
           using the normal sorting of the class name.
        """
        if self.begin == other.begin:
            if self.end == other.end:
                return str(type(self)) < str(type(other))
            return self.end < other.end
        else:
            return self.begin < other.begin

    @property
    def text(self):
        """The text covered by this annotation, read from the data pack.

        Raises:
            ValueError: if the annotation is not attached to a data pack.
        """
        if self.pack is None:
            raise ValueError(
                "Cannot get text because annotation is not "
                "attached to any data pack."
            )
        return self.pack.get_span_text(self.begin, self.end)

    @property
    def index_key(self) -> int:
        """The key used to index this entry, which is its ``tid``."""
        return self.tid

    def get(
        self,
        entry_type: Union[str, Type[EntryType]],
        components: Optional[Union[str, Iterable[str]]] = None,
        include_sub_type=True,
    ) -> Iterable[EntryType]:
        """
        This function wraps the :meth:`~forte.data.data_pack.DataPack.get`
        method to find entries "covered" by this annotation. See that
        method for more information.

        Example:

            .. code-block:: python

                # Iterate through all the sentences in the pack.
                for sentence in input_pack.get(Sentence):
                    # Take all tokens from each sentence created by
                    # NLTKTokenizer.
                    token_entries = sentence.get(
                        entry_type=Token,
                        components='NLTKTokenizer')
                    ...

            In the above code snippet, we get entries of type
            :class:`~ft.onto.base_ontology.Token` within each ``sentence``
            which were generated by ``NLTKTokenizer``. You can consider
            build coverage index between `Token` and `Sentence` if this
            snippet is frequently used.

        Args:
            entry_type: The type of entries requested.
            components: The component (creator) generating the entries
                requested. If `None`, will return valid entries generated
                by any component.
            include_sub_type: whether to consider the subtypes of the
                provided entry type. Default `True`.

        Yields:
            Each `Entry` found using this method.

        Raises:
            ValueError: if the annotation is not attached to a data pack.
        """
        # Same guard as ``text``: fail with a clear message rather than an
        # ``AttributeError`` on ``None``.
        if self.pack is None:
            raise ValueError(
                "Cannot get entries because annotation is not "
                "attached to any data pack."
            )
        yield from self.pack.get(entry_type, self, components, include_sub_type)
# pylint: disable=duplicate-bases
@dataclass
class Group(BaseGroup[Entry]):
    r"""An entry that groups other entries together.

    For example, a "coreference group" is a group of coreferential entities.
    Each group will store a set of members, no duplications allowed.

    Args:
        pack: The container that this group will be added to.
        members: Optional initial members, forwarded to the base class.
    """

    members: FList[Entry]
    member_type: str

    # Type every member must be an instance of; checked in ``add_member``.
    MemberType = Entry

    def __init__(
        self,
        pack: PackType,
        members: Optional[Iterable[Entry]] = None,
    ):  # pylint: disable=useless-super-delegation
        # The member type is kept as its full module name (a string), since
        # the value stored in the data store must be in a primitive form.
        self.member_type = get_full_module_name(self.MemberType)
        super().__init__(pack, members)

    def add_member(self, member: Entry):
        r"""Append one entry to this group's member list.

        Args:
            member: One member to be added to the group.

        Raises:
            TypeError: if ``member`` is not an instance of ``MemberType``.
        """
        if isinstance(member, self.MemberType):
            self.members.append(member)
            return

        raise TypeError(
            f"The members of {type(self)} should be "
            f"instances of {self.MemberType}, but got {type(member)}"
        )

    def get_members(self) -> List[Entry]:
        r"""Return the member entries of this group.

        Returns:
            A new list holding the entries currently in ``members``.

        Raises:
            ValueError: if the group is not attached to a data pack.
        """
        attached = self.pack is not None
        if not attached:
            raise ValueError(
                "Cannot get members because group is not "
                "attached to any data pack."
            )

        return list(self.members)
class MultiPackGeneric(MultiEntry, Entry):
    """A multi-pack entry that adds no fields of its own.

    Args:
        pack: The container that this entry will be added to.
    """

    def __init__(self, pack: PackType):
        # The pack is forwarded by keyword along the MRO
        # (``MultiEntry`` first, then ``Entry``).
        super().__init__(pack=pack)
# pylint: disable=duplicate-bases
@dataclass
class MultiPackGroup(MultiEntry, BaseGroup[Entry]):
    r"""Group type entries, such as "coreference group". Each group has a
    set of members.

    Args:
        pack: The container that this group will be added to.
        members: Optional initial members, added after base initialization.
    """

    member_type: str
    members: FList[Entry]

    # Type every member must be an instance of; checked in ``add_member``.
    MemberType = Entry

    def __init__(
        self, pack: PackType, members: Optional[Iterable[Entry]] = None
    ):  # pylint: disable=useless-super-delegation
        # The member type is kept as its full module name (a string), since
        # the value stored in the data store must be in a primitive form.
        self.member_type = get_full_module_name(self.MemberType)
        super().__init__(pack)

        # Unlike ``Group``, the initial members are added here rather than
        # being handed to the base constructor.
        if members is None:
            return
        self.add_members(members)

    def add_member(self, member: Entry):
        """Append one entry to this group's member list.

        Args:
            member: The entry to add.

        Raises:
            TypeError: if ``member`` is not an instance of ``MemberType``.
        """
        if isinstance(member, self.MemberType):
            self.members.append(member)
            return

        raise TypeError(
            f"The members of {type(self)} should be "
            f"instances of {self.MemberType}, but got {type(member)}"
        )

    def get_members(self) -> List[Entry]:
        """Return a new list holding the entries currently in ``members``."""
        return list(self.members)
@dataclass
class Query(Generics):
    r"""An entry type representing queries for information retrieval tasks.

    A new query starts with ``value`` set to ``None`` and an empty
    ``results`` mapping.

    Args:
        pack: Data pack reference to which this query will be added
    """

    value: Optional[QueryType]
    results: Dict[str, float]

    def __init__(self, pack: PackType):
        super().__init__(pack)
        # Both fields are filled in after the base initialization.
        self.value: Optional[QueryType] = None
        self.results: Dict[str, float] = {}

    def add_result(self, pid: str, score: float):
        """
        Record the score of one pack, replacing any score already stored
        under the same pack id.

        Args:
            pid: the pack id.
            score: the score for the pack

        Returns:
            None
        """
        self.results[pid] = score

    def update_results(self, pid_to_score: Dict[str, float]):
        r"""Merge several scores into the results of this query.

        Args:
            pid_to_score: A dict containing pack id -> score mapping
        """
        self.results.update(pid_to_score)
@dataclass
@total_ordering
class AudioAnnotation(Entry):
    r"""AudioAnnotation type entries, such as "recording" and "audio
    utterance". Each audio annotation has a :class:`~forte.data.span.Span`
    corresponding to its offset in the audio. Most methods in this class are
    the same as the ones in :class:`Annotation`, except that it replaces
    property `text` with `audio`.

    Args:
        pack: The container that this audio annotation
            will be added to.
        begin: The offset of the first sample in the audio annotation.
        end: The offset of the last sample in the audio annotation + 1.
    """

    begin: int
    end: int
    # NOTE(review): declared as a field but never assigned in ``__init__``
    # below -- presumably populated elsewhere; confirm against the callers.
    payload_idx: int

    def __init__(self, pack: PackType, begin: int, end: int):
        self._span: Optional[Span] = None
        self.begin: int = begin
        self.end: int = end
        super().__init__(pack)

    @property
    def audio(self):
        """The audio covered by this annotation, read from the data pack.

        Raises:
            ValueError: if the annotation is not attached to a data pack.
        """
        if self.pack is None:
            raise ValueError(
                "Cannot get audio because annotation is not "
                "attached to any data pack."
            )
        return self.pack.get_span_audio(self.begin, self.end)

    @property
    def span(self) -> Span:
        """The ``Span`` covering ``[begin, end)``, created lazily on first
        access and then reused."""
        # Delay span creation at usage.
        if self._span is None:
            self._span = Span(self.begin, self.end)
        return self._span

    def __eq__(self, other):
        r"""The eq function of :class:`AudioAnnotation`.

        Two audio annotations are regarded as the same if they have exactly
        the same type and the same ``begin`` and ``end`` offsets. Any object
        of a different type (including ``None``) compares unequal.

        Users can define their own eq function by themselves but this must
        be consistent to :meth:`hash`.
        """
        # Check the type first: objects of another type may not even have
        # ``begin``/``end``, and reading those attributes would raise an
        # ``AttributeError`` instead of simply reporting "not equal".
        if type(self) is not type(other):
            return False
        return (self.begin, self.end) == (other.begin, other.end)

    def __lt__(self, other):
        r"""To support total_ordering, `AudioAnnotation` must implement
        `__lt__`. The ordering is defined in the following way:

        1. If the begin of the audio annotations are different, the one with
           larger begin will be larger.
        2. In the case where the begins are the same, the one with larger
           end will be larger.
        3. In the case where both offsets are the same, we break the tie
           using the normal sorting of the class name.
        """
        if self.begin == other.begin:
            if self.end == other.end:
                return str(type(self)) < str(type(other))
            return self.end < other.end
        else:
            return self.begin < other.begin

    @property
    def index_key(self) -> int:
        """The key used to index this entry, which is its ``tid``."""
        return self.tid

    def get(
        self,
        entry_type: Union[str, Type[EntryType]],
        components: Optional[Union[str, Iterable[str]]] = None,
        include_sub_type=True,
    ) -> Iterable[EntryType]:
        """
        This function wraps the :meth:`~forte.data.data_pack.DataPack.get()`
        method to find entries "covered" by this audio annotation. See that
        method for more information. For usage details, refer to
        :meth:`forte.data.ontology.top.Annotation.get()`.

        Args:
            entry_type: The type of entries requested.
            components: The component (creator) generating the entries
                requested. If `None`, will return valid entries generated
                by any component.
            include_sub_type: whether to consider the subtypes of the
                provided entry type. Default `True`.

        Yields:
            Each `Entry` found using this method.

        Raises:
            ValueError: if the annotation is not attached to a data pack.
        """
        # Same guard as ``audio``: fail with a clear message rather than an
        # ``AttributeError`` on ``None``.
        if self.pack is None:
            raise ValueError(
                "Cannot get entries because annotation is not "
                "attached to any data pack."
            )
        yield from self.pack.get(entry_type, self, components, include_sub_type)
@dataclass
class ImageAnnotation(Entry):
    """
    ImageAnnotation type entries, such as "edge" and "box". Each
    ImageAnnotation has a ``image_payload_idx`` corresponding to its
    image representation in the payload array.

    Args:
        pack: The container that this image annotation
            will be added to.
        image_payload_idx: the index of the image payload in the DataPack's
            image payload list. If it's not set, it defaults to 0 which
            means it will load the first image payload.
    """

    def __init__(self, pack: PackType, image_payload_idx: int = 0):
        self._image_payload_idx = image_payload_idx
        super().__init__(pack)

    @property
    def image_payload_idx(self) -> int:
        """Index of the image payload this annotation refers to."""
        return self._image_payload_idx

    @property
    def image(self):
        """The image data of the referenced payload, read from the pack.

        Raises:
            ValueError: if the image annotation is not attached to any data
                pack.
        """
        if self.pack is None:
            raise ValueError(
                "Cannot get image because image annotation is not "
                "attached to any data pack."
            )
        return self.pack.get_payload_data_at(
            Modality.Image, self._image_payload_idx
        )

    @property
    def max_x(self):
        """Largest valid column index, i.e. ``image_shape[1] - 1``."""
        return self.image_shape[1] - 1

    @property
    def max_y(self):
        """Largest valid row index, i.e. ``image_shape[0] - 1``."""
        return self.image_shape[0] - 1

    @property
    def image_shape(self):
        """
        Returns the shape of the image.

        Raises:
            ValueError: if the image annotation is not attached to any data
                pack.
            ValueError: if the image shape is not valid. It must be either
                2D or 3D.

        Returns:
            The shape of the image.
        """
        if self.pack is None:
            raise ValueError(
                "Cannot get image because image annotation is not "
                "attached to any data pack."
            )
        payload = self.pack.get_payload_at(
            Modality.Image, self._image_payload_idx
        )
        image_shape = payload.cache_shape
        # ``Payload.cache_shape`` in this module is a plain method, not a
        # property, so attribute access yields a bound method. Call it when
        # needed so both a method and a property/attribute are supported.
        if callable(image_shape):
            image_shape = image_shape()
        # A payload whose cache was never set reports ``None``; treat that
        # as an invalid shape instead of failing inside ``len``.
        if image_shape is None or not 2 <= len(image_shape) <= 3:
            raise ValueError(
                "Image shape is not valid. "
                "It should be 2D ([height, width])"
                " or 3D ([height, width, channel])."
            )
        return image_shape

    def __eq__(self, other):
        """Image annotations are equal when they refer to the same image
        payload index; any other kind of object is unequal."""
        # Covers ``None`` as well as unrelated objects that do not expose
        # ``image_payload_idx`` at all.
        if not isinstance(other, ImageAnnotation):
            return False
        return self.image_payload_idx == other.image_payload_idx


class Region(ImageAnnotation):
    """
    A region class associated with an image payload.

    Args:
        pack: the container that this ``Region`` will be added to.
        image_payload_idx: the index of the image payload in the DataPack's
            image payload list. If it's not set, it defaults to 0 which
            means it will load the first image payload.
    """

    def __init__(self, pack: PackType, image_payload_idx: int = 0):
        super().__init__(pack, image_payload_idx)
        # An explicit ``None`` is treated like the default (first payload).
        self._image_payload_idx = (
            0 if image_payload_idx is None else image_payload_idx
        )

    def compute_iou(self, other) -> float:
        """Intersection over union of the two regions' images.

        The images returned by ``self.image`` and ``other.image`` are
        combined element-wise with logical and/or, and the counts of true
        elements are divided.

        Args:
            other: the other region to compare with.

        Returns:
            ``intersection / union`` of the two images.
        """
        intersection = np.sum(np.logical_and(self.image, other.image))
        union = np.sum(np.logical_or(self.image, other.image))
        return intersection / union


class Box(Region):
    """
    A box class with a center position and a box configuration.

    Note: all indices are zero-based and counted from top left corner of
    image.

    Args:
        pack: the container that this ``Box`` will be added to.
        tl_point: the indices of top left point of the box
            [row index, column index], the unit is one pixel.
        br_point: the indices of bottom right point of the box
            [row index, column index], the unit is one pixel.
        image_payload_idx: the index of the image payload in the DataPack's
            image payload list. If it's not set, it defaults to 0 which
            means it will load the first image payload.

    Raises:
        ValueError: if any index is negative, or if the top left point is
            not strictly above and to the left of the bottom right point.
    """

    def __init__(
        self,
        pack: PackType,
        tl_point: List[int],
        br_point: List[int],
        image_payload_idx: int = 0,
    ):
        super().__init__(pack, image_payload_idx)

        if tl_point[0] < 0 or tl_point[1] < 0:
            raise ValueError(
                f"input parameter top left point indices ({tl_point}) must"
                " be non-negative"
            )
        if br_point[0] < 0 or br_point[1] < 0:
            raise ValueError(
                f"input parameter bottom right point indices ({br_point}) must"
                " be non-negative"
            )
        if tl_point[0] >= br_point[0]:
            raise ValueError(
                f"top left point y coordinate({tl_point[0]}) must be less than"
                f" bottom right y coordinate({br_point[0]})"
            )
        if tl_point[1] >= br_point[1]:
            raise ValueError(
                f"top left point x coordinate({tl_point[1]}) must be less than"
                f" bottom right x coordinate({br_point[1]})"
            )

        # Points are given as [row, column], i.e. (y, x).
        self._y0, self._x0 = tl_point
        self._y1, self._x1 = br_point
        self._cy = round((self._y0 + self._y1) / 2)
        self._cx = round((self._x0 + self._x1) / 2)
        self._height = self._y1 - self._y0
        self._width = self._x1 - self._x0

    @classmethod
    def init_from_center_n_shape(
        cls,
        pack: PackType,
        cy: int,
        cx: int,
        height: int,
        width: int,
        image_payload_idx: int = 0,
    ):
        """
        A class method to initialize a ``Box`` from a box's center position
        and shape.

        Note: all indices are zero-based and counted from top left corner of
        image.

        Args:
            pack: the container that this ``Box`` will be added to.
            cy: the row coordinate of the box's center, the unit is one
                pixel.
            cx: the column coordinate of the box's center, the unit is one
                pixel.
            height: the height of the box, the unit is one pixel.
            width: the width of the box, the unit is one pixel.
            image_payload_idx: the index of the image payload in the
                DataPack's image payload list. If it's not set, it defaults
                to 0 which means it will load the first image payload.

        Returns:
            A ``Box`` instance.
        """
        # The top left corner is half the (rounded) extent away from the
        # center; the bottom right corner is one full extent further, so
        # the resulting box has exactly the requested height and width.
        return cls(
            pack,
            [cy - round(height / 2), cx - round(width / 2)],
            [cy - round(height / 2) + height, cx - round(width / 2) + width],
            image_payload_idx,
        )

    def compute_iou(self, other) -> float:
        """
        A function computes iou(intersection over union) between two boxes
        (unit: pixel).
        It overwrites the ``compute_iou`` function in its parent class
        ``Region``.

        Args:
            other: the other ``Box`` object to be computed with.

        Raises:
            ValueError: if ``other`` is not a ``Box``.

        Returns:
            A float value which is (intersection area/ union area) between
            two boxes.
        """
        if not isinstance(other, Box):
            raise ValueError(
                "The other object to compute iou with is"
                " not a Box object. "
                "You need to check the type of the other object."
            )

        if not self.is_overlapped(other):
            return 0.0

        # The overlap along an axis is the distance between the larger of
        # the two minima and the smaller of the two maxima. This also holds
        # when one box lies entirely inside the other's extent on that
        # axis, which a min of absolute corner differences gets wrong.
        overlap_width = min(self.box_max_x, other.box_max_x) - max(
            self.box_min_x, other.box_min_x
        )
        overlap_height = min(self.box_max_y, other.box_max_y) - max(
            self.box_min_y, other.box_min_y
        )
        intersection = overlap_width * overlap_height
        union = self.area + other.area - intersection
        return intersection / union

    @property
    def center(self):
        """Center of the box as ``(row, column)``."""
        return (self._cy, self._cx)

    @property
    def corners(self):
        """
        Get corners of box as ``(row, column)`` tuples, in the order top
        left, top right, bottom left, bottom right.
        """
        return [
            (self._y0, self._x0),
            (self._y0, self._x1),
            (self._y1, self._x0),
            (self._y1, self._x1),
        ]

    @property
    def box_min_x(self):
        """Column index of the box's left edge."""
        return self._x0

    @property
    def box_max_x(self):
        """Column index of the box's right edge, clipped to ``max_x``."""
        return min(self._x1, self.max_x)

    @property
    def box_min_y(self):
        """Row index of the box's top edge."""
        return self._y0

    @property
    def box_max_y(self):
        """Row index of the box's bottom edge, clipped to ``max_y``."""
        return min(self._y1, self.max_y)

    @property
    def area(self):
        """Area of the box, ``height * width`` (not clipped to the image)."""
        return self._height * self._width

    def is_overlapped(self, other):
        """
        A function checks whether two boxes are overlapped.

        Note: the comparisons are strict, so two boxes whose boundaries
        merely share a line segment or a corner are still reported as
        overlapped; only boxes that are strictly apart along an axis are
        reported as not overlapped.
        NOTE(review): this keeps the existing behavior; the previous note
        claimed touching boxes are not overlapped -- confirm which one is
        intended.

        Args:
            other: the other ``Box`` object to compared to.

        Returns:
            A boolean value indicating whether there is overlapped.
        """
        # If one box is on left side of other
        if self.box_min_x > other.box_max_x or other.box_min_x > self.box_max_x:
            return False

        # If one box is above other
        if self.box_min_y > other.box_max_y or other.box_min_y > self.box_max_y:
            return False
        return True


@dataclass
class Payload(Entry):
    """
    A payload class that holds data cache of one modality and its data
    source uri.

    Args:
        pack: The container that this `Payload` will be added to.

    Raises:
        ValueError: raised when the modality is not supported.
    """

    def __init__(
        self,
        pack: PackType,
    ):
        # The modality cannot be passed in from a generated ontology, so it
        # is derived from the concrete payload class instead. The subclasses
        # are defined further down in this module and are only looked up
        # when an instance is created.
        if isinstance(self, TextPayload):
            modality = Modality.Text
        elif isinstance(self, AudioPayload):
            modality = Modality.Audio
        elif isinstance(self, ImagePayload):
            modality = Modality.Image
        else:
            supported_modality = [member.name for member in Modality]
            raise ValueError(
                f"The requested Payload modality is not supported. "
                f"Currently we only support {supported_modality}"
            )
        self.modality = modality

        super().__init__(pack)
        self._uri: Optional[Union[Path, URL, str]] = None
        self._cache: Any = None
        self._cache_shape: Optional[Sequence[int]] = None
        self.replace_back_operations: Sequence[Tuple] = []
        self.processed_original_spans: Sequence[Tuple] = []
        self.orig_text_len: int = 0

    def get_type(self) -> type:
        """
        Get the class type of the payload class. For example, suppose a
        ``TextPayload`` inherits this ``Payload`` class, ``TextPayload`` will
        be returned.

        Returns:
            the type of the payload class.
        """
        return type(self)

    def load(self):
        """
        Behavior of this function depends on the registered `_load`,
        in general it should load the data from the provided URL to
        the internal cache.

        The result of ``_load`` is stored directly in the cache; the cached
        shape is not updated here.
        """
        self._cache = self._load()  # pylint: disable=assignment-from-no-return

    def _load(self) -> Any:
        # Placeholder; a real loader is attached to the payload class with
        # the ``load_func`` decorator.
        pass

    @property
    def cache(self) -> Union[str, np.ndarray]:
        """The cached data; ``load`` is triggered first when it is ``None``."""
        if self._cache is None:
            self.load()
        return self._cache

    @cache.setter
    def cache(self, data: Union[str, np.ndarray]):
        """Store ``data`` as the cache and record its shape.

        Raises:
            ValueError: if ``data`` is neither a ``str`` nor a numpy array.
                Note that the cache has already been replaced at that point.
        """
        self._cache = data
        if isinstance(data, np.ndarray):
            # if it's a numpy array, we need to set the shape
            # even if user input a shape, it will be overwritten
            cache_shape = data.shape
        elif isinstance(data, str):
            cache_shape = (len(data),)
        else:
            raise ValueError(
                f"Unsupported data type {type(data)} for cache. "
                f"Currently we only support str and numpy.ndarray."
            )
        self._cache_shape = cache_shape

    @property
    def uri(self) -> Optional[Union[str, Path, URL]]:
        """
        Universal resource identifier of the data source.

        Returns:
            Universal resource identifier of the data source.
        """
        return self._uri

    @uri.setter
    def uri(self, url: Optional[Union[str, Path, URL]]):
        self._uri = url

    def cache_shape(self) -> Optional[Sequence[int]]:
        """Return the shape recorded by the ``cache`` setter, or ``None`` if
        the cache was never set through it. Note this is a method, not a
        property."""
        return self._cache_shape

    def clear_cache(self):
        """Drop the cached data (the recorded shape is left untouched)."""
        self._cache = None

    def __getstate__(self):
        r"""
        Convert ``modality`` ``Enum`` object to str format for serialization.
        Numpy arrays in ``_cache`` and ``_embedding`` are converted to
        lists, and the pack reference is dropped.
        """
        # TODO: this function will be removed since
        # Entry store is being integrated into DataStore
        state = self.__dict__.copy()
        state["modality"] = self.modality.name
        state.pop("_Entry__pack")
        if isinstance(state["_cache"], np.ndarray):
            state["_cache"] = list(self._cache.tolist())
        if isinstance(state["_embedding"], np.ndarray):
            state["_embedding"] = list(self._embedding.tolist())
        return state

    def __setstate__(self, state):
        r"""
        Convert ``modality`` string to ``Enum`` object for deserialization,
        and turn the serialized lists back into numpy arrays.
        """
        # TODO: this function will be removed since
        # Entry store is being integrated into DataStore

        # During de-serialization, convert the list back to numpy array.
        # This must happen BEFORE the state is copied into ``__dict__``,
        # otherwise the converted values never reach the instance.
        if "_embedding" in state:
            state["_embedding"] = np.array(state["_embedding"])
        else:
            state["_embedding"] = np.empty(0)

        # Here we assume that if the payload is not text (in which case
        # cache is stored a string), cache will always be stored as a
        # numpy array (which is converted to a list during serialization).
        # This check can be made more comprehensive when new types of
        # payloads are introduced.
        if "_cache" in state and isinstance(state["_cache"], list):
            state["_cache"] = np.array(state["_cache"])

        self.__dict__.update(state)
        self.modality = getattr(Modality, state["modality"])


@dataclass
class AudioPayload(Payload):
    """
    A payload that caches audio data

    Attributes:
        sample_rate (Optional[int]): initialized to ``None``.
    """

    sample_rate: Optional[int]

    def __init__(self, pack: PackType):
        super().__init__(pack)
        self.sample_rate: Optional[int] = None


@dataclass
class TextPayload(Payload):
    """
    A payload that caches text data. Its cache starts out, and is cleared
    to, the empty string rather than ``None``.
    """

    def __init__(  # pylint: disable=useless-super-delegation
        self, pack: PackType
    ):
        super().__init__(pack)
        self._cache = ""

    def clear_cache(self):
        """Reset the cached text to the empty string."""
        self._cache = ""


@dataclass
class ImagePayload(Payload):
    """
    A payload that caches image data
    """

    def __init__(  # pylint: disable=useless-super-delegation
        self, pack: PackType
    ):
        super().__init__(pack)


#
# The registered functions for each payload loading methods.
#
load_funcs: Dict[str, Callable[[str, Any], None]] = {}


def load_func(payload_cls: Type["Payload"]):
    """
    A decorator to assign a loading function to a Payload type.

    The decorated function is installed as ``payload_cls._load`` and is
    returned unchanged, so the decorated name stays usable.

    Args:
        payload_cls: the Payload class to register.

    Returns:
        A decorator that registers the function it is applied to.
    """

    def decorator_load_func(func: Callable):
        payload_cls._load = func  # type: ignore  # pylint: disable=protected-access
        # Hand the function back; returning nothing would rebind the
        # decorated name to ``None``.
        return func

    return decorator_load_func


# Entry types that live in a single data pack.
SinglePackEntries = (
    Link,
    Group,
    Annotation,
    Generics,
    AudioAnnotation,
    ImageAnnotation,
    Payload,
    TextPayload,
    AudioPayload,
    ImagePayload,
)
# Entry types that live in a multi pack.
MultiPackEntries = (MultiPackLink, MultiPackGroup, MultiPackGeneric)
# Entry types that carry ``begin``/``end`` offsets.
AnnotationLikeEntries = (Annotation, AudioAnnotation)