Skip to content

Using collective learning with pytorch

This tutorial is a simple guide to trying out the collective learning protocol with your own machine learning code. Everything runs locally.

The most flexible way to use the collective learning backends is to write a class that implements the Collective Learning `MachineLearningInterface` defined in `ml_interface.py`. For more details on how to use the `MachineLearningInterface`, see its documentation.

However, the simpler way is to use one of the helper classes we provide, which implement most of the interface for popular ML libraries. In this tutorial we walk through using the `PytorchLearner`. First we load the data, then we define and configure the model, and finally we run Collective Learning.

A standard script for machine learning with PyTorch looks like the one below:

# ------------------------------------------------------------------------------
#
#   Copyright 2021 Fetch.AI Limited
#
#   Licensed under the Creative Commons Attribution-NonCommercial International
#   License, Version 4.0 (the "License"); you may not use this file except in
#   compliance with the License. You may obtain a copy of the License at
#
#       http://creativecommons.org/licenses/by-nc/4.0/legalcode
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
#
# ------------------------------------------------------------------------------
from torchsummary import summary
from torchvision import transforms, datasets
import torch.utils.data

import torch.nn as nn
import torch.nn.functional as nn_func

# define some constants
batch_size = 64
seed = 42
n_rounds = 20
train_fraction = 0.9
learning_rate = 0.001
height = 28
width = 28
n_classes = 10
num_test_batches = 10

# Seed torch's global RNG so that the train/test split, dataloader shuffling
# and the model's weight initialisation below are reproducible between runs.
torch.manual_seed(seed)

no_cuda = False
cuda = not no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")
# Worker processes / pinned memory are only requested when CUDA is in use.
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

# Load the data and hold out (1 - train_fraction) of it as a test set
data = datasets.MNIST('/tmp/mnist', transform=transforms.ToTensor(), download=True)
n_train = int(train_fraction * len(data))
n_test = len(data) - n_train
train_data, test_data = torch.utils.data.random_split(data, [n_train, n_test])

train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, **kwargs)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True, **kwargs)


# Define the model
class Net(nn.Module):
    """Small convolutional classifier for single-channel images.

    Two conv -> ReLU -> 2x2 max-pool stages feed two fully connected
    layers; ``forward`` returns log-probabilities over ``n_classes``
    (log-softmax along dim 1), the input format ``torch.nn.NLLLoss`` expects.
    """

    def __init__(self):
        super().__init__()
        # Layer attribute names double as state_dict keys, so keep them stable.
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, n_classes)

    def forward(self, x):
        images = x.view(-1, 1, height, width)
        # For 28x28 inputs: conv 24x24 -> pool 12x12 -> conv 8x8 -> pool 4x4,
        # hence the 4 * 4 * 50 flattened feature size.
        features = nn_func.max_pool2d(nn_func.relu(self.conv1(images)), 2, 2)
        features = nn_func.max_pool2d(nn_func.relu(self.conv2(features)), 2, 2)
        flat = features.view(-1, 4 * 4 * 50)
        hidden = nn_func.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        return nn_func.log_softmax(logits, dim=1)


# The model must live on the same device as the batches fed to it below.
model = Net().to(device)
opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.NLLLoss()

# Train and evaluate the model
for round_index in range(n_rounds):
    # train model
    model.train()

    for inputs, labels in train_dataloader:
        opt.zero_grad()

        # Data needs to be on same device as model
        inputs = inputs.to(device)
        labels = labels.to(device)

        output = model(inputs)

        loss = criterion(output, labels)
        loss.backward()
        opt.step()

    # evaluate model on at most the first num_test_batches test batches
    model.eval()
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(test_dataloader):
            if batch_idx == num_test_batches:
                break
            inputs = inputs.to(device)
            labels = labels.to(device)
            output = model(inputs)
            # NLLLoss() averages over the batch (reduction='mean'), so weight
            # each batch by its size to get a true per-sample average.
            batch_samples = labels.size(0)
            total_loss += criterion(output, labels).item() * batch_samples
            total_samples += batch_samples
    avg_loss = total_loss / total_samples if total_samples else float("nan")
    print(f"Average loss at round {round_index} is {avg_loss}")

