# Source code for forte.data.ontology.top

# Copyright 2019 The Forte Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from functools import total_ordering
from typing import Optional, Set, Tuple, Type, Any, Dict, Union, Iterable, List

import numpy as np

from forte.data.base_pack import PackType
from forte.data.ontology.core import (
    Entry,
    BaseLink,
    BaseGroup,
    MultiEntry,
    EntryType,
)
from forte.data.span import Span

# Public names exported by ``from forte.data.ontology.top import *``.
__all__ = [
    "Generics",
    "Annotation",
    "Group",
    "Link",
    "MultiPackGeneric",
    "MultiPackGroup",
    "MultiPackLink",
    "Query",
    "SinglePackEntries",
    "MultiPackEntries",
    "AudioAnnotation",
]

# Type of the ``value`` held by a :class:`Query`: either a dict or a numpy
# array.
QueryType = Union[Dict[str, Any], np.ndarray]


class Generics(Entry):
    r"""Base class for entries that belong to a single pack but are not
    anchored to any offsets inside it (the constructor takes only the pack).

    :class:`Query` in this module is built on top of this class.

    Args:
        pack: The container that this entry will be added to.
    """

    def __init__(self, pack: PackType):
        # Nothing extra to store: defer entirely to ``Entry``.
        super().__init__(pack=pack)
@total_ordering
class Annotation(Entry):
    r"""Annotation type entries, such as "token", "entity mention" and
    "sentence". Each annotation has a :class:`~forte.data.span.Span`
    corresponding to its offset in the text.

    Args:
        pack: The container that this annotation will be added to.
        begin: The offset of the first character in the annotation.
        end: The offset of the last character in the annotation + 1.
    """

    def __init__(self, pack: PackType, begin: int, end: int):
        # Only the raw offsets are stored eagerly; the ``Span`` object is
        # built on demand by the ``span`` property (or when pickling).
        self._span: Optional[Span] = None
        self._begin: int = begin
        self._end: int = end
        super().__init__(pack)

    def __getstate__(self):
        r"""For serializing Annotation, we should create Span annotations for
        compatibility purposes.

        The span is materialized first so that the offsets are still carried
        by ``_span`` after ``_begin`` / ``_end`` are removed from the state.
        """
        self._span = Span(self._begin, self._end)
        state = super().__getstate__()
        state.pop("_begin")
        state.pop("_end")
        return state

    def __setstate__(self, state):
        """
        For de-serializing Annotation, we load the begin, end from Span, for
        compatibility purposes.
        """
        super().__setstate__(state)
        self._begin = self._span.begin
        self._end = self._span.end

    @property
    def span(self) -> Span:
        # Delay span creation at usage.
        if self._span is None:
            self._span = Span(self._begin, self._end)
        return self._span

    @property
    def begin(self):
        return self._begin

    @property
    def end(self):
        return self._end

    def __eq__(self, other):
        r"""The eq function of :class:`Annotation`.
        By default, :class:`Annotation` objects are regarded as the same if
        they have the same type and the same span (``begin`` and ``end``).
        Objects of any other type, including ``None``, are never equal.

        Users can define their own eq function by themselves but this must
        be consistent to :meth:`__hash__`.
        """
        # Check the type before touching ``begin``/``end`` so that comparing
        # against entries without offsets (e.g. groups) returns False
        # instead of raising AttributeError.
        if other is None or type(other) is not type(self):
            return False
        return (self.begin, self.end) == (other.begin, other.end)

    def __hash__(self):
        r"""Hash on exactly the fields compared by :meth:`__eq__`.

        Defining ``__eq__`` in a class body without ``__hash__`` makes Python
        set ``__hash__`` to ``None``, i.e. the class would be unhashable.
        """
        return hash((type(self), self.begin, self.end))

    def __lt__(self, other):
        r"""To support total_ordering, `Annotation` must implement
        `__lt__`. The ordering is defined in the following way:

        1. If the begin of the annotations are different, the one with
           larger begin will be larger.
        2. In the case where the begins are the same, the one with larger
           end will be larger.
        3. In the case where both offsets are the same, we break the tie
           using the normal sorting of the class name.
        """
        if self.begin == other.begin:
            if self.end == other.end:
                return str(type(self)) < str(type(other))
            return self.end < other.end
        else:
            return self.begin < other.begin

    @property
    def text(self):
        if self.pack is None:
            raise ValueError(
                "Cannot get text because annotation is not "
                "attached to any data pack."
            )
        return self.pack.get_span_text(self.begin, self.end)

    @property
    def index_key(self) -> int:
        return self.tid

    def get(
        self,
        entry_type: Union[str, Type[EntryType]],
        components: Optional[Union[str, Iterable[str]]] = None,
        include_sub_type=True,
    ) -> Iterable[EntryType]:
        """
        This function wraps the :meth:`~forte.data.data_pack.DataPack.get`
        method to find entries "covered" by this annotation. See that method
        for more information.

        Example:

            .. code-block:: python

                # Iterate through all the sentences in the pack.
                for sentence in input_pack.get(Sentence):
                    # Take all tokens from each sentence created by
                    # NLTKTokenizer.
                    token_entries = sentence.get(
                        entry_type=Token,
                        component='NLTKTokenizer')
                    ...

            In the above code snippet, we get entries of type
            :class:`~ft.onto.base_ontology.Token` within each ``sentence``
            which were generated by ``NLTKTokenizer``. You can consider build
            coverage index between `Token` and `Sentence` if this snippet is
            frequently used.

        Args:
            entry_type: The type of entries requested.
            components: The component (creator)
                generating the entries requested. If `None`, will return
                valid entries generated by any component.
            include_sub_type: whether to consider the sub types of
                the provided entry type. Default `True`.

        Yields:
            Each `Entry` found using this method.

        Raises:
            ValueError: (on first iteration) if this annotation is not
                attached to any data pack, mirroring :attr:`text`.
        """
        if self.pack is None:
            raise ValueError(
                "Cannot get entries because annotation is not "
                "attached to any data pack."
            )
        yield from self.pack.get(entry_type, self, components, include_sub_type)
