Next Article in Journal
Automatic Vertebral Rotation Angle Measurement of 3D Vertebrae Based on an Improved Transformer Network
Previous Article in Journal
How the Post-Data Severity Converts Testing Results into Evidence for or against Pertinent Inferential Claims
 
 
Font Type:
Arial Georgia Verdana
Font Size:
Aa Aa Aa
Line Spacing:
Column Width:
Background:
Article

FedTKD: A Trustworthy Heterogeneous Federated Learning Based on Adaptive Knowledge Distillation

1
School of Computer Science and Technology, China University of Petroleum (East China), Qingdao 266580, China
2
CSIRO’Data61, Sydney 2015, Australia
3
School of Computer Science, Southwest Petroleum University, Chengdu 610500, China
4
School of Software, Tiangong University, Tianjin 300387, China
5
School of Computer Science and Engineering, Nanyang Technological University, Singapore 639798, Singapore
*
Authors to whom correspondence should be addressed.
Entropy 2024, 26(1), 96; https://doi.org/10.3390/e26010096
Submission received: 18 December 2023 / Revised: 18 January 2024 / Accepted: 19 January 2024 / Published: 22 January 2024

Abstract

:
Federated learning allows multiple parties to train models while jointly protecting user privacy. However, traditional federated learning requires each client to have the same model structure to fuse the global model. In real-world scenarios, each client may need to develop personalized models based on its environment, making it difficult to perform federated learning in a heterogeneous model environment. Some knowledge distillation methods address the problem of heterogeneous model fusion to some extent. However, these methods assume that each client is trustworthy. Some clients may produce malicious or low-quality knowledge, making it difficult to aggregate trustworthy knowledge in a heterogeneous environment. To address these challenges, we propose a trustworthy heterogeneous federated learning framework (FedTKD) to achieve client identification and trustworthy knowledge fusion. Firstly, we propose a malicious client identification method based on client logit features, which can exclude malicious information in fusing global logit. Then, we propose a selectivity knowledge fusion method to achieve high-quality global logit computation. Additionally, we propose an adaptive knowledge distillation method to improve the accuracy of knowledge transfer from the server side to the client side. Finally, we design different attack and data distribution scenarios to validate our method. The experiment shows that our method outperforms the baseline methods, showing stable performance in all attack scenarios and achieving an accuracy improvement of 2% to 3% in different data distributions.

1. Introduction

As deep learning continues to advance, many organizations use artificial intelligence technology to optimize their digital management processes. Nevertheless, while artificial intelligence brings undeniable benefits, it has ushered in a new set of challenges. First and foremost, the widespread adoption of cross-industry data analysis methodologies has transformed data analysis into a collaborative, multidisciplinary endeavor. This shift means that data analysis no longer confines itself to a single industry but rather spans various sectors. Consequently, the need to pool data resources from diverse domains has become increasingly apparent. This collaborative approach extends to constructing model training datasets, necessitating cooperation among institutions within the same industry and across different sectors. One solution to addressing these challenges is multi-party data sharing. However, implementing this approach introduces its own set of issues. Traditional data analysis practices entail the collection of data from multiple sources, centralizing it on a server for analysis. Unfortunately, this centralized data collection method poses a significant risk to data privacy. Furthermore, in certain sectors such as healthcare and finance, where data often contains sensitive personal or proprietary information, direct industry data sharing becomes infeasible. Consequently, the need to achieve multi-party federated data analysis while preserving data privacy has become an urgent concern. Simultaneously, the demand for complex tasks has driven organizations to train larger network models with an increased number of parameters. These large, highly accurate models have successfully addressed challenges in various domains. However, resource constraints and associated costs have made it challenging for some participants to engage in the training of such complex models. This predicament has placed the spotlight on addressing model training for resource-constrained participants.
To tackle the data privacy conundrum on mobile devices, Google pioneered federated learning technology [1]. Federated learning is a distributed computing method that solves data privacy issues by exchanging model parameters between mobile devices and servers instead of sharing raw data. Google also proposed the FedAvg method [1], a global fusion method that calculates the client weights based on the ratio of client samples to the total samples. However, in practical scenarios, this method faces two significant challenges. On the one hand, the inherent diversity in devices and data sources makes it difficult to achieve convergence in the fused global model when using the average weighting method employed by FedAvg. On the other hand, frequent model transfers between the server and clients result in increased communication costs, which can be unacceptable in resource-constrained wireless and mobile network environments. Therefore, some research [2,3,4] has concentrated on enhancing model convergence speed and optimizing communication costs in federated learning. Despite these efforts to address model convergence in federated learning, most of these methods assume a uniform network model and structure across all clients. In practice, different clients may design their own network models to suit their computational environments, giving rise to the “heterogeneous model fusion” challenge in federated learning. Conventional federated learning algorithms are ill-equipped to resolve this challenge.
To address the constraints imposed by computationally limited devices, Hinton et al. introduced knowledge distillation [5]. This technique guides the training of a small model (student model) on resource-constrained devices by first pre-training a complex network model (teacher model) with high performance. Knowledge is then transferred from the teacher model to the student model, enabling the latter to perform better. Notably, knowledge distillation permits the teacher and student models to possess different network structures, effectively solving the problem of knowledge transfer between heterogeneous networks. Leveraging the advantages of knowledge distillation, researchers have explored its integration with federated learning, which brings the series of federated knowledge distillation methods. Federated knowledge distillation, a two-step knowledge distillation process encompassing server-side and client-side knowledge distillation, has emerged as a solution to the heterogeneous model fusion problem in federated learning. However, the combination also brings some new challenges.
On the one hand, in the server-side knowledge distillation phase, each client model assumes the role of a teacher, while the server-side model serves as the student. Allocating the fusion weight of each student is the key issue. Some research endeavors implement multi-teacher knowledge distillation by maintaining fixed teacher weights. Moreover, certain approaches employ rotating teacher selection [6] to facilitate knowledge distillation. These methods address the challenge of knowledge fusion in heterogeneous models but do not entirely resolve the issue of knowledge distillation in federated learning. Federated learning, characterized by dynamic changes in the client model with each communication round, necessitates a method that can dynamically select the teacher model based on the quality of the client model. Furthermore, each federated learning communication round involves distributing server-side knowledge to individual clients, posing challenges in leveraging global knowledge to guide the training of each client’s private model.
On the other hand, we cannot share clients’ private data in the federated learning environment, so training a teacher model while preserving client privacy remains the new issue. Current research in this direction can be broadly categorized into two main approaches: the public dataset method [7,8,9] and the method based on Generative Adversarial Networks (GANs) [10,11,12]. Nevertheless, the GAN-based method requires clients to possess significant computational resources, and GAN training is a time-consuming process, limiting its accessibility to some participants. The public dataset method, too, faces its own set of challenges. It assumes the trustworthiness of all participating clients, but in reality, some clients may exhibit malicious behavior. These clients may intentionally manipulate the outputs of their local models on public datasets. Incorporating malicious information into global knowledge fusion can significantly impact overall accuracy. Additionally, variations in each client’s computational power and data quality result in differences in the quality of information output by client models on public datasets. Some algorithms (e.g., FedMD [7]) employ an average weighting method to merge this information, leading to compromised global information performance. While current approaches, such as public datasets and GAN-based methods, offer partial solutions to the problem of heterogeneous federated learning, they presume the trustworthiness of all federated learning participants. However, some clients may produce malicious or low-quality knowledge in real-world environments, making achieving trusted knowledge aggregation in heterogeneous settings challenging. For instance, in medical image applications, certain hospitals have designed personalized models tailored to their specific computational environments. These hospitals seek to reuse their original models and collaboratively train a more accurate model based on their existing models. Before embarking on the federated task, these hospitals must annotate their respective sample data to prepare the training dataset. Medical dataset annotation is a complex task that demands domain expertise and significant time investment. Annotators’ domain knowledge directly impacts the quality of image annotations. Errors in annotation may result in discrepancies in the knowledge output by each hospital’s model. Additionally, in an attempt to reduce costs or subvert the co-trained models, some individuals may intentionally upload entirely random knowledge or employ randomly labeled samples for model training, generating malicious knowledge. Utilizing such malicious or low-quality knowledge for global knowledge fusion can jeopardize the entire federated task.
These challenges make building a framework for Trustworthy AI increasingly imperative. In efforts to promote Trustworthy AI applications, Lu et al. have explored various domains, including software development processes [13] and pattern design [14]. Trustworthy federated learning technology plays a pivotal role in safeguarding privacy and enabling collaborative learning for Trustworthy AI. To actualize Trustworthy AI, Chen et al. [15] have devised a computational framework for trustworthy federated learning, ensuring the security of the entire federated learning process.
In response to these challenges, we propose a knowledge distillation-based approach that utilizes public data to address the problem of heterogeneous model fusion in federated learning. Our objectives encompass (1) achieving trustworthy global information fusion in heterogeneous federated learning environments and (2) enhancing the accuracy of knowledge transfer between servers and clients through knowledge distillation techniques. To attain these goals, we present a trustworthy federated learning framework grounded in knowledge distillation. This framework incorporates client information verification on the server side to prevent malicious clients from participating in information fusion. Additionally, we introduce a model training methodology based on two-stage knowledge refinement to elevate the accuracy of client models.
The main contributions of this paper are as follows:
  • We design a trustworthy federated learning framework in a heterogeneous environment, which can realize client identification and trustworthy knowledge fusion.
  • We propose a malicious client identification method based on client logit features. To the best of our knowledge, we are the first to propose a malicious client identification method based on logit information in heterogeneous federated learning.
  • We propose a trustworthy global logit computation method, which ensures the accuracy of global logit information by training the model on the server side. At the same time, it can adaptively fuse the logit information of each client to realize a high-quality fusion of logit information.
  • We propose an adaptive knowledge distillation method, which can adaptively select the parameters of knowledge distillation according to the confidence of the server-side global logit, and the method can improve the accuracy of the client model.
  • We design different attack scenarios to verify the reliability of our methods. Meanwhile, we compare five baseline algorithms under different data distribution scenarios to verify the performance of our approach.

2. Related Work

2.1. Federated Learning

The FedAvg method, initially proposed by Google [1], computes fusion weights for each client model based on the ratio of client samples to the total. However, due to device and data heterogeneity, FedAvg may lead to slow global model convergence. Therefore, several research focuses on improving convergence problem, such as the FedProx algorithm [3], which introduced a loss function that corrects the parameters of the clients’ models that deviate from the global model, thus accelerating the convergence of the global model. Karimireddy et al. also proposed the Scaffold algorithm [2], which corrects the client drift problem by introducing control variables. Wang et al. proposed the FedNova algorithm [16]. This method requires each client to upload normalized model parameters, and then the server calculates the average parameters, thus solving the global model fusion problem. Li et al. proposed the MOON algorithm [17], which solves the global model convergence problem by introducing the contrast loss of the model.
Although these methods solve the global model convergence problem, they require clients to use the same structural model, so they cannot work in a heterogeneous model environment, making it challenging to perform federated learning in a heterogeneous environment.

2.2. Federated Knowledge Distillation