There are three steps:

  1. Load the data
  2. Define the model
  3. Train the model

In this tutorial we are going to see how to modify each step to use collective learning. We'll end up with code like this:

# ------------------------------------------------------------------------------
#
#   Copyright 2021 Fetch.AI Limited
#
#   Licensed under the Creative Commons Attribution-NonCommercial International
#   License, Version 4.0 (the "License"); you may not use this file except in
#   compliance with the License. You may obtain a copy of the License at
#
#       http://creativecommons.org/licenses/by-nc/4.0/legalcode
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
#
# ------------------------------------------------------------------------------
import os

from typing_extensions import TypedDict
import torch.nn as nn
import torch.nn.functional as nn_func
import torch.utils.data
from torchsummary import summary
from torchvision import transforms, datasets

from colearn.training import initial_result, collective_learning_round, set_equal_weights
from colearn.utils.plot import ColearnPlot
from colearn.utils.results import Results, print_results
from colearn_pytorch.utils import categorical_accuracy
from colearn_pytorch.pytorch_learner import PytorchLearner

"""
MNIST training example using PyTorch

Used dataset:
- MNIST is set of 60 000 black and white hand written digits images of size 28x28x1 in 10 classes

What script does:
- Loads MNIST dataset from torchvision.datasets
- Randomly splits dataset between multiple learners
- Does multiple rounds of learning process and displays plot with results
"""

# define some constants
n_learners = 5
batch_size = 64

testing_mode = bool(os.getenv("COLEARN_EXAMPLES_TEST", ""))  # for testing
n_rounds = 20 if not testing_mode else 1
vote_threshold = 0.5
train_fraction = 0.9
vote_fraction = 0.05
learning_rate = 0.001
height = 28
width = 28
n_classes = 10
vote_batches = 2
score_name = "categorical accuracy"

no_cuda = False
cuda = not no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")
DataloaderKwargs = TypedDict('DataloaderKwargs', {'num_workers': int, 'pin_memory': bool}, total=False)
kwargs: DataloaderKwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

# Load the data and split for each learner.
DATA_DIR = os.environ.get('PYTORCH_DATA_DIR',
                          os.path.expanduser(os.path.join('~', 'pytorch_datasets')))
data = datasets.MNIST(DATA_DIR, transform=transforms.ToTensor(), download=True)
n_train = int(train_fraction * len(data))
n_vote = int(vote_fraction * len(data))
n_test = len(data) - n_train - n_vote
train_data, vote_data, test_data = torch.utils.data.random_split(data, [n_train, n_vote, n_test])


def _split_for_learners(dataset, n_parts):
    """Randomly split ``dataset`` into ``n_parts`` near-equal subsets.

    :param dataset: Dataset (or Subset) to divide between learners
    :param n_parts: Number of learners
    :return: (list of subsets, list of shuffling DataLoaders over them)
    """
    base, remainder = divmod(len(dataset), n_parts)
    # random_split requires the lengths to sum to len(dataset); spreading the
    # remainder over the first parts keeps that true for any n_parts.
    lengths = [base + 1 if i < remainder else base for i in range(n_parts)]
    subsets = torch.utils.data.random_split(dataset, lengths)
    loaders = [torch.utils.data.DataLoader(
        ds,
        batch_size=batch_size, shuffle=True, **kwargs) for ds in subsets]
    return subsets, loaders


learner_train_data, learner_train_dataloaders = _split_for_learners(train_data, n_learners)
learner_vote_data, learner_vote_dataloaders = _split_for_learners(vote_data, n_learners)
learner_test_data, learner_test_dataloaders = _split_for_learners(test_data, n_learners)


