Module tiresias.server.handler.gradient

Expand source code
import torch
import numpy as np
from tiresias.core import b64_decode, b64_encode
from tiresias.core.gradients import merge_gradients, put_gradients

def handle_gradient(task, data):
    """Apply a batch of encoded gradients to a task's model and return it re-encoded.

    Decodes the model from ``task["model"]``, merges the decoded gradients
    from ``data`` onto the model's parameters, clips the total gradient norm
    to 1.0, takes one Adam step at ``task["lr"]``, and returns the updated
    model in base64-encoded form.

    NOTE(review): the Adam optimizer is rebuilt on every call, so its moment
    estimates never persist across rounds — presumably intentional for a
    stateless handler; confirm.
    """
    net = b64_decode(task["model"])
    opt = torch.optim.Adam(net.parameters(), lr=task["lr"])

    # Decode each client's gradient payload, then fold them into one update.
    decoded = [b64_decode(encoded) for encoded in data]
    put_gradients(net, merge_gradients(decoded))

    # Bound the merged update's magnitude before stepping.
    torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
    opt.step()

    return b64_encode(net)

Functions

def handle_gradient(task, data)
Expand source code
def handle_gradient(task, data):
    """Merge submitted gradients into the task's model and return the updated model.

    Args:
        task: mapping with at least ``"model"`` (base64-encoded model, decodable
            by ``b64_decode``) and ``"lr"`` (learning rate for Adam).
        data: iterable of base64-encoded gradient payloads, one per submitter.

    Returns:
        The updated model, re-encoded via ``b64_encode``.

    NOTE(review): a fresh Adam optimizer is constructed on every call, so its
    running moment estimates are discarded between rounds — verify this is
    the intended (stateless) behavior.
    """
    model = b64_decode(task["model"])
    optimizer = torch.optim.Adam(model.parameters(), lr=task["lr"])
    # Decode each gradient, merge them, and write the result onto model params.
    put_gradients(model, merge_gradients([b64_decode(g) for g in data]))
    # Clip the total gradient norm to 1.0 before the optimizer update.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    return b64_encode(model)