1. Introduction
The unprecedented success of modern machine learning (ML) techniques in areas such as computer vision [1], neuroscience [2], image processing [3], robotics [4] and natural language processing [5] has led to increasing interest in their application to wireless communication systems in recent years.
Early efforts along this line of work fall under what is sometimes referred to as the “learning to communicate” paradigm, in which the goal is to automate one or more communication modules, such as the modulator-demodulator or the channel coder-decoder, by replacing them with suitable ML algorithms. Although important progress has been made for some particular communication systems, such as molecular ones [6], it is not yet clear whether ML techniques can offer a reliable alternative to model-based approaches, especially since typical wireless environments suffer from time-varying noise and interference.
Wireless networks have other important intrinsic features which may pave the way for more cross-fertilization between ML and communication, as opposed to applying ML algorithms as black boxes in replacement of one or more communication modules. For example, while in areas such as computer vision, neuroscience, and others, relevant data is generally available at one point, it is typically highly distributed across several nodes in wireless networks.
Examples include self-driving cars, where multiple sensors, both external and internal to the car, can be used to help the car navigate its environment; medical applications that diagnose a patient based on data from different medical institutions; and environmental monitoring to detect hazardous events or pollution; see [7,8] for more information. We give more details on the usefulness of such setups in Examples 1 and 2. A prevalent approach to implementing ML solutions in such cases consists of collecting all relevant data at one point (a cloud server) and then training a suitable ML model using all available data and processing power. However, because the volumes of data needed for training are generally large and network resources (e.g., power and bandwidth) are scarce, this approach might not be appropriate in many cases. In addition, some applications, such as autonomous driving, have stringent latency requirements that are incompatible with sharing the data. In other cases, it might be desirable not to share the raw data in order to enhance the privacy of the solution, in the sense that infringing on a user’s privacy is generally easier from the raw data itself than from the output of a neural network (NN) that takes the raw data as input.
The above has called for a new paradigm in which intelligence moves from the heart of the network to its edge, sometimes referred to as “Edge Learning”. In this new paradigm, communication plays a central role in the design of efficient ML algorithms and architectures because both data and computational resources, the main ingredients of an efficient ML solution, are highly distributed. A key aspect of building suitable ML-based solutions is whether the setting assumes that only the training phase involves distributed data, sometimes referred to as distributed learning, such as the Federated Learning (FL) of [9], or whether the inference (or test) phase also involves distributed data.
The considered problem setup is strongly related to the problems of distributed estimation and detection (see, e.g., [10,11,12,13] and references therein). We differentiate ourselves from these problems in that we assume no prior knowledge of the distribution of the data. This is a common setup in many practical applications, such as image or speech processing or text analysis, where the joint distribution of the observed data and the target variable is unknown or too complex to model.
Among the prior works most closely related to this paper, a growing line of research focuses on developing distributed learning algorithms and architectures. The works of [14,15] address the problem of distributed learning using kernel methods when each node observes independent samples drawn from the same distribution. In our setup, however, the nodes observe correlated data, which necessitates collaboration among all nodes during inference. Works such as [16,17] focus on the narrower problem of detection and impose certain restrictions on the scope of their investigation. Perhaps most popular and most related to our work is the FL of [9] which, as already mentioned, is most suitable for scenarios in which the training phase has to be performed distributively while the inference phase is performed centrally at one node. To this end, during the training phase, nodes (e.g., base stations) that possess data are all equipped with copies of a single NN model, which they simultaneously train on their locally available data-sets. The learned weight parameters are then sent to a cloud or parameter server (PS), which aggregates them, e.g., by simply computing their average. The process is repeated, every time re-initializing with the obtained aggregated model, until convergence. The rationale is that, this way, the model is progressively adjusted to account for all variations in the data, not only those of the local data-set. For recent advances on FL and applications in wireless settings, the reader may refer to [18,19,20] and references therein. Another relevant work is the Split Learning (SL) of [21] in which, for a multiaccess-type network topology, a two-part NN model, split into an encoder part and a decoder part, is learned sequentially. The decoder does not have its own data; in every round, the NN encoder part is fed with a distinct data-set and its parameters are initialized with those learned in the previous round. The learned two-part model is then used as follows during inference: one part of the model is used by an encoder and the other by a decoder. Another variation of SL, sometimes called “vertical SL”, was proposed recently in [22]. This approach uses vertical partitioning of the data; in the special case of a multiaccess topology, it is similar to the in-network learning solution that we propose in this paper.
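For concreteness, the PS aggregation step of FL described above (local training followed by weighted averaging of the learned parameters) can be sketched as follows. This is an illustrative toy, not the exact algorithm of [9]: the model is a plain weight vector and local training is simulated by a random perturbation.

```python
import numpy as np

def federated_averaging(local_weights, num_samples):
    """Aggregate locally trained weight vectors at the parameter server.

    local_weights: list of 1-D numpy arrays, one per node
    num_samples:   list of local data-set sizes, used as averaging weights
    """
    total = sum(num_samples)
    # Weighted average of the local models (FedAvg-style aggregation)
    return sum(w * (n / total) for w, n in zip(local_weights, num_samples))

# One communication round: each node "trains" locally (here simulated by
# a random perturbation of the global model), then the PS averages and
# the result re-initializes the next round.
rng = np.random.default_rng(0)
global_model = np.zeros(4)
for _round in range(3):
    locals_ = [global_model + 0.1 * rng.standard_normal(4) for _ in range(5)]
    global_model = federated_averaging(locals_, [100, 200, 150, 50, 100])
```

With equal local data-set sizes this reduces to the plain average of the local models.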
Compared to both SL and FL, which consider only the training phase to be distributed, in this paper we focus on the problem in which the inference phase also takes place distributively. More specifically, we study a network inference problem in which some of the nodes each possess, or can acquire, part of the data that is relevant for inference on a random variable Y. The node at which the inference needs to be performed is connected to the nodes that possess the relevant data through a number of intermediate nodes. We assume that the network topology is fixed and known. This may model, e.g., a setting in which a macro base station (BS) needs to make an inference on the position of a user on the basis of summary information obtained from correlated channel state information (CSI) measurements acquired at nearby edge BSs. Each of the edge nodes is connected to the central node either directly, via an error-free link of given finite capacity, or via intermediary nodes. While in some cases it might be enough to process the measurements of only a subset of the J data nodes, we assume that processing any strict subset of the measurements cannot yield the desired inference accuracy; as such, all J measurements need to be processed during the inference (or test) phase.
Example 1. (Autonomous Driving) One basic requirement of autonomous driving is the ability to cope with problematic roadway situations, such as those involving construction, road hazards, hand signals, and reckless drivers. Current approaches mainly depend on equipping the vehicle with more on-board sensors. Clearly, while this can allow better coverage of the navigation environment, it seems unlikely to successfully cope with the problem of blind spots due, e.g., to obstruction or hidden obstacles. In such contexts, external sensors, such as other vehicles’ sensors or cameras installed on the roofs of nearby buildings or on wireless towers, may help perform a more precise inference by offering a complementary, possibly better, view of the navigation scene. An example scenario is shown in Figure 1. The application requires real-time inference, which might be incompatible with current cellular radio standards, thus precluding the option of sharing the sensors’ raw data and processing it locally, e.g., at some on-board server. When equipped with suitable intelligence capabilities, each sensor can identify and extract those features of its measurement data that are not captured by the other sensors’ data. It then only needs to communicate those features, not its entire data.

Example 2. (Public Health) One of the early applications of machine learning is in the area of medical imaging and public health. In this context, various institutions can hold different modalities of patient data in the form of electronic health records, pathology test results, radiology, and other sensitive imaging data such as genetic markers for disease. A correct diagnosis may be contingent on being able to use all relevant data from all institutions. However, these institutions may not be authorized to share their raw data. Thus, it is desirable to distributively train machine learning models without sharing the patients’ raw data, in order to prevent illegal, unethical or unauthorized usage of it [23]. Local hospitals and tele-health screening centers seldom acquire enough diagnostic images on their own; collaborative distributed learning in this setting would enable each individual center to contribute data to an aggregate model without sharing any raw data.

1.1. Contributions
In this paper, we study the aforementioned network inference problem in which the network is modeled as a weighted acyclic graph and inference about a random variable is performed on the basis of summary information obtained from possibly correlated variables observed at a subset of the nodes. Following an information-theoretic approach in which we measure discrepancies between true values and their estimated fits using the average logarithmic loss, we first develop a bound on the best achievable accuracy under the network communication constraints. Then, considering a supervised setting in which nodes are equipped with NNs whose mappings need to be learned from distributively available training data-sets, we propose a distributed learning and inference architecture, and we show that it can be optimized using a distributed version of the well-known stochastic gradient descent (SGD) algorithm that we develop here. The resulting distributed architecture and algorithm, which we name “in-network learning (INL)”, generalize those introduced in [24] (see also [25,26]) for a specific, multiaccess-type network topology. We investigate in detail what the various nodes need to exchange during both the training and inference phases, as well as the associated bandwidth requirements. Finally, we provide a comparative study with (adaptations of) the FL and SL algorithms, along with experiments that illustrate our results. Part of the results of this paper have also been presented in [27,28]. However, in this paper we go beyond those works by offering a more comprehensive and detailed review of the state of the art. Additionally, we provide proofs for the theorem and lemmas presented in this paper, which were not included in the previous publications. Furthermore, we introduce additional insights and conclusions that further contribute to the overall understanding and significance of the research findings.
1.2. Outline and Notation
In Section 2 we formally describe the studied network inference problem. In Section 3 we present our in-network inference architecture, as well as a distributed algorithm for training it. Section 4 contains a comparative study with FL and SL in terms of bandwidth requirements, as well as some experimental results. Finally, in Section 5 we summarize the insights and results presented in this paper.
Throughout the paper, the following notation is used. Upper case letters denote random variables, e.g., $X$; lower case letters denote realizations of random variables, e.g., $x$; and calligraphic letters denote sets, e.g., $\mathcal{X}$. The cardinality of a set $\mathcal{X}$ is denoted by $|\mathcal{X}|$. For a random variable $X$ with probability mass function $P_X$, the shorthand $p(x) = P_X(x)$, $x \in \mathcal{X}$, is used. Boldface letters denote matrices or vectors, e.g., $\mathbf{X}$ or $\mathbf{x}$. For random variables $(X_1, X_2, \ldots)$ and a set of integers $\mathcal{K}$, the notation $X_{\mathcal{K}}$ designates the vector of random variables with indices in the set $\mathcal{K}$, i.e., $X_{\mathcal{K}} = (X_k : k \in \mathcal{K})$. If $\mathcal{K} = \emptyset$ then $X_{\mathcal{K}} = \emptyset$. In addition, for zero-mean random vectors $\mathbf{x}$ and $\mathbf{y}$, the quantities $\Sigma_{\mathbf{x}}$, $\Sigma_{\mathbf{y}}$ and $\Sigma_{\mathbf{x}|\mathbf{y}}$ denote, respectively, the covariance matrix of the vector $\mathbf{x}$, the covariance matrix of the vector $\mathbf{y}$, and the conditional covariance matrix of $\mathbf{x}$ given $\mathbf{y}$. Finally, for two probability measures $P_X$ and $Q_X$ over the same alphabet $\mathcal{X}$, the relative entropy or Kullback-Leibler divergence is denoted as $D_{KL}(P_X \| Q_X)$. That is, if $P_X$ is absolutely continuous with respect to $Q_X$, then $D_{KL}(P_X \| Q_X) = \mathbb{E}_{P_X}[\log (P_X(X)/Q_X(X))]$; otherwise, $D_{KL}(P_X \| Q_X) = \infty$.
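As a quick numerical companion to the definition of the Kullback-Leibler divergence above, the following sketch computes $D_{KL}(P \| Q)$ in bits for discrete distributions, returning infinity when absolute continuity fails.

```python
import numpy as np

def kl_divergence(p, q):
    """Kullback-Leibler divergence D(P || Q) in bits for discrete pmfs.

    Returns infinity when P is not absolutely continuous w.r.t. Q,
    i.e., when Q assigns zero mass to a point that P charges.
    """
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    support = p > 0                       # terms with p(x) = 0 contribute 0
    if np.any(q[support] == 0.0):
        return float("inf")
    return float(np.sum(p[support] * np.log2(p[support] / q[support])))
```

For example, a point mass compared against a fair coin gives exactly one bit, and the divergence in the reverse direction is infinite.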
2. Network Inference: Problem Formulation
We consider a distributed supervised learning setup in which multiple nodes observe different features of the same sample, sometimes referred to as distributed learning with a vertically partitioned dataset; see [8,29]. We additionally assume that the learning takes place over a communication-constrained network. Specifically, consider an $N$-node distributed network. Of these $N$ nodes, $J \leq N$ nodes possess or can acquire data that is relevant for inference on a random variable (r.v.) of interest $Y$, with alphabet $\mathcal{Y}$. Let $\mathcal{J} = \{1, \ldots, J\}$ denote the set of such nodes, with node $j \in \mathcal{J}$ observing samples from the random variable $X_j$, with alphabet $\mathcal{X}_j$. The relationship between the r.v. of interest $Y$ and the observed ones, $X_1, \ldots, X_J$, is given by their joint probability mass function, with $(x_1, \ldots, x_J) \in \mathcal{X}_1 \times \cdots \times \mathcal{X}_J$ and $y \in \mathcal{Y}$. For simplicity, we assume that the random variables are discrete; however, our technique can be applied to continuous variables as well. Inference on $Y$ needs to be performed at some node $N$, which is connected to the nodes that possess the relevant data through a number of intermediate nodes, and it has to be performed without any sharing of raw data. The network is modeled as a weighted directed acyclic graph and may represent, for example, a wired network or a wireless mesh network operated in time or frequency division, where the nodes may be servers, handsets, sensors, base stations or routers. We assume that the network graph is fixed and known. The edges in the graph represent point-to-point communication links that use channel coding to achieve close-to-error-free communication at rates below their respective capacities.

For a given loss function $\ell(\cdot, \cdot)$ that measures discrepancies between true values of $Y$ and their estimated fits, what is the best precision for the estimation of $Y$? Clearly, discarding any of the relevant data $(X_1, \ldots, X_J)$ can only reduce this precision. Thus, intuitively, features that collectively maximize information about $Y$ need to be extracted distributively by the nodes of the set $\mathcal{J}$, without explicit coordination between them, and then propagated and combined appropriately at node $N$. How should this be performed optimally without sharing raw data? In particular, how should each node process the information from its incoming edges (if any), and what should it transmit on each of its outgoing edges? Furthermore, how should the information be fused optimally at node $N$?
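Before the formal definitions, the following is a minimal sketch of how the network model, a weighted directed acyclic graph with per-edge capacities, might be represented in code; the node labels and capacity values are illustrative.

```python
from collections import defaultdict

class InferenceNetwork:
    """Directed acyclic graph with edge capacities (illustrative sketch).

    Data nodes 1..J hold observations; the last node is the decision node.
    """
    def __init__(self):
        self.capacity = {}                  # (src, dst) -> link capacity
        self.incoming = defaultdict(list)   # dst -> list of source nodes

    def add_edge(self, src, dst, capacity):
        self.capacity[(src, dst)] = capacity
        self.incoming[dst].append(src)

# Example topology: three data nodes (1, 2, 3) feed an intermediate
# node 4, which forwards its message to the decision node 5.
# Capacities are in bits per sample.
net = InferenceNetwork()
for j in (1, 2, 3):
    net.add_edge(j, 4, capacity=2.0)
net.add_edge(4, 5, capacity=4.0)
```

The `incoming` map mirrors the role played below by the set of edges entering a node, which determines what each node may process.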
More formally, we model the $N$-node network by a directed acyclic graph $\mathcal{G} = (\mathcal{N}, \mathcal{E}, \mathcal{C})$, where $\mathcal{N} = \{1, \ldots, N\}$ is the set of nodes, $\mathcal{E} \subseteq \mathcal{N} \times \mathcal{N}$ is the set of edges and $\mathcal{C} = \{C_{jl} : (j,l) \in \mathcal{E}\}$ is the set of edge weights. Each node represents a device and each edge represents a noiseless communication link with capacity $C_{jl}$. See Figure 2. The processing at the nodes of the set $\mathcal{J}$ is such that each of them assigns an index $m_{jl} \in \mathcal{M}_{jl}$ to each observation $x_j \in \mathcal{X}_j$ and each received index tuple, for each outgoing edge $(j,l) \in \mathcal{E}$. Specifically, for $j \in \mathcal{N}$ and $l$ such that $(j,l) \in \mathcal{E}$, let $\mathcal{I}(j) = \{i \in \mathcal{N} : (i,j) \in \mathcal{E}\}$ denote the set of nodes with an edge into node $j$. The encoding function at node $j \in \mathcal{J}$ is

$$\phi_j^{(l)} : \mathcal{X}_j \times \prod_{i \in \mathcal{I}(j)} \mathcal{M}_{ij} \longrightarrow \mathcal{M}_{jl}, \qquad (j,l) \in \mathcal{E}, \quad (1)$$

where $\prod$ designates the Cartesian product of sets. Similarly, for $k \in \mathcal{N} \setminus (\mathcal{J} \cup \{N\})$, node $k$ assigns an index $m_{kl} \in \mathcal{M}_{kl}$ to each received index tuple, for each edge $(k,l) \in \mathcal{E}$. That is,

$$\phi_k^{(l)} : \prod_{i \in \mathcal{I}(k)} \mathcal{M}_{ik} \longrightarrow \mathcal{M}_{kl}, \qquad (k,l) \in \mathcal{E}. \quad (2)$$

The ranges of the encoding functions are restricted in size, as

$$\log_2 |\mathcal{M}_{jl}| \leq C_{jl} \qquad \text{for all } (j,l) \in \mathcal{E}. \quad (3)$$

Node $N$ needs to infer on the random variable $Y$ using all incoming messages, i.e., via a decoding function

$$\psi : \prod_{i \in \mathcal{I}(N)} \mathcal{M}_{iN} \longrightarrow \hat{\mathcal{Y}}. \quad (4)$$

In this paper, we choose the reconstruction set $\hat{\mathcal{Y}}$ to be the set of distributions on $\mathcal{Y}$, i.e., $\hat{\mathcal{Y}} = \mathcal{P}(\mathcal{Y})$, and we measure discrepancies between true values of $Y$ and their estimated fits in terms of the average logarithmic loss, i.e., for $(y, \hat{P}) \in \mathcal{Y} \times \mathcal{P}(\mathcal{Y})$,

$$\ell_{\log}(y, \hat{P}) = \log \frac{1}{\hat{P}(y)}. \quad (5)$$

As such, the performance of a distributed inference scheme $\big( (\phi_j^{(l)})_{(j,l) \in \mathcal{E}}, \psi \big)$ for which (3) is fulfilled is given by its achievable relevance, given by

$$\Delta = H(Y) - \mathbb{E}\big[ \ell_{\log}(Y, \hat{Y}) \big], \quad (6)$$

which, for a discrete set $\mathcal{Y}$, is directly related to the error of misclassifying the variable $Y$. It is important to note that $H(Y)$ is a problem-specific constant; as such, the relevance given by (6) is simply another form of the logarithmic loss.
Figure 2.
Studied network inference model.
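To make the relevance criterion concrete, the following sketch computes the average logarithmic loss of a set of soft predictions and the corresponding relevance, i.e., H(Y) minus the average log loss as in (6). The binary labels and the fixed 0.8-confidence predictor are purely illustrative.

```python
import numpy as np

def log_loss(y_true, predicted_dists):
    """Average logarithmic loss in bits: empirical mean of log 1/P_hat(y)."""
    probs = np.array([d[y] for d, y in zip(predicted_dists, y_true)])
    return float(np.mean(-np.log2(probs)))

def entropy(pmf):
    """Shannon entropy H(Y) in bits of a discrete pmf."""
    pmf = np.asarray(pmf, dtype=float)
    pmf = pmf[pmf > 0]
    return float(-np.sum(pmf * np.log2(pmf)))

# Binary Y with uniform prior, so H(Y) = 1 bit.  A predictor that puts
# probability 0.8 on the true label achieves relevance
# H(Y) - E[log loss] = 1 + log2(0.8) bits.
y = [0, 1, 0, 1]
dists = [np.array([0.8, 0.2]), np.array([0.2, 0.8])] * 2
relevance = entropy([0.5, 0.5]) - log_loss(y, dists)
```

Since H(Y) is fixed by the problem, maximizing the relevance is the same as minimizing the average logarithmic loss, as noted after (6).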
In practice, in a supervised setting, the mappings given by (1), (2) and (4) need to be learned from a set of training data samples $\{(x_{1,i}, \ldots, x_{J,i}, y_i)\}_{i=1}^{n}$. The data is distributed such that the samples $\{x_{j,i}\}_{i=1}^{n}$ are available at node $j$ for $j \in \mathcal{J}$, and the desired predictions $\{y_i\}_{i=1}^{n}$ are available at the end decision node $N$. We parametrize the possibly stochastic mappings (1), (2) and (4) using NNs. This is depicted in Figure 3. We denote the parameters of the NN that parameterizes the encoding function at each node $j$ by $\theta_j$, and the parameters of the NN that parameterizes the decoding function at node $N$ by $\theta_N$. Letting $\theta = (\theta_1, \ldots, \theta_N)$, we aim to find the parameters $\theta$ that maximize the relevance of the network under the network constraints of (3). Given that the actual distribution is unknown and we only have access to a dataset, the loss function needs to strike a balance between performance on the dataset, given by an empirical estimate of the relevance, and the network’s ability to perform well on samples outside the dataset.
The NNs at the various nodes are arbitrary and can be chosen independently; for instance, they need not be identical as in FL. It is only required that the following mild condition, which, as will become clearer from what follows, facilitates back-propagation, be met. Specifically, for every $j \in \mathcal{J}$ and every edge $(j,l) \in \mathcal{E}$, under the assumption that all elements of the index tuple received by node $l$ have the same dimension, the size of the input layer of the NN at node $l$ must equal the sum of the sizes of the output layers of the NNs at the nodes of $\mathcal{I}(l)$, so that the incoming outputs can be concatenated (7). Similarly, for the nodes $k \notin \mathcal{J}$, the analogous dimension-matching condition is required to hold (8).
Remark 1. Conditions (7) and (8) are imposed only for ease of implementation of the training algorithm; the techniques presented in this paper, including the optimal trade-offs between relevance and complexity for the given topology, the associated loss function, the variational lower bound, and how to parameterize it using NNs, do not require (7) and (8) to hold. Alternative aggregation techniques, such as element-wise multiplication or element-wise averaging, can be employed to combine the information received by each node, in replacement of concatenation. The impact of these aggregation techniques has been analyzed in [22].
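As an illustration of the concatenation-based fusion just discussed, the following minimal numpy sketch shows encoders at three data nodes producing feature vectors that a decision node concatenates and maps to a distribution over labels. The one-layer encoders, the layer sizes and the random weights are purely illustrative, not the architecture used in our experiments.

```python
import numpy as np

rng = np.random.default_rng(1)

def encoder(x, W):
    """One-layer NN encoder at a data node: raw input -> feature vector."""
    return np.tanh(W @ x)

def decoder(features, V):
    """Decision node: concatenate incoming features, output a distribution."""
    z = V @ np.concatenate(features)
    e = np.exp(z - z.max())
    return e / e.sum()                     # softmax over the label alphabet

# Three data nodes, each observing a 5-dim sample and sending a 4-dim
# feature vector; the decoder's input layer therefore has size 3 * 4 = 12.
# This is the dimension-matching condition that makes concatenation work.
Ws = [rng.standard_normal((4, 5)) for _ in range(3)]
V = rng.standard_normal((2, 12))
xs = [rng.standard_normal(5) for _ in range(3)]
p_hat = decoder([encoder(x, W) for x, W in zip(xs, Ws)], V)
```

Replacing `np.concatenate` with an element-wise product or mean would implement the alternative aggregations mentioned in Remark 1, in which case the decoder input size equals a single feature dimension instead of the sum.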