Knowledge distillation can transfer knowledge between heterogeneous networks. Some research has applied knowledge distillation to federated learning scenarios to solve the fusion problem of heterogeneous models. However, we cannot share participants’ private data in the federated learning environment. How to train teacher models while guaranteeing privacy becomes an urgent problem. We can categorize this research into three primary categories: public dataset-based, calculating category information, and using data-generated approaches.
The public dataset-based approach accomplishes client-side and server-side knowledge distillation by constructing a public dataset between the server and the client. For example, Li et al. proposed FedMD [7], which supports using different models for each client and accomplishes heterogeneous model knowledge distillation through the public dataset. Lin et al. proposed the FedDF method, which aggregates the average soft-label information uploaded by the clients to the server side and then computes the average soft-label information sent to the client. Jiang et al. proposed the FedDistill method [9], which builds a personalized model for each client and transfers the knowledge from the global model to the personalized model. Some researchers use the category information of the client’s private dataset to transfer knowledge. For example, the FedDistill+ method [18] requires each client to calculate the average logit information of each category, and the server-side fuses this information to form the global logit information of each category and finally guides the client model training through this information. Chan et al. proposed the FedHe method, which also aggregates the logit information of private data to achieve knowledge aggregation and then uses this knowledge to guide the model training of each client. Chen et al. also proposed FedHKD [19], which sends the client’s data representations and corresponding soft predictions to the server, and the server aggregates this information to achieve knowledge merging. Generative model-based methods mainly solve the knowledge fusion problem by training GAN networks to generate client data samples. For example, Zhu et al. proposed the FedGEN [10] method to achieve client-side model aggregation by generating a lightweight data generator on the server side. Such as, Zhang et al. proposed the FedFTG method [11] to transfer knowledge from local models to global models by exploring the input space of local models through generators. The FedDTG method [12] implements knowledge fusion by training GAN networks to generate pseudo-samples.
Although the GAN method solves the problem of client data protection, training GAN requires the client to have high computational power, and the training process of GAN is very time-consuming. Therefore, some resource-limited clients cannot use this method. Although the public dataset approach can use fewer resources to train teacher models, the current method cannot guarantee trustworthy federated learning, which needs an evaluation step for client information. The private data-based approach does not require the construction of additional datasets. However, the method uses the average knowledge of each category to guide the client model training, which is only effective on very simple data sets. Therefore, achieving a credible heterogeneous federated learning algorithm is important while maintaining the efficiency and accuracy of knowledge calculation.

2.3. Trustworthy Federated Learning

Current research aims to achieve credible federated learning in three main categories: client identification, model parameter-based identification, and model parameter pruning method. The client selection method prevents malicious clients from participating in fusion by identifying their behavior. For example, Li et al. proposed a reliable federated learning framework [20], which uses a spectral anomaly detection method on the server side to realize the identification of the malicious client. Chen et al. also proposed a reinforcement learning method to adaptively select the trusted clients for model aggregation [21]. Model parameter identification-based methods mainly distinguish malicious clients by calculating the Euclidean distance between models, such as Krum and Mutil-Krum [22]. In addition, Chen et al. [23] also proposed a model fusion feature approach to identify malicious clients. Model parameter pruning-based methods achieve secure global model fusion by pruning the outlier parameters of malicious models. These methods include TrimmedMean [24], ClippedClustering [25], and Centered Clipping [26] methods.
Although these methods address the problem of trustworthy federated learning, most research is based on evaluating clients’ model accuracy or detecting models’ parameters. However, each client may have a different model structure in the heterogeneous federated learning scenario. Therefore, these methods cannot solve the problem of heterogeneous model detection. Some federated learning algorithms use model logit on datasets in heterogeneous environments for global information fusion. Therefore, using the logit information to detect malicious information becomes the key to solving this problem. Zhang et al. proposed a credible federated learning framework (RobustFL) [27], which realizes the identification of malicious clients by constructing a predictive model based on logit on the server side. Wang et al. also proposed a method based on the distribution of logit to determine the distribution of malicious samples [28].
However, these methods encounter the cold-start problem because they must collect specific malicious logit information and train a predicted model before detecting malicious information. The malicious model outputs different information in each communication round. We cannot train a recognition model in every communication round, which will affect the progress of federated learning. Therefore, we need a real-time dynamic detection method for malicious message identification.

3. Method

3.1. Problem Definition

In real federated learning environments, there are differences in computing resources and computing environments across clients, leading to clients needing to design their respective models according to their environments and problems. Therefore, each client may use a different network structure, known as the model heterogeneity problem. Traditional federated learning algorithms (e.g., FedAvg, FedPorx, etc.) require clients to use the same model and structure. These methods ensure that the parameters of each layer of the model are consistent, so the server-side computes the global model parameters by fusion of the model parameters of each layer. However, these methods cannot be used in heterogeneous federated learning. To describe the heterogeneous model problem, we will describe the critical processes of heterogeneous model fusion in federated learning and the problems in the process. The knowledge distillation method is a common method to solve the information transfer of heterogeneous networks. Therefore, we also use the knowledge distillation method to solve the fusion problem of heterogeneous models in federated learning.

3.1.1. The Processes of Federated Knowledge Distillation

Traditional federated learning approaches accomplish information transfer by passing models between the server and the client. In a federated heterogeneous modeling environment, the structure of each client’s model is different, which makes it unable to use the traditional approach. To complete the information transfer between the client and the server side, we exchange information through the logit information output from the model. The method is divided into six steps, as shown in Figure 1, where the key processes are described as follows.
(1) Train Local Model: The client trains the local model on the private dataset and then uses the local model to output the logit information on the private or public dataset.
(2) Output Logit Information: Different federated knowledge distillation methods require different logit information output. For example, public dataset-based knowledge distillation methods need to share a public dataset between server and client, and these methods require each client to upload each sample’s logit of the public dataset. However, private dataset-based methods only need to upload the logit information of the private dataset of each client (e.g., the average logit information of the client-side categories).
(3) Global Logit Information Fusion: After receiving the logit from each client, the server performs the global logit information fusion for this round. Then, the server sends down the global logit to each client.
(4) Model Training Based on Knowledge Distillation Method: Each client receives the global logit information and uses the knowledge distillation method to train the local model. Each client treats the global logit information as the teacher model output information and the local model as the student model.
The entire federated task repeats these steps until the specified communication rounds are completed.

3.1.2. The Problem of Heterogeneous Federated Learning

Although some federated knowledge distillation methods solve the heterogeneous model fusion problem, these methods assume that the logit information output by each client is trustworthy. In practical scenarios, a malicious client-side may tamper with the logit information to launch an attack behavior. Therefore, the following three problems are encountered during federated heterogeneous model fusion.
(1) Client Identification Problem: Some clients may intentionally upload malicious logit information to launch attacks. If the server uses this malicious information for global information fusion, this will seriously affect the accuracy of global information.
We use the CIFAR-10 dataset as an example to describe the problem, as shown in Figure 2. The CIFAR-10 is a 10-categorical dataset, and the logit output by each client for the samples in this dataset is a vector of probabilities of length 10, with each element of the vector representing the probability of belonging to a certain category. We assume that there are N clients involved in the fusion of global information, and the N-th client is malicious. Each client outputs logit information for the same sample data, and the N-th client outputs a malicious logit. If the server side directly adopts the average fusion method for information fusion, this will lead to the final prediction result being the cat. However, the actual sample’s category is a dog.
To solve this problem, we must design a client identification method to filter out malicious information to participate in the global information fusion. We assume that n clients participate in the global information fusion. We define C i as the ith client, whereas the set of n clients is defined as { C 1 , C 2 , , C n } . Then, the problem is shown in Equation (1). The S e l e c t T r u s t w o r t h y C l i e n t ( . ) is a method that selects the trustworthy client from all the clients.
{ C 1 , C 2 , , C k } SelectTrustworthyClient ( { C 1 , C 2 , , C n } )
(2) Global Knowledge Fusion Problem: When the server identifies the malicious client, the next step of the server is to use the selected trustworthy client for global information fusion. Due to the problem of computational resources and data distribution of the clients, this leads to quality differences in the accuracy of the output logit information of each client model, e.g., some clients may output incorrect or low-quality logit. The traditional federated heterogeneous model methods (FedMD, FedDistill) directly use the average weight fusion method to fuse the logit information, and these methods are unreasonable. To solve this problem, we must design an adaptive information fusion method to achieve high-quality global information fusion. We define the logit information output by the ith client as L i , and then the logit set output by all clients is defined as { L 1 , L 2 , , L k } . The problem is shown in Equation (2).
L global FusionGlobalLogit ( { L 1 , L 2 , , L k } )
(3) Client Knowledge Distillation Problem: In each communication round of federated learning, the server sends the global logit information to each client. How to use this global knowledge to train the local model is also a problem. We also use the knowledge distillation approach to solve this problem, where we treat the global logit information as the teacher model and the client’s model as the student model. We use the knowledge distillation method to transfer global information from the teacher to the student model. However, in real scenarios, the information output by the teacher model may also have errors. In that case, this will lead to the client using the wrong knowledge to guide the training of the local model, so we need to design an adaptive knowledge distillation method to solve this problem. This method can select the correct knowledge to guide the student model. We define M i as a model of the ith client, and the set of models of clients is denoted as { M 1 , M 2 , , M k } . Then, we define the problem as Equation (3).
{ M 1 , M 2 , , M k } AdaptiveKD ( L global )
The AdaptiveKD(.) is the adaptive knowledge distillation method, which can adaptively transfer the server-side global knowledge to the client models.
In summary, To accomplish trustworthy federated learning in a heterogeneous environment, we need to solve the following problems.
(1) Client Identification Problem: How to identify a malicious client based on logit information. We must exclude malicious clients from global knowledge fusion.
(2) Global Knowledge Fusion Problem: How to adaptively select clients based on the logit information to achieve high-quality global knowledge fusion.
(3) Global Knowledge Migration Problem: How to transfer global information from the server to the client and use the global knowledge to guide the client model training.
In this paper, we address the problem of federated heterogeneous model fusion by sharing a labeled dataset between the server and each client. The following four sections, Section 3.2, Section 3.3, Section 3.4 and Section 3.5, will detail how to solve these problems.

3.2. Malicious Client Identification Method Based on Logit Feature

