TensorFlow FedAVG
코코빵댕이
2022. 12. 21. 18:34

Paper: https://arxiv.org/pdf/1602.05629.pdf
TensorFlow-based FedAVG
Deadline mode
In the FL setting, prevents server-side dependency on stragglers and implements partial-client aggregation
Independent clients and per-client resource allocation
Model distribution and exchange of training results via socket streaming
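The socket layer itself is written up in the next post; as a minimal sketch of the idea, assuming pickle-serializable weight lists and length-prefixed framing (all names here are illustrative):

import pickle
import struct

def _recv_exact(sock, n):
    # Read exactly n bytes, or fail if the peer closes early
    buf = b''
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise ConnectionError('socket closed mid-message')
        buf += chunk
    return buf

def send_weights(sock, weights):
    payload = pickle.dumps(weights)
    sock.sendall(struct.pack('>I', len(payload)))  # 4-byte big-endian length header
    sock.sendall(payload)

def recv_weights(sock):
    (length,) = struct.unpack('>I', _recv_exact(sock, 4))
    return pickle.loads(_recv_exact(sock, length))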
Average Scheme

Weighted summation over the partially participating clients
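In numpy form, the weighted summation runs over only the clients that actually participated, each weighted by its share of the data (this mirrors the aggregation loop in the server code further down):

import numpy as np

def weighted_average(client_updates, client_data_sizes):
    # Each participating client's update is weighted by its share of the data
    total = np.sum(client_data_sizes)
    agg = np.multiply(client_updates[0], client_data_sizes[0] / total)
    for k in range(1, len(client_updates)):
        agg += np.multiply(client_updates[k], client_data_sizes[k] / total)
    return agg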

Callback Function
import time
import tensorflow as tf

class TimingCallback(tf.keras.callbacks.Callback):
    def __init__(self, logs=None):
        self.epoch_cnt = 0

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        # st (round start time), tr_time (per-round deadline), now_client_num and
        # local_epoch are module-level globals set by the training driver.
        global st
        self.epoch_cnt += 1
        self.epoch_end_time = time.time()
        self.epoch_during_time = self.epoch_end_time - self.epoch_start_time
        self.now_during_time = self.epoch_end_time - st
        # Stop early if the remaining budget cannot cover one more epoch
        if (tr_time - self.now_during_time) < self.epoch_during_time:
            local_epoch[now_client_num] = self.epoch_cnt - 1
            print('client :', str(now_client_num), 'local epoch : ', str(self.epoch_cnt),
                  'epoch during :', str(self.epoch_during_time),
                  ' remain :', str(tr_time - self.now_during_time))
            self.model.stop_training = True
The callback tracks, after every epoch, whether another epoch of training can still complete within the deadline.
Continuously tracking the remaining time keeps clients from spinning past the deadline and reduces wasted resources.
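The globals the callback reads are initialized by the training driver before each round. A minimal sketch, with illustrative values (the names match the callback above):

import time

tr_time = 30.0            # per-round training deadline in seconds (illustrative)
local_epoch = [0] * 10    # completed-epoch count per client, assuming 10 clients
now_client_num = 0        # index of the client currently training
st = time.time()          # round start time; the callback reads this via `global st`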
Model Creation
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dropout, Flatten, Dense

def create_model():
    model = Sequential()
    model.add(Conv2D(32, (3, 3), input_shape=(32, 32, 3), activation='relu', padding='same'))
    model.add(Dropout(0.2))
    model.add(Conv2D(32, (3, 3), activation='relu', padding='same'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))
    model.add(Dropout(0.2))
    model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))
    model.add(Dropout(0.2))
    model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dropout(0.2))
    model.add(Dense(1024, activation='relu'))
    model.add(Dropout(0.3))
    model.add(Dense(512, activation='relu'))
    model.add(Dropout(0.25))
    model.add(Dense(1024, activation='relu'))
    model.add(Dropout(0.3))
    model.add(Dense(10, activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
    return model
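The 32x32x3 input and 10-way softmax match CIFAR-10, and categorical_crossentropy expects one-hot labels. A minimal data-preparation sketch under that assumption:

from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical

(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0  # scale pixels to [0, 1]
y_train = to_categorical(y_train, 10)              # one-hot labels for categorical_crossentropy
y_test = to_categorical(y_test, 10)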
Adam Optimizer
import numpy as np

def adam(m, v, adam_iter, aggre_grad, global_grad):
    lr = 0.001
    beta1 = 0.9
    beta2 = 0.999
    adam_iter += 1
    # Bias-corrected step size
    lr_t = lr * np.sqrt(1.0 - beta2**adam_iter) / (1.0 - beta1**adam_iter)
    for x in range(len(v)):
        # Exponential moving averages of the aggregated gradient and its square
        m[x] += (1 - beta1) * (aggre_grad[x] - m[x])
        v[x] += (1 - beta2) * (aggre_grad[x]**2 - v[x])
        global_grad[x] -= (lr_t * m[x]) / (np.sqrt(v[x]) + 1e-5)
    return global_grad, m, v, adam_iter
Mitigates gradient drift caused by stragglers and partial aggregation, steering the global model toward convergence
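A minimal sketch of how adam() could plug into the server's round loop, replacing the plain global_weight - aggre_grad update shown in fed_avg below; the persistent m/v/adam_iter state is assumed and not shown in the post:

# Zero-initialize the Adam state once, matching the weight shapes (illustrative)
m = [np.zeros_like(w) for w in global_weight]
v = [np.zeros_like(w) for w in global_weight]
adam_iter = 0

# Inside the round loop, after aggre_grad has been computed:
global_weight, m, v, adam_iter = adam(m, v, adam_iter, aggre_grad, global_weight)
global_model.set_weights(global_weight)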
Basic Approach
Client
import copy
import numpy as np

def client(train_data, test_data, epochs, global_weight, ada):
    cb = TimingCallback()
    # Load the received global weights into the client's local model
    local_weights = copy.deepcopy(global_weight)
    local_model.set_weights(local_weights)
    # Load data
    local_x_train, local_y_train = divide(train_data)
    local_x_test, local_y_test = divide(test_data)
    local_x_train, local_y_train = np.array(local_x_train), np.array(local_y_train)
    local_x_test, local_y_test = np.array(local_x_test), np.array(local_y_test)
    # Train (with the deadline callback when ada == 1)
    if ada == 1:
        local_model.fit(local_x_train, local_y_train, epochs=epochs, batch_size=32, verbose=0,
                        validation_data=(local_x_test, local_y_test), callbacks=[cb])
    else:
        local_model.fit(local_x_train, local_y_train, epochs=epochs, batch_size=32, verbose=0,
                        validation_data=(local_x_test, local_y_test))
    # Return the weight delta (pseudo-gradient) rather than the raw weights
    return local_weights - np.array(local_model.get_weights())
The callback ends training as soon as no further local update can complete within the deadline.
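Both divide and local_model are defined outside the snippet shown here. As a hypothetical sketch, assuming each client's dataset is a list of (image, label) pairs:

# Hypothetical helper: unzip a list of (image, label) pairs into two lists
def divide(data):
    xs, ys = [], []
    for image, label in data:
        xs.append(image)
        ys.append(label)
    return xs, ys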
Server
def fed_avg(total_iteration, local_epoch, client_num):
    global_model = create_model()
    global_weight = copy.deepcopy(np.array(global_model.get_weights()))
    client_weights = [0 for x in range(client_num)]
    total_client_data_size = np.sum(client_data_size)
    total_acc = []
    for it in range(total_iteration):
        # Evaluate the current global model
        loss, acc = global_model.evaluate(x_test, y_test, verbose=2)
        print(it, "averaged global model accuracy: {:5.2f}%".format(100 * acc))
        total_acc.append(100 * acc)
        # Collect each client's weight delta
        for x in range(client_num):
            client_weights[x] = client(client_train[x], client_test[x], local_epoch, global_weight, 0)
        # Weighted aggregation by client data size
        for x in range(client_num):
            if x == 0:
                aggre_grad = np.multiply(client_weights[x], (client_data_size[x] / total_client_data_size))
            else:
                aggre_grad += np.multiply(client_weights[x], (client_data_size[x] / total_client_data_size))
        global_model.set_weights(global_weight - aggre_grad)
        global_weight = copy.deepcopy(np.array(global_model.get_weights()))
    return total_acc
The server sends the global weights to each client, receives the training results, and performs weighted aggregation.
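A minimal driver sketch tying the pieces together, assuming the CIFAR-10 arrays prepared above; the even round-robin split and the hyperparameter values are illustrative:

# Illustrative driver: split the data evenly across clients and run FedAVG
CLIENT_NUM = 10
client_train = [list(zip(x_train[i::CLIENT_NUM], y_train[i::CLIENT_NUM])) for i in range(CLIENT_NUM)]
client_test = [list(zip(x_test[i::CLIENT_NUM], y_test[i::CLIENT_NUM])) for i in range(CLIENT_NUM)]
client_data_size = [len(c) for c in client_train]

local_model = create_model()  # shared model instance reused by client()
total_acc = fed_avg(total_iteration=50, local_epoch=5, client_num=CLIENT_NUM)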
In the next post we will write the code that actually sends and receives the weights, and ultimately code that promotes convergence of the global model when the received weights contain noise.