Source code for forte.data.data_utils_io

# Copyright 2019 The Forte Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility functions related to data processing input/output.
"""
import os
from typing import Dict, List, Iterator, Any, Tuple

from forte.data.types import ReplaceOperationsType
from forte.data.span import Span

__all__ = [
    "batch_instances",
    "merge_batches",
    "slice_batch",
    "dataset_path_iterator",
]


[docs]def batch_instances(instances: List[Dict]): r"""Merge a list of ``instances``.""" batch: Dict[str, Any] = {} for instance in instances: for entry, fields in instance.items(): if isinstance(fields, dict): if entry not in batch: batch[entry] = {} for k, value in fields.items(): if k not in batch[entry]: batch[entry][k] = [] batch[entry][k].append(value) else: # context level feature if entry not in batch: batch[entry] = [] batch[entry].append(fields) return batch
[docs]def merge_batches(batches: List[Dict]): r"""Merge a list of ``batches``.""" merged_batch: Dict = {} for batch in batches: for entry, fields in batch.items(): if isinstance(fields, dict): if entry not in merged_batch: merged_batch[entry] = {} for k, value in fields.items(): if k not in merged_batch[entry]: merged_batch[entry][k] = [] merged_batch[entry][k].extend(value) else: # context level feature if entry not in merged_batch: merged_batch[entry] = [] merged_batch[entry].extend(fields) return merged_batch
[docs]def slice_batch(batch, start, length): r"""Return a sliced batch of size ``length`` from ``start`` in ``batch``.""" sliced_batch: Dict = {} for batch_key, fields in batch.items(): if isinstance(fields, dict): if batch_key not in sliced_batch: sliced_batch[batch_key] = {} for k, value in fields.items(): sliced_batch[batch_key][k] = value[start : start + length] else: # context level feature sliced_batch[batch_key] = fields[start : start + length] return sliced_batch
def dataset_path_iterator_with_base( dir_path: str, file_extension: str ) -> Iterator[Tuple[str, str]]: r"""An iterator returning file_paths in a directory containing files of the given datasets, including the original directory as the first element. """ for root, _, files in os.walk(dir_path): for data_file in files: if len(file_extension) > 0: if data_file.endswith(file_extension): yield dir_path, os.path.join(root, data_file) else: yield dir_path, os.path.join(root, data_file)
[docs]def dataset_path_iterator(dir_path: str, file_extension: str) -> Iterator[str]: r"""An iterator returning the file paths in a directory containing files of the given datasets. """ if not os.path.exists(dir_path): raise FileNotFoundError(f"Cannot find the directory [{dir_path}].") for root, _, files in os.walk(dir_path): for data_file in files: if len(file_extension) > 0: if data_file.endswith(file_extension): yield os.path.join(root, data_file) else: yield os.path.join(root, data_file)
def modify_text_and_track_ops( original_text: str, replace_operations: ReplaceOperationsType ) -> Tuple[str, ReplaceOperationsType, List[Tuple[Span, Span]], int]: r"""Modifies the original text using ``replace_operations`` provided by the user to return modified text and other data required for tracking original text. Args: original_text: Text to be modified. replace_operations: A list of spans and the corresponding replacement string that the span in the original string is to be replaced with to obtain the original string. Returns: modified_text: Text after modification. replace_back_operations: A list of spans and the corresponding replacement string that the span in the modified string is to be replaced with to obtain the original string. processed_original_spans: List of processed span and its corresponding original span. orig_text_len: length of original text. """ orig_text_len: int = len(original_text) mod_text: str = original_text increment: int = 0 prev_span_end: int = 0 replace_back_operations: List[Tuple[Span, str]] = [] processed_original_spans: List[Tuple[Span, Span]] = [] # Sorting the spans such that the order of replacement strings # is maintained -> utilizing the stable sort property of python sort replace_operations.sort(key=lambda item: item[0]) for span, replacement in replace_operations: if span.begin < 0 or span.end < 0: raise ValueError("Negative indexing not supported") if span.begin > len(original_text) or span.end > len(original_text): raise ValueError( "One of the span indices are outside the string length" ) if span.end < span.begin: print(span.begin, span.end) raise ValueError( "One of the end indices is lesser than start index" ) if span.begin < prev_span_end: raise ValueError( "The replacement spans should be mutually exclusive" ) span_begin = span.begin + increment span_end = span.end + increment original_span_text = mod_text[span_begin:span_end] mod_text = mod_text[:span_begin] + replacement + mod_text[span_end:] increment += len(replacement) - (span.end - span.begin) replacement_span = Span(span_begin, span_begin + len(replacement)) replace_back_operations.append((replacement_span, original_span_text)) processed_original_spans.append((replacement_span, span)) prev_span_end = span.end return ( mod_text, replace_back_operations, sorted(processed_original_spans), orig_text_len, )