Client identification stands as a foundational step in realizing trustworthy federated learning. This process involves identifying and excluding malicious clients, thus preventing their logit information from being incorporated into global information fusion. Our approach shares a public dataset between the server and each client to facilitate knowledge fusion within a heterogeneous federated learning framework. The global knowledge is derived by amalgamating the logit information produced by each client based on this public dataset.
Drawing on the method of [29], we also train a model at the server and use the model to output each sample’s logit for the public dataset. Since the server-side logit information is independent of the client model, this ensures that the server-side logit information is trustworthy. We can compare the logit information of each client and server to identify the type of client. The process consists of four stages, as shown in Figure 3.
Before describing these stages, we define the following elements. We define the set of clients as { C 1 , C 2 , , C k } , and the public dataset is denoted as D 0 ; it has a total of m samples and n categories. We define s j as the j-th sample in the public dataset D 0 , and l i j as the logit corresponding to the sample s j of client C i , then the set of logits on the public dataset output by the i-th client is { l i 1 , l i 2 , , l i m } . Similarly, we define the set of logit output from the server side as { l s 1 , l s 2 , , l s m } .
Stage 1 (Preprocessing of Logit Information): we first preprocess the logit information from each client and server, which mainly includes two steps: (1) logit information grouping and (2) vectorizing each subgroup logit.
(1) Logit information grouping: When the server receives the logit information uploaded by the client, it first groups the logit information according to its category. The process is shown in Equation (4).
The G r o u p B y C a t e g o r y ( . ) denotes grouping sample logit according to the sample category, where G i j is the j-th subgroup of the i-th client, and each client has n subgroups.
{ G i 1 , G i 2 , , G i n } GroupLogitByCategory ( { l i 1 , l i 2 , , l i m } )
At the same time, we also need to group the server-side logit information, and we define the set of the subgroup of the server as G s = { G s 1 , G s 2 , , G s n } .
(2) Vectorize sub-grouping logit: After finishing logit grouping, we need to merge logit within each sub-grouping into a one-dimensional vector. We define v i j as j-th vector of j-th subgroup, and then n subgroups correspond to n vectors. The process is shown in Equation (5).
{ v i 1 , v i 2 , , v i n } VectorLogit ( { G i 1 , G i 2 , , G i n } )
V e c t o r L o g i t ( . ) is denoted as a vectorized subgroup. We define the set of vectors for the i-th client as V i , where V i = { v i 1 , v i 2 , , v i n } .
Repeating steps (1) and (2), we compute the set of vectors for each client and server, respectively.
Stage 2 (Logit Feature Extraction): In the previous phase, we extracted the client and server logit vectors, and in this stage, we need to compute the logit feature values for each client based on the vector information. Since we also output the logit information independently on the server side, this can guarantee that the server-side logit information is trustworthy. Therefore, we can identify the client type by comparing the similarity between the client-side and server-side vectors.
To ensure trustworthy federated learning, we employ cosine similarity as a key metric to compare the server’s and clients’ vectors. Cosine similarity values range from 1 to 1, indicating the degree of similarity between two vectors. A value close to 1 signifies that the vectors are nearly identical, whereas a value approaching 1 indicates that the vectors are opposed. In this context, the eigenvectors of a malicious client are expected to differ significantly from those on the server side. Conversely, a trustworthy client’s vectors will consistently align closely with the server’s. This characteristic variance in cosine similarity values provides a reliable means to identify trustworthy clients.
We assume that there are n categories in the public dataset, and for the j-th category, we define the j-th vector of the i-th client as v i j and the vector of the server as v s j . We compute the client cosine similarity value using Equation (6). Then, we compute the cosine similarity values for all categories on the client side in turn.
c s i j = cosim ( v i j , v s j ) = v i j · v s j v i j · v s j , j [ 1 , n ]
The set of all categories of the i-th client is expressed as { c s i 1 , c s i 1 , , c s i m } . We take the cosine similarity set of each client as the client’s feature value.
Stage 3 (Clustering based on client features): In this stage, we use the client features extracted in the previous step and group the clients into two clusters using the clustering method, as shown in Equation (7).
{ G r o u p 1 , G r o u p 2 } Cluster ( { C S 1 , C S 2 , , C S k } )
represents the clustering method. G 1 and G 2 represent two class clusters, respectively.
Stage 4 (Verify Clients): In the previous stage, we divided the clients into two groups. However, we do not know which group is normal, so we need to verify these two groups. We use the accuracy of the clients in each group on the public dataset to determine whether they are credible clients. The group with the highest accuracy is the trusted client group, and the lower group is the malicious client group. The process is shown in Equation (8).
{ A c c 1 , A c c 2 } ValidateGroupAcc ( { G r o u p 1 , G r o u p 2 } )
Although this step can identify most malicious clients, some malicious clients (e.g., noise attack clients) can escape the method. To further exclude malicious clients, we need to verify the accuracy of the clients in the normal group. We first validate the accuracy of each client on the public dataset and then calculate the average accuracy of those clients. Finally, we calculate the difference between the client and average accuracy. We also exclude clients from participating in global logit calculations when their accuracy is below the average value and the difference exceeds a set threshold. The process is shown in Algorithm 1.

3.3. Global Information Fusion Methods

Once we have identified the malicious clients, our next task is to use these trustworthy clients’ logit to fuse the global logit. Most federated knowledge distillation methods (such as FedMD and FedDF) use average weight methods to fuse the global logit, which calculates the global logit value by the average of all clients’ logit value, and each client is given the same weight. However, in real scenarios, each client’s data distribution and computational resources are different, which leads to differences in the quality of the logit for the same sample. At the same time, some clients may output an error or a low-quality logit. Using the error logit to fuse the global logit will affect the accuracy of the global information. Therefore, we need to filter out the error information to improve the performance of global information. To accomplish this goal, we design a trustworthy logit calculation method that calculates the global logit in two parts. The flow of the method is shown in Figure 4. The flow of the algorithm is shown in Algorithm 2.
Algorithm 1 Client identification algorithms
Input: 
Each client’s Logit: L = { L 1 , L 2 , , L n } , Public Data: D 0
Output: 
Trustworthy Clients C = { C 1 , C 2 , , C K }
1:
/* Center Node Process */
2:
for Round t from 1 to T do
3:
    Init Trustworthy Client List T C _ L i s t = { } // Initialise client features List
4:
     M s TrainServerModel ( D 0 )
5:
     { l s 1 , l s 2 , , l s m } OutPutLogit ( M s , D 0 ) // Output logit on the Public Dataset
6:
     { G s 1 , G s 2 , , G s n } GroupLogitByCategories ( { l s 1 , l s 2 , , l s m } )
7:
     { v s 1 , v s 2 , , v s n } VectorLogitEachCategories ( { G s 1 , G s 2 , , G s n } )
8:
    for Client C i from 1 to K do
9:
          { G i 1 , G i 2 , , G i n } GroupLogitByCategories ( { l i 1 , l i 2 , , l i n } )
10:
         { v i 1 , v i 2 , , v i n } VectorLogitEachCategories ( { G i 1 , G i 2 , , G i n } )
11:
         { c s i 1 , c s i 2 , , c s i n } CalculateCosineValue ( { G i 1 , G i 2 , , G i n } )
12:
         C S i GetClientFeature ( { c s i 1 , c s i 2 , , c s i n } )
13:
     { Group 1 , Group 2 } Cluster ( { C S i , C S 2 , , C S k } ) // Clustering the clients
14:
     { A c c 1 , A c c 2 } ValidateACC ( G r o u p 1 , G r o u p 2 ) // Validation accuracy
15:
     U C _ L i s t AddMaxAccGroup ( { C a , C b , , C k } )
16:
     A a v g CalculateAvgAcc ( { A 1 , A 2 , , A k } )
17:
    for client C j from 1 to t in UC_List do
18:
        if  A i A a v g > ϵ  then
19:
            T C _ L i s t AddClientToList ( A i )
20:
Output Trustworthy Clients T C = { C a , C b , , C K }
In our method, we train a model on the server side using the public dataset, and we use the server model to output the logit of each sample of the public dataset. After we obtain all the samples’ logit, we divide the logit into two groups according to whether the predicted result of the logit is consistent with the actual labels of the sample. We divide the logit into the correctly predicted and the incorrectly predicted groups. When the server model has incorrect predictions, it means that the server side cannot predict this information, which requires the integration of logit from the clients to obtain the correct value [29]. We calculate the logit in two parts according to the grouping of the samples, and the calculation process is shown in Figure 4.
We define ( x i , y i ) as the i-th sample and the label, and the D 0 as the public dataset. Suppose D 0 is a K-categorical dataset with m samples. The set of all samples is denoted as D 0 = { ( x 1 , y 1 ) , ( x 2 , y 2 ) , , ( x n , y n ) } , y i [ 1 , K ] . We define l i m as the m-th sample logit output by the i-th client and define L i as the set of the logit of the public dataset output by client C i . The set is denoted as L i = { l i 1 , l i 2 , , l i m } .
We define l s m as the m-th sample’s logit output by the server model and p s ( x m ) as the predicted label of the server model for the sample x m . Then, we calculate the global logit for the sample x m based on the results of the server-side prediction. We define l g l o b a l m as the global logit of the m-th sample. The FusionLogit ( { l 1 m , l 2 m , , L k m } ) represents fusing k clients’ logit to calculate the new global logit for the sample x m . We use Equation (9) to calculate all the sample logit in turn. We will introduce how to calculate the logit in the following.
l g l o b a l m = l s m , p s x m = y . FusionLogit ( { l m 1 , l m 2 , , l m k } ) , p s x m y .
Part 1: If the server-side output logit is consistent with the actual category of labels, we directly use the server-side model output logit as part of the global logit.
Part 2: If the server-side output logit is wrong, we will obtain the logit from the client’s logit. In this condition, we can regard the global logit calculation problem as a multi-teacher logit fusion problem. Each client’s logit is the output of a teacher model, and our goal is to obtain the global logit from those teachers’ logit. However, the quality of logit output from each teacher model is different. To adaptively fuse multiple teachers’ logit information, we need to assign higher weights to high-quality information while assigning lower weights to low-quality information. We draw on the paper’s method [30] to calculate the weight value. This calculation process consists of three steps.
(1) The Calculation of Fusion Weight: We first calculate the cross-entropy loss for each client, where k denotes the k-th client. This loss can reflect the confidence level of each client in the sample [30]. The calculation process is shown in Equation (10).
L C E k = 1 C y c log ( S o f t m a x ( L c k , T ) )
where S o f t m a x ( L c k , T ) denotes the calculation of L c k at temperature T using the softmax function. Then, we calculate the weight of each client according to Equation (11).
w k = 1 K 1 1 exp ( L C E k ) 1 j exp L C E j
Unlike [30], we calculate each client’s weight for each category according to logit categories. We first group each client’s logit categories on the public dataset’s output and then use Equation (11) to calculate the client’s fusion weight. Eventually, we obtain a weight matrix where each row represents each category of logit, and each column represents each client’s weight. The weight matrix is shown in Equation (12).
W ( C , Y ) = w 11 w 12 w 1 k w m 1 w m 2 w m k
(2) Re-adjust Client Weight: After determining each client’s weight for each category, it is essential to address the issue of incorrect logit outputs. Some clients may produce an error logit for a sample, which necessitates the removal of these incorrect logits from the fusion process. When computing the logit for a sample, if an incorrect logit is identified and removed, the fusion weight assigned to that logit should consequently be reduced to zero. This adjustment then requires recalibrating the fusion weights for the remaining clients to maintain the balance in the model’s overall learning. To recalibrate these weights, we employ L1-normalization, which effectively redistributes the weights among the remaining clients. The p k ( x ) is the k-th client prediction, and y is the label of a sample x. The calculation process of the new weight is shown in Equation (13).
w k = 0 , p k ( x ) y . w j j = 1 N w j , p k ( x ) = y .
where w j is the L1-normalization for all the correct logits.
(3) New Logit Calculation: Having obtained the revised weights in the previous stage, we now proceed to recalculate the logit for each sample using these new weights in conjunction with the logit information from each client. The process for this new logit calculation is outlined in Equation (14).
L new = 1 K w k L n k
When we obtain the logit of all the samples in part 1 and part 2, we merge these two to form the global logit for this communication round. Then, the server sends this global logit down to each client. The process is shown in Algorithm 2.
Algorithm 2 The Calculation of Glogit Logit Information
Input: 
Client Logit List: L = { L 1 , L 2 , , L k } . Public Dataset D 0
Output: 
Global Logit L g = { l 1 , l 2 , , l m }
1:
/*Public Dataset Logit Calculate Process*/
2:
CrroctList = { } // Initialise IDs of Correctly predicted sample Logit
3:
ErrorCrroctList = { } // Initialise IDs of Error predicted sample Logit
4:
NewCrroctList = { } // Initialise IDs of new calculate sample Logit
5:
M s TrainServerModel ( D 0 )
6:
{ l 1 , l 2 , , l m } OutPutLogit ( M s , D 0 ) // OutPut Logit on the Public Dataset
7:
W ( C , Y ) CalculateWeightMatirc ( { L 1 , L 2 , , L k } )
8:
/**The Calculation of Logit in Part-1**/
9:
for Logit l s from 1 to m in Public Dataset do
10:
    if  l s = l a b e l  then
