# Copyright 2020 The Forte Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Public names exported by this module: the vocabulary class and its filters.
__all__ = ["Vocabulary", "VocabFilter", "FrequencyVocabFilter"]
from abc import ABC
from collections import Counter
from typing import List, Tuple, Dict, Union, Hashable, Iterable, Optional
from typing import TypeVar, Generic, Any, Set
from asyml_utilities.special_tokens import SpecialTokens
from forte.common import InvalidOperationException
# Type of the raw elements stored in a Vocabulary; they must be hashable
# because they are used as dictionary keys.
ElementType = TypeVar("ElementType", bound=Hashable)


class Vocabulary(Generic[ElementType]):
    r"""This class will store "Elements" that are added, assign "Ids" to them
    and return "Representations" if queried. These three are the main concepts
    in this class.

    1. Element: Any hash-able instance that the user wants to store.
    2. Id: Each element will have a unique Id, which is an integer.
    3. Representation: according to the configuration, the representation for
       an element could be an integer (in this case, would be "Id"), or
       a one-hot vector (in this case, would be a list of integers).

    The class adopts the special elements from `Texar-Pytorch`, which are:

    1. <PAD>: which will be mapped into Id of 0 or -1 and have
       different representation according to different setting.
    2. <UNK>: if added into the vocabulary,
       will be the default element if the queried element is not found.

    Note that these two special tokens are necessary for the system in certain
    cases and thus must be present in the vocabulary. The behavior of these
    special tokens is pre-defined based on different settings. To get around
    the default behavior (for example, if you have a pre-defined vocabulary
    with different setups), you can instruct the class not to add these
    tokens automatically, and use :func:`mark_special_element` instead.

    Here is a table on how our Vocabulary class behaves under different
    settings (assuming `use_unk` is False; when <UNK> is added it takes the
    next id after <PAD>, shifting the regular elements by one). Element0 means
    the first element that is added to the vocabulary. Elements added later
    will be element1, element2 and so on. They will follow the same behavior
    as element0. For readability, they are not listed in the table.

    .. list-table:: Vocabulary Behavior under different settings.

        * - `method`
          - custom (handled and implemented by the user)
          - indexing
          - indexing
          - one-hot
          - one-hot
        * - `use_pad`
          - assume False
          - True
          - False
          - True
          - False
        * - `get_pad_value`
          - None
          - 0
          - None
          - [0,0,0]
          - None
        * - `inner_mapping`
          - None
          - 0:pad 1:element0
          - 0:element0
          - -1:<PAD> 0:element0
          - 0:element0
        * - `element2repr`
          - raise Error
          - pad->0 element0->1
          - element0->0
          - <PAD>->[0,0,0] element0->[1,0,0]
          - element0->[1,0,0]
        * - `id2element`
          - raise Error
          - 0->pad 1->element0
          - 0->element0
          - -1 -> <PAD> 0->element0 (be careful)
          - 0->element0

    Args:
        method: The method to represent element in vocabulary, currently
            supporting "indexing" and "one-hot".
        use_pad: Whether to add <PAD> element to the vocabulary on
            creation. It will be added to the vocabulary first, but the id of
            it depends on the specific settings.
        use_unk: Whether to add <UNK> element to the vocabulary on
            creation. Elements that are not found in vocabulary will be
            directed to <UNK> element. It will be added right after the <PAD>
            element if provided.
        special_tokens: Additional special tokens to be added, they
            will be added at the beginning of vocabulary (but right after the
            <UNK> token) one by one.
        do_counting: Whether the vocabulary class will count the
            elements.
        pad_value: A customized value/representation to be used for
            padding, for example, following the PyTorch convention you may
            want to use -100. This value is only needed when `use_pad` is True.
            Default is None, where the value of padding is determined by
            the system.
        unk_value: A customized value/representation to be used for
            unknown value (`unk`). This value is only needed when `use_unk`
            is True. Default is None, where the value of `UNK` is determined
            by the system.

    Attributes:
        method (str): Same as above.
        use_pad (bool): Same as above.
        use_unk (bool): Same as above.
        do_counting (bool): Same as above.
    """

    def __init__(
        self,
        method: str = "indexing",
        use_pad: bool = True,
        use_unk: bool = True,
        special_tokens: Optional[List[str]] = None,
        do_counting: bool = True,
        pad_value: Any = None,
        unk_value: Any = None,
    ):
        self.method: str = method
        self.use_pad: bool = use_pad
        self.use_unk: bool = use_unk
        self.do_counting: bool = do_counting

        self._pad_id: Optional[int] = None
        self._unk_id: Optional[int] = None

        # Maps the raw element to the internal id.
        self._element2id: Dict = {}
        # Maps the internal id to the raw element.
        self._id2element: Dict = {}
        # Maps the internal id to the representation. This dict is populated
        # when users provided customized representation of elements.
        self._id2repr: Dict = {}
        # Count the number of appearance of an element, indexed by the element
        # id.
        self.__counter: Counter = Counter()

        # Initialize the id auto counter.
        self.next_id = 0

        # Store the base special token names and their surface form.
        # By default, following the texar-pytorch special tokens:
        #   PAD: <PAD>
        #   UNK: <UNK>
        self._base_special_tokens: Dict[str, str] = {}

        # Store the id position of the special ids.
        self.__special_ids: Set[int] = set()

        if use_pad:
            # If not specified, will use -1 for padding in the case of
            # one-hot. This will make the actual PAD representation to be
            # a vector of zeros.
            pad_id = -1 if method == "one-hot" else None
            self.add_special_element(
                SpecialTokens.PAD,
                element_id=pad_id,
                special_token_name="PAD",
                representation=pad_value,
            )

        if use_unk:
            self.add_special_element(
                SpecialTokens.UNK,
                special_token_name="UNK",
                representation=unk_value,
            )

        if special_tokens is not None:
            for t in special_tokens:
                self.add_special_element(t)

    def get_count(self, e: Union[ElementType, int]) -> int:
        """
        Get the counts of the vocabulary element.

        Args:
            e: The element to get counts for. It can be the element id or the
                element's raw type. Note that an `int` is always interpreted
                as an element id.

        Returns:
            The count of the element.

        Raises:
            InvalidOperationException: If the vocabulary does not count
                elements, or if the element is a special token.
            KeyError: If `e` is a raw element that is not in the vocabulary.
        """
        if not self.do_counting:
            raise InvalidOperationException(
                "The vocabulary is not configured to count the elements."
            )
        eid: int = e if isinstance(e, int) else self._element2id[e]
        if self.is_special_token(eid):
            raise InvalidOperationException(
                "Count for special element is not available."
            )
        return self.__counter[eid]

    def mark_special_element(
        self, element_id: int, element_name: str, representation: Any = None
    ):
        """
        Mark a particular (but already existing) index in the vocabulary to be
        a special required element (i.e `PAD` or `UNK`). The id is also
        registered as a special token, so it is always kept by :meth:`filter`.

        Args:
            element_id: The id to be set for the special element.
            element_name: The name of this element to be set, it can
                be one of `PAD`, `UNK`.
            representation: The representation/value that this element should
                be assigned. Default is None, then its representation will be
                computed from the internal indexing.

        Raises:
            ValueError: If `element_name` is not `PAD`/`UNK`, or if
                `element_id` is not in the vocabulary. No state is modified
                in either case.
        """
        if element_name not in ("PAD", "UNK"):
            raise ValueError(
                f"{element_name} is not a required special element, you can"
                f" add it in through `special_tokens` argument during class"
                f" creation, or calling the `add_special_element` method"
            )

        # Validate before mutating anything, so a bad id cannot leave
        # `use_pad`/`use_unk` enabled without a matching base special token.
        if element_id not in self._id2element:
            raise ValueError(
                f"Supplied {element_id} is not in the"
                f" current vocabulary."
            )

        if element_name == "PAD":
            self.use_pad = True
        else:
            self.use_unk = True

        self._base_special_tokens[element_name] = self._id2element[element_id]
        # Register as special so that `filter` copies it (and its base name)
        # regardless of the filter, keeping `use_pad`/`use_unk` consistent.
        self.__special_ids.add(element_id)

        # Store the customized representation
        if representation is not None:
            self._id2repr[element_id] = representation

    def is_special_token(self, element_id: int):
        """Check whether the element is a special token."""
        return element_id in self.__special_ids

    def add_special_element(
        self,
        element: Union[str, ElementType],
        element_id: Optional[int] = None,
        representation=None,
        special_token_name: Optional[str] = None,
    ):
        """
        This function will add special elements to the vocabulary, such as
        `UNK`, `PAD`, `BOS`, `CLS` symbols. Special tokens are not
        filtered by any `VocabFilter`. Some special tokens have their
        unique behavior in the system.

        .. note::
            most of the time, you don't have to call this method yourself,
            but should let the `init` function to handle that.

        Args:
            element: The surface form of this special element.
            element_id: The id to be used for this special token.
                If not provided, the vocabulary will use the next id internally.
                If the provided id is occupied, a `ValueError` will be thrown.
                The id can be any integer, including negative ones.
            representation: The representation you want to assign to this
                special token. If None, the representation may be computed
                based on the index (which depends on the vocabulary setting).
            special_token_name: An internal name of
                this special token. This only matters for the base special
                tokens: <PAD> or <UNK>, and the name should be "PAD" and "UNK"
                respectively. Any other name here is considered invalid,
                and a `ValueError` will be thrown if provided.

        Raises:
            ValueError: If `special_token_name` is invalid or `element_id` is
                already used. No state is modified in either case.
        """
        if special_token_name is not None and special_token_name not in (
            "PAD",
            "UNK",
        ):
            raise ValueError(
                "You don't have to and shouldn't provide the "
                "`special_token_name` if this token is not PAD or UNK"
            )

        if element_id is None:
            # Use auto-incremented id.
            element_id = self.__get_next_available_id()
        elif element_id in self._id2element:
            raise ValueError(
                f"ID {element_id} has already been used in Vocabulary. "
            )

        # Only record the base token once the id is known to be valid.
        if special_token_name is not None:
            self._base_special_tokens[special_token_name] = element

        self._element2id[element] = element_id
        self._id2element[element_id] = element
        self.__special_ids.add(element_id)

        if representation is not None:
            self._id2repr[element_id] = representation

    def add_element(
        self, element: ElementType, representation: Any = None, count: int = 1
    ) -> int:
        r"""This function will add a regular element to the vocabulary.

        Args:
            element: The element to be added.
            representation: The vocabulary representation of this element
                will use this value. For example, you may want to use `-100`
                for ignored tokens for PyTorch skipped tokens. Any value other
                than None (including falsy values such as `0`) is stored. It
                is only applied when the element is newly added. Note that the
                class does not check whether this representation is used by
                another element, so the caller has to manage the behavior
                itself.
            count: the count to be incremented for this element, default
                is 1 (i.e. consider it appear once on every add). This value
                will have effect only if `do_counting` is True.

        Returns:
            The internal id of the element.
        """
        if element in self._element2id:
            element_id_: int = self._element2id[element]
            if self.do_counting:
                self.__counter[element_id_] += count
            return element_id_

        element_id_ = self.__get_next_available_id()
        self._element2id[element] = element_id_
        self._id2element[element_id_] = element
        # `is not None` so that falsy representations such as 0 are kept,
        # matching `add_special_element`.
        if representation is not None:
            self._id2repr[element_id_] = representation
        if self.do_counting:
            self.__counter[element_id_] = count
        return element_id_

    def __get_next_available_id(self):
        """Find the next available id by incrementing the auto counter until
        one is found.
        """
        eid = self.next_id
        while eid in self._id2element:
            self.next_id += 1
            eid = self.next_id
        return eid

    def id2element(self, idx: int) -> ElementType:
        r"""This function will map id to element.

        Args:
            idx: The queried id of element.

        Returns:
            The corresponding element if exist. Check the behavior
            of this function under different setting in the documentation.

        Raises:
            KeyError: If the id is not found.
        """
        return self._id2element[idx]

    def element2repr(
        self, element: Union[ElementType, Any]
    ) -> Union[int, List[int]]:
        r"""This function will map element to representation.

        Args:
            element: The queried element. It can be either the same
                type as the element, or string (for the special tokens).

        Returns:
            Union[int, List[int]]: The corresponding representation
            of the element. Check the behavior of this function
            under different setting in the documentation.

        Raises:
            KeyError: If element is not found and vocabulary does
                not use <UNK> element.
            InvalidOperationException: If the element has no customized
                representation and `method` is not supported.
        """
        if self.use_unk:
            idx = self._element2id.get(
                element, self._element2id[self._base_special_tokens["UNK"]]
            )
        else:
            idx = self._element2id[element]

        # If a custom representation is set for this idx, we will use it.
        if idx in self._id2repr:
            return self._id2repr[idx]
        elif self.method == "indexing":
            return idx
        elif self.method == "one-hot":
            return self._one_hot(idx)
        else:
            raise InvalidOperationException(
                f"Cannot find the representation for idx at [{idx}], it does"
                f" not have a customized representation, and the representation"
                f" method [{self.method}] is not supported."
            )

    def to_dict(self) -> Dict[ElementType, Any]:
        """
        Create a dictionary from the vocabulary storing all the known elements.

        Returns:
            The vocabulary as a Dict from ElementType to the
            representation of the element (could be Integer or One-hot vector,
            depending on the settings of this class).
        """
        vocab_dict: Dict[ElementType, Any] = {}
        for element in self._element2id:
            vocab_dict[element] = self.element2repr(element)
        return vocab_dict

    def _one_hot(self, idx: int):
        """Compute the one-hot encoding on the fly."""
        vec_size = len(self._element2id)
        # The PAD slot is excluded from the vector length; with the default
        # one-hot setup PAD lives at id -1 and maps to an all-zero vector.
        if self.use_pad:
            vec_size -= 1
        vec = [0 for _ in range(vec_size)]
        if idx != -1:
            vec[idx] = 1
        return vec

    def __len__(self) -> int:
        r"""This function return the size of vocabulary.

        Returns:
            int: The number of elements, including <PAD>, <UNK>.
        """
        return len(self._element2id)

    def has_element(self, element: Union[ElementType, str]) -> bool:
        r"""This function checks whether an element is added to vocabulary.

        Args:
            element: The queried element.

        Returns:
            bool: Whether element is found.
        """
        return element in self._element2id

    def vocab_items(self) -> Iterable[Tuple[Union[ElementType, str], int]]:
        r"""This function will loop over the (element, id) pair inside this
        class.

        Returns:
            Iterable[Tuple]: Iterables of (element, id) pair.
        """
        return self._element2id.items()

    def get_pad_value(self) -> Union[None, int, List[int]]:
        r"""This function will get the representation of the PAD element for
        the vocabulary. The representation depends on the settings of this
        class, it can be an integer or a list of int (e.g. a vector).

        Returns:
            Union[None, int, List[int]]: The PAD element. Check
            the behavior of this function in the class documentation.
        """
        if self.use_pad:
            return self.element2repr(self._base_special_tokens["PAD"])
        return None

    def filter(self, vocab_filter: "VocabFilter") -> "Vocabulary":
        """
        This function will create a new vocabulary object, which
        is based on the current vocabulary, but leaves out the regular
        elements for which `vocab_filter.filter(element_id)` returns True.
        Special tokens are always kept with their original ids and
        customized representations. Regular elements are re-indexed in the
        new vocabulary, keeping their counts and customized representations.
        Calling this function will cause a full iteration over the
        vocabulary, thus normally, it should be called after collecting all
        the vocabulary in the dataset.

        Args:
            vocab_filter: The filter used to filter the vocabulary.

        Returns:
            A new vocabulary after filtering.
        """
        # Make a new vocab class:
        #   1. We do not add the PAD or UNK at init, but will copy them later.
        #   2. We also ignore other special tokens at init, but copy them later.
        #   3. We follow the do_counting setup.
        vocab: Vocabulary = Vocabulary(
            self.method,
            use_pad=False,
            use_unk=False,
            do_counting=self.do_counting,
        )

        # We then set these flag manually based on this vocabulary.
        vocab.use_pad = self.use_pad
        vocab.use_unk = self.use_unk

        # Now we copy all the vocabulary items to the new vocab.
        for element, eid in self.vocab_items():
            # Representations are keyed by id, not by element.
            representation = self._id2repr.get(eid)
            # Copy the special tokens regardless of the filter.
            if self.is_special_token(eid):
                element_name = None
                if (
                    "PAD" in self._base_special_tokens
                    and element == self._base_special_tokens["PAD"]
                ):
                    element_name = "PAD"
                if (
                    "UNK" in self._base_special_tokens
                    and element == self._base_special_tokens["UNK"]
                ):
                    element_name = "UNK"
                vocab.add_special_element(
                    element, eid, representation, element_name
                )
            elif not vocab_filter.filter(eid):
                vocab.add_element(
                    element,
                    representation=representation,
                    count=self.get_count(eid) if self.do_counting else 1,
                )

        return vocab
