import torch
import torch.distributed as dist
import numpy as np
from dataclasses import dataclass
@dataclass
class StreamerOptimizerArguments():
    world_size: int = 1
    r"""Number of GPUs to distribute the dataset across"""
    alpha: int = 3
    r"""The reach parameter for Hierarchical Gradient Normalization"""
    max_layers: int = 3
    r"""The maximum number of layers to stack"""
    optimize_every: int = 100
    r"""Take a gradient step after this many accumulated losses"""
    average_every: int = 1000
    r"""Average models across GPUs after this many optimizer calls"""
    hgn_timescale: bool = True
    r"""Allow the timescale parameter in Hierarchical Gradient Normalization"""
    hgn_reach: bool = True
    r"""Allow the reach parameter in Hierarchical Gradient Normalization"""
@staticmethod
def from_args(args):
return StreamerOptimizerArguments(
world_size=args.world_size,
alpha=args.alpha,
max_layers=args.max_layers,
optimize_every=args.optimize_every,
average_every=args.average_every,
            hgn_timescale=args.hgn_timescale,
            hgn_reach=args.hgn_reach,
)
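# Illustrative usage (not part of the source): the arguments can be constructed
# directly with the field names above, or via `from_args` from any argparse-style
# namespace that carries attributes with the same names.
#
#   opt_args = StreamerOptimizerArguments(world_size=1, alpha=3, max_layers=3,
#                                         optimize_every=100, average_every=1000)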
class StreamerOptimizer():
r"""
The optimizer used with streamer.
    This class takes care of optimization, gradient normalization, and averaging across GPUs.
:param StreamerOptimizerArguments args: The parameters used for the Streamer optimizer
"""
    def __init__(self, model, args: StreamerOptimizerArguments):
self.args = args
self.model = model
self.curr_n_layers = 0
self.average_counter = 0
self.step_counter = 0
self.reset()
def get_param_groups(self, layer_num):
r"""
Calculates the parameter groups and their weights.
        For example, if layer_num is 1 and we have 4 layers, then the parameter groups will be [[1], [0, 2], [3]]
        and their weights will be proportional to 1/alpha**i, favoring the groups closest to the layer (e.g., roughly [0.69, 0.23, 0.08] for alpha=3)
:param int layer_num: the index of the layer
:returns:
            * (*List(List(int))*): List of lists dividing the layers into groups so that different gradient multipliers can be assigned to them
* (*List(float)*): The weights assigned to the parameter groups
"""
less = [layer_num - i for i in range(len(self.f_params)) if layer_num-i>=0]
more = [layer_num + i for i in range(len(self.f_params)) if layer_num+i<len(self.f_params)]
groups = []
for i in range(max(len(less), len(more))):
g = set()
if i < len(less):
g.add(less[i])
if i < len(more):
g.add(more[i])
groups.append(list(g))
weights = torch.Tensor([1.0/(self.args.alpha**i) for i in range(len(groups))])
weights /= weights.sum()
return groups, weights
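    # Worked example (illustrative, derived from the code above): with 4 layers and
    # layer_num=1, the groups are [[1], [0, 2], [3]]; for alpha=3 the weights are
    # proportional to [1, 1/3, 1/9], i.e. approximately [0.692, 0.231, 0.077] after
    # normalization.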
def get_current_params(self):
r"""
Gets a reference to all parameters of all layers in the streamer model
"""
n_layers = 1
ll = self.model.streamer
self.f_params = [ll.f.get_params()]
self.p_params = [ll.p.get_params()]
self.l_params = [ll.l.get_params()]
        while ll.above is not None:
n_layers += 1
ll = ll.above
self.f_params.append(ll.f.get_params())
self.p_params.append(ll.p.get_params())
self.l_params.append(ll.l.get_params())
return n_layers
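    # Note: step() calls this again whenever the model has grown (n_layers >
    # curr_n_layers), so the cached parameter references always cover the most
    # recently stacked layer.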
def update_all_layers(self, layer):
# return if no loss to use
if not layer.objective_ready:
return
layer_num = layer.layer_num
layer_loss = layer.l.summarize_loss(layer.main_objective)
layer_timescale = float(layer.x_size) if self.args.hgn_timescale else 1.0
self.p_counter[layer_num] += layer_timescale
groups, reaches = self.get_param_groups(layer_num)
        for group, reach in zip(groups, reaches):
            # get parameters to accumulate
            params = [*self.p_params[layer_num]]
            for g in group: params.extend(self.f_params[g])
            # always retain the graph: the same loss is backwarded once per group,
            # and the computation graph can be shared with other layers' losses
            retain_graph = True
            (layer_loss*layer_timescale*reach).backward(inputs=params, retain_graph=retain_graph)
            for i in group: self.f_counter[i] += layer_timescale
# reset loss
self.step_counter[layer_num] += 1
layer.objective_ready = False
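    # In effect (per the code above), layer i's loss contributes to the f-params of
    # layer j with weight timescale_i * reach_weight(|i - j|), while f_counter[j]
    # accumulates timescale_i for every loss whose groups include layer j;
    # scale_gradients() later divides by these counters to normalize across layers.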
def get_gradients(self):
r"""
        Accumulates gradients on all layers' parameters from all losses in the model
"""
ll = self.model.streamer
self.update_all_layers(ll)
        while ll.above is not None:
ll = ll.above
self.update_all_layers(ll)
def equal_layers(self):
r"""
Determines if the streamer model has the same number of layers on all gpus
:returns:
* (*bool*): True if equal layers on all gpus
* (*int*): The number of layers on the current gpu
"""
l_global = torch.tensor([0.0]).cuda()
l_global = self.model.streamer.get_num_layers(l_global)
l_local = l_global.clone()
if self.args.world_size == 1:
return True, int(l_local.item())
dist.all_reduce(l_global)
        return bool((l_local*self.args.world_size == l_global).item()), int(l_local.item())
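    # Illustrative: with world_size=2, if rank 0 has 2 layers and rank 1 has 3, the
    # all-reduced total is 5 != 2*2, so equal_layers() reports False and step() skips
    # model averaging until both ranks have grown to the same depth.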
def reset(self):
r"""
Resets the counters of the optimizer
"""
self.f_counter = {i:0 for i in range(self.args.max_layers)}
self.p_counter = {i:0 for i in range(self.args.max_layers)}
self.l_counter = {i:0 for i in range(self.args.max_layers)}
self.step_counter = {i:0 for i in range(self.args.max_layers)}
def scale_gradients(self):
r"""
Scales the gradients of all modules by the counters to normalize the gradients
"""
def scale_module(params, counter):
for p_ix, param in enumerate(params):
if counter[p_ix] == 0: continue
for p in param:
                    if p.grad is not None: p.grad /= counter[p_ix]
scale_module(self.f_params, self.f_counter)
scale_module(self.p_params, self.p_counter)
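    # Illustrative: if layer 0's f-params received contributions from three losses
    # with timescales [4.0, 2.0, 2.0], then f_counter[0] == 8.0 and each of those
    # parameters' gradients is divided by 8.0 here.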
def average_models(self):
r"""
        Average the model parameters across all GPUs every :py:attr:`~StreamerOptimizerArguments.average_every` optimizer calls
"""
        for param in self.model.parameters():
dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
param.data /= self.args.world_size
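    # Note (assumption about the surrounding setup): all_reduce requires an initialized
    # default process group on every rank before running with world_size > 1, e.g.
    # dist.init_process_group("nccl", rank=rank, world_size=world_size).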
def step(self):
r"""
        Call the optimizer, which calculates the gradients and accumulates them on the parameters.
        This function does not take a gradient step on every call; a counter triggers the actual step once :py:attr:`~StreamerOptimizerArguments.optimize_every` losses have accumulated on the top layer
"""
# === AVERAGE MODELS === #
if self.args.world_size > 1: self.average_counter += 1
equal, n_layers = self.equal_layers()
if n_layers > self.curr_n_layers:
self.curr_n_layers = self.get_current_params()
if equal and self.average_counter >= self.args.average_every:
self.average_models()
self.average_counter = 0
# === OPTIMIZATION === #
self.get_gradients() # backwards
if self.step_counter[self.curr_n_layers-1] >= self.args.optimize_every:
self.scale_gradients()
self.model.optimize_model()
self.reset()
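# Usage sketch (illustrative, not from the source): a minimal streaming loop, assuming
# `model` exposes the attributes used above (`.streamer`, `.optimize_model()`,
# `.parameters()`) and a hypothetical forward `model(frame)` that updates each layer's
# loss and `objective_ready` flag.
#
#   opt_args = StreamerOptimizerArguments.from_args(cli_args)  # cli_args: argparse namespace
#   optimizer = StreamerOptimizer(model, opt_args)
#   for frame in video_stream:                                 # video_stream: hypothetical input
#       model(frame)        # hypothetical forward pass
#       optimizer.step()    # accumulates grads; actually steps every optimize_every losses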