11:
         C r r o c t L i s t AddLogit ( l s )
12:
    else
13:
         E r r o r L i s t AddLogit ( l s )
14:
/**The Calculation of Logit in Part-2**/
15:
for Logit l s from 1 to n in ErrorList do
16:
    Init S u m L o g i t s = 0 // Initialise the Logit value for this sample
17:
    for Client C i from 1 to C k  do
18:
         l s GetLogitFromClient ( C i )
19:
         w s GetWeightFromWeightMatric ( C i )
20:
         S u m L o g i t s = S u m L o g i t s + w s · l s
21:
     l s = S u m L o g i t s
22:
     N e w C o r e c t L i s t AddNewLogit ( l s )
23:
L g = MergeTwoPartLogit ( { C r r o c t L i s t , N e w L i s t } )
24:
Output Global logit L g

3.4. Calculation of Weighting Parameters Based on Teacher Logit Confidence Level

Having calculated the global logit information, the server will distribute the global logit to each client. In this phase, clients use the global logit as the output of a teacher model, training their local models on the public dataset via knowledge distillation. Traditional knowledge distillation involves a combined loss function, consisting of cross-entropy loss ( L C E ) and distillation loss ( L K D ), with a balance weight α . The total loss ( L t o t a l ) used to optimize the model is shown in Equation (15):
L t o t a l = α L C E + ( 1 α ) L K D
The values of α are fixed in the traditional knowledge distillation method throughout the model training process. However, in the heterogeneous federated learning knowledge fusion process, the data distribution and computational resources are different across clients, leading to quality differences in logit. When we use these client logits to compute the global logit, the accuracy of the global logit is dynamically changing, which will make the accuracy of the global logit different for each category in the same communication round [31]. To address this, an adaptive knowledge distillation method is required, one that can adjust weight parameters in line with sample categories to enhance client model accuracy. Thus, we propose a redefined equation for L t o t a l , as shown in Equation (16), where w k represents the weight for the k-th category.
L t o t a l = ( 1 w k ) L C E + w k L K D
Next, we calculate each category’s w k value. The value is determined by the accuracy of the global logit on the public dataset, which means that the higher the accuracy of the global logit for a category, the higher the global logit confidence value in that category. If the global logit has a low accuracy for one category, this means that the global logit is not credible in that category. To prevent the global logit’s incorrect knowledge from passing to the client model, we need to set w k to 0. Therefore, we can calculate the w k by evaluating the confidence value of the global logit. We borrow the calculation method from [32]. We divide the weight calculation process into three steps, and this process is shown in Figure 5.
Step 1 (Marginal Value Calculation): We need first to calculate the marginal value of each category, which can measure the confidence value of the teacher model in each category in the public dataset, and the larger the marginal value, the higher the prediction accuracy of the teacher model for that category. We define M k ( k , p t ( x ) ) as the marginal value of the sample x. The p t ( x ) is the probability that the teacher model predicts a sample x, where k is the actual category of the sample, and k is the other category. The L is the total number of categories. We obtain the marginal value by calculating the difference between the probability of predicting the sample as category k and the probability of predicting the sample as any other category. The calculation process is shown in Equation (17).
M k ( k , p t ( x ) ) = p k t ( x ) 1 L 1 k k p k t ( x )
Step 2 (Confidence Value Calculation): After calculating each sample’s marginal value, we first sum up the marginal values of each category, and then we calculate the average value of the marginal values of each category. We define C v a l u e k as the confidence value of the k-th category. The calculation process of C v a l u e k is shown in Equation (18).
C v a l u e k = E x k [ M k ( k , p k t ( x ) ) ]
Step 3 (Weight Calculation): When we calculate the value of the C v a l u e k , we can get the value of the w k based on this value. Notably, as the accuracy of global logit typically increases over communication rounds, the confidence value of the global logit for each category will be close to 1 after a certain number of rounds. At this point, if we directly use this value ( C v a l u e k = 1 ) to determine the weight value ( w k = 1 ), it will make the weight value of L C E become 0 ( 1 w k = 0 ), leading to ignoring the loss of the actual label of the sample. However, we aim to use the two losses fully to complete the model training. Therefore, we added the parameter β to limit the value of C v a l u e k , which prevents the value from being too large and leads to the problem of ignoring the L C E during the training process. The process is shown in Equation (19).
w k = 0 , C v a l u e k 0 . ( 1 β ) · C v a l u e k , C v a l u e k > 0 .
When C v a l u e k 0 , it indicates that the global logit has low confidence in the category k. We directly ignore the teacher’s knowledge in this situation. When C v a l u e k > 0 , it indicates that the global logit has higher confidence in the class. We set the value of β according to different data distribution scenarios, and we will describe the setting of this value in the subsequent experimental section.

3.5. Client Model Training Based on Adaptive Knowledge Distillation

This phase completes the client model training, which consists of two steps: (1) training the client model on the public dataset using adaptive knowledge distillation and (2) model retraining on the private dataset. The process is shown in Figure 6, and the algorithm flow is shown in Algorithm 3.
(1) Train Model on Public Dataset by Adaptive Knowledge Distillation: We regard the global logit as the output of the teacher model and the client model as the student model. Then, we use the knowledge distillation method to complete the training of the client model. We define the predicted value of the teacher’s model for the sample x i as p t ( x i ) . We define l k as the logit of a sample x i . We calculate the predicted value at sample temperature T by the softmax function. The calculation is shown in Equation (20).
p k t ( x ) = e l k ( x ) / T j = 1 K e l j ( x ) / T
We define L C E as the cross-entropy loss of the client. We can calculate the cross-entropy loss by client model prediction with the label y, as shown in Equation (21).
L C E = 1 N n = 1 N k = 1 K y k log ( p k ( x ) )
We define the knowledge distillation loss as L K D . The knowledge distillation loss L K D is obtained by calculating the K L loss between the teacher and student models. The calculation is shown in Equation (22).
L K D = 1 N n = 1 N k = 1 K p k t ( x i ) log ( p k ( x ) )
Before each communication round starts, the server sends the global logit and weight list to each client, and the client selects the w k adaptively according to the sample category. We define the set of weights for communication round t as { w 1 , w 2 , , w k } . Each client completes the computation of the loss of each category according to the category weights w k . The calculation process is shown in Equation (23).
L t o t a l = ( 1 w k ) L C E + w k L K D
Finally, each client uses the public dataset and the adaptive knowledge distillation method to train the local model.
(2) Model training using Private Dataset: In this stage, clients continue to train their local models on their private datasets using the model obtained from the previous phase. We use cross-entropy loss during model training at this node to optimize the model. We define the cross-entropy loss as L p r i v a t e _ C E and the client model prediction as p k ( x ) , where x is the private sample of the client, and k is the sample category. Then, we obtain L p r i v a t e _ C E by calculating the model prediction with the actual labels of y. The process is shown in Equation (24).
L p r i v a t e _ C E = 1 N n = 1 N k = 1 K y k log ( p k ( x ) )
Algorithm 3 Train Client Model
Input: 
Global Logit: L g , Each categories weight: W = { w 1 , w 2 w k }
Output: 
Trained Client Model M = { M 1 , M 2 M n }
1:
/* Each Client Training Step */
2:
for  C i from 1 to C n  do
3:
    /* Step 1:Train Client Model on Public Dataset */
4:
     M 0 InitClientModel ( r o u n d ) // Initiate the model with the previous model
5:
     p r e PredValueOnPublicData ( M 0 )
6:
     L C E GetCELoss ( y p r e , y l a b e l ) //Calculate CE Loss on Public Dataset
7:
     L K D GetKDLoss ( { s o f t m a x ( y p r e ) } , L g ) //Calculate KD Loss on Public Dataset
8:
     w k SelectKDWeightByCategories ( { w 1 , w 2 w k } )
9:
     L t o t a l = ( 1 w k ) L C E + w k L K D
10:
     M i TrainClientModel ( L t o t a l )
11:
    /* Step 2:Train Client Model on Private Dataset */
12:
     y p r e PredValueOnPrivateData ( M i )
13:
     L p r v i a t e _ C E GetCELoss ( y p r e , y l a b e l ) //Calculate CE Loss on Private Dataset
14:
     M i TrainClientModel ( L p r v i a t e _ C E )
15:
Output Model of Clients = { M 1 , M 2 , , M n }

4. System Design

In addressing the challenges of trustworthy federated learning in heterogeneous environments, we design the FedTKD framework. This framework includes two core processes: client model training through adaptive knowledge distillation and trustworthy information fusion on the server side. The server is responsible for client identification and the fusion of global information, while the client focuses on model training and logit information output. The process is shown in Figure 7, and the algorithm flow detailed in Algorithm 4.
(1) Initialization Phase: The server distributes a public dataset to each client. This dataset is sent only once at the beginning.
(2) Output Client Logit: Clients train their local models using both public and private datasets. Then, each client outputs a logit for each sample in the public dataset using the local model.
(3) Upload Logit Information: After generating logits for the public dataset, clients upload this logit information to the server. We have designed an efficient logit information storage format, as shown in Equation (25), where each sample in the public dataset corresponds to a unique piece of information.
{ I n d e x , L o g i t [ c l a s s _ 1 _ v a l u e , c l a s s _ 2 _ v a l u e , , c l a s s _ n _ v a l u e ] }
(4) Global Logit Information Fusion: Upon receiving logit from all clients in a communication round, the server identifies the type of each client and selects the information from trustworthy clients to perform global logit fusion. This is a three-step process:
Step 1 (Identify Client): The server uses Algorithm 1 to distinguish between malicious and trustworthy clients.
Step 2 (Trustworthy Logit Fusion): The server calculates new global logit information using Algorithm 2.
Step 3 (Category Weight Calculation): After global logit calculation, the server computes the weight value for each category based on confidence levels.
(5) Distribute Global Logit and Weight List: The server sends the latest global logit information and category weight list to each client for the new round.
(6) Train Client Model: This stage is crucial for transferring knowledge from the global knowledge base to each client model. It includes two steps.
Step 1 (Model Training Based on Adaptive Knowledge Distillation): Clients receive the global logit and weight list and use this information to train the local model on the public dataset.
Step 2 (Client Model Retraining): The client uses the model from the previous phase and continues to train the local model on the private dataset.
Algorithm 4 The Workflow of FedTKD
Input: 
Client Model: M = { M 1 , M 2 , M n } , Private Dataset: D = { D 1 , D 2 , D n } , Public Dataset: D 0 , Communication round: T.
Output: 
Trained client model M = { M 1 , M 2 , , M n }
1:
/* Client Process */
2:
for round t from 1 to T do
3:
    for  C k from 1 to C n  do
4:
        /* Step 1: Train Client Model On Public Data */
5:
         M k InitClientModel ( t ) // Load Previous Model
6:
         L g t RecieveGlobalLogit ( L g t )
7:
         M k 1 TrainModelByAKD ( D 0 , M k 1 , L g t ) // According to Algorithm 3
8:
        /* Step 2: Train Client Model On Private Data */
9:
         M k 2 TrainLocalModel ( D k , M k )
10:
        /* Step 3: Output Logit On Public Dataset */
11:
         { l k 1 , l k 2 , , l k m } OutPutLogit ( M k 2 , D 0 )
