Source code for hype.checkpoint

#!/usr/bin/env python3
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
from os.path import join as pjoin
import time
import torch


class LocalCheckpoint:
    """
    Module for managing model checkpoints.

    Args:
        path (str): path to save the checkpoint to
        include_in_all (dict): a dictionary of objects to save in every
            call to :func:`save`
        start_fresh (bool): If ``True``, then ignore any existing
            checkpoint, otherwise initialize from previous checkpoint
    """

    def __init__(self, path, include_in_all=None, start_fresh=False):
        self.path = path
        self.start_fresh = start_fresh
        # ``None`` stands for "nothing extra"; a fresh dict is created per
        # instance so no default object is shared between checkpoints.
        self.include_in_all = {} if include_in_all is None else include_in_all

    def initialize(self, params):
        """
        Initialize the checkpoint.

        If ``start_fresh`` is ``True``, then ``params`` is returned.
        Otherwise, if a checkpoint file exists at ``self.path``, it is
        loaded and returned; if no such file exists, ``params`` is returned.

        Args:
            params (dict): checkpoint contents

        Returns:
            dict: Either ``params`` or the contents of the checkpoint
            stored at ``self.path``
        """
        if not self.start_fresh and os.path.isfile(self.path):
            print(f'Loading checkpoint from {self.path}')
            # NOTE(security): torch.load unpickles the file contents; only
            # point ``path`` at checkpoints from a trusted source.
            return torch.load(self.path)
        return params

    def save(self, params, tries=10, retry_delay=60):
        """
        Save a checkpoint containing ``params`` merged with
        ``self.include_in_all``.

        The checkpoint is first written to a temporary file next to
        ``self.path`` and then moved into place with :func:`os.replace`, so
        a failed or interrupted write does not overwrite the previous
        checkpoint with a partial file.

        Args:
            params (dict): data to store in checkpoint. This is merged with
                anything supplied to ``include_in_all`` in the constructor;
                on key collisions the value from ``params`` wins
            tries (int): number of retries after the first failed attempt
                (so at most ``tries + 1`` attempts are made). If the
                retries are exhausted, no checkpoint is saved and the
                failure is only reported on stdout
            retry_delay (float): seconds to wait before each retry

        Returns:
            None
        """
        # Same directory as the target so that os.replace is a rename on a
        # single filesystem; the pid keeps concurrent writers apart.
        tmp_path = f'{self.path}.tmp.{os.getpid()}'
        remaining = tries
        while True:
            try:
                torch.save({**self.include_in_all, **params}, tmp_path)
                os.replace(tmp_path, self.path)
                return
            except Exception as err:
                # Saving is deliberately best-effort: a failure is reported
                # and retried instead of aborting the caller.
                try:
                    os.remove(tmp_path)
                except OSError:
                    # Nothing to clean up (the temporary file was never
                    # created or is already gone).
                    pass
                if remaining <= 0:
                    print("Giving up on saving...")
                    return
                print(f'Exception while saving ({err})\nRetrying ({remaining})')
                time.sleep(retry_delay)
                remaining -= 1