# pylint: disable=duplicate-bases
class Group(BaseGroup[Entry]):
    r"""Group is an entry that represent a group of other entries. For
    example, a "coreference group" is a group of coreferential entities.
    Each group will store a set of members, no duplications allowed.
    """

    MemberType: Type[Entry] = Entry

    def __init__(
        self,
        pack: PackType,
        members: Optional[Iterable[Entry]] = None,
    ):
        # Member tids. A set, so adding the same entry twice is a no-op.
        # Created before ``super().__init__`` because the base constructor
        # receives ``members`` and presumably adds them via ``add_member``
        # (TODO confirm against BaseGroup).
        self._members: Set[int] = set()
        super().__init__(pack, members)

    def add_member(self, member: Entry):
        r"""Add one entry to the group.

        Args:
            member: One member to be added to the group.

        Raises:
            TypeError: If ``member`` is not an instance of ``MemberType``.
        """
        if not isinstance(member, self.MemberType):
            raise TypeError(
                f"The members of {type(self)} should be "
                f"instances of {self.MemberType}, but got {type(member)}"
            )
        self._members.add(member.tid)

    def get_members(self) -> List[Entry]:
        r"""Get the member entries in the group.

        Returns:
            A list of instances of :class:`~forte.data.ontology.core.Entry`
            that are the members of the group. Each member appears once;
            the order follows iteration over the internal tid set.

        Raises:
            ValueError: If the group is not attached to any data pack.
        """
        if self.pack is None:
            raise ValueError(
                "Cannot get members because group is not "
                "attached to any data pack."
            )
        return [self.pack.get_entry(tid) for tid in self._members]
class MultiPackGeneric(MultiEntry, Entry):
    r"""Multi-pack counterpart of :class:`Generics`: an entry attached to a
    pack but not anchored to any offsets inside it.

    Args:
        pack: The container that this entry will be added to.
    """

    def __init__(self, pack: PackType):
        # All state lives in the bases; just forward the pack.
        super().__init__(pack=pack)
