Theodo

Dimensionality reduction for exploration and curation of datasets

Stepan Lebedev, 13 min read

One million integers embedded into 2D space with UMAP

The availability of AI models for everyday users skyrocketed last year, leading to a surge in Artificial Intelligence’s popularity. However, despite this advancement, companies still grapple with collecting accurate data for their AI projects. Well-curated data matters more than ever, yet many companies struggle with it. With databases growing rapidly and containing numerous parameters, human analysis becomes exceedingly difficult, if not impossible. Let’s see how we can simplify this task using a powerful projection algorithm!

Why is clean data important? Other approaches to getting insights

Unrepresentative, poorly structured, or inadequately annotated data significantly contributes to the failure of AI projects. This is primarily due to machine learning’s sensitivity to outliers, including aberrant, false, or unrepresentative data. Additionally, the quality of responses from Large Language Models (LLMs) hinges on the quality of provided context. Therefore, employing specific tools for proper data curation is imperative.

We could mention this interesting approach using LLMs, but it requires some knowledge of prompt engineering and may struggle to scale. In this article, we tackle the problem differently, using a classical, non-LLM algorithm.

Indeed, a family of nonlinear dimensionality reduction techniques simplifies the high-dimensional representation of the data and, in doing so, provides valuable insights into how to improve the quality of our dataset. There are many such algorithms (t-SNE, nonlinear PCA, Laplacian Eigenmaps, etc.), but we will concentrate on the one that, in our experience, gives the best results for visual analysis: UMAP.

What is UMAP (Uniform Manifold Approximation and Projection)? A primer

UMAP (Uniform Manifold Approximation and Projection) is a dimensionality reduction technique widely used to visualize high-dimensional data in a lower-dimensional space. It is particularly good at preserving the local structure of the data while retaining more of its global structure than alternatives such as t-SNE.

The UMAP algorithm starts by building a graph representation of the data, in which each datapoint is connected to its nearest neighbors. It then optimizes a low-dimensional layout of the datapoints that preserves these local relationships. Put simply, if two datapoints are close (with respect to some metric) in the high-dimensional space, they should end up close in the low-dimensional space.
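To make the graph-construction step more concrete, here is a minimal NumPy sketch of a simplified version of it. It computes exact k-nearest neighbors with brute-force Euclidean distances (umap-learn itself relies on approximate nearest-neighbor search to scale) and weights each edge with exp(-(d - ρ)/σ), where ρ is the distance to the point's nearest neighbor and σ is chosen by binary search so that each point's weights sum to log2(k), as described in the UMAP paper. The function names and toy data are our own; the symmetrization of the graph and the optimization of the 2D layout are left out.

```python
import numpy as np


def knn_graph(X: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
    """Indices and Euclidean distances of the k nearest neighbors of each row of X."""
    # Brute-force pairwise distances: O(n^2), fine for a toy example
    sq_norms = np.sum(X**2, axis=1)
    d2 = sq_norms[:, None] + sq_norms[None, :] - 2.0 * X @ X.T
    dist = np.sqrt(np.maximum(d2, 0.0))
    np.fill_diagonal(dist, np.inf)  # a point is not its own neighbor
    idx = np.argsort(dist, axis=1)[:, :k]
    return idx, np.take_along_axis(dist, idx, axis=1)


def edge_weights(knn_dist: np.ndarray, n_iter: int = 64) -> np.ndarray:
    """Weights exp(-(d - rho) / sigma) for each kNN edge, with one sigma per point.

    rho is the distance to the nearest neighbor; sigma is found by binary
    search so that the weights of each point sum to log2(k).
    """
    n, k = knn_dist.shape
    target = np.log2(k)
    rho = knn_dist[:, :1]
    lo, hi = np.zeros((n, 1)), np.full((n, 1), np.inf)
    sigma = np.ones((n, 1))
    for _ in range(n_iter):
        total = np.exp(-(knn_dist - rho) / sigma).sum(axis=1, keepdims=True)
        too_large = total > target  # sigma too large: weights decay too slowly
        hi = np.where(too_large, sigma, hi)
        lo = np.where(too_large, lo, sigma)
        # Double sigma until an upper bound is found, then bisect
        sigma = np.where(np.isinf(hi), sigma * 2.0, (lo + hi) / 2.0)
    return np.exp(-(knn_dist - rho) / sigma)


# Toy data: two well-separated blobs in 5 dimensions
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0.0, 0.1, (20, 5)), rng.normal(5.0, 0.1, (20, 5))])
idx, knn_dist = knn_graph(X, k=5)
weights = edge_weights(knn_dist)
```

On this toy data, every point's neighbors belong to its own blob, and each point's nearest neighbor always gets weight 1, which is what lets UMAP keep even sparse regions connected.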

One of the key advantages of UMAP is its ability to handle large datasets efficiently. It scales well to millions of datapoints and can be applied to both numerical and categorical data. UMAP also offers various parameters that let users control the trade-off between preserving local and global structure (n_neighbors), as well as the density of the resulting embedding (min_dist).

In summary, UMAP is a powerful dimensionality reduction technique that can be used for exploratory data analysis, clustering, and visualization. It provides a flexible and efficient way to reduce the dimensionality of complex datasets while preserving important structural information. In our case, we want to find aberrations and confusing datapoints with minimal effort, assuming we neither know the client’s domain in depth nor master sophisticated data analysis techniques. Reducing the data to 2D points with UMAP therefore seems a good fit for our use case.

If you want more details on how UMAP works, feel free to visit the official website.

UMAP application examples

To get a grip on the kind of visualization UMAP can provide, this section applies it to a common dataset: MNIST. MNIST is a dataset of handwritten digits, widely used as a benchmark in machine learning. It contains 60,000 training samples (about 6,000 images per digit) and 10,000 test samples (about 1,000 images per digit). Each image is grayscale and 28x28 pixels. Here are some examples of the images you can expect to find:

Figure: several examples of each class of handwritten digits from the MNIST dataset

The following snippet plots the UMAP output for MNIST and colors each projected datapoint according to its class:

import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from umap import UMAP


# Load MNIST dataset
mnist = fetch_openml('mnist_784', version=1, cache=True)
X = mnist.data
y = mnist.target.astype(int)


# Apply UMAP
umap = UMAP(n_components=2)
X_umap = umap.fit_transform(X)


# Plot UMAP output with class colors
plt.scatter(X_umap[:, 0], X_umap[:, 1], c=y, cmap='tab10', s=1)
plt.colorbar()
plt.title('UMAP Visualization of MNIST')
plt.show()

We will then have the following representation:

Figure: UMAP projection of MNIST. The points form ten distinct clusters, one per digit from 0 to 9, each in its own region of the plane: UMAP has reduced the MNIST images to 2D while visually separating the digit classes.

As you can see, most elements are plotted close to other images of the same class. However, some of them appear where they shouldn’t. If you keep track of the image associated with each of those datapoints, you can easily find annotation errors, difficult samples and aberrations.
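Spotting those points can be partly automated. The sketch below, a heuristic of our own rather than part of UMAP, flags every point whose label disagrees with most of its k nearest neighbors in the 2D projection; those are good candidates for manual review. It computes brute-force pairwise distances, so for the full 70,000-point MNIST projection you would run it on a subsample or swap in a spatial index such as scikit-learn’s NearestNeighbors. The values k=10 and threshold=0.5 are arbitrary assumptions.

```python
import numpy as np


def suspicious_points(points_2d: np.ndarray, labels: np.ndarray,
                      k: int = 10, threshold: float = 0.5) -> np.ndarray:
    """Indices of points whose label differs from most of their k nearest 2D neighbors."""
    # Brute-force pairwise distances in the projection: O(n^2) memory
    diff = points_2d[:, None, :] - points_2d[None, :, :]
    dist = np.sqrt((diff**2).sum(axis=-1))
    np.fill_diagonal(dist, np.inf)  # exclude each point from its own neighborhood
    neighbors = np.argsort(dist, axis=1)[:, :k]
    # Fraction of the k neighbors whose label differs from the point's own label
    disagreement = (labels[neighbors] != labels[:, None]).mean(axis=1)
    return np.flatnonzero(disagreement > threshold)


# Toy projection: two clusters of 50 points, with one simulated annotation error
rng = np.random.default_rng(1)
points_2d = np.vstack([rng.normal(0.0, 1.0, (50, 2)), rng.normal(10.0, 1.0, (50, 2))])
labels = np.array([0] * 50 + [1] * 50)
labels[3] = 1  # point 3 sits in the "0" cluster but is labeled "1"
flagged = suspicious_points(points_2d, labels)
```

With the MNIST snippet above, a call such as suspicious_points(np.asarray(X_umap)[:5000], np.asarray(y)[:5000]) would return the indices of images worth displaying and checking by hand.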

How to project real-world data with UMAP?

Raw data projection

The first example is easy to understand, but it does not reflect real-world conditions. You are unlikely to deal with such small grayscale images carrying relatively little information. In most cases, you will deal with large data vectors containing a lot of noise.

Let’s see what we get with another standard dataset, CIFAR-10. It still does not represent real-world data, since its images are relatively small (32x32 pixels), but it adds color (so each image actually has 32x32x3 = 3,072 dimensions) and more complex shapes (cars, animals, …). Here are some examples of what you can find in this dataset:

Figure: several examples of each class of real-world objects from the CIFAR-10 dataset

Here is a code snippet to visualize CIFAR-10:

import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from umap import UMAP


# Load CIFAR-10 dataset
cifar10 = fetch_openml('CIFAR_10', version=1, cache=True)
X = cifar10.data
y = cifar10.target.astype(int)


# Apply UMAP
umap = UMAP(n_components=2)
X_umap = umap.fit_transform(X)


# Plot UMAP output with class colors
plt.scatter(X_umap[:, 0], X_umap[:, 1], c=y, cmap='tab10', s=1)
plt.colorbar()
plt.title('UMAP Visualization of CIFAR10')
plt.show()

The plot will be:

Figure: UMAP projection of the raw CIFAR-10 pixels. The points show no apparent structure and no clear clusters, unlike MNIST: applied to raw pixels, UMAP does not visually separate the classes of this dataset.

Elements seem to be randomly plotted and no easy data analysis is possible. Why is that happening?

UMAP suffers from what is known as the curse of dimensionality. In short, as the number of dimensions grows, distances between datapoints concentrate: a point’s nearest and farthest neighbors end up at almost the same distance, so the metric barely discriminates between them. Since UMAP is a metric-based algorithm, it fails to map the elements correctly.
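This distance concentration is easy to observe numerically. The sketch below draws uniformly random points (an assumption for illustration: real images are not uniform noise) and computes the relative contrast, (farthest - nearest) / nearest, of the distances from a random query point. As the number of dimensions grows toward CIFAR-10’s 3,072, the contrast shrinks, meaning the nearest and farthest neighbors become harder to tell apart.

```python
import numpy as np


def relative_contrast(n_dims: int, n_points: int = 500, seed: int = 0) -> float:
    """(farthest - nearest) / nearest distance from a random query to random points."""
    rng = np.random.default_rng(seed)
    points = rng.random((n_points, n_dims))  # uniform in the unit hypercube
    query = rng.random(n_dims)
    distances = np.linalg.norm(points - query, axis=1)
    return float((distances.max() - distances.min()) / distances.min())


for n_dims in (2, 100, 1000, 3072):
    print(f"{n_dims:>5} dimensions: relative contrast = {relative_contrast(n_dims):.3f}")
```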

Pre-trained model for embeddings extraction

To address this problem, we need to condense the meaningful information in the data into smaller vectors called embeddings. To do so, we extract them with a pre-trained image model. We will use ResNet-50, but feel free to experiment with other models! ResNet-50 is a standard computer vision classification model (25.6 million parameters) that you can easily run locally. Any other vision model (variants of EfficientNet, MobileNet, etc.) will do the job we need here!

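import timm
import torch
import torch.nn.functional as F
from sklearn.datasets import fetch_openml
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from umap import UMAP


# Pick the best available device: Apple Silicon ("mps"), NVIDIA GPU ("cuda") or CPU
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


# Load Cifar 10 dataset: each row holds 3,072 pixel values (3 channels x 32 x 32, channel-first)
cifar10 = fetch_openml('CIFAR_10', version=1, cache=True, as_frame=False)
X = cifar10.data.reshape(-1, 3, 32, 32).astype("uint8")
y = cifar10.target.astype(int)


# Create model for embeddings extraction: num_classes=0 drops the classification head,
# so the model outputs the 2048-dimensional pooled features
resnet50 = timm.create_model('resnet50', pretrained=True, num_classes=0).to(device)
resnet50.eval()  # use BatchNorm running statistics instead of per-batch statistics


# ImageNet normalization expected by the pre-trained weights
mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)


# Keep the raw 32x32 images in memory; up-scaling is done batch by batch below
# (resizing all 60,000 images to 224x224 at once would need tens of GB of RAM)
dataset = TensorDataset(torch.from_numpy(X), torch.from_numpy(y))
loader = DataLoader(dataset, batch_size=64, shuffle=False)


# Extract embeddings batch by batch
embeddings = []
with torch.no_grad():
    for inputs, _ in tqdm(loader):
        inputs = inputs.to(device).float() / 255.0
        # Up-scale to the 224x224 resolution the model was trained on
        inputs = F.interpolate(inputs, size=(224, 224), mode="bilinear", align_corners=False)
        inputs = (inputs - mean) / std
        outputs = resnet50(inputs)
        embeddings.append(outputs.cpu())


embeddings = torch.cat(embeddings, dim=0).numpy()


# Project the 2048-dimensional embeddings to 2D
X_umap = UMAP(n_components=2).fit_transform(embeddings)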

Applying UMAP to those embeddings gives the following representation:

Figure: UMAP projection of ResNet-50 embeddings of CIFAR-10. The clusters are only partially separated, but each class can easily be identified, which suggests these embeddings suit UMAP better than the raw images. Their information is still not representative enough to fully separate the clusters, and some areas cannot be analyzed at first glance.

It is already more organized, and some analysis is possible. You may wonder why some datapoints are still badly plotted. Pre-trained vision models may struggle to produce meaningful embeddings when the images they are given differ too much from their training set. In our case, we used a model trained on ImageNet, a large image dataset whose images average around 469x387 pixels. Our up-scaled 32x32-pixel images look quite different from what the model has seen during training.

Fine-tuned model for embeddings extraction

However, it is possible to go a little further by fine-tuning the pre-trained model to improve the retrieved embeddings. The following representation is obtained after only one epoch of training on CIFAR-10:

Figure: UMAP projection of CIFAR-10 embeddings from the fine-tuned model. Hard-to-analyze overlapping areas are greatly reduced compared to the previous visualization and each class cluster is well defined, showing that this approach produces better embeddings for UMAP dimensionality reduction.

And now we have something more comfortable to work with!

Things to remember for real-world data projection with UMAP

These first examples were meant to make a few points:

The examples used image datasets, but remember that UMAP is completely agnostic to the kind of data you provide! If you manage to get meaningful embeddings of your data, you can use UMAP to check the data’s quality.

Finally, when the 2D representation is linked to the meaning of each datapoint (e.g. its class), it lets you curate the dataset even if you are not an expert in the business domain! You only need to spot badly plotted points. Furthermore, if you filter out some elements, you can immediately see the impact on the high-dimensional data. This is especially useful for GenAI projects, where data vectors are retrieved to populate the context based on their distance to the user’s prompt.

Let’s see what we can do with some data used for a GenAI project.

How to use UMAP for an AI chatbot project

Context about a Chatbot Project

When building a domain-specific chatbot, at some point you will create a text database to use as context. How does it work? When someone chats with your bot, their messages are vectorized, and a similarity search retrieves the corresponding context from the database. Prompt engineering then instructs the LLM on how to combine the question and the context. Finally, the model generates its response.
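The retrieval flow described above can be sketched in a few lines. This is a minimal illustration under loud assumptions: embed() is a hypothetical stand-in for a real LLM embedding model (a hashed bag of words), and build_prompt() is a hypothetical prompt template. Only the structure is the point: vectorize the question, rank the stored contexts by cosine similarity, and assemble the prompt that would be sent to the LLM.

```python
import re
import zlib

import numpy as np


def embed(text: str, dim: int = 1024) -> np.ndarray:
    """Hypothetical stand-in for an LLM embedding model: hashed bag of words, L2-normalized."""
    vec = np.zeros(dim)
    for token in re.findall(r"[a-z0-9]+", text.lower()):
        vec[zlib.crc32(token.encode()) % dim] += 1.0
    norm = np.linalg.norm(vec)
    return vec / norm if norm > 0 else vec


def retrieve(question: str, contexts: list[str], k: int = 2) -> list[str]:
    """Similarity search: the k contexts closest to the question in embedding space."""
    context_vecs = np.stack([embed(c) for c in contexts])
    # All vectors are unit-length, so the dot product is the cosine similarity
    scores = context_vecs @ embed(question)
    best = np.argsort(-scores)[:k]
    return [contexts[i] for i in best]


def build_prompt(question: str, contexts: list[str]) -> str:
    """Hypothetical prompt template assembling the retrieved context and the question."""
    context_block = "\n".join(f"- {c}" for c in contexts)
    return (
        "Answer the question using only the context below.\n"
        f"Context:\n{context_block}\n"
        f"Question: {question}"
    )


contexts = [
    "Our office is open Monday to Friday.",
    "Refunds are processed within 14 days of the return request.",
    "The support team answers emails within one business day.",
]
question = "How fast are refunds processed?"
prompt = build_prompt(question, retrieve(question, contexts, k=1))
print(prompt)  # this prompt would then be sent to the LLM
```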

If you’d like more insights about the process feel free to visit Teaching Custom Knowledge to AI Chatbots.

Why is it important to curate the context database?

The similarity search computes the distance between the embedding of the user’s message and the context embeddings stored in your database. Since context retrieval relies on this metric, it is important for elements to be “far” from one another. Otherwise, the retrieved context may not match the user’s message, and the response will not be satisfactory.
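To see which entries are too close to each other, one can compute, for each entry, the distance to its nearest other entry in the database. The sketch below assumes cosine distance (the text does not say which metric the similarity search uses) and uses made-up toy embeddings; the 0.01 cut-off is likewise an arbitrary assumption. Entries with a very small nearest-neighbor distance compete with each other during retrieval.

```python
import numpy as np


def nearest_neighbor_distance(embeddings: np.ndarray) -> np.ndarray:
    """Cosine distance from each entry to its closest *other* entry in the database."""
    unit = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    similarity = unit @ unit.T
    np.fill_diagonal(similarity, -np.inf)  # ignore each entry's similarity to itself
    return 1.0 - similarity.max(axis=1)


# Toy database: entries 0 and 1 are near-duplicates, entries 2 and 3 are distinct
embeddings = np.array([
    [1.0, 0.0, 0.0],
    [0.99, 0.01, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
])
distances = nearest_neighbor_distance(embeddings)
crowded = np.flatnonzero(distances < 0.01)  # hypothetical cut-off
```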

Example of GenAI dataset projection

To illustrate how we can use UMAP to improve the context database, we will use the same data as in Teaching Custom Knowledge to AI Chatbots (kindly provided by its author). This database contains 996 entries, each pairing a text with its LLM-generated embedding.

The main difference between images and text is that, for text, very large pre-trained models are available that produce good embeddings for (almost) any text we present to them. For comparison, the ResNet-50 we used previously has 25.6 million parameters, while GPT-3.5 is generally reported to have around 175 billion! We can therefore expect LLM-provided embeddings to be good enough for UMAP projection.

The pipeline for UMAP projection of the text data is:

Figure: pipeline for text data projection with UMAP. First, an LLM generates an embedding for each entry of the context text database; then, the UMAP algorithm computes the 2D projection.

Projecting those embeddings gives the following representation:

Figure: point cloud of the text database entries, with highlighted areas pointing out issues in the dataset such as redundant information and irrelevant data.

As you can see, these results are quite difficult to interpret. However, they reveal a lot of redundant information in our dataset. To deal with this issue, we apply the following filtering:

By changing the distance threshold, we can track how we are modifying our high-dimensional dataset and choose the best representation for our use case!
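The filter itself is not detailed in the text. One simple way to implement a distance-threshold filter, assumed here for illustration, is a greedy pass that keeps an entry only if its cosine distance to every entry kept so far is at least the threshold. The function name and toy vectors are hypothetical; sweeping the threshold shows how many entries survive.

```python
import numpy as np


def deduplicate(embeddings: np.ndarray, threshold: float) -> np.ndarray:
    """Greedy filter: keep an entry only if its cosine distance to every
    already-kept entry is at least `threshold`. Returns the kept indices."""
    unit = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    kept: list[int] = []
    for i, vec in enumerate(unit):
        # Cosine distances between this entry and all entries kept so far
        if not kept or (1.0 - unit[kept] @ vec).min() >= threshold:
            kept.append(i)
    return np.array(kept, dtype=int)


# Toy example: two pairs of near-duplicate vectors
embeddings = np.array([[1.0, 0.0], [1.0, 0.01], [0.0, 1.0], [0.01, 1.0]])
for threshold in (0.0, 0.05, 2.0):
    kept = deduplicate(embeddings, threshold)
    print(f"threshold={threshold}: kept {len(kept)} of {len(embeddings)} entries")
```

Because the pass is greedy, the first entry of each group of near-duplicates is the one kept. Once a threshold is chosen, UMAP can be re-run on embeddings[kept] to inspect the filtered dataset.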

After some experiments, we achieve the following representation:

Figure: UMAP projection of the filtered text database. There is less data, especially in the areas previously pointed out as problematic, which suggests that further use of the dataset will be more precise in this state (shorter search time, fewer context extraction errors, ...).

With a very simple filter, we got rid of redundant information while preserving useful pieces of data. By doing so, we divided the size of the dataset by 4!

Furthermore, as mentioned before, context retrieval works better when datapoints are far apart with respect to the chosen metric. As the new visualization shows, the filtering also improved this aspect of our database.

Conclusion on UMAP utility for dataset curation

In conclusion, UMAP emerges as a valuable asset in the arsenal of tools for data exploration in AI projects. Its ability to efficiently handle large and complex datasets, coupled with intuitive visualization capabilities, makes it an efficient tool for AI practitioners striving for success in their projects. As AI continues to advance, techniques like UMAP are likely to be key to unlocking the full potential of data-driven applications across various domains.
