πŸ›‘οΈ MATL Integration Tutorial

Complete guide to integrating MATL with your federated learning setup


🎯 What You'll Learn

This tutorial demonstrates how to:

  1. ✅ Install and configure MATL
  2. ✅ Convert existing FL code to use MATL
  3. ✅ Run a simple MNIST training example
  4. ✅ Monitor trust scores and detect attacks
  5. ✅ Deploy to production

Prerequisites:

  • Python 3.8+
  • PyTorch 2.0+
  • Basic understanding of federated learning

Time to Complete: ~30 minutes


Part 1: Installation & Setup

Install MATL SDK

pip install matl torch torchvision
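To confirm both packages import correctly and that the PyTorch version meets the prerequisite:

python -c "import matl, torch; print(torch.__version__)"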

Import Required Libraries

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np

# Import MATL
from matl import MATLClient, MATLMode
from matl.attacks import SignFlipAttack, GaussianNoiseAttack
from matl.metrics import calculate_metrics

Part 2: Define Your Model

Important: Use your existing PyTorch model - nothing MATL-specific required!

class SimpleCNN(nn.Module):
    """
    Simple CNN for MNIST classification.
    This is standard PyTorch - no MATL changes needed!
    """
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(1600, 128)  # 64 channels x 5 x 5 after two conv+pool stages
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return torch.log_softmax(x, dim=1)
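A quick sanity check that the forward pass produces the expected output shape for MNIST-sized inputs:

# Sanity check: a batch of 8 MNIST images -> 8 x 10 log-probabilities
model = SimpleCNN()
dummy = torch.randn(8, 1, 28, 28)
print(model(dummy).shape)  # torch.Size([8, 10])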

Part 3: Prepare Data

Standard federated learning data partitioning:

def load_mnist_data(num_clients=10, samples_per_client=5000):
    """
    Load MNIST and partition into client datasets.
    This is standard FL setup - no MATL-specific changes.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # Load full dataset
    full_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('./data', train=False, transform=transform)

    # Partition into client datasets (IID for simplicity)
    client_datasets = []
    indices = np.random.permutation(len(full_dataset))

    for i in range(num_clients):
        start_idx = i * samples_per_client
        end_idx = start_idx + samples_per_client
        client_indices = indices[start_idx:end_idx]
        client_dataset = Subset(full_dataset, client_indices)
        client_datasets.append(client_dataset)

    return client_datasets, test_dataset

# Load data
client_datasets, test_dataset = load_mnist_data(num_clients=10)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
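A quick check that the partition looks right before wiring up MATL:

print(f"Clients: {len(client_datasets)}")
print(f"Samples per client: {len(client_datasets[0])}")
print(f"Test samples: {len(test_dataset)}")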

Part 4: MATL Configuration

Initialize the MATL client:

# Initialize MATL client
matl_client = MATLClient(
    mode=MATLMode.MODE1,  # Use oracle-based validation (45% BFT)
    oracle_endpoint="https://oracle.matl.network",  # Public testnet
    node_id="tutorial_client",
    # private_key="path/to/key.pem",  # Optional: for production
)

print(f"MATL Client initialized")
print(f"Mode: {matl_client.mode}")
print(f"Oracle: {matl_client.oracle_endpoint}")

MATL Modes:

  • Mode 0: Peer comparison (33% BFT tolerance)
  • Mode 1: PoGQ oracle (45% BFT tolerance) ⭐ Recommended
  • Mode 2: PoGQ + TEE (50% BFT tolerance)
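Switching modes is just a matter of changing the constructor argument. Note that only MODE1 appears in this tutorial; the MODE0 and MODE2 enum names below are assumptions, so check your SDK version for the exact spelling and required arguments.

# Illustrative only: MODE0 and MODE2 are assumed enum names, and the
# constructor arguments may differ per mode.
peer_client = MATLClient(mode=MATLMode.MODE0, node_id="peer_node")  # 33% BFT
tee_client = MATLClient(
    mode=MATLMode.MODE2,  # 50% BFT, requires TEE-capable hardware
    oracle_endpoint="https://oracle.matl.network",
    node_id="tee_node",
)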


Part 5: Training WITH MATL (2-Line Integration!)

Here's the magic - only 2 lines of code change from standard FL:

def train_matl_fl(
    client_datasets,
    test_loader,
    matl_client,
    num_rounds=20,
    byzantine_ratio=0.0,  # fraction of malicious clients (e.g. 0.2 = 20%)
    attack_type=None,     # e.g. "sign_flip"; applied by simulated malicious clients
):
    """
    Federated learning WITH MATL protection.

    Key differences from baseline FL:
    1. Submit gradient to MATL for validation
    2. Aggregate using reputation-weighted method
    """
    global_model = SimpleCNN()
    optimizer = optim.SGD(global_model.parameters(), lr=0.01)

    for round_num in range(num_rounds):
        client_gradients = []

        # Each client trains locally (standard FL)
        for client_id, client_data in enumerate(client_datasets):
            local_model = SimpleCNN()
            local_model.load_state_dict(global_model.state_dict())

            # ... local training loop (standard PyTorch) ...

            # Extract gradient
            gradient = [p.grad.clone() for p in local_model.parameters()]

            # ============================================================
            # ⭐ MATL INTEGRATION LINE 1: Submit for validation
            # ============================================================
            result = matl_client.submit_gradient(
                gradient=gradient,
                metadata={
                    "client_id": client_id,
                    "round": round_num,
                }
            )

            print(f"Client {client_id}: Trust score = {result.trust_score:.3f}")

            client_gradients.append({
                "gradient": gradient,
                "trust_score": result.trust_score,
            })

        # ============================================================
        # ⭐ MATL INTEGRATION LINE 2: Reputation-weighted aggregation
        # ============================================================
        aggregated_gradient = matl_client.aggregate(
            gradients=[g["gradient"] for g in client_gradients],
            trust_scores=[g["trust_score"] for g in client_gradients],
            method="reputation_weighted"  # Key difference from FedAvg!
        )

        # Update global model (standard FL)
        for param, grad in zip(global_model.parameters(), aggregated_gradient):
            param.grad = grad
        optimizer.step()

    return global_model

That's it! Two lines of code to get Byzantine resistance.
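The local training loop is elided in the code above. For readers who want something runnable end to end, here is one possible version, along with a simple way to wire up the byzantine_ratio and attack_type arguments by having simulated attackers flip their gradient signs. This helper is illustrative only and is not part of the MATL SDK.

# Illustrative helper, not part of the MATL SDK: computes one averaged
# gradient over the client's data and optionally corrupts it to simulate
# a malicious client.
def local_train_step(local_model, client_data, attack_type=None, batch_size=64):
    loader = DataLoader(client_data, batch_size=batch_size, shuffle=True)
    criterion = nn.NLLLoss()  # pairs with the log_softmax output of SimpleCNN
    local_model.train()
    local_model.zero_grad()
    for data, target in loader:
        loss = criterion(local_model(data), target)
        loss.backward()  # gradients accumulate across batches
    gradient = [p.grad.clone() / len(loader) for p in local_model.parameters()]

    # Simulated Byzantine behaviour: a "sign_flip" attacker negates its gradient
    if attack_type == "sign_flip":
        gradient = [-g for g in gradient]
    return gradient

Inside train_matl_fl, the first int(byzantine_ratio * len(client_datasets)) clients could call this helper with attack_type set, and the remaining clients with attack_type=None.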


Part 6: Testing Under Attack

Let's simulate a Byzantine attack scenario:

# Baseline FL (no protection)
print("Training WITHOUT MATL...")
baseline_model = train_baseline_fl(
    client_datasets, test_loader, num_rounds=20
)

# MATL FL under attack (20% Byzantine clients)
print("\nTraining WITH MATL (20% Byzantine)...")
matl_model = train_matl_fl(
    client_datasets,
    test_loader,
    matl_client,
    num_rounds=20,
    byzantine_ratio=0.2,  # 20% malicious
)

Expected Results:

  • Baseline FL: accuracy degrades significantly under attack
  • MATL FL: maintains >90% of clean accuracy
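Note that train_baseline_fl is not defined earlier in this tutorial. Under the same assumptions as the sketch in Part 5, an unprotected FedAvg baseline could look like this:

# Unprotected FedAvg baseline (sketch; uses the illustrative local_train_step
# helper from Part 5). Every client gets equal weight, malicious or not.
def train_baseline_fl(client_datasets, test_loader, num_rounds=20, byzantine_ratio=0.0):
    global_model = SimpleCNN()
    optimizer = optim.SGD(global_model.parameters(), lr=0.01)
    num_byzantine = int(byzantine_ratio * len(client_datasets))

    for round_num in range(num_rounds):
        gradients = []
        for client_id, client_data in enumerate(client_datasets):
            local_model = SimpleCNN()
            local_model.load_state_dict(global_model.state_dict())
            attack = "sign_flip" if client_id < num_byzantine else None
            gradients.append(local_train_step(local_model, client_data, attack_type=attack))

        # Plain averaging -- no trust scores, no filtering
        for param, per_client in zip(global_model.parameters(), zip(*gradients)):
            param.grad = torch.stack(per_client).mean(dim=0)
        optimizer.step()

    return global_model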


Part 7: Monitoring & Visualization

Track trust scores and detection rates:

import matplotlib.pyplot as plt

def plot_results(baseline_acc, matl_acc, detection_rates):
    """Visualize training results"""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Plot 1: Accuracy over rounds
    rounds = range(1, len(baseline_acc) + 1)
    ax1.plot(rounds, baseline_acc, 'b--', label='Baseline (under attack)', linewidth=2)
    ax1.plot(rounds, matl_acc, 'g-', label='MATL (protected)', linewidth=2)
    ax1.set_xlabel('Training Round')
    ax1.set_ylabel('Test Accuracy (%)')
    ax1.set_title('Federated Learning Performance')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot 2: Detection rate
    ax2.plot(rounds, detection_rates, 'r-', linewidth=2, marker='o')
    ax2.axhline(y=90, color='g', linestyle='--', label='Target (90%)')
    ax2.set_xlabel('Training Round')
    ax2.set_ylabel('Byzantine Detection Rate (%)')
    ax2.set_title('MATL Attack Detection')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('matl_results.png', dpi=300)
    plt.show()
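The plotting helper expects per-round accuracy lists. The training functions above don't collect them, but a standard evaluation helper like the following (plain PyTorch, nothing MATL-specific) can be called after each round to build baseline_acc and matl_acc:

def evaluate(model, test_loader):
    """Return test-set accuracy in percent."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            pred = model(data).argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    return 100.0 * correct / total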

Part 8: Advanced Configuration

Custom Trust Policies

Customize how MATL aggregates gradients:

custom_aggregation = matl_client.aggregate(
    gradients=gradients,
    trust_scores=trust_scores,
    method="custom",
    policy={
        "min_trust_threshold": 0.7,  # Exclude gradients with trust < 0.7
        "weight_function": "quadratic",  # Use scoreΒ² as weight
        "outlier_detection": True,  # Additional statistical filtering
    }
)

Attack Types Supported

# Test different attack scenarios
attack_types = [
    "sign_flip",        # Flip gradient signs
    "gaussian_noise",   # Add random noise
    "model_poisoning",  # Target specific labels
    "sleeper_agent",    # Honest β†’ attack after N rounds
]

for attack in attack_types:
    print(f"\nTesting {attack}...")
    model = train_matl_fl(
        client_datasets,
        test_loader,
        matl_client,
        byzantine_ratio=0.2,
        attack_type=attack
    )

Part 9: Production Deployment Checklist

Installation & Setup

  • Install MATL SDK: pip install matl
  • Generate DID and keypair: matl keygen --output mykey.pem
  • Configure oracle endpoint (or run your own)

Code Integration

  • Replace aggregation logic with matl_client.aggregate()
  • Add gradient submission with matl_client.submit_gradient()
  • Test with small-scale experiment (10 clients, 10 rounds)

Security

  • Store private keys securely (never commit to git! See the example below.)
  • Use HTTPS for oracle communication
  • Enable authentication tokens for production oracle
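For the key-handling items above, the constructor from Part 4 already accepts a private_key path; a production-style configuration might look like the following (the endpoint and file paths are placeholders):

# Placeholder endpoint and key path -- keep the key file outside the repo
# and readable only by the service account running the node.
matl_client = MATLClient(
    mode=MATLMode.MODE1,
    oracle_endpoint="https://oracle.your-company.example",  # HTTPS only
    node_id="prod_node_01",
    private_key="/etc/matl/keys/prod_node_01.pem",
)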

Monitoring

  • Log trust scores for all clients
  • Set up alerts for low trust scores
  • Monitor detection rates (should be >85% for known attacks)

Performance

  • Benchmark overhead (should be <30% vs baseline)
  • Test with realistic network latency
  • Validate memory usage (<2GB per node for 1000 nodes)

Testing

  • Unit tests for MATL integration (see the sketch below)
  • Integration tests with attack simulations
  • Load tests (100+ clients)
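For the unit-test item above, one useful property to pin down is that reputation-weighted aggregation suppresses zero-trust gradients. The sketch below assumes aggregation behaves as a trust-weighted average and that matl_client is provided by a pytest fixture built from the Part 4 configuration; adjust it to the SDK's actual semantics.

# test_matl_aggregation.py -- sketch; assumes reputation-weighted aggregation
# acts as a trust-weighted average, so a zero-trust gradient contributes nothing.
import torch

def test_zero_trust_gradient_is_ignored(matl_client):
    honest = [torch.ones(4)]
    poisoned = [-torch.ones(4)]

    aggregated = matl_client.aggregate(
        gradients=[honest, poisoned],
        trust_scores=[1.0, 0.0],
        method="reputation_weighted",
    )

    # With zero trust, the poisoned gradient should not pull the result
    # away from the honest one.
    assert torch.allclose(aggregated[0], honest[0])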

Part 10: Next Steps

1. Try Different Attack Scenarios

  • Gaussian noise attack
  • Model poisoning
  • Sleeper agent (honest → attack after N rounds)

2. Experiment with Hyperparameters

  • Byzantine ratio (10%, 20%, 30%, 40%)
  • Trust score thresholds
  • Aggregation methods

3. Test with Your Own Model

  • Replace SimpleCNN with your architecture
  • Use your dataset (not MNIST)
  • Adapt local training loop to your needs

4. Deploy to Production

  • Set up Holochain DHT network
  • Run MATL oracle service
  • Configure monitoring and alerts

5. Join the Community


📊 Expected Results

| Metric             | Baseline FL | MATL (Clean) | MATL (20% Byz) |
|--------------------|-------------|--------------|----------------|
| Final Accuracy     | 98.5%       | 98.7%        | 97.8%          |
| Convergence Rounds | 15          | 16           | 18             |
| Detection Rate     | N/A         | N/A          | 95%+           |
| False Positives    | N/A         | N/A          | <1%            |

🎉 Summary

You've learned how to:

✅ Install and configure MATL
✅ Integrate MATL with existing FL code (just 2 lines!)
✅ Detect and mitigate Byzantine attacks
✅ Achieve 45% BFT tolerance

MATL protects your federated learning from adversaries while maintaining model performance.



Ready to secure your federated learning? Start integrating MATL today! 🛡️