🚨 Advanced Geometric Aggregation and Gradient Regularization Framework for Semi-supervised Federated Intrusion Detection in Heterogeneous IoT Environments
This repository contains the implementation of an Enhanced Semi-Supervised Federated Learning-based Intrusion Detection System (USSFL-IDS). It is based on the paper:
📄 Zhao et al., "Semi-supervised Federated-Learning-Based Intrusion Detection Method for Internet of Things", IEEE 2023
We further extend their work by integrating robust aggregation, gradient variance tracking, Fishr regularization, and visual analytics to improve generalization and transparency.
The goal is to detect network intrusions in Internet of Things (IoT) environments using semi-supervised federated learning, where client data is heterogeneous (non-IID) and privacy-sensitive.
Feature | Description |
---|---|
Gradient Aggregation | Supports both Arithmetic and Geometric Mean |
Fishr Regularization | Penalizes gradient variance divergence across clients |
Gradient Variance | Tracks gradient dispersion for robust training |
Discriminator Training | Differentiates known vs unknown samples |
Voting Mechanism | Aggregates client predictions by hard-label voting over their logits |
Visualizations | Confusion Matrix, Class Distribution, Accuracy/F1 Trends |
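The difference between the two aggregation modes in the table can be illustrated with a small sketch. It uses plain Python lists in place of tensors, and the function names `arithmetic_mean` and `geometric_mean` are hypothetical, not taken from `helper.py`. Because gradients can be negative, the geometric variant here averages magnitudes in log space and takes the sign by majority vote; that sign handling is an assumption, and the repository may treat it differently.

```python
import math
from typing import List


def arithmetic_mean(grads: List[List[float]]) -> List[float]:
    """Coordinate-wise arithmetic mean of client gradient vectors."""
    n = len(grads)
    return [sum(col) / n for col in zip(*grads)]


def geometric_mean(grads: List[List[float]], eps: float = 1e-12) -> List[float]:
    """Sign-aware coordinate-wise geometric mean (illustrative assumption).

    Magnitudes are averaged in log space; the sign of each coordinate is
    the majority sign across clients. An even split of signs yields 0.
    """
    out = []
    for col in zip(*grads):
        # Mean of log-magnitudes; eps avoids log(0) for zero gradients.
        log_mag = sum(math.log(abs(g) + eps) for g in col) / len(col)
        # Majority sign: +1, -1, or 0 when the clients are evenly split.
        sign_sum = sum((g > 0) - (g < 0) for g in col)
        sign = (sign_sum > 0) - (sign_sum < 0)
        out.append(sign * math.exp(log_mag))
    return out
```

Under this sketch, a single client with an unusually large gradient pulls the geometric mean less than it pulls the arithmetic mean, which is the usual motivation for treating it as the more robust option.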
USSFL/
├── utils/
│ ├── helper.py # Gradient aggregation & Fishr loss
│ ├── model_utils.py # CNN model architecture
│ ├── creator.py # Dataset creation and client/server setup
│ ├── train_utils.py # Training, prediction, evaluation utilities
│ ├── process_data_utils.py # Data preprocessing, splitting
│ └── visualization.py # 📊 All plots for metrics & analysis
├── data/
│ └── nba_iot_1000/ # N-BaIoT preprocessed CSVs per device & attack type
├── USSFL-IDS.py # Main training script
└── README.md
- Federated Learning (FL) across IoT devices
- Semi-supervised Learning (SSL) with open and private sets
- Dirichlet distribution-based split for simulating realistic non-IID scenarios
- Fishr Regularization: Aligns gradient variances across clients to reduce distributional shift
- Hard-label voting aggregation from client logits
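As a rough illustration of the Fishr idea from the list above, the sketch below computes each client's per-coordinate gradient variance and penalizes its squared distance from the mean variance across clients. It assumes per-sample gradients are already available as plain lists of floats; `gradient_variance` and `fishr_penalty` are hypothetical names, and the actual loss in `helper.py` may be weighted or normalized differently.

```python
from statistics import fmean, pvariance
from typing import List


def gradient_variance(per_sample_grads: List[List[float]]) -> List[float]:
    """Per-coordinate variance of one client's per-sample gradients."""
    return [pvariance(col) for col in zip(*per_sample_grads)]


def fishr_penalty(client_variances: List[List[float]]) -> float:
    """Mean squared distance of each client's variance vector from the
    mean variance vector across clients."""
    mean_var = [fmean(col) for col in zip(*client_variances)]
    distances = [
        sum((v - m) ** 2 for v, m in zip(var, mean_var))
        for var in client_variances
    ]
    return fmean(distances)
```

The penalty is zero when every client has the same gradient variance and grows as the clients diverge. In training it would typically be added to the task loss with a weighting coefficient.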
Plot Type | Generated By |
---|---|
Client-Class Distribution | plot_class_distribution() |
Accuracy, F1, Precision Curves | plot_metrics() |
Confusion Matrix | plot_conf_matrix() |
Comm. Overhead vs Accuracy | plot_comm_overhead_vs_accuracy() |
θc Threshold Impact | plot_theta_accuracy() |
Label Strategy Comparison | plot_label_strategy_comparison() |
All graphs are generated after the final communication round.
- Python 3.8+
- PyTorch
- NumPy
- Scikit-learn
- Matplotlib
pip install -r requirements.txt
Place preprocessed N-BaIoT CSVs under:
data/nba_iot_1000/<DeviceName>/<AttackType>_train.csv
data/nba_iot_1000/<DeviceName>/<AttackType>_test.csv
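Once the data is loaded, the Dirichlet-based non-IID split mentioned among the techniques can be simulated roughly as follows. This is a standard-library sketch, not the code in `process_data_utils.py`: the function name `dirichlet_split` and its parameters (`num_clients`, `alpha`, `seed`) are assumptions made for illustration. A Dirichlet sample is drawn per class by normalizing Gamma draws, and that class's sample indices are divided among clients in those proportions.

```python
import random
from collections import defaultdict
from typing import Hashable, List, Sequence


def dirichlet_split(labels: Sequence[Hashable], num_clients: int,
                    alpha: float, seed: int = 0) -> List[List[int]]:
    """Partition sample indices across clients, class by class."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)

    clients: List[List[int]] = [[] for _ in range(num_clients)]
    for y in sorted(by_class):
        idxs = by_class[y]
        rng.shuffle(idxs)
        # Dirichlet(alpha, ..., alpha) sample via normalized Gamma draws.
        gammas = [rng.gammavariate(alpha, 1.0) for _ in range(num_clients)]
        total = sum(gammas) or 1.0  # guard against an all-zero draw
        start, cum = 0, 0.0
        for c in range(num_clients):
            cum += gammas[c] / total
            # The last client takes the remainder so no index is dropped.
            end = len(idxs) if c == num_clients - 1 else round(cum * len(idxs))
            clients[c].extend(idxs[start:end])
            start = end
    return clients
```

Smaller values of `alpha` concentrate each class on fewer clients (stronger heterogeneity), while larger values approach an even split.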
python USSFL-IDS.py
This will:
- Load the dataset.
- Create clients and server.
- Train for configured rounds.
- Display visual metrics.
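During training, client predictions are combined by hard-label voting. The sketch below shows one plausible form of that step using plain lists: each client's logits are reduced to a class label by argmax, and the most frequent label per sample wins. The name `hard_label_vote` and the tie-breaking rule (lowest class index) are assumptions, not taken from `train_utils.py`.

```python
from collections import Counter
from typing import List


def argmax(row: List[float]) -> int:
    """Index of the largest value in a logit vector."""
    return max(range(len(row)), key=row.__getitem__)


def hard_label_vote(client_logits: List[List[List[float]]]) -> List[int]:
    """client_logits[c][i] is client c's logit vector for sample i.

    Returns one voted label per sample.
    """
    num_samples = len(client_logits[0])
    voted = []
    for i in range(num_samples):
        votes = Counter(argmax(logits[i]) for logits in client_logits)
        top = max(votes.values())
        # Tie-break (assumption): pick the smallest class index.
        voted.append(min(c for c, n in votes.items() if n == top))
    return voted


logits = [
    [[2.0, 0.1, 0.0], [0.0, 1.0, 0.2]],  # client 0 predicts 0, 1
    [[1.5, 0.3, 0.1], [0.1, 0.2, 3.0]],  # client 1 predicts 0, 2
    [[0.0, 2.0, 0.1], [0.0, 0.3, 0.9]],  # client 2 predicts 1, 2
]
print(hard_label_vote(logits))  # [0, 2]
```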
- Final test accuracy
- F1-score and precision plots
- Confusion matrix for test predictions
- Communication overhead calculations
- Class-wise client distribution
- Accuracy
- F1-Score (Macro)
- Precision (Macro)
- Confusion Matrix
- Communication Cost
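For reference, the macro-averaged metrics listed above can be derived from a confusion matrix as in the following dependency-free sketch. The function names are hypothetical. In practice, scikit-learn's `precision_score` and `f1_score` with `average="macro"` should give the same quantities; here, classes with no predictions or no true samples contribute a score of 0.

```python
from typing import Dict, List, Sequence


def confusion_matrix(y_true: Sequence[int], y_pred: Sequence[int],
                     classes: List[int]) -> List[List[int]]:
    """Rows are true classes, columns are predicted classes."""
    index = {c: i for i, c in enumerate(classes)}
    cm = [[0] * len(classes) for _ in classes]
    for t, p in zip(y_true, y_pred):
        cm[index[t]][index[p]] += 1
    return cm


def macro_scores(y_true: Sequence[int], y_pred: Sequence[int]) -> Dict[str, float]:
    """Accuracy plus macro-averaged precision and F1."""
    classes = sorted(set(y_true) | set(y_pred))
    cm = confusion_matrix(y_true, y_pred, classes)
    k = len(classes)
    precisions, f1s = [], []
    for i in range(k):
        tp = cm[i][i]
        fp = sum(cm[r][i] for r in range(k)) - tp
        fn = sum(cm[i]) - tp
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
        precisions.append(prec)
        f1s.append(f1)
    return {
        "accuracy": sum(cm[i][i] for i in range(k)) / len(y_true),
        "precision_macro": sum(precisions) / k,
        "f1_macro": sum(f1s) / k,
    }
```

Macro averaging weights every class equally, which matters for intrusion data where rare attack types would otherwise be dominated by benign traffic.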
- Zhao et al., "Semi-supervised Federated-Learning-Based Intrusion Detection Method for Internet of Things", IEEE 2023
- This implementation adds enhancements beyond the paper, such as Fishr loss and improved aggregation.
This project is for academic and research use only.