Federated Learning


Federated Learning is a way for a server to train a deep learning model without ever accessing the clients’ data. It does so by outsourcing the training process to the client nodes: each client trains on its own data, and the server only aggregates the model updates produced by this local training.

In the figure below, each client has a local model that it trains on its own private data. The device then uploads the trained local model to a cloud server. The cloud server combines the uploaded models into a global update, and the clients download this update to continue the training process.

[Figure: Federated Learning overview, with clients exchanging models with a cloud server]

In this notebook, I build a demo that simulates the Federated Learning process: how each client trains a model locally (i.e. how the local models are produced), and how the cloud server gathers these updates, computes the updated global model, and sends it back to the clients.

In order to do so, let’s break down the components in this system:

  • Server: Has a global model, a way to send models to and receive models from the clients, and a function to aggregate the clients’ local updates.
  • Client: Has a local dataset, a way to train the model locally on its own dataset, a way to send its update to the server, and a way to receive the aggregated result from the server.
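To make these two roles concrete before the PyTorch version, here is a minimal sketch in plain Python under toy assumptions that are not part of the notebook: each client privately holds noisy points from the line y = 3x, the “model” is a single weight w, local training is a few gradient-descent steps on mean squared error, and the server averages the returned weights. The helper names (make_client_data, local_train, aggregate) are hypothetical.

```python
import random

def make_client_data(n, true_w=3.0, seed=0):
    """Hypothetical private dataset: points on y = true_w * x plus a little noise."""
    rng = random.Random(seed)
    xs = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    return [(x, true_w * x + rng.gauss(0.0, 0.1)) for x in xs]

def local_train(w, data, lr=0.1, local_iteration=5):
    """Client side: start from the global weight and run a few gradient steps on MSE."""
    for _ in range(local_iteration):
        grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
        w -= lr * grad
    return w  # only the trained weight leaves the client, never the data

def aggregate(local_weights):
    """Server side: average the local models."""
    return sum(local_weights) / len(local_weights)

clients = [make_client_data(100, seed=i) for i in range(10)]
global_w = 0.0  # the server's initial global model
for _ in range(20):  # communication rounds
    local_ws = [local_train(global_w, data) for data in clients]  # each client trains locally
    global_w = aggregate(local_ws)  # server aggregates, then broadcasts global_w
print(round(global_w, 2))  # prints a value close to 3.0
```

Only the weight travels between the clients and the server; each client’s list of (x, y) pairs stays where it was created.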

Before we start, let’s review what a deep learning model looks like. Ignoring the nonlinear activation functions between layers, a model can be simplified as a stack of matrix multiplications: $y=W_n(\ldots(W_1x+b_1)\ldots)+b_n$. The “model” that we train is therefore just a big list of matrices and vectors: $W_1,\ldots,W_n$ and $b_1,\ldots,b_n$. These matrices and vectors do not all have the same shape, so in order to communicate $W$ and $b$ to the central server, we flatten them into one long vector $v = \texttt{flat}(W_1,\ldots,W_n,b_1,\ldots,b_n)$. We also need a function $\texttt{deflat}$ that converts $v$ back into $W_1,\ldots,W_n,b_1,\ldots,b_n$.
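As an illustration of the bookkeeping that $\texttt{flat}$ and $\texttt{deflat}$ need, here is a plain-Python sketch that uses nested lists in place of tensors (an assumption made for readability; the notebook itself works with PyTorch tensors). The key point is that deflat must know every shape, which is why the notebook’s vector_to_model takes a model with the same architecture.

```python
def flat(params):
    """Flatten matrices (lists of rows) and vectors (plain lists) into one list; record shapes."""
    vector, shapes = [], []
    for p in params:
        if isinstance(p[0], list):  # matrix: record (rows, cols) and append row by row
            shapes.append((len(p), len(p[0])))
            for row in p:
                vector.extend(row)
        else:  # vector: record its length
            shapes.append((len(p),))
            vector.extend(p)
    return vector, shapes

def deflat(vector, shapes):
    """Inverse of flat: cut the long vector back into pieces of the recorded shapes."""
    params, offset = [], 0
    for shape in shapes:
        if len(shape) == 2:
            rows, cols = shape
            params.append([vector[offset + r * cols: offset + (r + 1) * cols] for r in range(rows)])
            offset += rows * cols
        else:
            (n,) = shape
            params.append(vector[offset: offset + n])
            offset += n
    return params

W1 = [[1, 2, 3], [4, 5, 6]]  # 2x3 weight matrix
b1 = [7, 8]                  # bias vector of length 2
v, shapes = flat([W1, b1])
print(v)                              # [1, 2, 3, 4, 5, 6, 7, 8]
print(shapes)                         # [(2, 3), (2,)]
print(deflat(v, shapes) == [W1, b1])  # True
```

The concatenation order does not matter (PyTorch’s model.parameters() yields each layer’s weight followed by its bias rather than all $W$ first), as long as flat and deflat use the same order.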

In this notebook, we use two helper functions for that purpose:

  • A function that converts a deep learning model into a vector, so that the server can average it
  • A function that converts the vector back into the deep learning model, so that the client can reconstruct the global update.
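import torch

def model_to_vector(model):
  """
  Convert a PyTorch model to a vector of its weights.

  Args:
    model: A PyTorch model.

  Returns:
    A flattened vector of the model's weights (a new tensor, not shared with the model).
  """
  vector = []
  for param in model.parameters():
    # flatten each weight matrix / bias vector into 1-D
    vector.append(param.data.reshape(-1))
  # torch.cat copies, so the result does not share memory with the model
  return torch.cat(vector)


def vector_to_model(vector, model):
  """
  Convert a vector of weights back to a PyTorch model.

  Args:
    vector: A flattened vector of weights.
    model: A PyTorch model (with the same architecture as the model used to create the vector).

  Returns:
    The PyTorch model with its weights updated from the vector.
  """
  offset = 0
  with torch.no_grad():
    for param in model.parameters():
      param_size = param.numel()
      # copy the values into the existing parameter instead of re-pointing param.data
      # at a slice of `vector`; otherwise several models that receive the same vector
      # would silently share (and overwrite) the same memory during training
      param.copy_(vector[offset:offset + param_size].view_as(param))
      offset += param_size
  return model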
# let's create some test code
from torchvision.models import resnet18

model = resnet18()
vector = model_to_vector(model)
print(vector.shape)
model_recon = vector_to_model(vector, model)
x = torch.rand(1, 3, 224, 224)
y = model_recon(x)
print(y.shape)

torch.Size([11689512])
torch.Size([1, 1000])

We see that the model input $x$ has shape $(1,3,224,224)$: a batch of 1 image with 3 “channels”, each 224 pixels high and 224 pixels wide (PyTorch orders the dimensions as batch, channels, height, width). The 3 channels hold the colors R, G and B. A typical 8-bit color image stores a value from 0 to 255 for each pixel in each channel (the random tensor x above, by contrast, holds values in [0, 1)). We can download an example image and print out its value range to check:

import matplotlib.pyplot as plt
plt.imshow(x[0].permute(1,2,0))
[Plot: the random tensor x shown as an RGB image]

import matplotlib.pyplot as plt
# download the image
!wget https://www.stockvault.net/data/2011/02/06/117454/thumb16.jpg -O test.jpg

img = plt.imread('test.jpg')
print(img.shape)
print("Image height: {}, Image width: {}, #channels: {}".format(img.shape[0],img.shape[1],img.shape[2]))
print("Range: {}-{}".format(img.min(), img.max()))
plt.imshow(img)
--2024-11-30 23:58:54--  https://www.stockvault.net/data/2011/02/06/117454/thumb16.jpg
Resolving www.stockvault.net (www.stockvault.net)... 104.26.11.44, 172.67.74.48, 104.26.10.44, ...
Connecting to www.stockvault.net (www.stockvault.net)|104.26.11.44|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 35298 (34K) [image/jpeg]
Saving to: ‘test.jpg’

test.jpg            100%[===================>]  34.47K  --.-KB/s    in 0.004s  

2024-11-30 23:58:54 (8.53 MB/s) - ‘test.jpg’ saved [35298/35298]

(340, 604, 3)
Image height: 340, Image width: 604, #channels: 3
Range: 0-255

[Image: the downloaded test.jpg]
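The printed shape (340, 604, 3) is height × width × channels with integer values 0–255, whereas PyTorch models expect (batch, channels, height, width) floats. torchvision’s transforms.ToTensor() performs both the reordering and the division by 255 for 8-bit images. A minimal plain-Python sketch of that conversion on a hypothetical 2×2 RGB image (the helper name hwc_to_chw_scaled is made up for illustration):

```python
def hwc_to_chw_scaled(img):
    """Convert an H x W x C image of 0-255 ints into C x H x W floats in [0, 1]."""
    height, width, channels = len(img), len(img[0]), len(img[0][0])
    return [[[img[h][w][c] / 255.0 for w in range(width)] for h in range(height)]
            for c in range(channels)]

# hypothetical 2x2 RGB image: red, green (top row) / blue, white (bottom row)
img = [[[255, 0, 0], [0, 255, 0]],
       [[0, 0, 255], [255, 255, 255]]]
chw = hwc_to_chw_scaled(img)
print(len(chw), len(chw[0]), len(chw[0][0]))  # 3 2 2
print(chw[0])  # red channel: [[1.0, 0.0], [0.0, 1.0]]
batch = [chw]  # adding a batch dimension gives the (N, C, H, W) layout, here (1, 3, 2, 2)
```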

In this demo, we work with the MNIST dataset and simulate 10 clients, each holding a small dataset of 6,000 digit images.

[Figure: example MNIST digit images]

Above are some example MNIST images. The MNIST training set consists of 60,000 images of handwritten digits (plus a separate test set of 10,000 images), each of size 28x28.

As the images are grayscale, they have only 1 channel, so each input has 28 x 28 x 1 = 784 values.
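To see where 784 comes from, here is a plain-Python sketch of row-major flattening, which is what the nn.Flatten layer in the model below does to each 28x28 image (while keeping the batch dimension). The image is hypothetical: each pixel simply stores its own (row, col) position so that the index mapping is visible.

```python
def flatten(image):
    """Row-major flattening: pixel (row, col) lands at index row * width + col."""
    return [pixel for row in image for pixel in row]

HEIGHT, WIDTH = 28, 28
# hypothetical image whose pixel values encode their own (row, col) position
image = [[(r, c) for c in range(WIDTH)] for r in range(HEIGHT)]
vector = flatten(image)
print(len(vector))            # 784
print(vector[5 * WIDTH + 7])  # (5, 7)
```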

# Let's start with a model
import torch
from torch import nn

class LogisiticRegression(nn.Module):
  def __init__(self, input_dim, output_dim, device = "cuda"):
    super(LogisiticRegression, self).__init__()
    self.linear = nn.Linear(input_dim, output_dim) # parameters W, b of Wx+b: 784 -> 10
    self.flatten = nn.Flatten() # (batch, 28, 28) -> (batch, 784); has no parameters
    self.device = device # e.g. "cuda" for GPU
    self.to(device) # move the parameters to the device

  def forward(self, x): # y = f(x)
    x = x.to(self.device) # move the input to the same device as the parameters
    x = self.flatten(x) # make each image a 1-D vector
    out = self.linear(x) # pass it through a linear layer, i.e. Wx+b

    return out # raw scores (logits); softmax is applied later by the loss function

def test_model():
  model = LogisiticRegression(784, 10) # 784 = 28*28*1
  x = torch.randn(4, 28, 28) # a batch of 4 random "images"
  return model(x)


y = test_model()
print(y, y.shape)
tensor([[-0.1081,  0.2248, -0.4138, -0.5185,  0.0361,  0.7135, -0.0776, -0.0908,
         -0.1996,  1.5042],
        [-0.5581,  0.5065, -1.0035, -1.1533, -0.6793, -0.1168,  0.0972,  1.0448,
         -0.5441, -0.0458],
        [ 0.0676, -0.1420, -0.2188,  0.8619, -0.4488,  0.3295,  0.2454,  0.1947,
          0.6030,  0.3939],
        [ 1.2997,  0.7011,  0.3439,  0.6150,  0.6770,  0.7746,  0.4255,  0.0327,
         -1.3515,  0.9147]], device='cuda:0', grad_fn=<AddmmBackward0>) torch.Size([4, 10])
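The model outputs raw scores (logits), and, as the comment in forward notes, the softmax is left to the loss function: nn.CrossEntropyLoss combines log-softmax with the negative log-likelihood, so for one sample with true class $y$ and logits $z$ the loss is $-\log \frac{e^{z_y}}{\sum_j e^{z_j}}$ (by default it is then averaged over the batch). A plain-Python sketch on hypothetical logits for 3 classes:

```python
import math

def softmax(logits):
    """Turn raw scores into probabilities (subtract the max for numerical stability)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, label):
    """Per-sample cross-entropy: negative log of the probability assigned to the true class."""
    return -math.log(softmax(logits)[label])

logits = [2.0, 1.0, 0.1]  # hypothetical scores for 3 classes
print([round(p, 3) for p in softmax(logits)])  # [0.659, 0.242, 0.099]
print(round(cross_entropy(logits, 0), 4))       # 0.417
print(logits.index(max(logits)))                # predicted class: 0
```

Because the softmax is monotonic, the predicted class (the argmax) is the same whether it is taken over the logits or over the probabilities.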
import torch
from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader

# Load the MNIST training set; ToTensor() turns each image into a (1, 28, 28) float tensor in [0, 1]
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())

# Calculate the size of each partition
partition_size = len(train_dataset) // 10

# Split the dataset into 10 non-overlapping random partitions
# (any remainder goes into an 11th split, which is dropped with [:-1])
PARTITIONS = random_split(
    train_dataset,
    [partition_size] * 10 + [len(train_dataset) - partition_size * 10],
    generator=torch.Generator().manual_seed(0),  # fixed seed so the split is reproducible
)[:-1]

# Print the size of each partition
for i, partition in enumerate(PARTITIONS):
  print(f"Partition {i+1}: {len(partition)} samples")
Partition 1: 6000 samples
Partition 2: 6000 samples
Partition 3: 6000 samples
Partition 4: 6000 samples
Partition 5: 6000 samples
Partition 6: 6000 samples
Partition 7: 6000 samples
Partition 8: 6000 samples
Partition 9: 6000 samples
Partition 10: 6000 samples
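Conceptually, the split above shuffles all 60,000 indices once and cuts them into equal slices, so every client receives a random (IID) sample whose digit distribution roughly matches the full dataset. A minimal plain-Python sketch of that idea, using a hypothetical partition_indices helper (not torch’s implementation):

```python
import random

def partition_indices(n_samples, n_clients, seed=0):
    """Shuffle all indices once, then cut them into equal, non-overlapping shards.
    Leftover samples (n_samples % n_clients) are dropped, like the [:-1] above."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    size = n_samples // n_clients
    return [indices[i * size:(i + 1) * size] for i in range(n_clients)]

shards = partition_indices(60000, 10)
print([len(s) for s in shards])     # ten shards of 6000 indices each
print(len(set().union(*shards)))    # 60000: no index appears in two shards
```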

Client class

Each client has a unique client ID, holds a local dataset of 6,000 images, and has a local model. It has a method to connect to the server, a method to receive updates from the server, a method to train the model locally, and a method to send its update to the server after training.

We also have “hyperparameters” for tuning the training at the client: local_batch_size, the number of images passed to the model in one training iteration, and local_iteration, the number of training iterations (mini-batch steps) each client runs per round.
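A quick worked calculation shows how much local data each client actually touches with the values used in the run below (batch size 32, 4 local iterations, 100 rounds). Each call to local_train draws fresh shuffled batches, so “epochs” here is only an approximate measure.

```python
partition_size = 6000    # images per client
local_batch_size = 32    # values used in the training run below
local_iteration = 4
rounds = 100

samples_per_round = local_batch_size * local_iteration             # images seen per client per round
rounds_per_local_epoch = partition_size / samples_per_round         # rounds needed to see ~6000 images
local_epochs_total = rounds * samples_per_round / partition_size    # passes over the local data overall
print(samples_per_round, rounds_per_local_epoch, round(local_epochs_total, 2))  # 128 46.875 2.13
```

So over the whole run each client sees its local data only about twice, which is one reason local_iteration and rounds are worth tuning.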

from __future__ import annotations  # lets us annotate with Server before it is defined
# client class
class Client:
  def __init__(self, client_id, local_batch_size, local_iteration, device="cuda"):
    self.client_id = client_id
    self.local_bs = local_batch_size
    self.local_iter = local_iteration
    self.device = device
    # each client only ever sees its own partition of MNIST
    self.data_loader = DataLoader(PARTITIONS[client_id], self.local_bs, shuffle = True)
    self.model = LogisiticRegression(784, 10, device)
    self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
    self.loss_fn = nn.CrossEntropyLoss()
    self.losses = []

  def connect(self, server: Server):
    server.append_(self)
    # start from the server's global model, so that all clients share the same initialization
    self.receive_update(model_to_vector(server.model))

  # method to reconstruct the local model from the aggregated weights
  def receive_update(self, update: torch.Tensor):
    # clone so that no two clients ever share the same weight storage
    self.model = vector_to_model(update.clone(), self.model)

  # method to train the local model for local_iter mini-batch steps
  def local_train(self) -> None:
    for step, (image, label) in enumerate(self.data_loader, start=1):
      self.optimizer.zero_grad()
      label = label.to(self.device)
      output = self.model(image)  # forward() already moves the images to self.device
      loss = self.loss_fn(output, label)
      loss.backward()
      self.optimizer.step()
      if step == self.local_iter:
        # record the loss of the last local step of this round
        self.losses.append(loss.item())
        break

  def send_local_model(self) -> torch.Tensor:
    return model_to_vector(self.model)


Server class

There is one central server. As mentioned above, the central server connects to a list of clients and has a method to collect the clients’ local models and aggregate them. In this example, the aggregation is simply the element-wise average of all the local weight vectors.

The server has one special parameter, rounds, which sets how many communication rounds the training lasts. If we train for too few rounds, the model may be undertrained and perform poorly, while if we train for too long, we waste computing resources and the model may overfit.
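Plain averaging is the simplest aggregation rule. The Federated Averaging (FedAvg) rule weights each client’s model $w_k$ by its share of the data, $w = \sum_k \frac{n_k}{n} w_k$ with $n = \sum_k n_k$; because every client here holds exactly 6,000 images, this reduces to the plain mean that aggregate computes with torch.mean. A plain-Python sketch with a hypothetical weighted_average helper:

```python
def weighted_average(local_models, num_samples):
    """FedAvg-style aggregation: weight each client's flattened model by its data share."""
    total = sum(num_samples)
    dim = len(local_models[0])
    return [sum(n * model[i] for model, n in zip(local_models, num_samples)) / total
            for i in range(dim)]

models = [[1.0, 2.0], [3.0, 6.0]]              # hypothetical flattened local models
print(weighted_average(models, [6000, 6000]))  # equal data, plain mean: [2.0, 4.0]
print(weighted_average(models, [1000, 3000]))  # unequal data: [2.5, 5.0]
```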

from tqdm import tqdm
class Server:
  def __init__(self, rounds, device="cuda"):
    self.device = device
    self.model = LogisiticRegression(784, 10, device)  # the global model
    self.clients = []
    self.rounds = rounds

  def append_(self, client: Client):
    self.clients.append(client)

  # one communication round: local training -> upload -> aggregate -> broadcast
  def train_round_(self):
    weights = []
    for client in self.clients:
      client.local_train()
      weights.append(client.send_local_model())
    new_update = self.aggregate(weights)
    for client in self.clients:
      client.receive_update(new_update)
    return new_update

  # element-wise mean of the clients' flattened weight vectors
  def aggregate(self, weights: list[torch.Tensor]) -> torch.Tensor:
    return torch.mean(torch.stack(weights), dim=0)

  def train(self):
    for _ in tqdm(range(self.rounds)):
      last_update = self.train_round_()
    # store the final aggregated weights in the server's global model
    self.model = vector_to_model(last_update, self.model)


server = Server(100)
for i in range(10):
  client = Client(client_id = i, local_batch_size = 32, local_iteration = 4)
  client.connect(server)
server.train()
100%|██████████| 100/100 [00:21<00:00,  4.76it/s]

Let’s plot the loss curves. In Machine Learning, the loss measures how far the model’s predictions are from the target values on the training data. If the loss decreases over time, that is a good sign that training is working (although it says nothing yet about performance on unseen data).

import matplotlib.pyplot as plt
# plot each client's loss progress (one point per round: the loss of its last local step)
for client in server.clients:
  plt.plot(client.losses)
plt.title("Clients Loss Progress")
plt.xlabel("Round")
plt.ylabel("Loss")
plt.show()

[Plot: Clients Loss Progress, per-client training loss over rounds]

Test model performance

First, let’s pick a random sample and check whether the model predicts it correctly. We load the MNIST test set, which the model was never trained on. If the model predicts test images correctly, this indicates that it generalizes well beyond its training data.

import random
# visualize an example
test_mnist = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(dataset=test_mnist, batch_size=256, shuffle=False)


index = random.randrange(len(test_mnist))  # randint(0, len) could return an out-of-range index
image, label = test_mnist[index]
with torch.no_grad():
  prediction = server.model(image).argmax().item()  # forward() moves the image to the model's device
print(f"Label: {label}, prediction: {prediction}")
plt.imshow(image.squeeze(), cmap='gray')
plt.show()
Label: 4, prediction: 4

[Image: the sampled test digit]

Accuracy

Next, we compute the percentage of the test dataset that the model predicts correctly, using the accuracy function.
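The calculation is the same as in a plain-Python sketch: take the highest-scoring class for each sample, count matches with the true labels, and divide by the number of samples. The logits and labels below are hypothetical, and accuracy_from_logits is a made-up helper name.

```python
def accuracy_from_logits(batch_logits, labels):
    """Percentage of samples whose highest-scoring class equals the true label."""
    predictions = [max(range(len(z)), key=z.__getitem__) for z in batch_logits]
    correct = sum(p == y for p, y in zip(predictions, labels))
    return 100 * correct / len(labels)

# hypothetical scores for 4 test images over 3 classes
batch_logits = [[2.0, 0.1, -1.0], [0.3, 0.2, 0.9], [0.0, 1.5, 0.2], [1.0, 0.0, 3.0]]
labels = [0, 2, 1, 1]
print(accuracy_from_logits(batch_logits, labels))  # 75.0 (the last image is misclassified)
```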


def accuracy(model):
  model.eval()
  correct = 0
  total = 0
  with torch.no_grad():  # no gradients are needed for evaluation
    for images, labels in test_loader:
      output = model(images)  # forward() moves the images to the model's device
      labels = labels.to(model.device)
      predicted = output.argmax(dim=1)  # class with the highest score
      total += labels.size(0)
      correct += (predicted == labels).sum().item()
  return 100 * correct / total

acc = accuracy(server.model)
print("Final model testing accuracy: {}%".format(acc))
Final model testing accuracy: 91.59%

There are many ways to make the model perform better or worse. We can change the parameters mentioned above: the number of local training iterations per client, the client batch size, the model shared between the clients and the server, and the total number of Federated Learning rounds.

-> Try tweaking these numbers and come up with a configuration that beats the accuracy reported in this notebook. <-

