코드

FedProx - pytorch implementation

코코빵댕이 2022. 12. 21. 21:16

원문: https://arxiv.org/pdf/1812.06127.pdf

 

이번 post에서는 FedProx를 구현합니다. 

 

다양한 system Heterogeneity에 내구성을 갖고 global model의 convergence를 유도 하기 위해 Proximal term을 둔것이 핵심 입니다. 

 

Server

def server(self):
    for t in tqdm(range(self.args.r)):
        print('round', t + 1, ':')
        # 클라이언트 샘플링
        m = np.max([int(self.args.C * self.args.K), 1])
        index = random.sample(range(0, self.args.K), m)
        
        # 배포
        self.dispatch(index)
        
        # 클라이언트 local updating
        self.client_update(index)
        
        # 집계
        self.aggregation(index)
        
    return self.nn

서버는 FedAVG와 다르지 않습니다. 전체 client 중 일부를 뽑아 훈련에 참여 하고 집계하는 과정입니다.

 

weight dispatch

def dispatch(self, index):
        for j in index:
            for old_params, new_params in zip(self.nns[j].parameters(), self.nn.parameters()):
                old_params.data = new_params.data.clone()

 

 

client training

 

def client_update(self, index):  # update client's local model
        for k in index:
            self.nns[k] = train(self.args, self.nns[k], self.nn)

실제로는 병렬이지만 시뮬레이션을 위해 순차적으로 각 client를 훈련합니다.

 

Aggregation

 

def aggregation(self, index):
        s = 0
        for j in index:
            # normal
            s += self.nns[j].len

        params = {}
        for k, v in self.nns[0].named_parameters():
            params[k] = torch.zeros_like(v.data)

        for j in index:
            for k, v in self.nns[j].named_parameters():
                params[k] += v.data * (self.nns[j].len / s)

        for k, v in self.nn.named_parameters():
            v.data = params[k].data.clone()

집계입니다. 각 client의 parameter를 데이터 크기에 따라 가중집계 합니다.

FedAVG의 normal mode 에서는 각 client의 데이터 크기에 상관없이 일괄적으로 평균화를 수행했습니다.

하지만 그 과정에서 데이터 크기를 고려하지 않는것은 global model의 bias의 원인이 됩니다.

 

Prox_server

 

import copy
import random

import numpy as np
import torch
from tqdm import tqdm

from model import ANN
from client import train, test


class FedProx:
    def __init__(self, args):
        self.args = args
        self.nn = ANN(args=self.args, name='server').to(args.device)
        self.nns = []
        for i in range(self.args.K):
            temp = copy.deepcopy(self.nn)
            temp.name = self.args.clients[i]
            self.nns.append(temp)

    def server(self):
        for t in tqdm(range(self.args.r)):
            print('round', t + 1, ':')
            
            m = np.max([int(self.args.C * self.args.K), 1])
            index = random.sample(range(0, self.args.K), m)  # st
            
            self.dispatch(index)
            
            self.client_update(index)
            
            self.aggregation(index)

        return self.nn

    def aggregation(self, index):
        s = 0
        for j in index:
            
            s += self.nns[j].len

        params = {}
        for k, v in self.nns[0].named_parameters():
            params[k] = torch.zeros_like(v.data)

        for j in index:
            for k, v in self.nns[j].named_parameters():
                params[k] += v.data * (self.nns[j].len / s)

        for k, v in self.nn.named_parameters():
            v.data = params[k].data.clone()

    def dispatch(self, index):
        for j in index:
            for old_params, new_params in zip(self.nns[j].parameters(), self.nn.parameters()):
                old_params.data = new_params.data.clone()

    def client_update(self, index):  
        for k in index:
            self.nns[k] = train(self.args, self.nns[k], self.nn)

model

class ANN(nn.Module):
    def __init__(self, args, name):
        super(ANN, self).__init__()
        self.name = name
        self.len = 0
        self.loss = 0
        self.fc1 = nn.Linear(args.input_dim, 20)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout()
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, 20)
        self.fc4 = nn.Linear(20, 1)

    def forward(self, data):
        x = self.fc1(data)
        x = self.sigmoid(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        x = self.fc3(x)
        x = self.sigmoid(x)
        x = self.fc4(x)
        x = self.sigmoid(x)

        return x

서버와 client가 공유하는 모델입니다. Edge device의 자원을 고려해 light한 형태가 이상적 입니다.

 

client train

 

FedAvg와 훈련 절차는 동일 합니다. 그 과정에서 Proximal term을 붙입니다.

def train(args, model, server):
    model.train()
    Dtr, Val, Dte = nn_seq_wind(model.name, args.B)
    model.len = len(Dtr)
    
    # 모델 복사
    global_model = copy.deepcopy(server)
    lr = args.lr
    if args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                     weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                    momentum=0.9, weight_decay=args.weight_decay)
    # optimizer에 따른 성능 비교
    
    stepLR = StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    # 훈련
    # 최소 업데이트 수
    min_epochs = 5
    
    #최신 모델 설정
    recent_model = None
    
    min_val_loss = 5
    print('client 훈련 시작')
    loss_function = nn.MSELoss().to(args.device)
    for epoch in tqdm(range(args.E)):
        train_loss = []
        
        for (seq, label) in Dtr:
            seq = seq.to(args.device)
            label = label.to(args.device)
            y_pred = model(seq)
            optimizer.zero_grad()
            # proximal_term 계산
            proximal_term = 0.0
            for w, w_t in zip(model.parameters(), global_model.parameters()):
                proximal_term += (w - w_t).norm(2)
			#손실 계산
            loss = loss_function(y_pred, label) + (args.mu / 2) * proximal_term
            
            train_loss.append(loss.item())
            #역전파
            loss.backward()
            #optimizer 갱신
            optimizer.step()
        stepLR.step()
        
        # validation
        val_loss = get_val_loss(args, model, Val)
        if epoch + 1 >= min_epochs and val_loss < min_val_loss:
            min_val_loss = val_loss
            recent_model = copy.deepcopy(model)

        print('epoch {:03d} train_loss {:.8f} val_loss {:.8f}'.format(epoch, np.mean(train_loss), val_loss))
        model.train()

    return recent_model

def get_val_loss(args, model, Val):
    model.eval()
    loss_function = nn.MSELoss().to(args.device)
    val_loss = []
    for (seq, label) in Val:
        with torch.no_grad():
            seq = seq.to(args.device)
            label = label.to(args.device)
            y_pred = model(seq)
            loss = loss_function(y_pred, label)
            val_loss.append(loss.item())

    return np.mean(val_loss)