Feb 22, 2023
Representation learning is at the core of deep learning systems, as it enables them to automatically discover and learn the underlying structure in data. Learning representations is not an easy task because it requires processing massive amounts of data, which can be difficult to obtain, especially for tasks in medicine and healthcare. In such settings, an AI system that can learn a given task from less data is essential, since not all data is accessible. This is where reasoning-modulated representations come into the picture.
The idea behind reasoning-modulated representation learning is to use a neural network to learn a representation of the input and then reason about that representation to better understand the input. Because the reasoning process is used to modulate the learned representation, embedding it into the system lets the neural network make more informed predictions, decisions, and recommendations.
Reasoning-modulated representations aim to create AI systems that are not just data-driven, but also capable of understanding and reasoning about the information they are processing, leading to improved performance and more human-like intelligence.
Magnetic Resonance Imaging and Computerized Tomography scans
Magnetic Resonance Imaging (MRI) and Computerized Tomography (CT) scans are two types of diagnostic imaging techniques commonly used in medicine. These methods allow radiologists to visualize the internal structures of the body.
MRI is a non-invasive imaging technique that uses a strong magnetic field and radio waves to produce detailed images of the body's internal structures. The magnetic field causes the hydrogen nuclei in the body's tissues to align in a particular direction. Radio-frequency pulses then disturb this alignment, and as the nuclei return to it they emit signals that are detected by the scanner's receiver. These signals are used to create highly detailed images of the body's internal structures, including soft tissues such as organs, muscles, and ligaments. MRI is especially useful for diagnosing conditions that affect the brain and nervous system, such as stroke, brain tumors, and spinal cord injuries.
Computerized Tomography (CT) scans, on the other hand, use a series of X-rays to produce images of the body's internal structures. The machine emits X-rays, which pass through the body and are measured by detectors on the opposite side. A computer then uses these measurements to reconstruct detailed cross-sectional images of the body's internal structures. CT scans are particularly useful for diagnosing bone conditions, such as fractures, and for detecting tumors and other abnormalities in the body's organs.
Importance of representation learning for understanding MRI and CT scans
Representation learning is an important field in deep learning. It studies how to discover the underlying features, distributions, or patterns that describe a dataset. These representations can be either hand-engineered or learned automatically.
In medicine, learning representations automatically with deep neural networks is extremely difficult. Part of the reason is that the fine-grained details that distinguish one scan from another are often subtle, and scans can look largely homogeneous across a dataset, which makes it hard to learn the proper structures. Medical scans such as MRI and CT capture detailed images of the body's internal structures, and the resulting image data can be very high-dimensional and noisy. Moreover, the imaged structures can vary greatly in shape, size, and appearance, depending on the patient's age, health status, and medical history.
To learn representations from such datasets with deep learning, two requirements must be kept in mind:
A large and complex neural network, so that it can extract meaningful representations.
A large dataset to train on, so that the model's predictions generalize.
But, as discussed earlier, a large dataset is not always available in medical settings. Hence we need another method that can still generalize well.
Issues with learning representations
Before we go further, let’s look at some of the issues that deep learning engineers often face when learning representations with deep neural networks in medicine.
Data variability: Medical images can vary widely in quality, resolution, and image artifacts, which can make it difficult for deep learning models to learn accurate and generalizable representations.
Data sparsity: Medical imaging datasets can be relatively small, which can limit the ability of deep learning models to learn complex and accurate representations.
Labeling challenges: Medical imaging datasets may require expert annotations or segmentation to accurately label regions of interest, which can be time-consuming and may require specialized expertise.
Data privacy: These datasets can contain sensitive patient information, which can limit access to data and make it difficult to build large and diverse datasets for representation learning.
Generalization: Deep learning models trained on one type of medical imaging data may not necessarily generalize well to other types of imaging data or to different patient populations.
Interpretability: The representations learned by deep learning models from medical imaging data can be highly complex and difficult to interpret, which can limit their clinical applicability and usefulness for medical professionals.
Leveraging reasoning-modulated representation learning
To tackle the issues above, we will explore reasoning-modulated representation (RMR) learning, a type of representation learning in which a deep learning model is combined with an algorithmic approach. An algorithmic approach consists of predefined rules or laws; for instance, a mathematical formula that transforms a given input into an output. In other words, an algorithm has known parameters and operations that define the transformation, whereas a deep neural network optimizes its own parameters to find a good solution.
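To make the contrast concrete, here is a minimal sketch on a toy problem that is not from the source: converting Celsius to Fahrenheit. The algorithmic version hard-codes the known parameters (1.8 and 32), while the data-driven version starts from zero and learns the same parameters by gradient descent on examples. The function names, learning rate, and epoch count are illustrative assumptions.

```python
def algorithm(celsius: float) -> float:
    """Algorithmic approach: the parameters (1.8 and 32) are known in advance."""
    return 1.8 * celsius + 32.0


def fit_linear(xs, ys, lr=0.01, epochs=5000):
    """Data-driven approach: learn w and b by gradient descent on mean squared error."""
    w, b = 0.0, 0.0
    n = len(xs)
    for _ in range(epochs):
        # Gradients of (1/n) * sum((w*x + b - y)^2) with respect to w and b.
        grad_w = sum(2 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
        grad_b = sum(2 * (w * x + b - y) for x, y in zip(xs, ys)) / n
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b


# Training examples generated from the known formula; inputs are kept in [-1, 1]
# so that plain gradient descent with this learning rate stays stable.
xs = [-1.0, -0.5, 0.0, 0.5, 1.0]
ys = [algorithm(x) for x in xs]
w, b = fit_linear(xs, ys)
print(f"learned w={w:.3f}, b={b:.3f}")
```

The learned parameters end up very close to 1.8 and 32, but only because clean examples were available to learn from; the algorithm needed none.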
The authors of the paper “Neural Algorithmic Reasoning” argue that algorithms have played a crucial role in recent technological advancements. Although algorithms have been the basis for technical progress in many fields, deep learning methods possess fundamentally different qualities that prevent them from achieving similar levels of generalization. However, the authors suggest that if deep learning methods were better able to mimic or adapt algorithms, they could reach this level of generalization. Additionally, by representing embeddings in a continuous space of learned algorithms, deep neural networks can adapt known algorithms more closely to real-world problems and potentially find more efficient solutions than those proposed by human computer scientists.
Note that applying a classical algorithm directly to real-world data usually creates a bottleneck: the rich, high-dimensional data must first be compressed into the few abstract inputs the algorithm expects. The same authors argue that the predictions and representations learned by deep learning have significantly different properties from algorithms: the former offer scant guarantees but can adapt to a wide range of challenges, whilst the latter offer high assurances but are rigid to the problem being addressed.
Inspired by deep reinforcement learning, it is possible to leverage the algorithmic approach within deep neural networks. Let’s look at how the two techniques are combined in RMR, which involves two major approaches:
Prior-knowledge workflow: We have a well-defined task at hand, such as looking for outliers in given CT or MRI scans. We can describe the shape, size, and appearance of the relevant structure in terms of the patient's age, health status, medical history, and the region of interest in the scan. This allows us to make stronger predictions from less data. The authors caution that such “knowledge, however, usually requires us to be mindful of abstract properties of the data—and such properties cannot always be robustly extracted from natural observations.”
Data-driven approach: A deep neural network is trained directly on the natural data, such as the raw scans.
Ideally, in the prior-knowledge workflow, the abstract knowledge is extracted from the data. Suppose we are working with MRI scans from the later stages of Alzheimer’s disease. In this case, the MRI may show a decrease in the size of various brain regions (mainly the temporal and parietal lobes). We can treat definitive properties of the affected regions, such as the shape, size, and appearance of the structure given the patient's age, health status, medical history, and region of interest, as abstract knowledge. We denote these properties x′. Keep in mind that dim(x′) ≪ dim(x); that is, x′ has a much lower dimension than the original input x.
Similarly, we define y′ as the output corresponding to this abstract knowledge. These outputs are assumed to be produced by an algorithm A that is known or can be computed trivially.
Thus the whole step can be written as A(x′) = y′.
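As a rough illustration of A(x′) = y′, the sketch below assumes a hypothetical algorithm A that flags brain regions whose volume falls more than a fixed fraction below a reference volume. The region names, volumes, reference values, and 10% threshold are all made up for illustration and are not clinical values; the point is only that x′ is a handful of numbers and A is a fixed, known rule.

```python
from typing import Dict

# Hypothetical abstract input x': a few region volumes (arbitrary units)
# instead of millions of raw voxel intensities. All values are made up.
x_abstract: Dict[str, float] = {"temporal_lobe": 82.0, "parietal_lobe": 91.0, "frontal_lobe": 99.0}

# Hypothetical reference volumes that the regions are compared against.
reference: Dict[str, float] = {"temporal_lobe": 100.0, "parietal_lobe": 100.0, "frontal_lobe": 100.0}


def algorithm_A(x: Dict[str, float], ref: Dict[str, float], threshold: float = 0.1) -> Dict[str, bool]:
    """Known, rule-based algorithm: flag a region if its volume is more than
    `threshold` (as a fraction of the reference) below the reference volume."""
    return {region: (ref[region] - vol) / ref[region] > threshold for region, vol in x.items()}


y_abstract = algorithm_A(x_abstract, reference)  # y' = A(x')
print(y_abstract)  # prints {'temporal_lobe': True, 'parietal_lobe': False, 'frontal_lobe': False}
```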
In neural-network terms, this mapping is approximated as y′ ≈ g(P(f(x′))). The expression follows an encode-process-decode paradigm with three modules:
Encoder f: This module projects x′ into a latent space Z, a high-dimensional continuous space. The high dimensionality lets the network avoid the bottleneck issues described above.
Processor P: The processor simulates individual algorithmic steps in the high-dimensional latent space, which also reduces computational effort. The processor is the key component because it learns the underlying law of the task at hand. It is generally a deep MLP or a graph neural network (GNN). Once a suitable processor P has been obtained through training, it can be plugged into other neural networks. Because it remains a high-dimensional, differentiable component, it does not suffer from bottleneck effects. Its operations can be related to algorithm A, but keep in mind that P is only an approximation and does not exactly represent the algorithm.
Decoder g: The decoder projects latent vectors back into the abstract output space, similar to the decoder of an autoencoder.
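The encode-process-decode structure can be sketched in plain Python as below. This is a minimal, untrained sketch under stated assumptions: the sizes (3-dimensional x′ and y′, 16-dimensional latent space) are arbitrary, all layers are randomly initialised dense layers, and the processor is a two-layer MLP rather than a GNN. It only shows how data flows through f, P, and g; in practice the three modules would be trained on (x′, y′) pairs so that g(P(f(x′))) approximates A(x′).

```python
import random
from typing import List, Tuple

Vector = List[float]
Layer = Tuple[List[List[float]], Vector]

random.seed(0)

# Assumed sizes: 3-dimensional abstract input x', 16-dimensional latent space Z,
# and 3-dimensional abstract output y'.
ABSTRACT_IN, LATENT, ABSTRACT_OUT = 3, 16, 3


def linear(in_dim: int, out_dim: int) -> Layer:
    """A randomly initialised dense layer: weight matrix (out_dim x in_dim) plus zero bias."""
    weights = [[random.gauss(0.0, 0.1) for _ in range(in_dim)] for _ in range(out_dim)]
    return weights, [0.0] * out_dim


def apply(layer: Layer, v: Vector) -> Vector:
    """Compute W v + b."""
    weights, bias = layer
    return [sum(w * x for w, x in zip(row, v)) + b for row, b in zip(weights, bias)]


def relu(v: Vector) -> Vector:
    return [max(0.0, x) for x in v]


f = linear(ABSTRACT_IN, LATENT)                          # encoder f: x' -> z
p1, p2 = linear(LATENT, LATENT), linear(LATENT, LATENT)  # processor P: two-layer MLP on z
g = linear(LATENT, ABSTRACT_OUT)                         # decoder g: z -> y'


def encoder(x: Vector) -> Vector:
    return apply(f, x)


def processor(z: Vector) -> Vector:
    # P maps the latent space to itself, so the representation stays high-dimensional.
    return apply(p2, relu(apply(p1, z)))


def decoder(z: Vector) -> Vector:
    return apply(g, z)


def encode_process_decode(x_abstract: Vector) -> Vector:
    """Approximates y' = A(x') as g(P(f(x')))."""
    return decoder(processor(encoder(x_abstract)))


x = [0.82, 0.91, 0.99]
z = encoder(x)
y_hat = encode_process_decode(x)
print(len(z), len(processor(z)), len(y_hat))  # prints: 16 16 3
```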
For the data-driven approach, we follow similar steps as in the prior-knowledge workflow, but replace the encoder f with f′ and the decoder g with g′, so that we learn a function g′(P(f′(x))) ≈ Φ(x), where x is a natural input and Φ is the target function of the natural task. The processor P stays the same: its parameters are frozen in the data-driven pipeline, so the learnable parts f′ and g′ can be optimized with gradient descent without affecting the learned parameters of P.
Note: P might not perfectly represent A, which in turn might not perfectly represent Φ. The authors use a skip connection in their implementation of P, but it has no learnable parameters and does not give the system the ability to learn a correction of P in the natural setting. Their choice is motivated by the desire both to maintain the semantics and interpretability of P and to force the model to rely on the processor P rather than simply bypass it. They show empirically that their pipeline is surprisingly robust to imperfect P models, even with weak (linear) encoders and decoders.
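The freezing step can be illustrated with a deliberately tiny scalar example. Everything here is an assumption for illustration: the "pre-trained" processor is just P(z) = 2z, the natural task is a hypothetical Φ(x) = 6x, and f′ and g′ are single learnable scalars a and b. Gradients flow through P via the chain rule, but only f′ and g′ are updated.

```python
# Frozen, "pre-trained" processor: a stand-in for P obtained from the abstract pipeline.
P_WEIGHT = 2.0


def processor(z: float) -> float:
    """Frozen processor P(z); its parameter is never updated below."""
    return P_WEIGHT * z


def phi(x: float) -> float:
    """Hypothetical target function Phi for the natural task."""
    return 6.0 * x


xs = [-1.0, -0.5, 0.5, 1.0]
ys = [phi(x) for x in xs]

a, b = 0.5, 0.3  # learnable parameters of f'(x) = a*x and g'(z) = b*z
lr = 0.01

for _ in range(2000):
    grad_a = grad_b = 0.0
    for x, y in zip(xs, ys):
        z = a * x          # f'(x)
        h = processor(z)   # P(f'(x)); P_WEIGHT receives no update
        pred = b * h       # g'(P(f'(x)))
        err = pred - y
        # Chain rule through the frozen processor: dP/dz = P_WEIGHT.
        grad_b += 2 * err * h
        grad_a += 2 * err * b * P_WEIGHT * x
    a -= lr * grad_a / len(xs)
    b -= lr * grad_b / len(xs)

print(f"a*b = {a * b:.3f}, P_WEIGHT = {P_WEIGHT}")
```

Because P is fixed, the optimizer can only succeed by finding f′ and g′ that fit around it (here, any pair with a·b ≈ 3), which mirrors how the natural pipeline is forced to rely on the processor.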
This approach creates an opportunity to learn better representations. Once learned, these representations can be reused for various downstream tasks.
Challenges remain, however. Since the parameters of P are kept frozen in the natural pipeline, f′ is left with the difficult task of mapping natural inputs onto an appropriate manifold that P can meaningfully operate over.
The authors argue that the success of this mapping depends on carefully tuning the hyperparameters of f′.
Here are some key points to consider:
A common representation learning setting is one where we know something about the task's generative process; for example, the system must process some known attributes of the task (in our case, the structure, shape, and region of interest in the scan), which can make the task easier to approach with less data.
However, explicitly making use of this abstract information is often quite tricky. Processing it introduces a bottleneck, because algorithms often assume that their inputs are provided without error, whereas real-world data always contains noise, and that noise can vary from sample to sample.
RMR encapsulates the x′ → y′ path in a high-dimensional GNN pre-trained on large quantities of data (which can usually be pre-generated, even synthetically). In effect, it learns a neural network approximator that emulates the algorithm, which makes the algorithmic component flexible and robust to noise.
The learned algorithmic approximator, the processor P, can then be reused in the natural pipeline.
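The idea of pre-generating training data from a known algorithm can be sketched as follows. This is a minimal sketch with several assumptions: A is a hypothetical linear rule, Gaussian noise is added to the inputs to mimic imperfect extraction of x′, and the approximator is a linear model trained with SGD instead of the high-dimensional GNN used in RMR.

```python
import random
from typing import List, Tuple

random.seed(0)


def algorithm_A(x: List[float]) -> float:
    """Hypothetical known algorithm producing y' from a 2-dimensional x'."""
    return 2.0 * x[0] - x[1] + 0.5


def generate_synthetic(n: int, noise_std: float = 0.05) -> List[Tuple[List[float], float]]:
    """Pre-generate (x', y') pairs: sample clean abstract inputs, run A on them,
    and store noisy copies of the inputs to mimic imperfect extraction of x'."""
    data = []
    for _ in range(n):
        x = [random.uniform(-1.0, 1.0), random.uniform(-1.0, 1.0)]
        noisy_x = [xi + random.gauss(0.0, noise_std) for xi in x]
        data.append((noisy_x, algorithm_A(x)))
    return data


def train_approximator(data: List[Tuple[List[float], float]], lr: float = 0.01, epochs: int = 20):
    """Fit a linear approximator y ~ w[0]*x[0] + w[1]*x[1] + b with SGD on squared error."""
    w, b = [0.0, 0.0], 0.0
    for _ in range(epochs):
        for x, y in data:
            err = w[0] * x[0] + w[1] * x[1] + b - y
            w[0] -= lr * 2 * err * x[0]
            w[1] -= lr * 2 * err * x[1]
            b -= lr * 2 * err
    return w, b


train_data = generate_synthetic(1000)
w, b = train_approximator(train_data)
print([round(wi, 2) for wi in w], round(b, 2))
```

Since the training pairs come from A itself, the dataset can be made as large as needed, and training on noisy inputs exposes the approximator to the kind of imperfect x′ it will meet later.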
There are ample opportunities to leverage RMR to enhance AI in healthcare. The idea is compelling because it can help models generalize better: deep learning models function and generalize better when they have an anchor point to lean on. The processor P provides that anchor point and its reasoning power, mapping inputs to outputs efficiently and with less data. The overall idea is to give neural networks some predefined knowledge to improve generalization.