코드
FedProx - pytorch implementation
코코빵댕이
2022. 12. 21. 21:16
원문: https://arxiv.org/pdf/1812.06127.pdf

이번 post에서는 FedProx를 구현합니다.
다양한 system Heterogeneity에 내구성을 갖고 global model의 convergence를 유도 하기 위해 Proximal term을 둔것이 핵심 입니다.
Server
def server(self):
for t in tqdm(range(self.args.r)):
print('round', t + 1, ':')
# 클라이언트 샘플링
m = np.max([int(self.args.C * self.args.K), 1])
index = random.sample(range(0, self.args.K), m)
# 배포
self.dispatch(index)
# 클라이언트 local updating
self.client_update(index)
# 집계
self.aggregation(index)
return self.nn
서버는 FedAVG와 다르지 않습니다. 전체 client 중 일부를 뽑아 훈련에 참여 하고 집계하는 과정입니다.
weight dispatch
def dispatch(self, index):
for j in index:
for old_params, new_params in zip(self.nns[j].parameters(), self.nn.parameters()):
old_params.data = new_params.data.clone()
client training
def client_update(self, index): # update client's local model
for k in index:
self.nns[k] = train(self.args, self.nns[k], self.nn)
실제로는 병렬이지만 시뮬레이션을 위해 순차적으로 각 client를 훈련합니다.
Aggregation
def aggregation(self, index):
s = 0
for j in index:
# normal
s += self.nns[j].len
params = {}
for k, v in self.nns[0].named_parameters():
params[k] = torch.zeros_like(v.data)
for j in index:
for k, v in self.nns[j].named_parameters():
params[k] += v.data * (self.nns[j].len / s)
for k, v in self.nn.named_parameters():
v.data = params[k].data.clone()
집계입니다. 각 client의 parameter를 데이터 크기에 따라 가중집계 합니다.
FedAVG의 normal mode 에서는 각 client의 데이터 크기에 상관없이 일괄적으로 평균화를 수행했습니다.
하지만 그 과정에서 데이터 크기를 고려하지 않는것은 global model의 bias의 원인이 됩니다.
Prox_server
import copy
import random
import numpy as np
import torch
from tqdm import tqdm
from model import ANN
from client import train, test
class FedProx:
def __init__(self, args):
self.args = args
self.nn = ANN(args=self.args, name='server').to(args.device)
self.nns = []
for i in range(self.args.K):
temp = copy.deepcopy(self.nn)
temp.name = self.args.clients[i]
self.nns.append(temp)
def server(self):
for t in tqdm(range(self.args.r)):
print('round', t + 1, ':')
m = np.max([int(self.args.C * self.args.K), 1])
index = random.sample(range(0, self.args.K), m) # st
self.dispatch(index)
self.client_update(index)
self.aggregation(index)
return self.nn
def aggregation(self, index):
s = 0
for j in index:
s += self.nns[j].len
params = {}
for k, v in self.nns[0].named_parameters():
params[k] = torch.zeros_like(v.data)
for j in index:
for k, v in self.nns[j].named_parameters():
params[k] += v.data * (self.nns[j].len / s)
for k, v in self.nn.named_parameters():
v.data = params[k].data.clone()
def dispatch(self, index):
for j in index:
for old_params, new_params in zip(self.nns[j].parameters(), self.nn.parameters()):
old_params.data = new_params.data.clone()
def client_update(self, index):
for k in index:
self.nns[k] = train(self.args, self.nns[k], self.nn)
model
class ANN(nn.Module):
def __init__(self, args, name):
super(ANN, self).__init__()
self.name = name
self.len = 0
self.loss = 0
self.fc1 = nn.Linear(args.input_dim, 20)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
self.dropout = nn.Dropout()
self.fc2 = nn.Linear(20, 20)
self.fc3 = nn.Linear(20, 20)
self.fc4 = nn.Linear(20, 1)
def forward(self, data):
x = self.fc1(data)
x = self.sigmoid(x)
x = self.fc2(x)
x = self.sigmoid(x)
x = self.fc3(x)
x = self.sigmoid(x)
x = self.fc4(x)
x = self.sigmoid(x)
return x
서버와 client가 공유하는 모델입니다. Edge device의 자원을 고려해 light한 형태가 이상적 입니다.
client train
FedAvg와 훈련 절차는 동일 합니다. 그 과정에서 Proximal term을 붙입니다.


def train(args, model, server):
model.train()
Dtr, Val, Dte = nn_seq_wind(model.name, args.B)
model.len = len(Dtr)
# 모델 복사
global_model = copy.deepcopy(server)
lr = args.lr
if args.optimizer == 'adam':
optimizer = torch.optim.Adam(model.parameters(), lr=lr,
weight_decay=args.weight_decay)
else:
optimizer = torch.optim.SGD(model.parameters(), lr=lr,
momentum=0.9, weight_decay=args.weight_decay)
# optimizer에 따른 성능 비교
stepLR = StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
# 훈련
# 최소 업데이트 수
min_epochs = 5
#최신 모델 설정
recent_model = None
min_val_loss = 5
print('client 훈련 시작')
loss_function = nn.MSELoss().to(args.device)
for epoch in tqdm(range(args.E)):
train_loss = []
for (seq, label) in Dtr:
seq = seq.to(args.device)
label = label.to(args.device)
y_pred = model(seq)
optimizer.zero_grad()
# proximal_term 계산
proximal_term = 0.0
for w, w_t in zip(model.parameters(), global_model.parameters()):
proximal_term += (w - w_t).norm(2)
#손실 계산
loss = loss_function(y_pred, label) + (args.mu / 2) * proximal_term
train_loss.append(loss.item())
#역전파
loss.backward()
#optimizer 갱신
optimizer.step()
stepLR.step()
# validation
val_loss = get_val_loss(args, model, Val)
if epoch + 1 >= min_epochs and val_loss < min_val_loss:
min_val_loss = val_loss
recent_model = copy.deepcopy(model)
print('epoch {:03d} train_loss {:.8f} val_loss {:.8f}'.format(epoch, np.mean(train_loss), val_loss))
model.train()
return recent_model
def get_val_loss(args, model, Val):
model.eval()
loss_function = nn.MSELoss().to(args.device)
val_loss = []
for (seq, label) in Val:
with torch.no_grad():
seq = seq.to(args.device)
label = label.to(args.device)
y_pred = model(seq)
loss = loss_function(y_pred, label)
val_loss.append(loss.item())
return np.mean(val_loss)