# Define the model
class Net(nn.Module):
    """Convolutional MNIST-style classifier returning log-probabilities.

    Architecture: two (conv -> ReLU -> 2x2 max-pool) stages, then a 500-unit
    hidden dense layer and an ``n_classes``-way output with log-softmax over
    dim 1 (suitable for ``torch.nn.NLLLoss``).
    """

    def __init__(self):
        super().__init__()
        # Attribute names are the state_dict keys shared between learners.
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, n_classes)

    def forward(self, x):
        out = x.view(-1, 1, height, width)
        # Both conv stages share the same activation/pooling recipe.
        for conv in (self.conv1, self.conv2):
            out = nn_func.max_pool2d(nn_func.relu(conv(out)), 2, 2)
        out = out.view(-1, 4 * 4 * 50)
        out = nn_func.relu(self.fc1(out))
        return nn_func.log_softmax(self.fc2(out), dim=1)


# Make n instances of PytorchLearner with model and torch dataloaders
all_learner_models = []
for train_loader, vote_loader, test_loader in zip(
        learner_train_dataloaders, learner_vote_dataloaders, learner_test_dataloaders):
    # Move the model to its device before building the optimizer over its parameters
    model = Net().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
    learner = PytorchLearner(
        model=model,
        train_loader=train_loader,
        vote_loader=vote_loader,
        test_loader=test_loader,
        device=device,
        optimizer=opt,
        criterion=torch.nn.NLLLoss(),
        num_test_batches=vote_batches,
        vote_criterion=categorical_accuracy,
        minimise_criterion=False
    )

    all_learner_models.append(learner)

# Ensure all learners start with exactly the same weights
set_equal_weights(all_learner_models)

# Net.forward views its input as (-1, 1, height, width), so describe the
# input in (height, width) order.
summary(all_learner_models[0].model, input_size=(height, width), device=str(device))

# Train the model using Collective Learning
results = Results()
results.data.append(initial_result(all_learner_models))

plot = ColearnPlot(score_name=score_name)

for round_index in range(n_rounds):
    results.data.append(
        collective_learning_round(all_learner_models,
                                  vote_threshold, round_index)
    )
    print_results(results)

    plot.plot_results_and_votes(results)

plot.block()

print("Colearn Example Finished!")

The first step is to modify the data loading code. Each learner needs its own training, vote and test sets drawn from the data. This is easy to do with PyTorch's `random_split` utility:

# Near-equal lengths that always sum to len(test_data), as random_split requires
base, remainder = divmod(len(test_data), n_learners)
data_split = [base + 1 if i < remainder else base for i in range(n_learners)]
learner_test_data = torch.utils.data.random_split(test_data, data_split)

The model definition is the same as before. To use collective learning, we need to create an object that implements the `MachineLearningInterface`. To make it easier to use the `MachineLearningInterface` with PyTorch, we've defined `PytorchLearner`. `PytorchLearner` implements standard training and evaluation routines as well as the `MachineLearningInterface` methods.

# ------------------------------------------------------------------------------
#
#   Copyright 2021 Fetch.AI Limited
#
#   Licensed under the Creative Commons Attribution-NonCommercial International
#   License, Version 4.0 (the "License"); you may not use this file except in
#   compliance with the License. You may obtain a copy of the License at
#
#       http://creativecommons.org/licenses/by-nc/4.0/legalcode
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
#
# ------------------------------------------------------------------------------
from typing import Optional, Callable
from collections import OrderedDict, defaultdict

try:
    import torch
except ImportError:
    raise Exception(
        "Pytorch is not installed. To use the pytorch "
        "add-ons please install colearn with `pip install colearn[pytorch]`."
    )

import torch.nn
import torch.optim
import torch.utils
import torch.utils.data
from torch.nn.modules.loss import _Loss