12:
         S e r v e r SendLogitToServer ( L k )
13:
/* Server Process */
14:
for round t from 1 to T do
15:
    for  C i from 1 to C n  do
16:
         { L 1 , L 2 , , L n } ReceiveClientLogit ( { C 1 , C 2 , , C n } )
17:
     { C 1 , C 2 C k } IdentifyClientType ( { L 1 , L 2 , , L n } ) // Algorithm 1
18:
     L g t CalculateGlobalLogit ( { L 1 , L 2 , , L k } ) // Algorithm 2
19:
     { w 1 , w 2 , , w k } CalculateCategoryWeight ( L g t )
20:
     { C 1 , C 2 , , C k } SendLogitAndWeight ( L g t , W ) //Send to Clients
21:
    Train Client Model According to global logit L g t and weight list W.
22:
Output Client Model List: M = { M 1 , M 2 , , M n }

5. Experiment

5.1. Experiment Setup

5.1.1. Experiment Datasets

Dataset Description: Our experiment utilizes three categorical datasets: MNIST, Fashion-MNIST, and CIFAR-10, each dataset is described as follows:
MNIST: A dataset for handwritten digit classification consisting of 60,000 training and 10,000 test samples. Each sample is a 28 × 28 grayscale image representing digits from 0 to 9.
Fashion-MNIST: It is a ten-category dataset. This dataset comprises 60,000 training and 10,000 test images, each a single-channel grayscale image.
CIFAR-10: A more diverse dataset with ten categories containing 50,000 training and 10,000 test samples. Each sample is a 32 × 32 color image.
Dataset Preprocessing: Our algorithm necessitates the construction of a shared public dataset between the server and client sides, alongside private datasets for each client. We achieve this goal through the following preprocessing steps:
(1) Public Dataset Part: A portion of the original dataset is reserved as the public dataset. We ensure a uniform representation by selecting 10% of data from each category. For instance, in CIFAR-10’s ten categories, 500 samples per category are allocated to the public dataset.
(2) Private Dataset Part: To mimic the Non-IID data distribution of each client, We control client data distribution by adjusting the Dirichlet function’s parameter α . A lower α value leads to a more Non-IID distribution, while a higher value approximates an IID (independently identically distributed) scenario. This distribution is visualized in Figure 8, where each dot’s size represents the sample count, and its color indicates the sample category.

5.1.2. Baseline

We compared our method with five baseline algorithms. These include three methods (FedMD, FedDF, and FedDistill), which require sharing public datasets between the server and clients, and two methods (FedDistill+ and FedHe), which do not require public datasets.
FedMD [7]: This method solves the heterogeneous model knowledge fusion problem by constructing a public dataset. The server fuses the logit information of each client with the average weight method, and the client uses the knowledge distillation method to train the client model on the public dataset and then migrates the client model to the private dataset for further training.
FedDF [8]: This method needs to construct an auxiliary dataset, and the server side integrates the logit information of each client to the global logit. Each client treats the global logit as the teacher’s model knowledge to guide the local model training.
FedDistill [9]: This method needs to build a public dataset. The approach applies knowledge distillation methods to federated learning, which uses global logit as a teacher to guide the local model training of each client.
FedDistill+ [18]: This method does not require a public dataset. Each client first calculates the average logit of each category, and the server side aggregates the logit of each category and calculates the average value. Finally, each client uses the average logit as a regularization loss to guide the local model training.
FedHe [33]: This method does not require a public dataset. It also uses the average logit information of each category as a loss to assist client model training.

5.1.3. Metrics

We evaluate the performance of each algorithm by three key metrics, described below.
(1) Average Client Model Accuracy on the Test Dataset: This metric assesses the algorithm’s effectiveness in enhancing the accuracy of the client’s model. Higher average accuracy indicates better performance across all clients. We assume that K clients participate in federated learning with m communication rounds. We denote the accuracy of the i-th client in the t-th communication round as A t i . Then, we define A a v g t as the average client model accuracy for the t-th round, and we can calculate the average accuracy of K clients for t-th rounds according to Equation (26).
A a v g t = 1 K i = 1 K A i t , t [ 1 , m ]
Then, we define A a v g as the set for all communication rounds, and the set is defined as A a v g = { A a v g 1 , A a v g 2 , , A a v g m } .
(2) Accuracy of Global Logit on the Public Dataset: This metric reflects the server’s capability to amalgamate knowledge from each client. When the server calculates the global logit at t-th communication rounds, we can test the global logit’s accuracy on the public dataset. We define A g l o b a l t as the accuracy of global logit. The m communication round set is as shown in Equation (27).
A g o b a l = { A g o b a l 1 , A g o b a l 2 , , A g o b a l m }
(3) Accuracy of Client Logit on the Public Dataset: When clients finish training the local model, each client outputs the logit of the public dataset. We get this metric by calculating the accuracy of the client’s logit on the public dataset. We define A i t as the accuracy of the i-th client’s logit on the public dataset at the t-th round. Then, the set of all the clients at the t-th round is shown in Equation (28).
A t = { A 1 t , A 2 t , , A k t }

5.1.4. Heterogeneous Network Setup

To simulate a heterogeneous federated learning environment, we employed CNN networks as our foundational architecture, customizing them to have varying structures. We categorized the clients into five distinct groups, with each group utilizing a different network structure but maintaining consistency within the group.
Client Network Setup: For the MNIST and Fashion-MNIST datasets, we designed five pairs of CNN networks for processing single-channel images. For the CIFAR-10 dataset, we designed five CNN networks processing three channels.
Server Network Setup: It is crucial that the parameter of the server model is more complex than that of the client-side model. This ensures the server’s model can fully integrate knowledge from each client model.
We designed experiments with different numbers of clients (10, 15, and 20, respectively) participating in the federation task, and the client network configurations in each experiment are shown in Table 1. Taking the CIFAR-10 dataset as an example, we elaborate on the parameter settings of both client- and server-side network models. Each network is characterized by different convolutional layers and structures. For instance, Conv-X denotes the convolutional module, and C 2 d ( a , b ) means the 2D convolution. F C is the fully connected module, and L i n e ( a , b ) denotes the linear layer.

5.1.5. Hardware and Software Environment

Hardware Setup: Our experiments were conducted on a high-performance workstation equipped with an Intel i9-12900K CPU, 64GB of RAM, and an NVIDIA RTX 3090 graphics card.
Software Setup: We developed a federated learning framework (FedBolt) to implement key functionalities such as partitioning client datasets, customizing client training models, simulating various client attacks, and facilitating multi-client co-training.

5.2. Experimental Results and Analysis

Our experiments are structured into four parts to validate the performance and reliability of our algorithm: malicious client identification, attack experiments, different data distribution experiments, and framework performance analysis. Notably, the malicious client attack experiments are crucial in evaluating our algorithm’s robustness. We also compare our algorithm against other baselines in both attack and normal scenarios.

5.2.1. Malicious Client Attack Experiment

To verify the performance of each algorithm under different attack scenarios, we designed three attack-type scenarios. The details of each attack scenario are as follows.
Label Flipping Attack (Type-1): In the classification task, each sample’s logit information is a vector of K values, and the maximum of these values is the classification result of the model. We select this maximum value and swap it to a random location. We follow this strategy to replace 50% of the number of samples’ logit to form a label-flipping attack.
Noisy Data Attack (Type-2): We add a certain percentage of noisy data (e.g., Gaussian noise) to the client’s private dataset to train the local model, reducing the client’s model’s accuracy. Then, we use the low-quality model to output the logit in the public dataset, which generates inaccurate logits of the public dataset.
SecondMax Attack (Type-3): We tamper with the logit information corresponding to the sample. We first find the position of the largest value in the logit, then randomly select 50 % of the remaining positions, and modify the value of these positions to (MaxValue − 0.00001 ). For example, in a ten-classification task, the element at position 9 is the maximum value ( 3.56789 ). We modify the values of the elements at positions 0, 1, 5, 6, and 8 to ( 3.56789 0.00001 ).
Federated Task Setup: Our experiments involved scenarios with 10, 15, and 20 clients participating in a federated task with 100 communication rounds. At the same time, each client performs one epoch of local model training. We divide the clients into two groups: malicious clients and normal clients. Client IDs 1, 3, 5, 7, 9, 11, 13, 15, 17, and 19 are normal clients, and client IDs 2, 4, 6, 8, 10, 12, 14, 16, 18, and 20 are malicious clients, where the malicious client group performs the specified attack method.
Dataset Setup: In all three attack experiments, clients were set up with independently homogeneous data distributions (IID). For instance, in a 10-client experiment, each client has 6000 samples for MNIST and Fashion-MNIST. Each client has 5000 samples for CIFAR-10.
In particular, in the noise attack experiments, we set different ratios of noise data for the malicious client groups, in which for the MNIST dataset, the noise data ratios of clients 2, 4, 6, 8, 10, 12, 14, 16, 18, 20 are 0.91 ,   0.92 ,   0.93 ,   0.94 ,   0.95 ,   0.91 ,   0.92 ,   0.93 ,   0.94 ,   0.95 , respectively. For the Fasion-MNIST dataset, the noise data ratios of malicious clients are 0.75 ,   0.8 ,   0.85 ,   0.9 ,   0.95 ,   0.75 ,   0.8 ,   0.85 ,   0.9 ,   0.95 , respectively. For the CIFAR-10 dataset, they are 0.5 ,   0.55 ,   0.6 ,   0.65 ,   0.7 ,   0.5 ,   0.55 ,   0.6 ,   0.65 ,   0.7 , respectively.
To evaluate the performance of the client identification module, we process logit information in the following three steps:
(1) Vector Conversion: We gathered the logit information from the client and server. Then, we group the logit for each category and transform each subgroup into a one-dimensional vector.
(2) Cosine Similarity Calculation: For each category vector, we calculated the cosine similarity between the client’s category vectors and the server’s corresponding category vectors. This similarity measurement serves as a pivotal metric in our evaluation.
(3) Feature Value Calculation: We regard the cosine similarity values for the various categories as the feature values for each client.
We take the CIFAR-10 dataset as an example and apply this methodology to this dataset. We conduct the experiment on the three specified attack scenarios. The feature values for each client under these scenarios are illustrated in Figure 9. This figure reveals notable disparities in feature values between normal and malicious clients, particularly in Type-1 and Type-3 attack scenarios. In contrast, the Type-2 (noisy data attack) scenario demonstrates a variation in feature values across certain categories, though not as pronounced as in the other types.
After computing the feature values for each category of each client, we proceed with the following steps:
(1) Client Clustering: We take the cosine similarity value of each category as the feature values of each client. Then, we employ the K-means method to cluster clients into two subgroups.
(2) Accuracy Verification on Public Data: For each resulting cluster, we assess the accuracy of the logit on the public dataset. The subgroup exhibiting the highest accuracy is deemed the ’trusted client list,’ while the one with lower accuracy is classified as comprising malicious clients.
We follow the steps above for clustering analysis, and the clustering outcomes are presented in Figure 10. In Attack Types 1 and 3, our method successfully identifies malicious clients from the first communication round. In later rounds, the difference between the feature values of normal and malicious clients will become larger and larger, and our method can more easily separate the client types. However, in Attack Type 2, our method requires up to six communication rounds to categorize the clients accurately. The initial rounds are marked by relatively low logit accuracy, which obscures the differences between normal and malicious clients. As a result, some malicious clients may initially be misclassified as normal. To mitigate this issue, we need to verify client logit accuracy further, as detailed in the fourth stage of our method. This additional step is crucial for ensuring accurate client categorization in scenarios where initial data may not be distinct enough for immediate classification.