class VocabFilter(ABC):
    """
    Base class for vocabulary filters, which are used to implement constraints
    to choose a subset of a vocabulary. For example, one can filter out vocab
    elements that occur fewer than a certain number of times.

    Subclasses are expected to override :meth:`filter`.

    Args:
        vocab: The vocabulary object to be filtered.
    """

    def __init__(self, vocab: Vocabulary):
        # The vocabulary this filter consults (e.g. for element counts).
        self._vocab = vocab

    def filter(self, element_id: int) -> bool:
        """
        Given the element id, determine whether the element should be
        filtered out.

        Args:
            element_id: The element id to be checked.

        Returns:
            True if the element should be filtered out (removed from the new
            vocabulary), False if it should be kept.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
class FrequencyVocabFilter(VocabFilter):
    """
    A frequency based filter. An element is filtered out when its count is
    below `min_frequency` or above `max_frequency`. A negative threshold
    disables the corresponding check.

    Args:
        vocab: The vocabulary object. It must be created with
            `do_counting=True`.
        min_frequency: The min frequency threshold, default -1 (i.e. no
            frequency check for min).
        max_frequency: The max frequency threshold, default -1 (i.e. no
            frequency check for max).

    Raises:
        InvalidOperationException: If `vocab` does not collect counts.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        min_frequency: int = -1,
        max_frequency: int = -1,
    ):
        super().__init__(vocab)
        self.min_freq, self.max_freq = min_frequency, max_frequency

        counts_available = vocab.do_counting
        if not counts_available:
            raise InvalidOperationException(
                "The provided vocabulary is not configured to collect counts, "
                "cannot filter the vocabulary based on counts."
            )

    def filter(self, element_id: int) -> bool:
        """
        Decide whether the element should be dropped based on its count.

        Args:
            element_id: The element id to be checked.

        Returns:
            True if the element's count violates an enabled threshold.
        """
        count = self._vocab.get_count(element_id)
        too_rare = self.min_freq >= 0 and count < self.min_freq
        too_common = self.max_freq >= 0 and count > self.max_freq
        return too_rare or too_common