# pylint: disable=duplicate-bases
class MultiPackGroup(MultiEntry, BaseGroup[Entry]):
    r"""Group type entries, such as "coreference group", whose members are
    referenced by ``(pack_id, tid)`` pairs and may therefore come from
    different packs.

    Members are kept in insertion order; unlike :class:`Group`, adding the
    same entry twice records it twice.
    """

    MemberType: Type[Entry] = Entry

    def __init__(
        self, pack: PackType, members: Optional[Iterable[Entry]] = None
    ):
        # (pack_id, tid) of each member, in insertion order.
        self._members: List[Tuple[int, int]] = []
        super().__init__(pack)
        if members is not None:
            self.add_members(members)

    def add_member(self, member: Entry):
        r"""Add one entry to the group.

        Args:
            member: One member to be added to the group.

        Raises:
            TypeError: If ``member`` is not an instance of ``MemberType``.
        """
        if not isinstance(member, self.MemberType):
            raise TypeError(
                f"The members of {type(self)} should be "
                f"instances of {self.MemberType}, but got {type(member)}"
            )
        # Store the member's pack id rather than the pack's index within the
        # container (see forte issue 559).
        self._members.append((member.pack_id, member.tid))

    def get_members(self) -> List[Entry]:
        r"""Get the member entries in the group, in insertion order.

        Returns:
            A list of the member entries, resolved through
            ``self.pack.get_subentry``.

        Raises:
            ValueError: If the group is not attached to any pack, matching
                :meth:`Group.get_members`.
        """
        if self.pack is None:
            raise ValueError(
                "Cannot get members because group is not "
                "attached to any data pack."
            )
        return [
            self.pack.get_subentry(pack_id, member_tid)
            for pack_id, member_tid in self._members
        ]
# ``eq=False``: a plain ``@dataclass`` would generate an ``__eq__`` comparing
# ``value`` and ``results`` field by field. That raises ValueError
# ("truth value of an array is ambiguous") for multi-element numpy-array
# values, and it also sets ``__hash__`` to None. Queries therefore keep the
# equality/hash inherited from ``Entry``.
@dataclass(eq=False)
class Query(Generics):
    r"""An entry type representing queries for information retrieval tasks.

    Args:
        pack: Data pack reference to which this query will be added

    Attributes:
        value: The query content, either a dict or a numpy array
            (``QueryType``); ``None`` until it is assigned.
        results: Mapping from pack id to the score of that pack for this
            query.
    """

    value: Optional[QueryType]
    results: Dict[str, float]

    def __init__(self, pack: PackType):
        super().__init__(pack)
        self.value: Optional[QueryType] = None
        self.results: Dict[str, float] = {}

    def add_result(self, pid: str, score: float):
        """
        Set the result score for a particular pack (based on the pack id).
        An existing score for the same pack id is overwritten.

        Args:
            pid: the pack id.
            score: the score for the pack

        Returns:
            None
        """
        self.results[pid] = score

    def update_results(self, pid_to_score: Dict[str, float]):
        r"""Updates the results for this query. Scores for pack ids already
        present are overwritten.

        Args:
             pid_to_score: A dict containing pack id -> score mapping
        """
        self.results.update(pid_to_score)