5.2.2. Experimental Analysis of Different Attack Scenarios

This experiment mainly evaluates our algorithm’s robustness across various attack scenarios. We benchmark our algorithm against five baseline methods, using two key performance metrics: the accuracy of the global logit on the public dataset and the average client model’s accuracy.
We tested the algorithms under predefined attack scenarios, adhering to our experimental configuration. As the FedDistill+ and FedHe methods do not necessitate a public dataset, our comparison of the accuracy of the global logitis is limited to the three algorithms (FedMD, FedDF, FedDistill). The experimental results are shown in Table 2. The experiment shows that our algorithms consistently outperform others across three datasets.
Notably, in Attack Types 1 and 2, the global logit accuracy of FedMD, FedDF, and FedDistill exhibited a significant decline. This observation suggests that these algorithms may not be as effective in mitigating the impact of Attack Types 1 and 2.
To illustrate the influence of the malicious client’s logit accuracy on the overall global logit, we meticulously tracked the logit accuracy of each client across all communication rounds. The results are shown in Figure 11. We analyze the impact of different attack scenarios on each algorithm as follows:
Attack Type 1: In this scenario, malicious clients significantly compromised the global logit accuracy of methods such as FedMD, FedDF, and FedDistill. The primary reason for this degradation is that these methods do not filter out malicious clients during the fusion process of the global logit.
Attack Type 2: In this scenario, FedMD, FedDF, and FedDistill methods use a simplistic weighted calculation for fusing global logit. Consequently, they fail to discern low-quality model outputs, leading to a reduction in logit accuracy.
Attack Type 3: The FedMD and FedDF methods are particularly vulnerable to tampered logit information, resulting in diminished global knowledge accuracy.
Contrastingly, our method consistently maintains high accuracy across all three attack scenarios. This resilience stems from our method’s capability to exclude malicious information from the fusion process and our server’s function to detect and ensure the fusion of high-quality logit information.
Furthermore, we evaluated the impact of different algorithms on average client accuracy. To assess how malicious clients affect federated tasks, we also introduce the scenarios without malicious clients (denoted as ’normal type’). These results are detailed in Table 3. our algorithm demonstrates superior accuracy on both the Fashion-MNIST and CIFAR-10 datasets. On the MNIST dataset, our algorithm slightly trails behind the FedHe algorithm in Attack Type 1 and FedDistill in Attack Type 3. This outcome can be attributed to the relative simplicity of the MNIST dataset, where methods like FedDistill+ and FedHe, which regularize categories, tend to achieve higher accuracy.
We conducted a comparative analysis of different algorithms to evaluate their performance under various attack scenarios. we take the CIFAR-10 as an example, and the results are shown in Figure 12. Notably, algorithm performance varies significantly depending on the type of attack:
Attack Type 1: Both FedDistill+ and FedHe methods experienced a notable decrease in accuracy. This suggests that these methods may be more vulnerable or less equipped to handle the specific challenges posed by this attack type.
Attack Type 2: The FedMD and FedDF methods showed a marked decline in performance. This indicates that these methods, while possibly effective in other scenarios, struggle to maintain accuracy under the conditions of Attack Type 2.
To illustrate how the accuracy of different algorithms fluctuates across communication rounds, we conducted an in-depth analysis using the CIFAR-10 dataset. We calculated the average client model accuracy at each communication round, as shown in Figure 13.
In Attack Types 1 and 3, the accuracy of the FedMD and FedDF algorithms showed a significant decrease with increasing communication rounds. This decline is attributed to their reliance on a simple weighted average method for global logit fusion. Such a method proves ineffective in excluding the erroneous logit information from malicious clients, thereby diminishing the overall accuracy of the global logit. As a result, clients utilizing this compromised information for local model training experience a reduction in their model’s accuracy.
Our algorithm consistently achieves the highest accuracy rates across all three attack scenarios. The strength of our method lies in its ability to prevent malicious information from influencing the model fusion process. Additionally, it empowers the server to fuse high-quality knowledge from the client selectively. This capability ensures that our approach remains stable and effective under various attack scenarios.

5.2.3. Experimental Analysis of Different Data Distribution Scenarios

In this section, we focus on assessing each algorithm’s performance in different data distribution scenarios. The experimental setup is as follows:
Federated Task Configuration: We set different numbers of clients (10, 15, and 20) to participate in the federated task with 100 communication rounds. Each client completed one training epoch, while the server conducted two on the public dataset.
Network Model Configuration: We follow Table 1 of the previous section to set up the client-side and server-side models in each scenario.
Dataset Division: We divided each dataset using the Dirichlet method of the previous subsection. This involved creating three Non-IID data scenarios ( α = { 0.5 ,   1 ,   10 } ) and one IID data distribution scenario.
Following these specified configurations, the experiments were conducted separately for each data distribution scenario. The accuracy results of each algorithm under these conditions are compiled and presented in Table 4.
The experiment demonstrates that our algorithms consistently achieve the highest accuracy across all data distribution scenarios on the three datasets. Notably, the performance of the FedMD and FedDF algorithms exhibits a decline in Non-IID scenarios ( α = { 0.5 ,   1 } ) , a trend that is particularly pronounced in the CIFAR-10 dataset.
To further explore the relationship between client and global logit accuracy, we meticulously tracked and analyzed the accuracy of each client and the global logit for every communication round, as shown in Figure 14. The accuracy of the FedMD and FedDF algorithms diminishes as the disparity in Non-IID data distribution increases. This trend underscores that the performance of algorithms relying on average weighting methods, such as FedMD and FedDF, is significantly hindered in Non-IID scenarios. In contrast, our algorithm and FedDistill demonstrate robustness against the variation in data distribution, maintaining consistent accuracy irrespective of the data scenario.
To verify the performance of each algorithm under different data distribution scenarios, we also calculated the average client model accuracy metrics, and the experimental results are shown in Table 5. The experiment results show that our algorithms obtain the highest values on both Fashion-MNIST and CIFAR-10 datasets. The FedDistill+ and FedHe algorithms show a serious drop in accuracy in Non-IID ( α = { 0.5 ,   1 } ) scenarios, which indicates that these two algorithms do not work stably in Non-IID scenarios.
To compare the accuracy of each algorithm on different rounds, we take the CIFAR-10 dataset as an example and compare the accuracy of five algorithms under different data distribution scenarios. The results are shown in Figure 15. The experiments show that our algorithms obtain the highest accuracy under different data scenarios.
We utilized the CIFAR-10 dataset to assess the effectiveness of the adaptive knowledge distillation method. This involved monitoring the weight values assigned to each category over each communication round. We adjusted the hyperparameters for different data distribution scenarios for optimal performance. Specifically, in Non-IID scenarios ( α = { 0.5 ,   1 } ), we set the β parameter to 0.8. For the Non-IID scenario ( α = 10 ) and the IID scenario, we set the β parameter to 0.75.
The weight values of each category for different data scenarios are shown in Figure 16. The results show that the weight values of each scenario no longer change after 15 communication rounds. This is because when our communication rounds reach a certain number of rounds, our algorithm has fully integrated the knowledge of each client. At the same time, the prediction of the samples of each category in the public dataset reaches the best accuracy, so after 15 rounds, the weight values of each category no longer change.

5.2.4. Performance Validation of the FedTKD Framework

We evaluated the computational efficiency of four main modules within our federated learning framework. These modules are Client-Side Identification, Server-Side Model Training, Global Logit Information Computation, and Category Weight Computation.
Experimental Setup: We set the number of communication rounds in the federated task to 50 for this evaluation. We meticulously recorded the computation time of each module across every communication round under each attack scenario. These results are presented in Figure 17.
Experimental Analysis: The experimental result reveals that the computation times of each round for each module are generally consistent, with only occasional variations. These inconsistencies are attributed to the high-performance workstations used for the experiments, which ran other tasks concurrently. This multitasking environment led to fluctuating resource allocation, resulting in sporadic peaks in computation time.
To gain a more comprehensive understanding of computational efficiency, we also calculated the average computation time for each module throughout the entire federated task. The detailed results are tabulated in Table 6. We differentiate between ‘total time 1’ and ‘total time 2’. The ‘total time 1’ is the aggregate time for all modules. The ‘total time 2’ is the cumulative time of all modules, excluding the server-side training model module.
The experiment shows that the number of clients does not influence the training model and logit calculation computation times. The training model times depend on network parameters, epochs, and the number of samples. The logit calculation time depends on the number of samples in the public dataset. The computation time of these two modules does not increase with the number of nodes because the factors affecting these modules are independent of the number of nodes. The time required for client identification and weight calculation increases with the number of clients. For server-side computational efficiency, the ‘total time 2’ is more relevant as server-side model training is independent of client-side tasks. We assume the server has sufficient computational resources to pre-train the model before receiving the client’s logit. Thus, the ‘total time 2’ can reflect the efficiency of our algorithm in processing client logit messages. Notably, in scenarios with 20 clients, our framework can complete essential steps in under 10 s, which is crucial for the real-time detection of client type.

6. Conclusions

To address the issue of trustworthy computing and information fusion in heterogeneous federated environments, we have designed a trustworthy federated learning framework (FedTKD). This framework encompasses several key functions. Firstly, we proposed a client recognition method based on logit features to exclude malicious clients from participating in global information fusion. Additionally, we proposed a reliable global logit fusion method to ensure high-quality information fusion. Finally, we proposed an adaptive knowledge distillation method to improve the accuracy of knowledge transfer from the server side to the client side. We conducted experiments in various attack scenarios, including logit tag flipping and low-quality information fusion scenarios, to assess the reliability and robustness of FedTKD. Furthermore, we have evaluated the algorithm’s performance under different data distribution scenarios and compared it with five baseline algorithms. The results demonstrate that our algorithms outperform others in various circumstances.
However, it is important to note that our approach requires a common dataset shared between the server and clients, which should include samples of all types. In practical scenarios, creating such a comprehensive common dataset can be challenging. Therefore, our future research will investigate how to achieve trustworthy heterogeneous federated learning when only a subset of sample types is available. We also plan to explore a wider range of client attacks to expand the applicability of our method. Additionally, we intend to deploy FedTKD in a distributed environment to assess and optimize its real-world performance.

Author Contributions

Conceptualization, L.C., W.Z. and C.W.T.; methodology, L.C., S.Q., D.Z. and X.Z.; software, L.C., C.D. and X.Z.; validation, L.C., C.D., X.Z. and S.Q.; formal analysis, L.C., D.Z., S.Q. and X.Z.; data curation, L.C., C.D. and Y.Z.; writing—original draft preparation, L.C., D.Z. and C.W.T.; writing—review and editing, L.C., C.W.T. and W.Z.; project administration, W.Z. and C.W.T. visualization, L.C., C.D. and Y.Z.; supervision, C.W.T.; funding acquisition, L.C., W.Z. and C.W.T. All authors have read and agreed to the published version of the manuscript.

Funding

The work is supported by the Singapore Ministry of Education (AcRF Tier 1 RG91/22 and NTU startup fund), the National Natural Science Foundation of China (No. 62072469), and the China Scholarship Council (No. 202206450035).

Institutional Review Board Statement

Not applicable.

Informed Consent Statement

Not applicable.

Data Availability Statement

Data are contained within the article.

Conflicts of Interest

The authors declare no conflicts of interest.

