TensorFlow FedAVG

Paper: https://arxiv.org/pdf/1602.05629.pdf
TensorFlow-based FedAVG
Deadline mode
Prevents server-side blocking on stragglers in the FL setting and implements partial client aggregation
Independent clients and resource allocation
Model distribution and exchange of training results via socket streaming
Average Scheme

Weighted summation over the partial set of participants
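In the notation of the cited paper (the symbols below are mine, not the post's): under deadline-based partial participation, only the clients $k \in S_t$ that report back in time enter round $t$'s aggregate. Each client returns the update $\Delta_k = w_t - w_k^{t+1}$ (exactly what client() below returns), and the server applies

$$w_{t+1} = w_t - \sum_{k \in S_t} \frac{n_k}{n}\,\Delta_k, \qquad n = \sum_{k \in S_t} n_k,$$

where $n_k$ is client $k$'s local data size. In the basic server loop shown later every client participates, so $S_t$ is the full client set.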

Callback Function
import time
import tensorflow as tf

class TimingCallback(tf.keras.callbacks.Callback):
    def __init__(self, logs=None):
        self.epoch_cnt = 0

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        # st (round start time), tr_time (round deadline), local_epoch and
        # now_client_num are globals set by the caller before training begins
        global st
        self.epoch_cnt += 1
        self.epoch_end_time = time.time()
        self.epoch_during_time = self.epoch_end_time - self.epoch_start_time
        self.now_during_time = self.epoch_end_time - st
        # Stop early if the remaining budget cannot cover one more epoch
        if (tr_time - self.now_during_time) < self.epoch_during_time:
            local_epoch[now_client_num] = self.epoch_cnt - 1
            print('client :', str(now_client_num), 'local epoch :', str(self.epoch_cnt),
                  'epoch during :', str(self.epoch_during_time),
                  ' remain :', str(tr_time - self.now_during_time))
            self.model.stop_training = True
The callback function tracks whether further training should be performed.
Continuously tracking the remaining time keeps clients from spinning idly past the deadline and reduces resource consumption. A usage sketch follows.
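A minimal sketch of how the callback plugs into local training. The setup below is an assumption consistent with the callback code: st, tr_time, now_client_num, and local_epoch are the globals it reads and writes, and local_model with its local_x_train/local_y_train data is assumed to exist already.

import time

# Hypothetical per-round setup: these are the globals read by TimingCallback
tr_time = 30.0              # deadline (seconds) for this round's local training
st = time.time()            # round start time
now_client_num = 0          # index of the client currently training
local_epoch = [10] * 5      # planned local epochs per client; trimmed on timeout

cb = TimingCallback()
local_model.fit(local_x_train, local_y_train, epochs=local_epoch[now_client_num],
                batch_size=32, verbose=0, callbacks=[cb])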
Model Creation
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dropout, Flatten, Dense

def create_model():
    model = Sequential()
    model.add(Conv2D(32, (3, 3), input_shape=(32, 32, 3), activation='relu', padding='same'))
    model.add(Dropout(0.2))
    model.add(Conv2D(32, (3, 3), activation='relu', padding='same'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))
    model.add(Dropout(0.2))
    model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))
    model.add(Dropout(0.2))
    model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dropout(0.2))
    model.add(Dense(1024, activation='relu'))
    model.add(Dropout(0.3))
    model.add(Dense(512, activation='relu'))
    model.add(Dropout(0.25))
    model.add(Dense(1024, activation='relu'))
    model.add(Dropout(0.3))
    model.add(Dense(10, activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
    return model
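The input shape (32, 32, 3) and the 10-way softmax match CIFAR-10, and since the loss is categorical_crossentropy the labels must be one-hot encoded. A minimal data-preparation sketch under that assumption (the post itself does not show this step):

import numpy as np
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical

# Assumed dataset: CIFAR-10, matching input_shape=(32, 32, 3) and Dense(10)
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0   # scale pixels to [0, 1]
y_train = to_categorical(y_train, 10)               # one-hot for categorical_crossentropy
y_test = to_categorical(y_test, 10)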
Adam Optimizer
import numpy as np

def adam(m, v, adam_iter, aggre_grad, global_grad):
    # Server-side Adam applied to the aggregated client update (pseudo-gradient)
    lr = 0.001
    beta1 = 0.9
    beta2 = 0.999
    adam_iter += 1
    # Compact form of Adam's bias correction folded into the learning rate
    lr_t = lr * np.sqrt(1.0 - beta2**adam_iter) / (1.0 - beta1**adam_iter)
    for x in range(len(v)):
        m[x] += (1 - beta1) * (aggre_grad[x] - m[x])      # first-moment estimate
        v[x] += (1 - beta2) * (aggre_grad[x]**2 - v[x])   # second-moment estimate
        global_grad[x] -= (lr_t * m[x]) / (np.sqrt(v[x]) + 1e-5)
    return global_grad, m, v, adam_iter
Applying server-side Adam to the aggregated update counters the gradient drift caused by stragglers and partial aggregation and promotes convergence.
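A minimal sketch of how the server-side Adam state would be initialized and applied each round; the variable names (global_model, aggre_grad, global_weight) are assumptions consistent with the functions in this post.

# Hypothetical server-side Adam state, one slot per weight tensor
m = [np.zeros_like(w) for w in global_model.get_weights()]
v = [np.zeros_like(w) for w in global_model.get_weights()]
adam_iter = 0

# After a round's weighted aggregation has produced aggre_grad:
global_weight, m, v, adam_iter = adam(m, v, adam_iter, aggre_grad, global_weight)
global_model.set_weights(global_weight)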
Basic Approach
Client
import copy

def client(train_data, test_data, epochs, global_weight, ada):
    cb = TimingCallback()
    # Load the received global weights into the client model
    # (local_model is a global Keras model shared across client calls)
    local_weights = copy.deepcopy(global_weight)
    local_model.set_weights(local_weights)
    # Load this client's data
    local_x_train, local_y_train = divide(train_data)
    local_x_test, local_y_test = divide(test_data)
    local_x_train, local_y_train = np.array(local_x_train), np.array(local_y_train)
    local_x_test, local_y_test = np.array(local_x_test), np.array(local_y_test)
    # Train; in deadline mode (ada == 1) TimingCallback may stop training early
    if ada == 1:
        local_model.fit(local_x_train, local_y_train, epochs=epochs, batch_size=32, verbose=0,
                        validation_data=(local_x_test, local_y_test), callbacks=[cb])
    else:
        local_model.fit(local_x_train, local_y_train, epochs=epochs, batch_size=32, verbose=0,
                        validation_data=(local_x_test, local_y_test))
    # Return the update (old weights minus new weights): a pseudo-gradient
    return local_weights - np.array(local_model.get_weights())
Via the callback function, training is terminated when no further local update can be completed within the training window.
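The divide helper used above is not shown in the post; a minimal sketch, assuming each client dataset is a list of (image, label) pairs:

def divide(data):
    # Hypothetical helper: split a list of (image, label) pairs into two lists
    xs, ys = [], []
    for image, label in data:
        xs.append(image)
        ys.append(label)
    return xs, ys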
Server
def fed_avg(total_iteration, local_epoch, client_num):
    global_model = create_model()
    global_weight = copy.deepcopy(np.array(global_model.get_weights()))
    client_weights = [0 for x in range(client_num)]
    total_client_data_size = np.sum(client_data_size)
    total_acc = []
    for it in range(total_iteration):
        # Evaluate the current global model
        loss, acc = global_model.evaluate(x_test, y_test, verbose=2)
        print(it, "avg global model accuracy: {:5.2f}%".format(100 * acc))
        total_acc.append(100 * acc)
        # Each client trains locally and returns its weight update
        for x in range(client_num):
            client_weights[x] = client(client_train[x], client_test[x], local_epoch, global_weight, 0)
        # Weighted summation of the updates, proportional to local data size
        for x in range(client_num):
            if x == 0:
                aggre_grad = np.multiply(client_weights[x], client_data_size[x] / total_client_data_size)
            else:
                aggre_grad += np.multiply(client_weights[x], client_data_size[x] / total_client_data_size)
        # Apply the aggregated update to the global weights
        global_model.set_weights(global_weight - aggre_grad)
        global_weight = copy.deepcopy(np.array(global_model.get_weights()))
    return total_acc
The server sends the global weights to each client, collects the training results, and performs the weighted aggregation.
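A minimal sketch of kicking off training; it assumes the globals referenced above (client_train, client_test, client_data_size, x_test, y_test) have already been prepared, for example by partitioning CIFAR-10 across clients.

# Hypothetical run: 5 clients, 5 local epochs, 100 communication rounds
local_model = create_model()   # shared model instance reused by client()
total_acc = fed_avg(total_iteration=100, local_epoch=5, client_num=5)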
The next post will add the code that sends and receives the weights over sockets, and finally the code that promotes convergence of the global model when the weights contain noise.