@total_ordering
class AudioAnnotation(Entry):
    r"""AudioAnnotation type entries, such as "recording" and "audio
    utterance". Each audio annotation has a :class:`~forte.data.span.Span`
    corresponding to its offset in the audio. Most methods in this class are
    the same as the ones in :class:`Annotation`, except that it replaces
    property `text` with `audio`.

    Args:
        pack: The container that this audio annotation
            will be added to.
        begin: The offset of the first sample in the audio annotation.
        end: The offset of the last sample in the audio annotation + 1.
    """

    def __init__(self, pack: PackType, begin: int, end: int):
        # Only the raw offsets are stored eagerly; the ``Span`` object is
        # built on demand by the ``span`` property (or when pickling).
        self._span: Optional[Span] = None
        self._begin: int = begin
        self._end: int = end
        super().__init__(pack)

    @property
    def audio(self):
        if self.pack is None:
            raise ValueError(
                "Cannot get audio because annotation is not "
                "attached to any data pack."
            )
        return self.pack.get_span_audio(self.begin, self.end)

    def __getstate__(self):
        r"""For serializing AudioAnnotation, we should create Span annotations
        for compatibility purposes.

        The span is materialized first so that the offsets are still carried
        by ``_span`` after ``_begin`` / ``_end`` are removed from the state.
        """
        self._span = Span(self._begin, self._end)
        state = super().__getstate__()
        state.pop("_begin")
        state.pop("_end")
        return state

    def __setstate__(self, state):
        """
        For de-serializing AudioAnnotation, we load the begin, end from Span,
        for compatibility purposes.
        """
        super().__setstate__(state)
        self._begin = self._span.begin
        self._end = self._span.end

    @property
    def span(self) -> Span:
        # Delay span creation at usage.
        if self._span is None:
            self._span = Span(self._begin, self._end)
        return self._span

    @property
    def begin(self):
        return self._begin

    @property
    def end(self):
        return self._end

    def __eq__(self, other):
        r"""The eq function of :class:`AudioAnnotation`.
        By default, :class:`AudioAnnotation` objects are regarded as the same
        if they have the same type and the same span (``begin`` and ``end``).
        Objects of any other type, including ``None``, are never equal.

        Users can define their own eq function by themselves but this must
        be consistent to :meth:`__hash__`.
        """
        # Check the type before touching ``begin``/``end`` so that comparing
        # against entries without offsets returns False instead of raising
        # AttributeError.
        if other is None or type(other) is not type(self):
            return False
        return (self.begin, self.end) == (other.begin, other.end)

    def __hash__(self):
        r"""Hash on exactly the fields compared by :meth:`__eq__`.

        Defining ``__eq__`` in a class body without ``__hash__`` makes Python
        set ``__hash__`` to ``None``, i.e. the class would be unhashable.
        """
        return hash((type(self), self.begin, self.end))

    def __lt__(self, other):
        r"""To support total_ordering, `AudioAnnotation` must implement
        `__lt__`. The ordering is defined in the following way:

        1. If the begin of the audio annotations are different, the one with
           larger begin will be larger.
        2. In the case where the begins are the same, the one with larger
           end will be larger.
        3. In the case where both offsets are the same, we break the tie
           using the normal sorting of the class name.
        """
        if self.begin == other.begin:
            if self.end == other.end:
                return str(type(self)) < str(type(other))
            return self.end < other.end
        else:
            return self.begin < other.begin

    @property
    def index_key(self) -> int:
        return self.tid

    def get(
        self,
        entry_type: Union[str, Type[EntryType]],
        components: Optional[Union[str, Iterable[str]]] = None,
        include_sub_type=True,
    ) -> Iterable[EntryType]:
        """
        This function wraps the
        :meth:`~forte.data.data_pack.DataPack.get()` method to find
        entries "covered" by this audio annotation. See that method for more
        information. For usage details, refer to
        :meth:`forte.data.ontology.top.Annotation.get()`.

        Args:
            entry_type: The type of entries requested.
            components: The component (creator)
                generating the entries requested. If `None`, will return
                valid entries generated by any component.
            include_sub_type: whether to consider the sub types of
                the provided entry type. Default `True`.

        Yields:
            Each `Entry` found using this method.

        Raises:
            ValueError: (on first iteration) if this audio annotation is not
                attached to any data pack, mirroring :attr:`audio`.
        """
        if self.pack is None:
            raise ValueError(
                "Cannot get entries because annotation is not "
                "attached to any data pack."
            )
        yield from self.pack.get(entry_type, self, components, include_sub_type)
# Entry classes whose instances live inside a single pack.
SinglePackEntries = (
    Link,
    Group,
    Annotation,
    Generics,
    AudioAnnotation,
)

# Multi-pack counterparts of the entry classes above.
MultiPackEntries = (
    MultiPackLink,
    MultiPackGroup,
    MultiPackGeneric,
)