References

  1. McMahan, B.; Moore, E.; Ramage, D.; Hampson, S.; y Arcas, B.A. Communication-efficient learning of deep networks from decentralized data. In Proceedings of the Artificial Intelligence and Statistics, Fort Lauderdale, FL, USA, 20–22 April 2017; pp. 1273–1282. [Google Scholar]
  2. Karimireddy, S.P.; Kale, S.; Mohri, M.; Reddi, S.; Stich, S.; Suresh, A.T. Scaffold: Stochastic controlled averaging for federated learning. In Proceedings of the International Conference on Machine Learning, Virtual, 13–18 July 2020; pp. 5132–5143. [Google Scholar]
  3. Li, T.; Sahu, A.K.; Zaheer, M.; Sanjabi, M.; Talwalkar, A.; Smith, V. Federated optimization in heterogeneous networks. Proc. Mach. Learn. Syst. 2020, 2, 429–450. [Google Scholar]
  4. Xie, C.; Koyejo, S.; Gupta, I. Asynchronous federated optimization. arXiv 2019, arXiv:1903.03934. [Google Scholar]
  5. Hinton, G.; Vinyals, O.; Dean, J. Distilling the knowledge in a neural network. arXiv 2015, arXiv:1503.02531. [Google Scholar]
  6. Fukuda, T.; Suzuki, M.; Kurata, G.; Thomas, S.; Cui, J.; Ramabhadran, B. Efficient Knowledge Distillation from an Ensemble of Teachers. In Proceedings of the Interspeech, Stockholm, Sweden, 20–24 August 2017; pp. 3697–3701. [Google Scholar]
  7. Li, D.; Wang, J. Fedmd: Heterogenous federated learning via model distillation. arXiv 2019, arXiv:1910.03581. [Google Scholar]
  8. Lin, T.; Kong, L.; Stich, S.U.; Jaggi, M. Ensemble distillation for robust model fusion in federated learning. Adv. Neural Inf. Process. Syst. 2020, 33, 2351–2363. [Google Scholar]
  9. Jiang, D.; Shan, C.; Zhang, Z. Federated learning algorithm based on knowledge distillation. In Proceedings of the 2020 International Conference on Artificial Intelligence and Computer Engineering (ICAICE), Beijing, China, 23–25 October 2020; pp. 163–167. [Google Scholar]
  10. Zhu, Z.; Hong, J.; Zhou, J. Data-free knowledge distillation for heterogeneous federated learning. In Proceedings of the International Conference on Machine Learning, Virtual, 18–24 July 2021; pp. 12878–12889. [Google Scholar]
  11. Zhang, L.; Shen, L.; Ding, L.; Tao, D.; Duan, L.Y. Fine-tuning global model via data-free knowledge distillation for non-iid federated learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, New Orleans, LA, USA, 18–24 June 2022; pp. 10174–10183. [Google Scholar]
  12. Zhang, Z.; Shen, T.; Zhang, J.; Wu, C. Feddtg: Federated data-free knowledge distillation via three-player generative adversarial networks. arXiv 2022, arXiv:2201.03169. [Google Scholar]
  13. Lu, Q.; Zhu, L.; Xu, X.; Whittle, J.; Xing, Z. Towards a roadmap on software engineering for responsible AI. In Proceedings of the 1st International Conference on AI Engineering: Software Engineering for AI, Pittsburgh, PA, USA, 16–17 May 2022; pp. 101–112. [Google Scholar]
  14. Lu, Q.; Zhu, L.; Xu, X.; Whittle, J. Responsible-AI-by-design: A pattern collection for designing responsible AI systems. IEEE Softw. 2023, 40, 63–71. [Google Scholar] [CrossRef]
  15. Chen, L.; Zhang, W.; Xu, L.; Zeng, X.; Lu, Q.; Zhao, H.; Chen, B.; Wang, X. A Federated Parallel Data Platform for Trustworthy AI. In Proceedings of the 2021 IEEE 1st International Conference on Digital Twins and Parallel Intelligence (DTPI), Beijing, China, 15 July–15 August 2021; pp. 344–347. [Google Scholar]
  16. Wang, J.; Liu, Q.; Liang, H.; Joshi, G.; Poor, H.V. Tackling the objective inconsistency problem in heterogeneous federated optimization. Adv. Neural Inf. Process. Syst. 2020, 33, 7611–7623. [Google Scholar]
  17. Li, Q.; He, B.; Song, D. Model-contrastive federated learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, Nashville, TN, USA, 19–25 June 2021; pp. 10713–10722. [Google Scholar]
  18. Seo, H.; Park, J.; Oh, S.; Bennis, M.; Kim, S.L. Federated Knowledge Distillation. In Machine Learning and Wireless Communications; Cambridge University Press: Cambridge, UK, 2022; p. 457. [Google Scholar]
  19. Chen, H.; Vikalo, H. The Best of Both Worlds: Accurate Global and Personalized Models through Federated Learning with Data-Free Hyper-Knowledge Distillation. arXiv 2023, arXiv:2301.08968. [Google Scholar]
  20. Li, S.; Cheng, Y.; Wang, W.; Liu, Y.; Chen, T. Learning to detect malicious clients for robust federated learning. arXiv 2020, arXiv:2002.00211. [Google Scholar]
  21. Chen, L.; Dong, C.; Qiao, S.; Huang, Z.; Nie, Y.; Hou, Z.; Tan, C. FedDRL: A Trustworthy Federated Learning Model Fusion Method Based on Staged Reinforcement Learning. arXiv 2023, arXiv:2307.13716. [Google Scholar]
  22. Blanchard, P.; El Mhamdi, E.M.; Guerraoui, R.; Stainer, J. Machine learning with adversaries: Byzantine tolerant gradient descent. In Proceedings of the Advances in Neural Information Processing Systems 30 (NIPS 2017), Long Beach, CA, USA, 4–9 December 2017. [Google Scholar]
  23. Chen, L.; Zhao, D.; Tao, L.; Zeng, X.; Tan, C. A Credible and Fair Federated Learning Framework Based on Blockchain. IEEE Trans. Artif. Intell. 2024, 1, 1–15. [Google Scholar] [CrossRef]
  24. Yin, D.; Chen, Y.; Kannan, R.; Bartlett, P. Byzantine-robust distributed learning: Towards optimal statistical rates. In Proceedings of the International Conference on Machine Learning, Stockholm, Sweden, 10–15 July 2018; pp. 5650–5659. [Google Scholar]
  25. Li, S.; Ngai, E.C.H.; Voigt, T. An Experimental Study of Byzantine-Robust Aggregation Schemes in Federated Learning. arXiv 2023, arXiv:2302.07173. [Google Scholar] [CrossRef]
  26. Karimireddy, S.P.; He, L.; Jaggi, M. Learning from history for byzantine robust optimization. In Proceedings of the International Conference on Machine Learning, Virtual, 18–24 July 2021; pp. 5311–5319. [Google Scholar]
  27. Zhang, J.; Ge, C.; Hu, F.; Chen, B. RobustFL: Robust federated learning against poisoning attacks in industrial IoT systems. IEEE Trans. Ind. Inform. 2021, 18, 6388–6397. [Google Scholar] [CrossRef]
  28. Wang, Y.; Xie, L.; Liu, X.; Yin, J.L.; Zheng, T. Model-agnostic adversarial example detection through logit distribution learning. In Proceedings of the 2021 IEEE International Conference on Image Processing (ICIP), Anchorage, AK, USA, 19–22 September 2021; pp. 3617–3621. [Google Scholar]
  29. Cheng, S.; Wu, J.; Xiao, Y.; Liu, Y. Fedgems: Federated learning of larger server models via selective knowledge fusion. arXiv 2021, arXiv:2110.11027. [Google Scholar]
  30. Zhang, H.; Chen, D.; Wang, C. Confidence-aware multi-teacher knowledge distillation. In Proceedings of the ICASSP 2022—2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), Singapore, Singapore, 23–27 May 2022; pp. 4498–4502. [Google Scholar]
  31. He, Y.; Chen, Y.; Yang, X.; Zhang, Y.; Zeng, B. Class-wise adaptive self distillation for heterogeneous federated learning. In Proceedings of the 36th AAAI Conference on Artificial Intelligence, Virtual, 22 February–1 March 2022; Volume 22. [Google Scholar]
  32. Lukasik, M.; Bhojanapalli, S.; Menon, A.K.; Kumar, S. Teacher’s pet: Understanding and mitigating biases in distillation. arXiv 2021, arXiv:2106.10494. [Google Scholar]
  33. Chan, Y.H.; Ngai, E.C. Fedhe: Heterogeneous models and communication-efficient federated learning. In Proceedings of the 2021 17th International Conference on Mobility, Sensing and Networking (MSN), Exeter, UK, 13–15 December 2021; pp. 207–214. [Google Scholar]
