Source code for hypergol.tensorflow_model_manager

import json
from pathlib import Path

import tensorflow as tf
from tqdm.auto import tqdm


class TensorflowModelManager:
    """Class for managing TensorFlow model training."""

    def __init__(self, model, optimizer, batchProcessor, project, restoreWeightsPath=None):
        """
        Parameters
        ----------
        model: BaseModel
            model subclassed from BaseModel that is to be trained
        optimizer: TensorFlow optimizer
            optimizer from TensorFlow package to use for training
        batchProcessor: Dataset
            Hypergol dataset to use for training
        project: HypergolProject
            Hypergol project to handle directories
        restoreWeightsPath: path
            path to restore variables from a previously trained model
        """
        self.model = model
        self.optimizer = optimizer
        self.batchProcessor = batchProcessor
        self.project = project
        self.restoreWeightsPath = restoreWeightsPath
        # Incremented once per ``train()`` call; used as the TensorBoard step
        # and as the name of the directory a model snapshot is saved into.
        self.globalStep = 0
        # Both writers are created in ``start()``.
        self.trainingSummaryWriter = None
        self.evaluationSummaryWriter = None

    def save_model(self):
        """Saves TensorFlow model, block definitions, and weights

        Everything is written into the directory
        ``<project.modelDataPath>/<model.modelName>/<globalStep>``.

        Raises
        ------
        FileExistsError
            if the directory for the current ``globalStep`` already exists
        """
        modelDirectory = Path(self.project.modelDataPath, self.model.modelName, str(self.globalStep))
        # exist_ok=False makes mkdir raise when the directory is already there,
        # so a snapshot previously saved for the same step is never overwritten.
        modelDirectory.mkdir(parents=True, exist_ok=False)
        tf.saved_model.save(self.model, export_dir=str(modelDirectory), signatures={'output': self.model.get_outputs})
        self._save_block_configs(modelDirectory)
        self.model.save_weights(filepath=f'{modelDirectory}/{self.model.modelName}.h5', save_format='h5')

    def _save_block_configs(self, modelDirectory):
        """Writes the config of every model block to ``<modelDirectory>/<blockName>.json``

        Parameters
        ----------
        modelDirectory: path
            existing directory to write the JSON files into
        """
        for modelBlock in self.model.get_model_blocks():
            configPath = Path(modelDirectory, f'{modelBlock.blockName}.json')
            # The context manager closes (and flushes) each file deterministically
            # instead of leaving the handle open until garbage collection.
            with open(configPath, 'w', encoding='utf-8') as configFile:
                json.dump(modelBlock.get_config(), configFile)

    def restore_model_weights(self):
        """Restores TensorFlow model weights

        Weights are loaded from ``<restoreWeightsPath>/<model.modelName>.h5``.
        """
        self.model.load_weights(filepath=f'{self.restoreWeightsPath}/{self.model.modelName}.h5')

    def train(self, withTracing):
        """Runs a single training step for the model

        Takes the next batch from the batch processor, applies one optimizer
        update, logs the loss to the training summary writer and increments
        ``globalStep``.

        Parameters
        ----------
        withTracing: bool
            log TensorFlow graph metadata for the step

        Returns
        -------
        loss
            the value returned by ``model.get_loss`` for this batch
        """
        inputs, targets = next(self.batchProcessor)
        if withTracing:
            tf.summary.trace_on(graph=True, profiler=False)
        with tf.GradientTape() as tape:
            loss = self.model.get_loss(targets=targets, training=True, **inputs)
        grads = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
        # Watch this issue: https://github.com/PyCQA/pylint/issues/3596
        with self.trainingSummaryWriter.as_default():  # pylint: disable=not-context-manager
            tf.summary.scalar(name='Loss', data=loss, step=self.globalStep)
            if withTracing:
                tf.summary.trace_export(
                    name=f'{self.model.modelName}{self.globalStep}',
                    step=self.globalStep,
                    profiler_outdir=str(Path(self.project.tensorboardPath, self.model.modelName, 'trainGraph'))
                )
        self.globalStep += 1
        return loss

    def evaluate(self, withTracing):
        """Runs a single evaluation step for the model

        Takes the next batch from the batch processor, logs the loss and the
        model's metrics to the evaluation summary writer and saves the batch
        together with the evaluation outputs through the batch processor.
        ``globalStep`` is not changed.

        Parameters
        ----------
        withTracing: bool
            log TensorFlow graph metadata for step

        Returns
        -------
        loss
            the value returned by ``model.get_loss`` for this batch
        """
        inputs, targets = next(self.batchProcessor)
        if withTracing:
            tf.summary.trace_on(graph=True, profiler=False)
        loss = self.model.get_loss(targets=targets, training=False, **inputs)
        outputs = self.model.get_evaluation_outputs(**inputs)
        with self.evaluationSummaryWriter.as_default():  # pylint: disable=not-context-manager
            tf.summary.scalar(name='Loss', data=loss, step=self.globalStep)
            self.model.produce_metrics(targets=targets, training=False, globalStep=self.globalStep, **inputs)
            if withTracing:
                tf.summary.trace_export(
                    name=f'{self.model.modelName}{self.globalStep}',
                    step=self.globalStep,
                    profiler_outdir=str(Path(self.project.tensorboardPath, self.model.modelName, 'evaluateGraph'))
                )
        self.batchProcessor.save_batch(inputs=inputs, targets=targets, outputs=outputs)
        return loss

    def start(self):
        """Prepares to run the training cycle by creating the model data directories,
        create the ``SummaryWriters`` for Tensorboard for training and evaluation,
        initialises the batchprocessor (opens the output dataset for writing) and
        reloads the weights if ``restoreWeightsPath`` is specified.
        """
        tensorboardDirectory = Path(self.project.tensorboardPath, self.model.modelName)
        tensorboardDirectory.mkdir(parents=True, exist_ok=True)
        self.trainingSummaryWriter = tf.summary.create_file_writer(logdir=str(Path(tensorboardDirectory, 'train')))
        self.evaluationSummaryWriter = tf.summary.create_file_writer(logdir=str(Path(tensorboardDirectory, 'evaluate')))
        self.batchProcessor.start()
        if self.restoreWeightsPath is not None:
            self.evaluate(withTracing=False)  # model call needed to initialize layers/weights before reloading
            self.model.built = True
            self.restore_model_weights()

    def run(self, stepCount, evaluationSteps, tracingSteps):
        """Runs a training schedule

        Model is saved in every evaluation step and at the very last step.

        Parameters
        ----------
        stepCount: int
            num total steps in the schedule
        evaluationSteps: List[int]
            which steps to produce metrics on an evaluation sample
        tracingSteps: List[int]
            which steps to log graph metadata (components, memory consumption, etc.) to Tensorboard
        """
        self.start()
        try:
            for step in tqdm(range(stepCount)):
                if step in evaluationSteps:
                    self.save_model()
                    self.evaluate(withTracing=step in tracingSteps)
                self.train(withTracing=step in tracingSteps)
            self.save_model()
        finally:
            # Close the output dataset even when training was interrupted.
            self.finish()

    def finish(self):
        """Finishes training run by closing the output dataset.

        Runs even if the training was interrupted by an exception (e.g. Ctrl-C)
        """
        self.batchProcessor.finish()