from colearn.ml_interface import (
    MachineLearningInterface,
    Weights,
    ProposedWeights,
    ColearnModel,
    convert_model_to_onnx,
    ModelFormat,
    DiffPrivBudget,
    DiffPrivConfig,
    TrainingSummary,
    ErrorCodes,
)

from opacus import PrivacyEngine

_DEFAULT_DEVICE = torch.device("cpu")


class PytorchLearner(MachineLearningInterface):
    """
    Pytorch learner implementation of machine learning interface
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        train_loader: torch.utils.data.DataLoader,
        vote_loader: torch.utils.data.DataLoader,
        test_loader: Optional[torch.utils.data.DataLoader] = None,
        need_reset_optimizer: bool = True,
        device=_DEFAULT_DEVICE,
        criterion: Optional[_Loss] = None,
        minimise_criterion=True,
        vote_criterion: Optional[Callable[[torch.Tensor, torch.Tensor], float]] = None,
        num_train_batches: Optional[int] = None,
        num_test_batches: Optional[int] = None,
        diff_priv_config: Optional[DiffPrivConfig] = None,
    ):
        """
        :param model: Pytorch model used for training
        :param optimizer: Training optimizer
        :param train_loader: Train dataset
        :param vote_loader: Dataset used to score weights when voting
        :param test_loader: Optional test dataset - if not specified, the test score of proposed weights is 0
        :param need_reset_optimizer: True to clear optimizer history before training, False to keep history.
        :param device: Pytorch device - CPU or GPU
        :param criterion: Loss function (required by train and test)
        :param minimise_criterion: True if a lower score is better when voting, False if higher is better
        :param vote_criterion: Function to measure model performance for voting; if None, the average
            criterion loss is used as the score
        :param num_train_batches: Number of training batches (defaults to all batches of train_loader)
        :param num_test_batches: Number of testing batches (None means all batches)
        :param diff_priv_config: Contains differential privacy (dp) budget related configuration
        """

        # Model has to be on same device as data
        self.model: torch.nn.Module = model.to(device)
        self.optimizer: torch.optim.Optimizer = optimizer
        self.criterion = criterion
        self.train_loader: torch.utils.data.DataLoader = train_loader
        self.vote_loader: torch.utils.data.DataLoader = vote_loader
        self.test_loader: Optional[torch.utils.data.DataLoader] = test_loader
        self.need_reset_optimizer = need_reset_optimizer
        self.device = device
        self.num_train_batches = num_train_batches or len(train_loader)
        self.num_test_batches = num_test_batches
        self.minimise_criterion = minimise_criterion
        self.vote_criterion = vote_criterion

        self.dp_config = diff_priv_config
        self.dp_privacy_engine = PrivacyEngine()

        if diff_priv_config is not None:
            # Opacus wraps model, optimizer and train loader for DP training
            (
                self.model,
                self.optimizer,
                self.train_loader,
            ) = self.dp_privacy_engine.make_private(
                module=self.model,
                optimizer=self.optimizer,
                data_loader=self.train_loader,
                max_grad_norm=diff_priv_config.max_grad_norm,
                noise_multiplier=diff_priv_config.noise_multiplier,
            )

        # Baseline score that proposed weights are compared against in vote()
        self.vote_score = self.test(self.vote_loader)

    def mli_get_current_weights(self) -> Weights:
        """
        :return: The current weights of the model
        """

        # Build the state dict once and clone every tensor so later training
        # cannot mutate the returned snapshot.
        state_dict = self.model.state_dict()
        current_state_dict = OrderedDict(
            (key, tensor.clone()) for key, tensor in state_dict.items()
        )
        return Weights(
            weights=current_state_dict, training_summary=self.get_training_summary()
        )

    def mli_get_current_model(self) -> ColearnModel:
        """
        :return: The current model and its format
        """

        return ColearnModel(
            model_format=ModelFormat(ModelFormat.ONNX),
            model_file="",
            model=convert_model_to_onnx(self.model),
        )

    def set_weights(self, weights: Weights):
        """
        Rewrites weight of current model
        :param weights: Weights to be stored
        """

        self.model.load_state_dict(weights.weights)

    def reset_optimizer(self):
        """
        Clear optimizer state, such as number of iterations, momentums.
        This way, the outdated history can be erased.
        """

        # Replace the per-parameter state with an empty defaultdict
        self.optimizer.__setstate__({"state": defaultdict(dict)})

    def train(self):
        """
        Trains the model on the training dataset for up to num_train_batches batches
        :raises Exception: if no criterion was specified
        """

        if not self.criterion:
            raise Exception("Criterion is unspecified so train method cannot be used")

        if self.need_reset_optimizer:
            # erase the outdated optimizer memory (momentums mostly)
            self.reset_optimizer()

        self.model.train()

        for batch_idx, (data, labels) in enumerate(self.train_loader):
            if batch_idx == self.num_train_batches:
                break
            self.optimizer.zero_grad()

            # Data needs to be on same device as model
            data = data.to(self.device)
            labels = labels.to(self.device)

            output = self.model(data)

            loss = self.criterion(output, labels)
            loss.backward()
            self.optimizer.step()

    @staticmethod
    def _dp_budget_exceeded(weights: Weights) -> bool:
        """
        :param weights: Weights whose training summary should be inspected
        :return: True if the training summary reports an exhausted DP budget
        """

        summary = weights.training_summary
        return (
            summary is not None
            and summary.error_code is not None
            and summary.error_code == ErrorCodes.DP_BUDGET_EXCEEDED
        )

    def mli_propose_weights(self) -> Weights:
        """
        Trains model on training set and returns new weights after training
        - Current model is reverted to original state after training (even if training raises)
        - If the DP budget is exhausted, the current (untrained) weights are returned
          carrying the training summary that reports it
        :return: Weights after training
        """

        current_weights = self.mli_get_current_weights()
        if self._dp_budget_exceeded(current_weights):
            return current_weights

        try:
            self.train()
            new_weights = self.mli_get_current_weights()
        finally:
            # Proposing must not change the learner's accepted model
            self.set_weights(current_weights)

        if self._dp_budget_exceeded(new_weights):
            current_weights.training_summary = new_weights.training_summary
            return current_weights

        return new_weights

    def mli_test_weights(self, weights: Weights) -> ProposedWeights:
        """
        Tests given weights on the vote set (and test set, if any) and returns weights with score values.
        The current model weights are restored afterwards, even if testing raises.
        :param weights: Weights to be tested
        :return: ProposedWeights - Weights with vote and test score
        """

        current_weights = self.mli_get_current_weights()
        self.set_weights(weights)
        try:
            vote_score = self.test(self.vote_loader)
            # Explicit None check: truthiness of a DataLoader calls len(), which
            # is not available for every dataset.
            if self.test_loader is not None:
                test_score = self.test(self.test_loader)
            else:
                test_score = 0
        finally:
            self.set_weights(current_weights)

        vote = self.vote(vote_score)
        return ProposedWeights(
            weights=weights, vote_score=vote_score, test_score=test_score, vote=vote
        )

    def vote(self, new_score) -> bool:
        """
        Compares current model score with proposed model score and returns vote
        :param new_score: Proposed score
        :return: bool positive or negative vote
        """

        if self.minimise_criterion:
            return new_score < self.vote_score
        else:
            return new_score > self.vote_score

    def test(self, loader: torch.utils.data.DataLoader) -> float:
        """
        Tests performance of the model on specified dataset, using at most
        num_test_batches batches when that is set.
        :param loader: Dataset for testing
        :return: vote_criterion over all scored outputs if set, otherwise the
            average per-sample criterion loss
        :raises Exception: if no criterion was specified or no batches were scored
        """

        if not self.criterion:
            raise Exception("Criterion is unspecified so test method cannot be used")

        self.model.eval()
        total_score = 0.0
        total_samples = 0
        all_labels = []
        all_outputs = []
        # With reduction="mean" (the torch default) each batch loss is a mean,
        # so it must be weighted by the batch size before averaging.
        weight_by_batch = getattr(self.criterion, "reduction", "mean") == "mean"
        with torch.no_grad():
            for batch_idx, (data, labels) in enumerate(loader):
                # Check the limit before counting, so the unscored batch is not counted
                if self.num_test_batches and batch_idx == self.num_test_batches:
                    break
                data = data.to(self.device)
                labels = labels.to(self.device)
                output = self.model(data)
                batch_samples = labels.shape[0]
                total_samples += batch_samples
                if self.vote_criterion is not None:
                    all_labels.append(labels)
                    all_outputs.append(output)
                else:
                    loss = self.criterion(output, labels).item()
                    total_score += loss * batch_samples if weight_by_batch else loss
        # A single-batch loader is valid; only an empty one is an error.
        if total_samples == 0:
            raise Exception("No batches in loader")
        if self.vote_criterion is None:
            return float(total_score / total_samples)
        return self.vote_criterion(
            torch.cat(all_outputs, dim=0), torch.cat(all_labels, dim=0)
        )

    def mli_accept_weights(self, weights: Weights):
        """
        Updates the model with the proposed set of weights
        :param weights: The new weights
        """

        self.set_weights(weights)
        # Refresh the baseline used by vote()
        self.vote_score = self.test(self.vote_loader)

    def get_training_summary(self) -> Optional[TrainingSummary]:
        """
        Differential Privacy Budget
        :return: the target and consumed epsilon so far, or None when DP is not configured
        """

        if self.dp_config is None:
            return None

        delta = self.dp_config.target_delta
        target_epsilon = self.dp_config.target_epsilon
        consumed_epsilon = self.dp_privacy_engine.get_epsilon(delta)

        budget = DiffPrivBudget(
            target_epsilon=target_epsilon,
            consumed_epsilon=consumed_epsilon,
            target_delta=delta,
            consumed_delta=delta,  # delta is constant per training
        )

        err = (
            ErrorCodes.DP_BUDGET_EXCEEDED
            if consumed_epsilon >= target_epsilon
            else None
        )

        return TrainingSummary(
            dp_budget=budget,
            error_code=err,
        )

We create a set of PytorchLearners by passing in the model and the datasets:

all_learner_models = []
for train_loader, vote_loader, test_loader in zip(
        learner_train_dataloaders, learner_vote_dataloaders, learner_test_dataloaders):
    # Move the model to its device before constructing the optimizer over
    # its parameters (as the full example above does).
    model = Net().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
    learner = PytorchLearner(
        model=model,
        train_loader=train_loader,
        vote_loader=vote_loader,
        test_loader=test_loader,
        device=device,
        optimizer=opt,
        criterion=torch.nn.NLLLoss(),
        num_test_batches=vote_batches,
        vote_criterion=categorical_accuracy,
        minimise_criterion=False
    )

    all_learner_models.append(learner)

Then we give all the models the same weights to start off with:

# Give every learner the same starting weights before collective learning begins
set_equal_weights(all_learner_models)

And then we can move on to the final stage, which is training with Collective Learning. The function collective_learning_round performs one round of collective learning. One learner is selected to train and propose an update. The other learners vote on the update, and if the vote passes then the update is accepted. Then a new round begins.

# Train the model using Collective Learning
results = Results()
results.data.append(initial_result(all_learner_models))

# Same plotting helper as the full example above
plot = ColearnPlot(score_name=score_name)

for round_index in range(n_rounds):
    results.data.append(
        collective_learning_round(all_learner_models,
                                  vote_threshold, round_index)
    )
    print_results(results)

    # Redraw scores and votes after every round
    plot.plot_results_and_votes(results)

# Block on the final plot, as the full example above does
plot.block()