Federated Learning

Federated Learning (FL) is a machine learning approach that enables model training across decentralised devices or servers while keeping the training data localised.

In traditional machine learning, data is typically gathered and centralised in one location for model training.

  • However, in federated learning, the model is trained across multiple devices or servers, each holding its own data
  • The fundamental idea behind FL is to bring the ML model to the data rather than moving the data to a central location
  • This approach has two main advantages:
• reduced communication costs and time (datasets can be very large, on the order of several gigabytes)
    • privacy preservation (thanks to the ability to leverage distributed datasets without centralising sensitive information)

Principles of Federated Learning

  • Decentralized Training: in FL, the training process is distributed across multiple devices or servers
    • Each device independently computes an update to the model based on its local data
  • Model Aggregation: after local computations, these model updates are sent to a central server or aggregator, which combines them to update the global model (a minimal sketch of one such round follows this list)
    • This global model is then sent back to the devices for the next round of local computations
  • Privacy Preservation: one of the main advantages of FL is privacy
    • Since raw data never leaves the local devices, users’ sensitive information is kept private
    • Only model updates, which are typically anonymised and aggregated, are shared
  • Reduced Communication Overhead: FL reduces the need to transfer large amounts of data to a central server, as only model updates are exchanged
    • This is particularly beneficial in scenarios with limited bandwidth or high communication costs
  • Collaborative Learning: FL is well-suited for scenarios where multiple parties want to collaborate on building a shared ML model without sharing their raw data
  • Iterative Process: FL typically involves multiple rounds of communication and model updates
    • The process iterates until the model achieves satisfactory performance or convergence
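
To make the round structure concrete, here is a minimal sketch in Python/NumPy, using a toy linear model and plain (unweighted) averaging; all names and numbers are illustrative, not a production FL framework:

```python
import numpy as np

rng = np.random.default_rng(0)

def local_update(w, X, y, lr=0.1, steps=10):
    # One client's contribution: a few gradient-descent steps on its own data.
    # Raw data (X, y) never leaves this function; only the new weights do.
    for _ in range(steps):
        grad = 2 * X.T @ (X @ w - y) / len(y)  # MSE gradient, toy linear model
        w = w - lr * grad
    return w

# Three clients, each holding a private toy dataset (illustrative only).
clients = [(rng.normal(size=(50, 4)), rng.normal(size=50)) for _ in range(3)]
w_global = np.zeros(4)

for round_t in range(5):                       # iterative process
    local_models = [local_update(w_global.copy(), X, y) for X, y in clients]
    w_global = np.mean(local_models, axis=0)   # aggregation on the server
```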

Next-word prediction

Federated Learning was originally developed by Google as a way to train the next-word predictor for the virtual keyboards of Android phones [1], preserving users’ privacy: the model is trained without direct use of their text messages.

[1] Hard, Andrew, Kanishka Rao, Rajiv Mathews, Swaroop Ramaswamy, Françoise Beaufays, Sean Augenstein, Hubert Eichner, Chloé Kiddon, and Daniel Ramage. “Federated learning for mobile keyboard prediction.” arXiv preprint arXiv:1811.03604 (2018).

Is Federated Learning usable/suitable for cybersecurity applications?

Scenario: a cybersecurity company leverages AI technologies to provide security services (threat detection and response) to multiple organisations.

Problem: as for Google, using data from the clients for training purposes might not be possible for confidentiality reasons.

  • Spam filtering application: data consists of e-mails, and not only malicious ones; legitimate messages are needed as well
  • Intrusion detection system: data consists of network traffic, malicious and benign
  • Anomaly detection in industrial control systems: data consists of sensor readings

The Federated Learning process

  • A central server collects the weights of the model trained locally by the clients (no data sharing among clients or with the server)
  • The server computes the weighted average of the weights and sends the resulting model back to the clients for further training


Steps of the FL process

    1. FL assumes a fixed set of K clients, each with a fixed local dataset
    2. At each round t, a random fraction F of clients is selected (for efficiency reasons) and the server sends them an ANN model $w_t$ for local training
    3. Each selected client k updates the common model with its local data using one or more steps of mini-batch gradient descent: $w_t^k \leftarrow w_t - \alpha \nabla J_k(w_t)$
    4. Next, the clients send the updated models $w_t^k$ to the server for aggregation (note that the aggregation is done with all K clients’ models): $w_{t+1} \leftarrow \sum_{k=1}^{K} \frac{n_k}{n} w_t^k$
    5. This process is iterated for several rounds until the desired test-set accuracy is reached
    6. The whole process is called the Federated Averaging (FedAvg) algorithm.

Federated Averaging (FedAvg)

The server computes the average of the clients’ models, weighted by the number of local training samples ($n_k$):

$$\boldsymbol{w}_{t+1} \leftarrow \sum_{k=1}^{K} \frac{n_k}{n} \boldsymbol{w}_t^k \quad \text{where } n = \sum_{k=1}^{K} n_k$$
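
As a worked example with made-up numbers: two clients holding $n_1 = 100$ and $n_2 = 300$ samples give $n = 400$, so the aggregation reduces to

$$\boldsymbol{w}_{t+1} = \frac{100}{400}\,\boldsymbol{w}_t^1 + \frac{300}{400}\,\boldsymbol{w}_t^2 = 0.25\,\boldsymbol{w}_t^1 + 0.75\,\boldsymbol{w}_t^2$$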

Important parameters of FedAvg

The fraction F of clients that execute the update (randomly selected)

The number of local updates performed by each client:

$$\mu_k = E \cdot S = E \cdot \frac{n_k}{B}$$

where E = number of local epochs, S = MBGD steps per epoch, and B = local mini-batch size
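
For example (assumed values): a client with $n_k = 5000$ local samples and batch size $B = 50$ performs $S = 5000/50 = 100$ MBGD steps per epoch; with $E = 2$ local epochs, that client runs $\mu_k = 2 \cdot 100 = 200$ local updates per round.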

FedAvg algorithm

Server executes:
    initialize $w_0$
    for each round t = 0, 1, 2, ... do
        m ← max(F·K, 1)
        $S_t$ ← (random set of m clients)
        for each client k ∈ $S_t$ in parallel do
            $w_t^k$ ← ClientUpdate(k, $w_t$)
        $w_{t+1} \leftarrow \sum_{k=1}^{K} \frac{n_k}{n} w_t^k$

ClientUpdate(k, w):  // Run on client k
    ℬ ← (split $X_k^{train}$ into batches of size B)
    for each local epoch i from 1 to E do
        for batch b ∈ ℬ do
            w ← w − α∇$J_k$(w, b)
    return w to server
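
A hedged Python sketch of the same algorithm (a toy linear model with MSE loss stands in for the ANN; `fedavg`, its parameters and the data layout are illustrative assumptions, not a reference implementation):

```python
import random
import numpy as np

def fedavg(clients, w0, F=0.2, rounds=100, E=5, B=50, lr=0.01):
    # `clients` is a list of (X, y) arrays; the gradient below assumes a toy
    # linear model with MSE loss, standing in for the ANN of the slides.
    K = len(clients)
    n_k = np.array([len(y) for _, y in clients], dtype=float)
    n = n_k.sum()
    models = [w0.copy() for _ in range(K)]  # latest model from each client

    def client_update(k, w):
        X, y = clients[k]
        for _ in range(E):                            # E local epochs
            idx = np.random.permutation(len(y))
            for s in range(0, len(y), B):             # mini-batches of size B
                b = idx[s:s + B]
                grad = 2 * X[b].T @ (X[b] @ w - y[b]) / len(b)
                w = w - lr * grad                     # w <- w - alpha * grad J_k(w, b)
        return w

    w_global = w0
    for t in range(rounds):
        m = max(int(F * K), 1)
        for k in random.sample(range(K), m):          # random fraction of clients
            models[k] = client_update(k, w_global.copy())
        # Weighted average over all K clients' latest models (as in the slides).
        w_global = sum((n_k[k] / n) * models[k] for k in range(K))
    return w_global
```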


Limitations of FedAvg

Unlike in other domains (e.g., image classification [1]), FL is not as robust with non-IID and unbalanced datasets of network traffic

FL assumes test data available at the server site to control the training process

[Figure: FedAvg performance on volumetric DDoS and botnet C&C traffic]

[1] B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. y Arcas, “Communication-efficient learning of deep networks from decentralized data,” in Artificial intelligence and statistics, 2017, pp. 1273–1282.

FedAvg assigns the same amount of computation to all the clients selected for a round of training, irrespective of the accuracy level reached by the global model on specific clients’ data.

The weighted average of FedAvg gives more importance to the weights of the clients with large local training sets, to the detriment of the smallest ones.

  • A model trained with FedAvg may fail to learn attacks characterised by out-of-distribution features that are available only in small local training sets: $w_{t+1} \leftarrow \sum_{k=1}^{K} \frac{n_k}{n} w_t^k$

State of the art

1. Focus on performance:

  • max accuracy (e.g., comparison with centralised training)
  • min communication rounds

2. Data sharing for improving convergence on non-IID and unbalanced datasets
3. Often, the vanilla FedAvg algorithm is used (weighted averaging)
4. The stopping procedure is usually neglected

Source: S. I. Popoola, R. Ande, B. Adebisi, G. Gui, M. Hammoudeh and O. Jogunola, “Federated Deep Learning for Zero-Day Botnet Attack Detection in IoT-Edge Devices,” IEEE Internet of Things Journal, vol. 9, no. 5, pp. 3930-3944, 1 March 2022.

FLAD: Adaptive Federated Learning for DDoS Attack Detection

Roberto Doriguzzi-Corin, Domenico Siracusa, “FLAD: Adaptive Federated Learning for DDoS attack detection”, in Computers & Security, Volume 137, 2024, doi: 10.1016/j.cose.2023.103597

FLAD in a nutshell

FLAD enhances FedAvg with a mechanism to monitor the performance of the global model on the clients’ data, with no sharing of clients’ data:

  • to ensure that all the traffic profiles (benign and malicious) have been learnt by the model
  • to implement an early-stopping strategy

FLAD workflow

  • At time $t_0$ the server generates random parameters (weights and biases) and distributes them to the K clients
  • At time $t_1$ the server receives the updates from the clients and aggregates them to obtain the new parameters $w_1$
  • The server then sends the updated model $w_1$ to all the clients for local validation
  • The server receives the accuracy scores (e.g., F1 score) computed by the clients on their respective validation sets
  • The number of epochs and MBGD steps assigned to each client depends on its accuracy score: the higher the score, the fewer epochs and steps. Zero MBGD steps means the client is not selected for that round

The process continues until convergence is reached (early stopping based on the accuracy score).
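
The allocation rule below is an illustrative sketch of this adaptive idea only: it maps each client's accuracy score to a number of epochs and MBGD steps by linear interpolation between the bounds used in the experimental setup later in this section. `allocate_work` and the 0.99 "done" threshold are hypothetical and may differ from FLAD's exact rule:

```python
# Bounds taken from the experimental setup table later in this section.
MIN_EPOCHS, MAX_EPOCHS = 1, 5
MIN_STEPS, MAX_STEPS = 10, 1000

def allocate_work(f1_scores, done_threshold=0.99):
    # Assign more local work to clients where the global model performs worse.
    # Linear interpolation and `done_threshold` are illustrative choices,
    # not necessarily FLAD's exact allocation rule.
    plan = {}
    for client, f1 in f1_scores.items():
        if f1 >= done_threshold:
            plan[client] = (0, 0)   # zero MBGD steps: not selected this round
        else:
            epochs = round(MIN_EPOCHS + (1.0 - f1) * (MAX_EPOCHS - MIN_EPOCHS))
            steps = round(MIN_STEPS + (1.0 - f1) * (MAX_STEPS - MIN_STEPS))
            plan[client] = (epochs, steps)
    return plan

print(allocate_work({"client_a": 0.55, "client_b": 0.97, "client_c": 0.995}))
# -> client_a gets many epochs/steps, client_b few, client_c sits out the round
```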

Dataset: CIC-DDoS2019

| Attack | #Flows | Transport | Description |
|--------|--------|-----------|-------------|
| DNS | 441931 | UDP | DDoS attacks that exploit a specific UDP-based network service to overwhelm the victim with responses to queries sent by the attacker to a server using the spoofed victim’s IP address. Six types of network services have been exploited to generate these attacks: Domain Name System (DNS), Lightweight Directory Access Protocol (LDAP), Microsoft SQL (MSSQL), Network Time Protocol (NTP), Network Basic Input/Output System (NetBIOS) and Port Mapper (Portmap). |
| LDAP | 11400 | UDP | As above. |
| MSSQL | 9559537 | UDP | As above. |
| NTP | 1194836 | UDP | As above. |
| NetBIOS | 7553086 | UDP | As above. |
| Portmap | 186449 | UDP | As above. |
| SNMP | 1334534 | UDP | Reflected amplification attack leveraging the Simple Network Management Protocol (SNMP), a UDP-based protocol used to configure network devices. |
| SSDP | 2580154 | UDP | Attack based on the Simple Service Discovery Protocol (SSDP), which enables UPnP devices to send and receive information over UDP. Vulnerable devices send UPnP replies to the spoofed IP address of the victim. |
| TFTP | 6503575 | UDP | Attack built by reflecting the files requested to a Trivial File Transfer Protocol (TFTP) server toward the victim’s spoofed IP address. |
| Syn Flood | 6056402 | TCP | Attack that exploits the TCP three-way handshake mechanism to consume the victim’s resources with a flood of SYN packets. |
| UDP Flood | 6969476 | UDP | Attack built with high rates of small spoofed UDP packets with the aim of consuming the victim’s network resources. |
| UDPLag | 474018 | UDP | UDP traffic generated to slow down the victim’s connection with the online gaming server. |
| WebDDoS | 146 | TCP | A short DDoS attack (around 3100 packets) against a web server on port 80. |
| Total | 42865789 | | |

Despite the huge number of flows, the dataset is heavily imbalanced: it contains 8 predominant DDoS attack types with more than one million flows each, a few tens of thousands of flows for the LDAP and Portmap reflection attacks, and only 146 flows for the WebDDoS attack.

Independent and identically distributed data

Definition: a collection of random variables is independent and identically distributed if each random variable has the same probability distribution as the others and all are mutually independent.

(Clauset, Aaron (2011). “A brief primer on probability distributions” (PDF). Santa Fe Institute.)

Non-i.i.d. data

The dissimilarity between clients’ local data distributions can be quantified with the Jensen-Shannon Distance between probability distributions.
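
As a brief illustration with hypothetical class distributions for two clients, the JS distance can be computed with SciPy:

```python
import numpy as np
from scipy.spatial.distance import jensenshannon

# Hypothetical class distributions of two clients' local datasets,
# e.g., P(class) over {Syn Flood, DNS reflection, benign}.
p = np.array([0.80, 0.15, 0.05])   # client A: mostly Syn Flood traffic
q = np.array([0.10, 0.70, 0.20])   # client B: mostly DNS reflection traffic

# jensenshannon() returns the JS distance (square root of the JS divergence):
# 0 for identical distributions, 1 for maximally different ones (base 2).
print(jensenshannon(p, q, base=2))
```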

FL with non-i.i.d. data

MNIST is a dataset of handwritten digits; CIFAR is a dataset of images (cats, dogs, aeroplanes, etc.)

They are two public datasets for image classification applications (10 classes each)

Source: Zhao et al. “Federated Learning with Non-IID Data”, 2018

  • SGD: centralised training using Mini-batch Gradient Descent
  • IID: each client is randomly assigned a uniform distribution over 10 classes
  • Non-IID: the data is sorted by class and divided to create two extreme cases (a partitioning sketch in code follows this list):
    • 1-class non-IID: each client receives a data partition from only a single class
    • 2-class non-IID: the sorted data is divided into 20 partitions and each client is randomly assigned 2 partitions from 2 classes
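
A hedged sketch of the two partitioning schemes described above (`one_class_partition` and `two_class_partition` are hypothetical helper names; random labels stand in for MNIST/CIFAR):

```python
import numpy as np

rng = np.random.default_rng(0)

def one_class_partition(labels, num_clients):
    # 1-class non-IID: each client receives data from only a single class.
    classes = np.unique(labels)
    return [np.flatnonzero(labels == classes[i % len(classes)])
            for i in range(num_clients)]

def two_class_partition(labels, num_clients):
    # 2-class non-IID: sort by class, split into 2*num_clients shards and
    # hand each client 2 randomly chosen shards (usually 2 distinct classes).
    order = np.argsort(labels)
    shards = np.array_split(order, 2 * num_clients)
    shard_ids = rng.permutation(2 * num_clients)
    return [np.concatenate((shards[shard_ids[2 * i]], shards[shard_ids[2 * i + 1]]))
            for i in range(num_clients)]

labels = rng.integers(0, 10, size=1000)   # stand-in for MNIST/CIFAR labels
partitions = two_class_partition(labels, num_clients=10)
```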

Experimental setup: data preparation

| Attack | Samples | Training | Validation | Test |
|--------|---------|----------|------------|------|
| WebDDoS | 402 | 321 | 37 | 44 |
| LDAP | 854 | 633 | 135 | 86 |
| Portmap | 1605 | 1299 | 145 | 161 |
| DNS | 3207 | 2595 | 291 | 321 |
| UDPLag | 6400 | 5184 | 576 | 640 |
| NTP | 12807 | 10372 | 1153 | 1282 |
| SNMP | 25649 | 20775 | 2309 | 2565 |
| SSDP | 51207 | 41477 | 4609 | 5121 |
| Syn Flood | 102400 | 82940 | 9216 | 10244 |
| TFTP | 204800 | 165887 | 18433 | 20480 |
| UDP Flood | 409601 | 331772 | 36864 | 40965 |
| NetBIOS | 819200 | 663551 | 73728 | 81921 |
| MSSQL | 1638404 | 1327105 | 147457 | 163842 |

Pathological non-IID data partition: each client receives data from only a single attack plus an equivalent amount of samples of benign traffic.

Experimental setup: ANN architecture

| Name | Value | Description |
|------|-------|-------------|
| PATIENCE | 25 | Max FL rounds with no progress. |
| Min epochs | 1 | Min number of local training epochs. |
| Max epochs | 5 | Max number of local training epochs. |
| Min steps | 10 | Min number of MBGD steps. |
| Max steps | 1000 | Max number of MBGD steps. |
| n × f | 10 × 11 | Size of the MLP input layer. |
| Hidden layers | 2 | Number of hidden layers. |
| m | 32 | Number of neurons/layer. |


Convergence analysis on non-i.i.d attacks

Evaluation on unseen data

| Attack | FLAD (E=A, S=A) | FedAvg (E=1, B=50) | FedAvg (E=5, B=50) |
|--------|-----------------|--------------------|--------------------|
| WebDDoS | 0.7864 | 0.0727 | 0.7182 |
| LDAP | 0.9306 | 0.8972 | 0.9306 |
| Portmap | 0.9250 | 0.8548 | 0.9221 |
| DNS | 0.9779 | 0.9060 | 0.8799 |
| UDPLag | 0.9652 | 0.9978 | 0.9984 |
| NTP | 0.9660 | 0.9874 | 0.9701 |
| SNMP | 0.9574 | 0.9211 | 0.9586 |
| SSDP | 0.9663 | 0.9983 | 0.9988 |
| Syn | 0.9767 | 0.3188 | 0.3254 |
| TFTP | 0.9439 | 0.9483 | 0.9372 |
| UDP | 0.9656 | 0.9996 | 0.9995 |
| NetBIOS | 0.9218 | 0.8581 | 0.9272 |
| MSSQL | 0.9981 | 0.9994 | 0.9176 |
| Average | 0.9699 | 0.9396 | 0.9155 |

(E=A, S=A: epochs and MBGD steps set adaptively by FLAD.)

True Positive Rate (TPR) measured on the test sets of the clients using the models obtained with FLAD and FedAvg.

We want to compare the ability of the aggregated models to correctly classify all the attacks.

FedAvg does not perform well on out-of-distribution (o.o.d.) attack traffic (the TCP-based attacks WebDDoS and Syn Flood)

Summary

| | FLAD | FedAvg |
|---|------|--------|
| Selection of clients | Accuracy-based | Random |
| Local computation | Dynamic | Static |
| Stopping procedure | Multiple (e.g., patience, target accuracy/std_dev) | Test set |

Further remarks

FLAD has been evaluated for DDoS attack scenarios.

  • However, the proposed approach can be effectively adopted in other cybersecurity applications where clients are expected to contribute zero-day attack samples, whose profiles are not available to the server to assess the global model
  • In the experiments, PATIENCE=25 has been adopted to stop the FL process

  • Alternatively, more advanced stopping strategies are also possible with FLAD:

    • E.g., wait until a target average accuracy is reached, perhaps also combined with a target standard deviation of the accuracy scores to ensure that the performance is stable across all local datasets (see the sketch after this list)
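
A minimal sketch of such a combined stopping rule (`should_stop` and its parameters are hypothetical; it assumes at least two clients report scores each round):

```python
import statistics

def should_stop(score_history, patience=25, target_acc=None, target_std=None):
    # score_history: one list of per-client F1 scores for each completed round.
    # Rule 1 (patience): stop if the mean score has not improved for `patience` rounds.
    means = [statistics.mean(scores) for scores in score_history]
    best_round = means.index(max(means))
    if len(means) - 1 - best_round >= patience:
        return True
    # Rule 2 (targets): stop when the latest round is accurate on average and,
    # optionally, stable across clients (low std of per-client scores).
    latest = score_history[-1]
    if target_acc is not None and statistics.mean(latest) >= target_acc:
        if target_std is None or statistics.stdev(latest) <= target_std:
            return True
    return False

# Example: mean 0.96 with per-client scores spread between 0.93 and 0.99.
print(should_stop([[0.99, 0.96, 0.93]], patience=25, target_acc=0.95, target_std=0.05))
```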

Challenges in FL: malicious clients

A compromised client seeks to corrupt the global model [1], e.g., through data poisoning:

  • label flipping: setting the labels of certain attacks to 0 (benign)

This manipulation may result in the model missing certain types of cyber attacks
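
For illustration, a label-flipping client needs only a few lines before local training (`flip_attack_labels` is a hypothetical helper, shown to clarify the threat):

```python
import numpy as np

def flip_attack_labels(y, attack_class):
    # Poisoned client: relabel one attack class as benign (0) before local
    # training, so its updates teach the global model to miss that attack.
    y_poisoned = y.copy()
    y_poisoned[y == attack_class] = 0
    return y_poisoned
```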

[1] Kumar, K. N., Mohan, C. K., & Cenkeramaddi, L. R. (2023). The impact of adversarial attacks on federated learning: A survey. IEEE Transactions on Pattern Analysis and Machine Intelligence, 46(5), 2672-2691.

Challenges in FL: malicious server

Clients of FL may face the threat of reconstruction attacks [1] perpetrated by a malicious server

  • It can exploit information on the global model architecture, clients’ gradients and other metadata to infer details of the original clients’ training data

[1] Geiping, J., Bauermeister, H., Dröge, H., & Moeller, M. (2020). Inverting gradients - How easy is it to break privacy in federated learning? Advances in Neural Information Processing Systems, 33, 16937-16947.

Defenses against malicious servers

Homomorphic Encryption (HE) [1]:

  • Allows computations to be performed on encrypted data
  • With HE, gradient aggregation can be performed on ciphertexts without decrypting them in advance
  • HE makes it challenging for a malicious server to access sensitive information (a minimal sketch follows the reference below)

[1] Zhang, C., Li, S., Xia, J., Wang, W., Yan, F., & Liu, Y. (2020). BatchCrypt: Efficient homomorphic encryption for cross-silo federated learning. In 2020 USENIX Annual Technical Conference (USENIX ATC 20) (pp. 493-506).
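
A hedged sketch of HE-protected aggregation using the python-paillier (`phe`) library: Paillier is additively homomorphic, which covers exactly what FedAvg aggregation needs (adding ciphertexts and scaling them by plaintext constants). This illustrates the principle only; it is not the BatchCrypt system of [1], which adds quantisation and batching for efficiency. All weights and dataset sizes are hypothetical:

```python
from phe import paillier

public_key, private_key = paillier.generate_paillier_keypair(n_length=2048)

# Each client encrypts its model parameters before upload (one weight shown;
# all numbers are hypothetical).
client_weights = [0.42, 0.51, 0.38]
n_k = [100, 300, 600]                 # local dataset sizes
n = sum(n_k)
encrypted = [public_key.encrypt(w) for w in client_weights]

# The server computes the weighted average directly on ciphertexts,
# never seeing any individual weight in the clear.
aggregate = sum(e * (nk / n) for e, nk in zip(encrypted, n_k))

# Only the private-key holder(s) can decrypt the aggregated model.
print(private_key.decrypt(aggregate))  # ~= 0.42*0.1 + 0.51*0.3 + 0.38*0.6 = 0.423
```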