코드

Tensorflow FedAVG

코코빵댕이 2022. 12. 21. 18:34

원 논문: https://arxiv.org/pdf/1602.05629.pdf

 

Tensorflow based FedAVG

Deadline mode

 

FL 환경에서 Straggler(응답이 늦은 client) 때문에 Server가 대기하는 의존성을 막기 위해, deadline 안에 결과를 보낸 일부 client만으로 Aggregation(partial client aggregation)을 수행하도록 구현

 

독립 client 및 자원 할당

 

Socket Streaming 을 통한 model 전파 및 훈련 결과 송수신

 

Average Scheme

참여한 일부 client(partial participants)의 결과를 각 client의 데이터 크기에 비례하여 가중합(Weighted Summation)

 

 

CallBack Function

class TimingCallback(tf.keras.callbacks.Callback):
    """Keras callback that stops local training before the round deadline.

    After every epoch it compares the time still left in the round with how
    long the epoch that just finished took. If another epoch of that length
    would not fit, it records the epoch count for the current client and sets
    ``self.model.stop_training`` so ``fit`` returns early.

    Module-level globals read: ``st`` (round start timestamp, compared against
    ``time.time()``), ``tr_time`` (round time budget, in the same seconds unit),
    ``now_client_num`` (index of the client being trained).
    Module-level global written: ``local_epoch[now_client_num]``.
    """

    def __init__(self, logs=None):
        # Initialise the Keras base class (sets up `model` and related state).
        super().__init__()
        # `logs` is unused; kept so existing call sites keep working.
        self.epoch_cnt = 0  # epochs completed during the current fit()

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        self.epoch_cnt += 1
        self.epoch_end_time = time.time()
        self.epoch_during_time = self.epoch_end_time - self.epoch_start_time
        # Wall-clock time elapsed since the round started.
        self.now_during_time = self.epoch_end_time - st
        remaining = tr_time - self.now_during_time
        # Stop if one more epoch as long as this one would overrun the budget.
        if remaining < self.epoch_during_time:
            # NOTE(review): stores epoch_cnt - 1 while the log line prints
            # epoch_cnt; confirm which count is intended.
            local_epoch[now_client_num] = self.epoch_cnt - 1
            print('client :', str(now_client_num), 'local epoch : ', str(self.epoch_cnt), 'epoch during :',
                  str(self.epoch_during_time), ' remain :', str(remaining))
            self.model.stop_training = True

Callback 함수를 통해 매 epoch 종료 시 남은 시간 안에 추가 epoch을 수행할 수 있는지 추적

남은 시간(remaining time)을 지속적으로 추적하여, deadline 안에 끝낼 수 없는 epoch을 시작하지 않음으로써 client의 공회전(불필요한 연산)과 자원 소모를 줄임

 

model 생성

