Source code for forte.train_pipeline

# Copyright 2019 The Forte Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Optional, List

from forte.common.configuration import Config
from forte.common.resources import Resources
from forte.data.base_reader import BaseReader
from forte.evaluation.base.base_evaluator import Evaluator
from forte.pipeline import Pipeline
from forte.processors.base import BaseProcessor
from forte.trainer.base import BaseTrainer

logger = logging.getLogger(__name__)


class TrainPipeline:
    """A pipeline that prepares resources from the training data, drives a
    :class:`BaseTrainer` over packs produced by a reader, and optionally runs
    prediction and evaluation during validation."""

    def __init__(
        self,
        train_reader: BaseReader,
        trainer: BaseTrainer,
        dev_reader: BaseReader,
        configs: Config,
        preprocessors: Optional[List[BaseProcessor]] = None,
        evaluator: Optional[Evaluator] = None,
        predictor: Optional[BaseProcessor] = None,
    ):
        self.resource = Resources()
        self.configs = configs

        train_reader.initialize(self.resource, self.configs.reader)

        if preprocessors is not None:
            for p in preprocessors:
                p.initialize(
                    resources=self.resource, configs=configs.preprocessor
                )
            self.preprocessors = preprocessors
        else:
            self.preprocessors = []

        self.train_reader = train_reader
        self.dev_reader = dev_reader
        self.trainer = trainer

        # Keep the optional components even when they are None, so that the
        # None checks in `run`, `_validate` and `finish` do not raise
        # AttributeError.
        self.predictor = predictor
        self.evaluator = evaluator
        if self.evaluator is not None:
            self.evaluator.initialize(self.resource, self.configs.evaluator)

    def run(self):
        logger.info("Preparing the pipeline")
        self.prepare()
        logger.info("Initializing the trainer")

        # Initialize the trainer after the prepare step, since prepare will
        # update the resources.
        self.trainer.initialize(self.resource, self.configs)

        if self.predictor is not None:
            logger.info("Initializing the predictor")
            self.predictor.initialize(self.resource, self.configs)

        logger.info("The pipeline is training")
        self.train()
        self.finish()

    def prepare(self):
        # Run the preprocessors over the training data once, so they can
        # populate the shared resources (e.g. vocabularies) before training.
        prepare_pl: Pipeline = Pipeline()
        prepare_pl.set_reader(self.train_reader)
        for p in self.preprocessors:
            prepare_pl.add(p)
        prepare_pl.run(self.configs.config_data.train_path)

    def train(self):
        epoch = 0
        while True:
            epoch += 1
            for pack in self.train_reader.iter(
                self.configs.config_data.train_path
            ):
                for instance in pack.get_data(**self.trainer.data_request()):
                    self.trainer.consume(instance)

            self.trainer.epoch_finish_action(epoch)

            if self.trainer.validation_requested():
                dev_res = self._validate(epoch)
                self.trainer.validation_done()
                self.trainer.post_validation_action(dev_res)

            if self.trainer.stop_train():
                return

            logger.info("End of epoch %d", epoch)

    def _validate(self, epoch: int):
        validation_result = {"epoch": epoch}

        # Both the predictor and the evaluator are needed to score a dataset:
        # the predictor fills a copied pack with predictions and the evaluator
        # compares it against the gold pack.
        if self.predictor is not None and self.evaluator is not None:
            # Evaluate on the validation set.
            for pack in self.dev_reader.iter(
                self.configs.config_data.val_path
            ):
                predicted_pack = pack.view()
                self.predictor.process(predicted_pack)
                self.evaluator.consume_next(predicted_pack, pack)
            validation_result["eval"] = self.evaluator.get_result()

            # Evaluate on the test set.
            for pack in self.dev_reader.iter(
                self.configs.config_data.test_path
            ):
                predicted_pack = pack.view()
                self.predictor.process(predicted_pack)
                self.evaluator.consume_next(predicted_pack, pack)
            validation_result["test"] = self.evaluator.get_result()

        return validation_result

    def finish(self):
        self.train_reader.finish(self.resource)
        self.dev_reader.finish(self.resource)
        self.trainer.finish(self.resource)
        if self.predictor is not None:
            self.predictor.finish(self.resource)
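
A minimal usage sketch (not part of this module) showing how the pieces above fit together. MyReader, MyTrainer, MyPredictor and MyEvaluator are hypothetical stand-ins for concrete subclasses of BaseReader, BaseTrainer, BaseProcessor and Evaluator; the config keys mirror the ones TrainPipeline reads (reader, preprocessor, evaluator, and config_data.train_path / val_path / test_path), and the exact Config construction is an assumption, not a prescribed API.

# Hypothetical components; replace with real Forte readers/trainers.
# from my_project.components import MyReader, MyTrainer, MyPredictor, MyEvaluator

from forte.common.configuration import Config
from forte.train_pipeline import TrainPipeline

configs = Config(
    {
        "reader": {},
        "preprocessor": {},
        "evaluator": {},
        "config_data": {
            "train_path": "data/train/",
            "val_path": "data/dev/",
            "test_path": "data/test/",
        },
    },
    default_hparams=None,
)

pipeline = TrainPipeline(
    train_reader=MyReader(),    # hypothetical BaseReader subclass
    dev_reader=MyReader(),
    trainer=MyTrainer(),        # hypothetical BaseTrainer subclass
    configs=configs,
    predictor=MyPredictor(),    # hypothetical BaseProcessor subclass
    evaluator=MyEvaluator(),    # hypothetical Evaluator subclass
)
pipeline.run()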