
FedAVG pytorch implementation

코코빵댕이 2022. 12. 21. 15:00

Paper: https://arxiv.org/pdf/1602.05629.pdf

 

[Figure: the Federated Averaging scheme (Algorithm 1 in the paper)]
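In each communication round, the server sends the current global weights to a randomly selected fraction C of the K clients; each selected client runs E epochs of local SGD on its own data and sends its updated weights back, and the server aggregates them into the new global model (in the paper, weighted by each client's data share n_k/n). The code below simulates this loop sequentially on a single machine.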

 

import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import copy
import random
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, utils, datasets
from torchsummary import summary

# Fix random seeds for reproducibility
seed = 55

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# cuDNN settings for deterministic GPU behavior
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

 

Importing the required libraries

transforms_mnist = transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.1307,), (0.3081,))
                                       ])
# Build the data transformer (normalize with the MNIST mean/std)

mnist_data_train = datasets.MNIST('./data/mnist/', train=True, download=True, transform=transforms_mnist)
mnist_data_test = datasets.MNIST('./data/mnist/', train=False, download=True, transform=transforms_mnist)
# Load the MNIST data (train and test paths unified to './data/mnist/')

Loading the training and test data

 

classes = np.array(list(mnist_data_train.class_to_idx.values()))
classes_test = np.array(list(mnist_data_test.class_to_idx.values()))
num_classes = len(classes_test)

# IID setting: partition the data uniformly at random across clients
def iid_partition(dataset, clients):

  num_items_per_client = int(len(dataset)/clients)
  client_dict = {}
  image_idxs = [i for i in range(len(dataset))]

  for i in range(clients):
    client_dict[i] = set(np.random.choice(image_idxs, num_items_per_client, replace=False))
    image_idxs = list(set(image_idxs) - client_dict[i])

  return client_dict
  

# Non-IID setting: sort by label, split into shards, and assign shards to clients
def non_iid_partition(dataset, clients, total_shards, shards_size, num_shards_per_client):
  shard_idxs = [i for i in range(total_shards)]
  client_dict = {i: np.array([], dtype='int64') for i in range(clients)}
  idxs = np.arange(len(dataset))
  data_labels = dataset.targets.numpy()
  
  label_idxs = np.vstack((idxs, data_labels))
  label_idxs = label_idxs[:, label_idxs[1,:].argsort()]
  idxs = label_idxs[0,:]

  for i in range(clients):
    rand_set = set(np.random.choice(shard_idxs, num_shards_per_client, replace=False))
    shard_idxs = list(set(shard_idxs) - rand_set)

    for rand in rand_set:
      client_dict[i] = np.concatenate((client_dict[i], idxs[rand*shards_size:(rand+1)*shards_size]), axis=0)
  
  return client_dict

Partitioning the data and setting up clients
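 

As a usage sketch (my addition, not in the original post; the shard numbers follow the paper's non-IID MNIST setup of 200 shards of size 300, 2 per client):

# Partition the 60,000 MNIST training images across 100 clients
iid_dict = iid_partition(mnist_data_train, clients=100)  # 600 random images per client
non_iid_dict = non_iid_partition(mnist_data_train, clients=100,
                                 total_shards=200, shards_size=300,
                                 num_shards_per_client=2)  # 2 label-sorted shards per client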

 

# Model to be trained (server side)
class MNIST_CNN(nn.Module):
  def __init__(self):
    super(MNIST_CNN, self).__init__()

    self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
    self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
    self.conv3 = nn.Conv2d(64, 128, kernel_size=5, padding=2)  # padding=2 keeps the 4x4 map, so the flattened size (128*2*2=512) matches fc1
    
    self.pool = nn.MaxPool2d(2,2)
    self.dropout = nn.Dropout(p=0.3)

    self.fc1 = nn.Linear(512, 256)
    self.out = nn.Linear(256, 10)

  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = self.pool(F.relu(self.conv3(x)))
    x = self.dropout(x)
    x = torch.flatten(x, 1)
    x = F.relu(self.fc1(x))
    x = self.out(x)
    out = F.log_softmax(x, dim=1)

    return out

Creating the training model
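 

A quick shape check (my addition, not in the original post) confirms that the padded conv3 keeps the flattened feature size at 512:

# Sanity check: a dummy batch of four 28x28 grayscale images should yield (4, 10) log-probabilities
model = MNIST_CNN()
dummy = torch.randn(4, 1, 28, 28)
print(model(dummy).shape)  # torch.Size([4, 10])
# summary(model, (1, 28, 28))  # optional torchsummary view of per-layer shapes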

 

class CustomDataset(Dataset):
  def __init__(self, dataset, idxs):
      self.dataset = dataset
      self.idxs = list(idxs)

  def __len__(self):
      return len(self.idxs)

  def __getitem__(self, item):
      image, label = self.dataset[self.idxs[item]]
      return image, label

class ClientUpdate(object):
  def __init__(self, dataset, batchSize, learning_rate, epochs, idxs):
    self.train_loader = DataLoader(CustomDataset(dataset, idxs), batch_size=batchSize, shuffle=True)
    # Load only this client's subset of the data
    self.learning_rate = learning_rate
    self.epochs = epochs

  def train(self, model):
    # Move the model to the GPU if one is available (the batches are moved below)
    if torch.cuda.is_available():
      model.cuda()

    # The model returns log-probabilities (log_softmax), so NLLLoss is the matching criterion
    criterion = nn.NLLLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=self.learning_rate, momentum=0.5)
    # Local optimization with plain SGD

    e_loss = []
    # Run the requested number of local epochs
    for epoch in range(1, self.epochs+1):
      train_loss = 0.0
      model.train()
      # Iterate over this client's minibatches
      for data, labels in self.train_loader:
        # Move the batch to the GPU
        if torch.cuda.is_available():
          data, labels = data.cuda(), labels.cuda()

        # Reset the gradients before each update
        optimizer.zero_grad()

        # Forward pass
        output = model(data)

        # Compute the loss
        loss = criterion(output, labels)

        # Backward pass
        loss.backward()

        # Optimizer step
        optimizer.step()
        # Accumulate the sample-weighted loss
        train_loss += loss.item()*data.size(0)

      # Average loss over this epoch
      train_loss = train_loss/len(self.train_loader.dataset)
      e_loss.append(train_loss)

    total_loss = sum(e_loss)/len(e_loss)

    return model.state_dict(), total_loss

Client training
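 

To watch a single client in isolation, here is a small sketch (my addition; it assumes the iid_dict from the partition example above, and the hyperparameters are illustrative):

# Train client 0 alone for 5 local epochs on a fresh model
client0 = ClientUpdate(dataset=mnist_data_train, batchSize=10, learning_rate=0.01,
                       epochs=5, idxs=iid_dict[0])
local_weights, local_loss = client0.train(model=MNIST_CNN())
print(f'client 0 average local loss: {local_loss:.4f}')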

 

def training(model, rounds, batch_size, lr, ds, data_dict, C, K, E, plt_title, plt_color):

  # Copy the global model weights
  global_weights = model.state_dict()

  train_loss = []

  # Global communication rounds
  for curr_round in range(1, rounds+1):
    w, local_loss = [], []

    # Sample a fraction C of the K clients (at least one)
    m = max(int(C*K), 1)

    S_t = np.random.choice(range(K), m, replace=False)

    # Each selected client trains on a copy of the current global model
    for k in S_t:
      local_update = ClientUpdate(dataset=ds, batchSize=batch_size, learning_rate=lr, epochs=E, idxs=data_dict[k])
      weights, loss = local_update.train(model=copy.deepcopy(model))

      w.append(copy.deepcopy(weights))
      local_loss.append(copy.deepcopy(loss))

    # Plain (unweighted) average of the client weights
    weights_avg = copy.deepcopy(w[0])
    for k in weights_avg.keys():
      for i in range(1, len(w)):
        weights_avg[k] += w[i][k]

      weights_avg[k] = torch.div(weights_avg[k], len(w))

    global_weights = weights_avg

    # Load the averaged weights into the global model
    model.load_state_dict(global_weights)

    # Track the average local loss for this round
    loss_avg = sum(local_loss) / len(local_loss)
    train_loss.append(loss_avg)

  return model

Server-side training
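 

Putting it together, a minimal driver sketch (my addition; the hyperparameters are assumptions roughly following the paper's MNIST experiments, C=0.1, K=100, E=5, B=10, and iid_dict comes from the partition example above):

# Run 100 communication rounds of FedAvg on the IID partition
mnist_cnn = MNIST_CNN()
trained_model = training(model=mnist_cnn, rounds=100, batch_size=10, lr=0.01,
                         ds=mnist_data_train, data_dict=iid_dict,
                         C=0.1, K=100, E=5,
                         plt_title='MNIST CNN (IID)', plt_color='orange')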

 

Limitations of this implementation:

1. Although federated learning is meant to represent parallel training across clients, this implementation is restricted to sequential training.

2. Since everything runs virtually on a single device, system heterogeneity is hard to reproduce.

3. Plain (unweighted) averaging can bias the global model against clients with small datasets (see the weighted-averaging sketch after this list).

4. Client drift is possible in the non-IID setting.
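
Regarding point 3, the paper's aggregation actually weights each client by its share of the data, w_{t+1} = Σ_k (n_k/n) · w_k. A minimal sketch of that weighted average (my addition; a drop-in replacement for the plain average in training above):

def weighted_average(client_weights, client_sizes):
  # client_weights: list of state_dicts; client_sizes: number of samples per client (n_k)
  total = sum(client_sizes)
  avg = copy.deepcopy(client_weights[0])
  for key in avg.keys():
    avg[key] = client_weights[0][key] * (client_sizes[0] / total)
    for i in range(1, len(client_weights)):
      avg[key] += client_weights[i][key] * (client_sizes[i] / total)
  return avg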

 

In the next post, independent clients will be set up through client virtualization so that each trains independently and system heterogeneity can be implemented.

 

 

 

 

 

 
