Source code for hype.train

#!/usr/bin/env python3
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch as th
import numpy as np
import timeit
import gc
from tqdm import tqdm
from torch.utils import data as torch_data

# Factor applied to opt.lr by train() for burn-in epochs
# (those with epoch < opt.burnin).
_lr_multiplier = 0.1


def train(
        device,
        model,
        data,
        optimizer,
        opt,
        log,
        rank=1,
        queue=None,
        ctrl=None,
        checkpointer=None,
        progress=False
):
    """
    Function to train embeddings

    Runs epochs ``opt.epoch_start`` .. ``opt.epochs - 1``. For epochs below
    ``opt.burnin`` the flag ``data.burnin`` is set and the learning rate is
    ``opt.lr * _lr_multiplier``; otherwise ``opt.lr`` is used. After each
    epoch, the worker with ``rank == 1`` reports the mean batch loss through
    ``queue``, ``ctrl`` or ``log`` (in that order of preference).

    Args:
        device (torch.device): which device to train on
        model (torch.nn.Module): model to train; must expose ``nobjects``
            and ``loss(preds, targets, size_average=True)``
        data (BatchedDataset or AdjacencyDataset): dataloader; when it is a
            ``torch.utils.data.Dataset`` it is wrapped in a shuffling
            ``DataLoader`` using ``opt.batchsize`` and ``opt.ndproc``
        optimizer (torch.optim.Optimizer): optimizer whose ``step`` accepts
            ``lr`` and ``counts`` keyword arguments
        opt (SimpleNamespace): command line options
        log (logging.Logger): log
        rank (int): thread rank if using multiple training threads
        queue (multiprocessing.Queue): Queue to put epoch stats into if using
            multiple threads/asynchronous control
        ctrl (Callable): called as ``ctrl(model, epoch, elapsed, loss)`` on
            every ``opt.eval_each``-th epoch when ``queue`` is None; a truthy
            ``ctrl.checkpoint`` attribute enables ``checkpointer``
        checkpointer (Callable): checkpointing function, called as
            ``checkpointer(model, epoch, epoch_loss)``
        progress (bool): whether or not to display progress bar per epoch
    """
    if isinstance(data, torch_data.Dataset):
        loader = torch_data.DataLoader(
            data,
            batch_size=opt.batchsize,
            shuffle=True,
            num_workers=opt.ndproc,
        )
    else:
        loader = data

    # One slot per batch; reset to zero at the start of every epoch.
    epoch_loss = th.zeros(len(loader))
    # Default per-object counts handed to the optimizer when ASGD is off.
    counts = th.zeros(model.nobjects, 1).to(device)

    for epoch in range(opt.epoch_start, opt.epochs):
        epoch_loss.fill_(0)
        data.burnin = False
        lr = opt.lr
        t_start = timeit.default_timer()
        if epoch < opt.burnin:
            data.burnin = True
            lr = opt.lr * _lr_multiplier
            if rank == 1:
                log.info(f'Burn in negs={data.nnegatives()}, lr={lr}')

        loader_iter = tqdm(loader) if progress and rank == 1 else loader
        for i_batch, (inputs, targets) in enumerate(loader_iter):
            inputs = inputs.to(device)
            targets = targets.to(device)

            # count occurrences of objects in batch
            if getattr(opt, 'asgd', False):
                counts = th.bincount(inputs.view(-1), minlength=model.nobjects)
                # bincount yields an integer tensor: convert to double
                # *before* dividing, otherwise the in-place division either
                # fails (float result cannot be stored in a Long tensor) or
                # truncates the per-object frequencies.
                counts = counts.clamp_(min=1).double()
                counts = counts.div_(inputs.size(0)).unsqueeze(-1)

            optimizer.zero_grad()
            preds = model(inputs)
            loss = model.loss(preds, targets, size_average=True)
            loss.backward()
            optimizer.step(lr=lr, counts=counts)
            epoch_loss[i_batch] = loss.item()

        # Measured after the last batch so the figure covers the whole epoch
        # and is defined even when the loader yields no batches.
        elapsed = timeit.default_timer() - t_start

        if rank == 1:
            if hasattr(data, 'avg_queue_size'):
                qsize = data.avg_queue_size()
                misses = data.queue_misses()
                log.info(f'Average qsize for epoch was {qsize}, num_misses={misses}')

            if queue is not None:
                queue.put((epoch, elapsed, th.mean(epoch_loss).item(), model))
            elif ctrl is not None and epoch % opt.eval_each == (opt.eval_each - 1):
                with th.no_grad():
                    ctrl(model, epoch, elapsed, th.mean(epoch_loss).item())
            else:
                # NOTE: the trailing comma means this line is not strict
                # JSON; the format is kept unchanged for existing consumers.
                log.info(
                    'json_stats: {'
                    f'"epoch": {epoch}, '
                    f'"elapsed": {elapsed}, '
                    f'"loss": {th.mean(epoch_loss).item()}, '
                    '}'
                )

            if checkpointer and hasattr(ctrl, 'checkpoint') and ctrl.checkpoint:
                checkpointer(model, epoch, epoch_loss)

        gc.collect()