# Copyright 2019 The Forte Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from abc import ABC
from typing import Counter as CounterType, Dict, List, Optional
from asyml_utilities.special_tokens import SpecialTokens
from forte.processors.base import PackProcessor
__all__ = [
"Alphabet",
"VocabularyProcessor",
]
class Alphabet:
    """
    A bidirectional mapping between token instances and integer indices.

    The four special tokens PAD, BOS, EOS and UNK always occupy indices
    0-3 (in that order), so their ids are stable across runs.

    Args:
        name: The name of the alphabet, used as the default file stem
            when saving/loading.
        word_cnt: Optional counter of words. Every key is added to the
            vocabulary, after which the alphabet is closed (frozen).
        keep_growing: If True, new instances not found during `get_index`
            will be added to the vocabulary.
        ignore_case_in_query:
            If it's True, Alphabet will try to query the lower-cased input
            from its vocabulary if it cannot find the input in its keys.
        other_embeddings: Optional extra embedding data carried alongside
            the vocabulary; not interpreted by this class.
    """

    def __init__(
        self,
        name,
        word_cnt: Optional[CounterType[str]] = None,
        keep_growing: bool = True,
        ignore_case_in_query: bool = True,
        other_embeddings: Optional[Dict] = None,
    ):
        self.__name = name
        self.reserved_tokens = SpecialTokens
        self.instance2index: Dict = {}
        self.instances: List = []

        # Reserve the first four indices for the special tokens, in a
        # fixed order so pad/bos/eos/unk ids are deterministic.
        for sp in [
            self.reserved_tokens.PAD,
            self.reserved_tokens.BOS,
            self.reserved_tokens.EOS,
            self.reserved_tokens.UNK,
        ]:
            self.instance2index[sp] = len(self.instance2index)
            self.instances.append(sp)

        self.pad_id = self.instance2index[self.reserved_tokens.PAD]  # 0
        self.bos_id = self.instance2index[self.reserved_tokens.BOS]  # 1
        self.eos_id = self.instance2index[self.reserved_tokens.EOS]  # 2
        self.unk_id = self.instance2index[self.reserved_tokens.UNK]  # 3

        self.keep_growing = keep_growing
        self.ignore_case_in_query = ignore_case_in_query
        self.other_embeddings = other_embeddings

        if word_cnt is not None:
            for word in word_cnt:
                self.add(word)
            # Freeze the vocabulary once the provided counts are loaded.
            self.close()

    def add(self, instance):
        """Add `instance` to the vocabulary if it is not already present."""
        if instance not in self.instance2index:
            self.instance2index[instance] = len(self.instance2index)
            self.instances.append(instance)

    def get_index(self, instance):
        """
        Look up the index of a token, growing the vocabulary if allowed.

        A `None` input maps to the PAD index. Unknown tokens are added
        when `keep_growing` is True; otherwise a lower-cased lookup is
        attempted (when `ignore_case_in_query` is True) before falling
        back to the UNK index.

        Args:
            instance: the input token

        Returns:
            the index of the queried token in the dictionary
        """
        if instance is None:
            return self.instance2index[self.reserved_tokens.PAD]
        try:
            return self.instance2index[instance]
        except KeyError:
            if self.keep_growing:
                self.add(instance)
                return self.instance2index[instance]
            else:
                if self.ignore_case_in_query:
                    try:
                        return self.instance2index[instance.lower()]
                    except KeyError:
                        return self.instance2index[self.reserved_tokens.UNK]
                else:
                    return self.instance2index[self.reserved_tokens.UNK]

    def get_instance(self, index):
        """Return the token stored at `index`.

        Raises:
            IndexError: if `index` is out of range of the vocabulary.
        """
        try:
            return self.instances[index]
        except IndexError as e:
            raise IndexError("unknown index: %d" % index) from e

    def size(self):
        """Return the number of entries, including the special tokens."""
        return len(self.instances)

    def items(self):
        """Return an items view of (instance, index) pairs."""
        return self.instance2index.items()

    def close(self):
        """Freeze the vocabulary: `get_index` will no longer add tokens."""
        self.keep_growing = False

    def open(self):
        """Unfreeze the vocabulary: `get_index` may add new tokens."""
        self.keep_growing = True

    def get_content(self):
        """Return the serializable state (both mapping directions)."""
        return {
            "instance2index": self.instance2index,
            "instances": self.instances,
        }

    def __from_json(self, data):
        # Restore both directions of the mapping from a saved dict.
        self.instances = data["instances"]
        self.instance2index = data["instance2index"]

    def save(self, output_directory, name=None):
        """
        Save both alphabet records to the given directory.

        Args:
            output_directory: Directory to save model and weights.
            name: The alphabet saving name, optional.
        """
        saving_name = name if name else self.__name
        if not os.path.exists(output_directory):
            os.makedirs(output_directory)
        with open(
            os.path.join(output_directory, saving_name + ".json"),
            "w",
            encoding="utf-8",
        ) as out:
            json.dump(
                self.get_content(),
                out,
                indent=4,
            )

    def load(self, input_directory, name=None):
        """
        Load alphabet records saved by :meth:`save` and freeze growth.

        Args:
            input_directory: Directory the records were saved into.
            name: The alphabet file stem, optional (defaults to the
                alphabet's own name).
        """
        loading_name = name if name else self.__name
        with open(
            os.path.join(input_directory, loading_name + ".json"),
            encoding="utf-8",
        ) as f:
            self.__from_json(json.load(f))
        self.keep_growing = False
class VocabularyProcessor(PackProcessor, ABC):
    """
    Build vocabulary from the input DataPack, write the result into the
    shared resources.
    """

    def __init__(self) -> None:
        super().__init__()
        # Minimum occurrence count for vocabulary entries (0 = keep all);
        # presumably applied by concrete subclasses — confirm in callers.
        self.min_frequency = 0