Figure 1. Heterogeneous network fusion method based on logit information.
Figure 1. Heterogeneous network fusion method based on logit information.
Entropy 26 00096 g001
Figure 2. Logit attack behavior of malicious clients.
Figure 2. Logit attack behavior of malicious clients.
Entropy 26 00096 g002
Figure 3. Trustworthy client identification method based on logit feature information.
Figure 3. Trustworthy client identification method based on logit feature information.
Entropy 26 00096 g003
Figure 4. The calculation process of global logit.
Figure 4. The calculation process of global logit.
Entropy 26 00096 g004
Figure 5. Calculation of weighting parameters based on teacher logit confidence level.
Figure 5. Calculation of weighting parameters based on teacher logit confidence level.
Entropy 26 00096 g005
Figure 6. The process of client model training based on adaptive knowledge distillation.
Figure 6. The process of client model training based on adaptive knowledge distillation.
Entropy 26 00096 g006
Figure 7. The system architecture of FedTKD.
Figure 7. The system architecture of FedTKD.
Entropy 26 00096 g007
Figure 8. Different data distribution scenarios for the Cifar-10 dataset (10 clients).
Figure 8. Different data distribution scenarios for the Cifar-10 dataset (10 clients).
Entropy 26 00096 g008
Figure 9. The feature values of each client in different attack scenarios.
Figure 9. The feature values of each client in different attack scenarios.
Entropy 26 00096 g009
Figure 10. Clustering results for each client under different attack scenarios (10 clients).
Figure 10. Clustering results for each client under different attack scenarios (10 clients).
Entropy 26 00096 g010
Figure 11. The accuracy of each client in the public dataset under different attack scenarios (10 clients).
Figure 11. The accuracy of each client in the public dataset under different attack scenarios (10 clients).
Entropy 26 00096 g011
Figure 12. Average client accuracy for each algorithm in different attack scenarios.
Figure 12. Average client accuracy for each algorithm in different attack scenarios.
Entropy 26 00096 g012
Figure 13. Average client accuracy for different attack scenarios (10 clients).
Figure 13. Average client accuracy for different attack scenarios (10 clients).
Entropy 26 00096 g013
Figure 14. The accuracy of each client in a public dataset with different data distributions (10 clients).
Figure 14. The accuracy of each client in a public dataset with different data distributions (10 clients).
Entropy 26 00096 g014
Figure 15. Average client accuracy for different data distribution scenarios (10 clients).
Figure 15. Average client accuracy for different data distribution scenarios (10 clients).
Entropy 26 00096 g015
Figure 16. Weight allocation for each category in different data scenarios (10 clients).
Figure 16. Weight allocation for each category in different data scenarios (10 clients).
Entropy 26 00096 g016
Figure 17. Computation time for each model per communication round.
Figure 17. Computation time for each model per communication round.
Entropy 26 00096 g017
Table 1. Network configuration in heterogeneous environments.
Table 1. Network configuration in heterogeneous environments.
IDConv-1Conv-2Conv-3Conv-4Conv-5Conv-6FC
ServerK = 3, S = 1C2d (16, 16)C2d (16, 32)C2d (32, 64)C2d (32, 64)C2d (54, 128)Line (512, 512)
Out = 16C2d (16, 16)C2d (32, 32)C2d (64, 64)C2d (64, 64)C2d (256, 512)Line (512, 10)
1, 2, 11, 16K = 3, S = 1C2d (16, 16)C2d (16, 32)C2d (32, 64)C2d (32, 64)C2d (54, 128)Line (512, 512)
Out = 16C2d (16, 16)C2d (32, 32)C2d (64, 64)C2d (64, 64)C2d (256, 512)Line (512, 10)
3, 4, 12, 17K = 3, S = 1C2d (16, 16)C2d (16, 32)C2d (32, 64)C2d (32, 64)C2d (64, 128)Line (256, 256)
Out = 16C2d (16, 16)C2d (32, 32)C2d (64, 64)C2d (64, 64)C2d (256, 256)Line (256, 10)
5, 6, 13, 18K = 3, S = 1C2d (16, 16)C2d (16, 32)C2d (32, 64)C2d (32, 64)-Line (512, 128)
Out = 16C2d (16, 16)C2d (32, 32)C2d (64, 64)C2d (64, 64)-Line (128, 10)
7, 8, 14, 19K = 3, S = 1C2d (16, 16)C2d (16, 32)C2d (32, 64)--Line (1024, 128)
Out = 16C2d (16, 16)C2d (32, 32)C2d (64, 64)--Line (128, 10)
9, 10, 15, 20K = 3, S = 1C2d (16, 16)C2d (32, 32)C2d (64, 64)--Line (2048, 256)
Out = 16C2d (16, 16)C2d (32, 32)C2d (64, 128)--Line (256, 10)
Table 2. The global logit accuracy of each algorithm in different attack types.
Table 2. The global logit accuracy of each algorithm in different attack types.
NumberMethodMNISTFashion-MNISTCIFAR-10
Type-1Type-2Type-3Type-1Type-2Type-3Type-1Type-2Type-3
10 clientsFedMD0.540.9610.9680.5770.8940.9010.5480.6170.709
FedDF0.750.9530.9730.7060.8970.9150.410.6280.639
FedDistill0.830.9960.9960.7960.9640.9660.520.920.988
Ours0.9970.9980.9980.9960.9980.9960.9970.9960.998
15 clientsFedMD0.8460.9450.9490.6840.8670.8840.5560.6010.676
FedDF0.8350.9410.9580.7460.8680.8860.4460.5950.634
FedDistill0.8910.9790.9780.8150.9130.9110.7150.8610.992
Ours0.9880.9860.9920.9360.9390.9320.9980.9960.997
20 clientsFedMD0.8370.9360.9430.6750.8850.8690.5280.5830.663
FedDF0.8350.9410.9580.6760.8550.8690.4460.5640.601
FedDistill0.8910.9790.9780.7870.8980.9030.6110.8510.992
Ours0.9890.9890.990.9320.9370.9310.9980.9930.998
Table 3. Average client model accuracy of each algorithm in different attack scenarios.
Table 3. Average client model accuracy of each algorithm in different attack scenarios.
NumberMethodMNISTFashion-MNISTCIFAR-10
Type-1Type-2Type-3NormalType-1Type-2Type-3NormalType-1Type-2Type-3Normal
10 clientsFedMD0.9280.9450.9480.9610.8350.8410.8280.8510.5370.5170.5120.56
FedDF0.9390.9380.9530.9520.8370.8410.8470.8490.5230.5270.5210.578
FedDistill0.9520.9450.9540.9560.8480.8390.8520.8520.570.5330.5780.584
FedDistill+0.960.8690.960.9590.8510.8120.8510.8510.580.530.5760.582
FedHe0.9610.8670.6730.9610.8530.8150.6740.8520.5810.5250.5660.583
Ours0.9550.9470.9550.9580.8620.8470.8550.8580.5860.5410.5890.592
15 clientsFedMD0.9420.9360.9270.9490.8220.8190.8230.8350.5130.5010.5050.547
FedDF0.9350.9270.9410.9390.7460.8160.8260.8280.4460.4920.5030.546
FedDistill0.9420.9340.9430.9450.8230.8160.8260.8270.5520.5180.5530.557
FedDistill+0.9470.8320.9470.9470.8380.7840.840.8410.5280.4840.5180.528
FedHe0.9410.7790.9380.950.8420.7890.8340.8430.5340.4920.5280.536
Ours0.9470.9390.9460.9450.8430.8210.8420.8480.5640.5280.560.562
20 clientsFedMD0.9360.9280.9120.9420.8130.8090.8110.8280.4820.4810.4870.525
FedDF0.9230.9210.9340.9320.8150.8080.8180.8180.4640.4740.4830.527
FedDistill0.9420.9340.9390.9380.8170.8130.8210.8170.5310.5020.5410.54
FedDistill+0.9380.7870.9370.9370.8320.7640.8330.8310.4970.4570.4960.498
FedHe0.950.8290.9380.9410.8360.7680.8250.8340.5340.4920.5280.506
Ours0.9420.9340.9410.9410.8430.8210.8380.8350.5450.5320.5450.546
Table 4. The global logit accuracy of each algorithm in different data distribution scenarios.
Table 4. The global logit accuracy of each algorithm in different data distribution scenarios.
NumberMethodMNISTFashion-MNISTCIFAR-10
α = 0 . 5 α = 1 α = 10 IID α = 0 . 5 α = 1 α = 10 IID α = 0 . 5 α = 1 α = 10 IID
10 clientsFedMD0.950.9580.9610.9740.8710.8750.8840.910.5560.5760.5410.66
FedDF0.9480.9470.9490.9720.8660.8780.8840.9050.5890.6090.5950.726
FedDistill0.9930.9930.9970.9970.9270.9460.9530.960.9910.9940.990.994
Ours0.9970.9980.9980.9980.9970.9960.9970.9980.9950.9960.9980.999
15 clientsFedMD0.9370.9450.9410.9570.8550.8620.8720.8830.570.5760.5830.669
FedDF0.9290.9370.9480.9510.8590.8560.8760.8820.5820.5720.5750.665
FedDistill0.9660.9670.9690.9760.8890.8950.9010.9070.9830.9910.9930.995
Ours0.9870.9880.9890.9920.9210.9320.9340.9350.9960.9980.9970.998
20 clientsFedMD0.9280.9360.9410.9510.8550.860.8620.8750.5460.550.5680.645
FedDF0.9290.9310.9380.940.8510.860.8610.8720.5510.5540.5430.631
FedDistill0.9650.9670.970.9710.8770.8840.8930.8990.990.9910.9930.997
Ours0.9870.9890.9890.9910.9320.9320.9270.9350.9920.9960.9980.998
Table 5. Average client model accuracy of each algorithm in different data distribution scenarios.
Table 5. Average client model accuracy of each algorithm in different data distribution scenarios.
NumberMethodMNISTFashion-MNISTCIFAR-10
α = 0 . 5 α = 1 α = 10 IID α = 0 . 5 α = 1 α = 10 IID α = 0 . 5 α = 1 α = 10 IID
10 clientsFedMD0.9310.950.9590.9610.7810.8250.8480.8510.4860.4990.5420.56
FedDF0.9330.9380.9480.9520.7990.8240.8450.8490.5170.5450.5570.578
FedDistill0.9380.9460.9530.9560.7770.8270.8420.8520.5340.5510.5740.584
FedDistill+0.7990.8650.9560.9590.7030.7860.8450.8510.3710.4530.5490.582
FedHe0.7880.8460.9550.9610.6980.7620.8460.8520.3910.4740.5490.583
Ours0.9560.9580.9570.9580.8120.8320.8530.8580.5450.5650.5830.592
15 clientsFedMD0.9010.9280.9440.9490.7510.7880.8330.8350.440.4890.5280.547
FedDF0.9020.920.9350.9390.7660.7910.8220.8280.5050.5090.5310.546
FedDistill0.9130.9250.9390.9450.7660.7930.8210.8270.5050.5280.5480.557
FedDistill+0.7310.8380.9390.9470.6570.7640.8320.8410.3360.390.4970.528
FedHe0.7050.8590.9430.950.6460.7280.8310.8430.3540.4130.5060.536
Ours0.9240.9340.9370.9450.770.8020.8360.8480.5280.5420.5550.562
20 clientsFedMD0.8840.9090.9370.9420.7340.7810.8230.8280.4260.4630.5130.525
FedDF0.8990.9130.9270.9320.7660.7870.8210.8180.4850.4970.5140.527
FedDistill0.9070.9190.9340.9380.7510.7950.8220.8170.4990.5150.5330.54
FedDistill+0.6860.8430.9290.9370.6190.740.8220.8310.3140.3790.4760.498
FedHe0.6730.8230.9330.9410.6570.7350.8240.8340.3360.3730.4830.506
Ours0.9180.9240.9390.9410.7690.7960.830.8350.5190.5220.540.546
Table 6. Average computation time for each module on the server side (seconds).
Table 6. Average computation time for each module on the server side (seconds).
NumberAttackTrain ModelIdentify ClientCalculate LogitCalculate WeightTotal Time 1Total Time 2
10 clientsType-110.1122.8020.2982.08115.2945.181
Type-210.0952.7240.2962.03915.1555.06
Type-310.1962.6790.3011.93215.1084.912
15 clientsType-110.34.0170.2922.95117.567.26
Type-210.4973.9640.2973.00517.7647.267
Type-310.263.8850.2932.80717.2156.955
20 clientsType-110.6255.6110.2984.06920.6049.978
Type-210.3215.2590.2953.98419.869.539
Type-310.3375.2780.2993.82619.749.403
Disclaimer/Publisher’s Note: The statements, opinions and data contained in all publications are solely those of the individual author(s) and contributor(s) and not of MDPI and/or the editor(s). MDPI and/or the editor(s) disclaim responsibility for any injury to people or property resulting from any ideas, methods, instructions or products referred to in the content.

Share and Cite

MDPI and ACS Style

Chen, L.; Zhang, W.; Dong, C.; Zhao, D.; Zeng, X.; Qiao, S.; Zhu, Y.; Tan, C.W. FedTKD: A Trustworthy Heterogeneous Federated Learning Based on Adaptive Knowledge Distillation. Entropy 2024, 26, 96. https://doi.org/10.3390/e26010096

AMA Style

Chen L, Zhang W, Dong C, Zhao D, Zeng X, Qiao S, Zhu Y, Tan CW. FedTKD: A Trustworthy Heterogeneous Federated Learning Based on Adaptive Knowledge Distillation. Entropy. 2024; 26(1):96. https://doi.org/10.3390/e26010096

Chicago/Turabian Style

Chen, Leiming, Weishan Zhang, Cihao Dong, Dehai Zhao, Xingjie Zeng, Sibo Qiao, Yichang Zhu, and Chee Wei Tan. 2024. "FedTKD: A Trustworthy Heterogeneous Federated Learning Based on Adaptive Knowledge Distillation" Entropy 26, no. 1: 96. https://doi.org/10.3390/e26010096

Note that from the first issue of 2016, this journal uses article numbers instead of page numbers. See further details here.

Article Metrics

Back to TopTop