def create_model(input_shape=(32, 32, 3), num_classes=10):
    """Build and compile the CNN classifier shared by the server and clients.

    The network has three convolutional stages with 32, 64 and 128 filters.
    Each stage is Conv-Dropout(0.2)-Conv-MaxPool. They are followed by a
    dropout-regularised dense head (1024-512-1024) and a softmax output.

    Args:
        input_shape: shape of one input image; defaults to (32, 32, 3).
        num_classes: number of output classes; defaults to 10.

    Returns:
        A compiled ``Sequential`` model. It uses SGD with categorical
        cross-entropy, so labels are expected to be one-hot encoded.
    """
    model = Sequential()
    for stage, filters in enumerate((32, 64, 128)):
        # Only the very first layer of the network declares the input shape.
        shape_kwargs = {'input_shape': input_shape} if stage == 0 else {}
        model.add(Conv2D(filters, (3, 3), activation='relu', padding='same', **shape_kwargs))
        model.add(Dropout(0.2))
        model.add(Conv2D(filters, (3, 3), activation='relu', padding='same'))
        model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    # Dense head: each hidden Dense layer is preceded by its own Dropout.
    for rate, units in ((0.2, 1024), (0.3, 512), (0.25, 1024)):
        model.add(Dropout(rate))
        model.add(Dense(units, activation='relu'))
    model.add(Dropout(0.3))
    model.add(Dense(num_classes, activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
    return model

 

Adam Optimizer

 

def adam(m, v, adam_iter, aggre_grad, global_grad, *, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-5):
    """Apply one Adam step to the global weights using the aggregated update.

    Uses the bias-corrected step size
    ``lr_t = lr * sqrt(1 - beta2**t) / (1 - beta1**t)``. Then, per layer:
    ``m = beta1*m + (1-beta1)*g``, ``v = beta2*v + (1-beta2)*g**2`` and
    ``w -= lr_t * m / (sqrt(v) + eps)``.

    Args:
        m: per-layer first-moment estimates.
        v: per-layer second-moment estimates.
        adam_iter: number of Adam steps taken so far (t before this step).
        aggre_grad: per-layer aggregated update used as the gradient g.
        global_grad: per-layer global weights to update.
        lr, beta1, beta2, eps: Adam hyperparameters (keyword-only). The
            defaults reproduce the previous hard-coded values.

    Returns:
        ``(global_grad, m, v, adam_iter)`` after the step. Note that ``m``,
        ``v`` and ``global_grad`` are updated in place (item assignment and
        in-place array arithmetic) as well as returned.
    """
    adam_iter += 1
    lr_t = lr * np.sqrt(1.0 - beta2 ** adam_iter) / (1.0 - beta1 ** adam_iter)

    # Index loop (not zip) because the containers are updated by item assignment.
    for i in range(len(v)):
        m[i] += (1 - beta1) * (aggre_grad[i] - m[i])
        v[i] += (1 - beta2) * (aggre_grad[i] ** 2 - v[i])
        global_grad[i] -= (lr_t * m[i]) / (np.sqrt(v[i]) + eps)
    return global_grad, m, v, adam_iter

집계된 update(aggre_grad)에 Adam을 적용하여 straggler 및 partial Aggregation으로 인한 Gradient drift를 완화하고 수렴을 유도

 

Basic Approach

 

Client

def client(train_data, test_data, epochs, global_weight, ada):
    """Run one round of local training for a single client.

    Loads ``global_weight`` into the module-level ``local_model`` and trains it
    on this client's data. Returns how much each weight tensor moved.

    Args:
        train_data: this client's training split (unpacked with ``divide``).
        test_data: this client's validation split (unpacked with ``divide``).
        epochs: maximum number of local epochs.
        global_weight: sequence of per-layer weight arrays sent by the server.
        ada: when ``1``, attach ``TimingCallback`` so training stops early
            once another epoch would not fit in the remaining round time.

    Returns:
        1-D object ``np.ndarray`` whose i-th entry is
        ``global_weight[i] - trained_weight[i]`` (old minus new).
    """
    # Keep an untouched copy of the starting point to diff against later.
    start_weights = copy.deepcopy(global_weight)
    local_model.set_weights(start_weights)

    x_train, y_train = divide(train_data)
    x_test, y_test = divide(test_data)
    x_train, y_train = np.array(x_train), np.array(y_train)
    x_test, y_test = np.array(x_test), np.array(y_test)

    # The deadline callback is only used in adaptive (ada == 1) mode;
    # callbacks=None is fit()'s default, i.e. no callbacks.
    callbacks = [TimingCallback()] if ada == 1 else None
    local_model.fit(x_train, y_train, epochs=epochs, batch_size=32, verbose=0,
                    validation_data=(x_test, y_test), callbacks=callbacks)

    # Build the per-layer delta element by element. np.array() on a list of
    # differently-shaped arrays raises ValueError on NumPy >= 1.24.
    trained_weights = local_model.get_weights()
    delta = np.empty(len(trained_weights), dtype=object)
    for i, (before, after) in enumerate(zip(start_weights, trained_weights)):
        delta[i] = before - after
    return delta

ada == 1인 경우 callback 함수를 통해, 남은 시간 안에 다음 local epoch을 마칠 수 없으면 훈련을 조기 종료

 

Server

def fed_avg(total_iteration, local_epoch, client_num):
    """Run synchronous FedAvg for ``total_iteration`` rounds.

    Each round does three things:
    - evaluates the current global model on ``x_test``/``y_test``;
    - sends the global weights to clients ``0..client_num-1`` (with ``ada=0``,
      i.e. no deadline callback);
    - subtracts the data-size-weighted sum of the returned weight deltas from
      the global model.

    Args:
        total_iteration: number of communication rounds.
        local_epoch: number of local epochs each client trains per round.
        client_num: number of participating clients.

    Returns:
        List of global-model accuracies in percent, one per round. Each is
        measured at the *start* of the round, so the model produced by the
        final round is not evaluated.
    """
    def _pack(arrays):
        # np.array() on differently-shaped per-layer arrays raises ValueError
        # on NumPy >= 1.24; build a 1-D object array element by element.
        packed = np.empty(len(arrays), dtype=object)
        for i, arr in enumerate(arrays):
            packed[i] = arr
        return packed

    global_model = create_model()
    global_weight = _pack(global_model.get_weights())

    # NOTE(review): normalises by the total over every entry of
    # client_data_size, not only the first client_num clients. If fewer
    # clients participate, the weights sum to < 1. Confirm this is intended.
    total_client_data_size = np.sum(client_data_size)
    total_acc = []
    for round_idx in range(total_iteration):
        _, acc = global_model.evaluate(x_test, y_test, verbose=2)
        print(round_idx, "avg global 모델의 정확도: {:5.2f}%".format(100 * acc))
        total_acc.append(100 * acc)

        client_weights = [
            client(client_train[k], client_test[k], local_epoch, global_weight, 0)
            for k in range(client_num)
        ]

        # Data-size-weighted sum of the per-client deltas.
        aggre_grad = None
        for k, delta in enumerate(client_weights):
            share = np.multiply(delta, client_data_size[k] / total_client_data_size)
            aggre_grad = share if aggre_grad is None else aggre_grad + share

        if aggre_grad is None:
            # No participating clients this round: keep the global model as is.
            continue
        global_model.set_weights(global_weight - aggre_grad)
        global_weight = _pack(global_model.get_weights())
    return total_acc

client에게 global 가중치를 송부하고, 훈련 결과(가중치 변화량)를 받아 데이터 크기 기반 Weighted Aggregation을 수행

 

다음 포스트에서는 가중치를 송부하고 수신하는 코드를 작성하며, 최종적으로는 가중치에 노이즈가 포함되는 경우에도 global model의 수렴을 촉진하는 코드를 작성합니다.