TensorFlow FedAVG

코코빵댕이 2022. 12. 21. 18:34

Paper: https://arxiv.org/pdf/1602.05629.pdf

 

TensorFlow-based FedAVG

Deadline mode

 

In the FL setting, avoids the server-side dependency on stragglers and implements partial-client aggregation

 

Independent clients and resource allocation

 

Model distribution and exchange of training results via socket streaming

 

Averaging Scheme

Weighted summation over the partially participating clients
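
As an illustration of this weighted summation (a sketch of mine, not code from the post; weighted_sum, deltas, sizes, and participants are assumed names), each reporting client's update is scaled by its share of the participating data:

import numpy as np

def weighted_sum(deltas, sizes, participants):
    # Aggregate per-client weight deltas over the clients that reported back.
    # deltas       -- per-client lists of layer arrays (weight deltas)
    # sizes        -- per-client training-set sizes
    # participants -- indices of the clients that finished before the deadline
    total = sum(sizes[k] for k in participants)
    agg = [np.zeros_like(layer) for layer in deltas[participants[0]]]
    for k in participants:
        scale = sizes[k] / total
        for i, layer in enumerate(deltas[k]):
            agg[i] += scale * layer
    return agg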

 

 

Callback Function

import time
import tensorflow as tf

class TimingCallback(tf.keras.callbacks.Callback):
    """Stops local training once another epoch would overrun the round deadline."""
    def __init__(self):
        super().__init__()
        self.epoch_cnt = 0

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        global st  # round start time, set by the server before fit()
        self.epoch_cnt += 1
        self.epoch_end_time = time.time()
        self.epoch_during_time = self.epoch_end_time - self.epoch_start_time
        self.now_during_time = self.epoch_end_time - st
        # Stop if the time left in the budget (tr_time) is shorter than
        # the duration of the epoch that just finished.
        if (tr_time - self.now_during_time) < self.epoch_during_time:
            local_epoch[now_client_num] = self.epoch_cnt - 1
            print('client :', str(now_client_num), 'local epoch :', str(self.epoch_cnt),
                  'epoch during :', str(self.epoch_during_time),
                  ' remain :', str(tr_time - self.now_during_time))
            self.model.stop_training = True

The callback tracks whether another local update can still be performed during training.

Continuously tracking the remaining time keeps clients from wasting cycles on an epoch that cannot finish before the deadline and so reduces resource consumption.
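
A minimal sketch of how the deadline globals could be wired up before a client's fit() call (my illustration; tr_time, st, local_epoch, and now_client_num follow the code above, and create_model() is defined in the next section):

import time
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical

client_num = 4
tr_time = 60.0                  # per-round wall-clock budget in seconds (assumed value)
local_epoch = [0] * client_num  # epochs each client actually completed
now_client_num = 0              # index of the client about to train
st = time.time()                # round start time, read by TimingCallback

(x_train, y_train), _ = cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
y_train = to_categorical(y_train, 10)

model = create_model()
model.fit(x_train[:5000], y_train[:5000], epochs=50, batch_size=32,
          verbose=0, callbacks=[TimingCallback()])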

 

Model Creation

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dropout, Flatten, Dense

def create_model():
    # CIFAR-10 CNN: three conv blocks (32-64-128 filters) followed by a dense head
    model = Sequential()
    model.add(Conv2D(32, (3, 3), input_shape=(32,32,3), activation='relu', padding='same'))
    model.add(Dropout(0.2))
    model.add(Conv2D(32, (3, 3), activation='relu', padding='same'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))
    model.add(Dropout(0.2))
    model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))
    model.add(Dropout(0.2))
    model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dropout(0.2))
    model.add(Dense(1024, activation='relu'))
    model.add(Dropout(0.3))
    model.add(Dense(512, activation='relu'))
    model.add(Dropout(0.25))
    model.add(Dense(1024, activation='relu'))
    model.add(Dropout(0.3))
    model.add(Dense(10, activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
    return model

 

Adam Optimizer

 

import numpy as np

def adam(m, v, adam_iter, aggre_grad, global_grad):
    """Server-side Adam step applied to the aggregated client update."""
    lr = 0.001
    beta1 = 0.9
    beta2 = 0.999

    adam_iter += 1
    # bias-corrected step size (efficient form from the Adam paper)
    lr_t = lr * np.sqrt(1.0 - beta2**adam_iter) / (1.0 - beta1**adam_iter)

    for x in range(len(v)):
        m[x] += (1 - beta1) * (aggre_grad[x] - m[x])     # first-moment estimate
        v[x] += (1 - beta2) * (aggre_grad[x]**2 - v[x])  # second-moment estimate
        # global_grad holds the global weights; step them against the update
        global_grad[x] -= (lr_t * m[x]) / (np.sqrt(v[x]) + 1e-5)
    return global_grad, m, v, adam_iter

Applying Adam on the server side counters the gradient drift caused by stragglers and partial aggregation, and encourages convergence.
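
A sketch of how the server-side optimizer state might be initialized and applied each round (my illustration; the zero-initialized moment buffers m and v are an assumption, while aggre_grad, global_weight, and global_model follow the FedAVG loop shown below):

m = [np.zeros_like(layer) for layer in global_weight]  # first-moment buffer
v = [np.zeros_like(layer) for layer in global_weight]  # second-moment buffer
adam_iter = 0

# each round, after aggregating the client deltas into aggre_grad:
global_weight, m, v, adam_iter = adam(m, v, adam_iter, aggre_grad, global_weight)
global_model.set_weights(global_weight)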

 

Basic Approach

 

Client

import copy

def client(train_data, test_data, epochs, global_weight, ada):
    cb = TimingCallback()
    # set the received global weights on the client model
    # (local_model is a module-level Keras model created via create_model())
    local_weights = copy.deepcopy(global_weight)
    local_model.set_weights(local_weights)

    # load this client's data shard
    local_x_train, local_y_train = divide(train_data)
    local_x_test, local_y_test = divide(test_data)

    local_x_train, local_y_train = np.array(local_x_train), np.array(local_y_train)
    local_x_test, local_y_test = np.array(local_x_test), np.array(local_y_test)

    # train: with ada == 1 the timing callback enforces the deadline
    if ada == 1:
        local_model.fit(local_x_train, local_y_train, epochs=epochs, batch_size=32,
                        verbose=0, validation_data=(local_x_test, local_y_test),
                        callbacks=[cb])
    else:
        local_model.fit(local_x_train, local_y_train, epochs=epochs, batch_size=32,
                        verbose=0, validation_data=(local_x_test, local_y_test))
    # return the weight delta (old weights minus new weights);
    # dtype=object keeps the ragged list of layer arrays intact on newer NumPy
    return local_weights - np.array(local_model.get_weights(), dtype=object)

With the callback enabled, training ends as soon as a further local update can no longer be performed within the training window.
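
The divide helper is not shown in this post; a plausible sketch, assuming each shard is a list of (image, one-hot label) pairs:

def divide(data):
    # split a list of (image, label) pairs into separate feature/label lists
    xs, ys = [], []
    for x, y in data:
        xs.append(x)
        ys.append(y)
    return xs, ys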

 

Server

def fed_avg(total_iteration, local_epoch, client_num):

    global_model = create_model()
    # dtype=object keeps the ragged list of layer arrays intact on newer NumPy
    global_weight = copy.deepcopy(np.array(global_model.get_weights(), dtype=object))

    client_weights = [0 for x in range(client_num)]
    total_client_data_size = np.sum(client_data_size)
    total_acc = []
    for rnd in range(total_iteration):
        # evaluate the current global model on the held-out test set
        loss, acc = global_model.evaluate(x_test, y_test, verbose=2)
        print(rnd, "avg global model accuracy: {:5.2f}%".format(100 * acc))
        total_acc.append(100 * acc)

        # each client trains on its shard and returns a weight delta
        for x in range(client_num):
            client_weights[x] = client(client_train[x], client_test[x],
                                       local_epoch, global_weight, 0)

        # weighted aggregation of the deltas by client data size
        for x in range(client_num):
            if x == 0:
                aggre_grad = np.multiply(client_weights[x],
                                         client_data_size[x] / total_client_data_size)
            else:
                aggre_grad += np.multiply(client_weights[x],
                                          client_data_size[x] / total_client_data_size)

        # apply the aggregated delta to the global weights
        global_model.set_weights(global_weight - aggre_grad)
        global_weight = copy.deepcopy(np.array(global_model.get_weights(), dtype=object))
    return total_acc

The server sends the global weights to each client, collects the training results, and performs the weighted aggregation.
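
An end-to-end sketch of driving the loop (the even CIFAR-10 partitioning and shard format here are my assumptions, not from the post):

import copy
import numpy as np
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical

(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train.astype('float32') / 255.0, x_test.astype('float32') / 255.0
y_train, y_test = to_categorical(y_train, 10), to_categorical(y_test, 10)

client_num = 5
shard = len(x_train) // client_num
tshard = len(x_test) // client_num
client_train = [list(zip(x_train[i*shard:(i+1)*shard], y_train[i*shard:(i+1)*shard]))
                for i in range(client_num)]
client_test = [list(zip(x_test[i*tshard:(i+1)*tshard], y_test[i*tshard:(i+1)*tshard]))
               for i in range(client_num)]
client_data_size = [shard] * client_num

local_model = create_model()  # shared client model; reset via set_weights each call
total_acc = fed_avg(total_iteration=10, local_epoch=1, client_num=client_num)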

 

In the next post, we will write the code that sends and receives the weights, and ultimately code that promotes convergence of the global model when the weights contain noise.

 

 
