An Adaptive Contrastive Learning Model for Spike Sorting
Shenzhen, China
[email protected], [email protected]
Shengjie Zheng
Shenzhen Institute of Advanced Technology, Chinese Academy of Sciences
University of Chinese Academy of Sciences
Shenzhen, China
[email protected]
Abstract
Brain-computer interfaces (BCIs) are a means for electronic devices to communicate directly with the
brain. For most medical brain-computer interface tasks, multi-unit neuronal activity or local field
potentials are sufficient for decoding. For BCIs used in neuroscience research, however, it is important
to separate out the activity of individual neurons. With the development of large-scale silicon probe
technology and the increasing number of probe channels, manually interpreting and labeling spikes is
becoming increasingly impractical. In this paper, we propose a novel modeling framework, the Adaptive
Contrastive Learning Model, which learns representations of spikes through contrastive learning, with a
loss function that maximizes mutual information as its theoretical basis. The framework relies on the
fact that data with similar features share the same labels, whether the task is multi-class or binary.
With this theoretical support, we simplify the multi-class classification problem into multiple
binary-classification problems, improving both accuracy and runtime efficiency. Moreover, we introduce a
series of data augmentations for spikes, which also address the loss of classification performance
caused by overlapping spikes.
1 Introduction
Brain-computer interfaces (BCIs) are a means for electronic devices to communicate directly with
the brain. The most common brain-computer interface applications are in medical care, typically
in the form of collecting neural signals from a subject’s brain to help the subject perform tasks
∗ Lang Qian and Shengjie Zheng contributed equally to this work.
† Cheng Yang and Xiaojian Li are co-corresponding authors.
2 Method
Step 1: The signal recordings obtained through microelectrodes are not directly usable because they contain
Local Field Potentials (LFP) and high-frequency noise. We therefore band-pass filter the raw data from 300 Hz
to 3000 Hz. Data in the 1 Hz to 100 Hz range are usually LFP, data in the 300 Hz to 3000 Hz range contain the
spikes we normally use, and data at higher frequencies are usually high-frequency noise Rey et al. [2015],
as shown in Fig. 1.
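The band split described in Step 1 can be illustrated with a minimal numpy sketch. It assumes a 20 kHz sampling rate (the rate of the in-vivo recordings described in the Dataset section) and uses an ideal FFT-domain mask as a swapped-in technique; the paper does not state which filter it uses, and the function name `bandpass_fft` is hypothetical.

```python
import numpy as np

def bandpass_fft(x, fs, low=300.0, high=3000.0):
    """Zero-phase 'ideal' band-pass: keep FFT bins with low <= f <= high, zero all others."""
    spectrum = np.fft.rfft(x)
    freqs = np.fft.rfftfreq(len(x), d=1.0 / fs)
    spectrum[(freqs < low) | (freqs > high)] = 0.0
    return np.fft.irfft(spectrum, n=len(x))

fs = 20_000                                       # assumed sampling rate in Hz
t = np.arange(fs) / fs                            # 1 s of signal
lfp = np.sin(2 * np.pi * 50 * t)                  # stand-in for LFP (1 Hz to 100 Hz band)
spike_band = 0.2 * np.sin(2 * np.pi * 1000 * t)   # stand-in for spike-band activity
hf_noise = 0.1 * np.sin(2 * np.pi * 6000 * t)     # stand-in for high-frequency noise
filtered = bandpass_fft(lfp + spike_band + hf_noise, fs)
```

An ideal mask rings around sharp transients, so practical pipelines more often use an IIR design such as a Butterworth band-pass (for example `scipy.signal.butter` with `output='sos'`, applied with `scipy.signal.sosfiltfilt` for zero phase).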
Step 2: After band-pass filtering the recorded signal, we need to locate the spikes. Many methods exist,
for example Absolute Value, the Nonlinear Energy Operator, and the Stationary Wavelet Transform Product
Gibson et al. [2008]; locating spikes is not the focus of this study.
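The Absolute Value method named in Step 2 can be sketched as follows. This is an illustrative assumption, not the detector used in the paper (which does not specify one): the noise level is estimated robustly as median(|x|)/0.6745, a widely used rule for spike detection, while the multiplier `k = 5`, the 30-sample dead time, and alignment on the largest absolute deflection are hypothetical choices.

```python
import numpy as np

def detect_spikes_abs(x, k=5.0, dead_time=30):
    """Absolute-value threshold detector (sketch).

    Noise std is estimated as median(|x|) / 0.6745; the threshold is k times that estimate.
    Returns the aligned peak indices and the threshold used.
    """
    sigma = np.median(np.abs(x)) / 0.6745
    thr = k * sigma
    above = np.abs(x) > thr
    onsets = np.flatnonzero(above[1:] & ~above[:-1]) + 1   # rising edges of |x| > thr
    peaks, last = [], -dead_time - 1
    for onset in onsets:
        if onset - last <= dead_time:        # skip crossings inside the previous spike's dead time
            continue
        window = np.abs(x[onset:onset + dead_time])
        peak = onset + int(np.argmax(window))  # align on the largest absolute deflection
        peaks.append(peak)
        last = peak
    return np.array(peaks, dtype=int), thr
```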
Figure 1: The neural information is recorded using an array of electrodes, each measuring nearby
neural activity. The signal acquisition system converts the raw neural signal into a digital signal,
which is then processed by band-pass filters and passed to a spike detector. Image modified from
Zheng et al. [2022].
The key element of contrastive learning is constructing augmented data with the same semantic meaning
as the original sample. Contrastive learning maximizes the similarity between similar samples and the
difference between dissimilar samples, and in doing so extracts latent representations of the samples.
For example, we can recolor an image or scale the amplitude of audio, yet we can still clearly perceive
that the augmented data share the same semantics as the original data Mohsenvand et al. [2020]. However,
augmentation methods for spike data are not the same as those for image, text, and audio data. To solve
the problem of spike data augmentation, we propose a system for multi-channel spike data augmentation
that ensures the augmented sample and the original sample share the same semantics.
We hold the view that problem solving starts from the problem itself. When augmenting the spikes fired
by neurons, we try to simulate the variations in spike data that neurons in the brain produce under
various possible influences: spikes fired by the same neuron at different moments are not exactly the
same, but they share similar semantics. After examining the data fired by neurons and consulting experts
with relevant biological background, we propose five multi-channel spike data augmentation methods, as
shown in Fig. 2: (1) random noise, (2) DC shifting, (3) horizontal shifting (left or right), (4) amplitude
scaling, and (5) spike overlapping.
This paragraph details the five data augmentation methods. All operations are performed on normalized
data with 10 channels and 81 sampling points. (1) Random noise: the noise values follow a normal
distribution with a variance of 0.1. (2) DC shifting: the offset follows a uniform distribution from -0.1
to 0.1. (3) Horizontal shifting: the shift follows a uniform distribution from -20 to 20 and takes integer
values. (4) Amplitude scaling: the scale factor follows a uniform distribution from 0.9 to 1.1. (5) Spike
overlapping: randomly select a spike (the raw spike itself may be selected), shift it horizontally,
randomly reduce its amplitude, and add it to the raw data.
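The five augmentations above can be sketched in numpy with the stated distributions. Several details are not given in the text and are assumptions here: one DC offset, shift, and scale factor is shared by all channels of a spike; vacated samples after a shift are zero-filled; and the amplitude reduction in spike overlapping draws a gain from U(0, `max_gain`) with a hypothetical `max_gain = 0.5`. All function names are illustrative, not the authors' code.

```python
import numpy as np

N_CHANNELS, N_SAMPLES = 10, 81  # shape of one normalized multi-channel spike, as in the text

def shift_time(x, k):
    """Shift all channels by k samples (k > 0: right, k < 0: left); vacated samples are zero (assumption)."""
    out = np.zeros_like(x)
    if k > 0:
        out[:, k:] = x[:, :-k]
    elif k < 0:
        out[:, :k] = x[:, -k:]
    else:
        out[:] = x
    return out

def random_noise(x, rng):
    # (1) additive Gaussian noise with variance 0.1, i.e. standard deviation sqrt(0.1)
    return x + rng.normal(0.0, np.sqrt(0.1), size=x.shape)

def dc_shift(x, rng):
    # (2) one offset ~ U(-0.1, 0.1), applied to every channel (assumption)
    return x + rng.uniform(-0.1, 0.1)

def horizontal_shift(x, rng):
    # (3) integer shift ~ U{-20, ..., 20}; integers() excludes its upper bound, hence 21
    return shift_time(x, int(rng.integers(-20, 21)))

def amplitude_scale(x, rng):
    # (4) one scale factor ~ U(0.9, 1.1) for the whole spike
    return x * rng.uniform(0.9, 1.1)

def spike_overlap(x, pool, rng, max_gain=0.5):
    # (5) pick a spike from pool (which may contain x itself), shift it, shrink it, add it to x.
    # max_gain is hypothetical: the paper only says the amplitude is "randomly reduced".
    other = pool[rng.integers(len(pool))]
    other = shift_time(other, int(rng.integers(-20, 21)))
    return x + rng.uniform(0.0, max_gain) * other

# Usage: build one augmented view of the first spike in a placeholder pool of spikes.
rng = np.random.default_rng(0)
pool = rng.normal(size=(100, N_CHANNELS, N_SAMPLES))
view = spike_overlap(amplitude_scale(pool[0], rng), pool, rng)
```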
Classifying spike data differs from traditional image classification: because the number of spike
classes cannot be determined in advance, spike classification cannot be treated as a standard
multi-classification task. We therefore propose a new spike data classification model, the adaptive
contrastive classification model (ACCM), as shown in Fig. 3.
Figure 2: The five augmentations of spike data, from top to bottom: random noise, DC shifting,
horizontal shifting, amplitude scaling, and spike overlapping.
To give users more choices, we design three encoder architectures for the backbone network of this
model, as shown in Fig. 3:
(1) A Transformer encoder, abbreviated ACCM-T, as shown in Fig. 4(a). This model attends not only to the
peak and waveform of the spike data but also to channel-to-channel relationships. We therefore chose the
Transformer encoder block for its multi-head attention mechanism. To focus further on channel information,
the model also includes a positional encoding step. This variant is also suitable for large batches of data.
(2) A 1D-CNN encoder, abbreviated ACCM-C, as shown in Fig. 4(b). To extract different latent features of
the spike on the current channel, the data pass through three 1D-convolution operations, all with a kernel
size of 1 and with output dimensions of 5, 10, and 15, respectively. This design focuses on the
characteristics of the spike itself, such as its waveform, amplitude, and other relevant factors. After the
three 1D convolutions, we apply concatenation and flattening. We do not use L2 regularization, because
spike data are very sensitive to amplitude. Finally, we reduce the dimensionality of the data with several
fully connected layers. We also avoid 2D convolutions in this model, because the relationships between
microelectrode channels are not truly linear but resemble a graph structure. This encoder is best suited to
spike data for which sensitivity to peaks and patterns matters.
(3) A recurrent neural network encoder, abbreviated ACCM-R, as shown in Fig. 4(c). This model focuses more
on channel-to-channel relationships. When neural electrodes have hundreds or thousands of channels, not all
channels detect spike signals, and the neural information contained in several channels is sufficient for
spike sorting. This encoder is best suited to channel-sensitive spike data.
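The three backbone designs can be made concrete with a forward-pass-only numpy sketch. The following are assumptions not stated in the paper: each electrode channel is treated as one token (ACCM-T) or one recurrent step (ACCM-R); the embedding size is 32 with 4 attention heads; the attention layer omits the output projection, LayerNorm, and feed-forward sublayer; the three kernel-size-1 convolutions of ACCM-C run in parallel on a (time, channel) layout and are followed by one 128-unit hidden layer with ReLU; and ACCM-R uses a plain (Elman) tanh cell. All names are hypothetical; this is not the authors' implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
N_CH, N_T, D = 10, 81, 32   # 10 channels x 81 samples (from the text); D is an assumed embedding size

def relu(z):
    return np.maximum(z, 0.0)

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def init_normal(*shape):
    # N(0, variance 0.1) initialization, as described in the Experiment section
    return rng.normal(0.0, np.sqrt(0.1), size=shape)

# (a) ACCM-T style: channels are tokens, so attention weights are channel-to-channel
def sinusoidal_positions(n_tokens, d_model):
    pos = np.arange(n_tokens)[:, None]
    i = np.arange(d_model)[None, :]
    angle = pos / np.power(10000.0, (2 * (i // 2)) / d_model)
    return np.where(i % 2 == 0, np.sin(angle), np.cos(angle))

def encode_transformer(spike, p, n_heads=4):
    h = spike @ p["w_in"] + sinusoidal_positions(N_CH, D)     # (N_CH, D)
    dh = D // n_heads
    def heads(m):                                              # (N_CH, D) -> (heads, N_CH, dh)
        return m.reshape(N_CH, n_heads, dh).transpose(1, 0, 2)
    q, k, v = heads(h @ p["w_q"]), heads(h @ p["w_k"]), heads(h @ p["w_v"])
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(dh))     # (heads, N_CH, N_CH)
    h = h + (attn @ v).transpose(1, 0, 2).reshape(N_CH, D)     # concat heads, residual connection
    return h.mean(axis=0), attn

# (b) ACCM-C style: kernel-size-1 convolutions (5, 10, 15 filters), concatenate, flatten, dense
def encode_cnn(spike, p):
    x = spike.T                                                # (time, channels)
    feats = [relu(x @ w + b) for w, b in p["convs"]]           # kernel size 1 == per-time-step channel mixing
    h = np.concatenate(feats, axis=1).reshape(-1)              # (N_T * 30,)
    h = relu(h @ p["fc1"])
    return h @ p["fc2"]

# (c) ACCM-R style: a plain RNN that reads the electrode channels as a sequence
def encode_rnn(spike, p):
    h = np.zeros(D)
    for channel in spike:                                      # one recurrent step per channel
        h = np.tanh(channel @ p["w_x"] + h @ p["w_h"] + p["b"])
    return h

params_t = {"w_in": init_normal(N_T, D), "w_q": init_normal(D, D),
            "w_k": init_normal(D, D), "w_v": init_normal(D, D)}
params_c = {"convs": [(init_normal(N_CH, f), np.zeros(f)) for f in (5, 10, 15)],
            "fc1": init_normal(N_T * 30, 128), "fc2": init_normal(128, D)}
params_r = {"w_x": init_normal(N_T, D), "w_h": init_normal(D, D), "b": np.zeros(D)}
```

The sketch shows why the variants differ: ACCM-T and ACCM-R mix information across channels, whereas a kernel-size-1 convolution only mixes channels at each time step before the dense layers combine everything.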
During training, to avoid falling into local optima, we iteratively train the two-class and four-class
classifiers, which helps the model approach the global optimum. Under each binary-classification branch, a
controller decides whether the data assigned to that branch should be classified again.
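The control flow of the adaptive classification (split the data in two, then let a controller decide whether each branch is split again) can be sketched as below. This is a hypothetical illustration, not ACCM itself: a 2-means split on precomputed embeddings stands in for the trained contrastive binary classifier, and the controller rule (both halves hold at least `min_size` samples and the centroid gap exceeds `min_separation` times the average within-cluster spread) is an assumption, since the paper does not state the controller's criterion.

```python
import numpy as np

def two_means(x, n_iter=50):
    """Stand-in binary classifier: 2-means with deterministic farthest-point initialization."""
    a = x[np.argmax(np.linalg.norm(x - x.mean(axis=0), axis=1))]
    b = x[np.argmax(np.linalg.norm(x - a, axis=1))]
    centers = np.stack([a, b])
    labels = np.zeros(len(x), dtype=int)
    for _ in range(n_iter):
        dist = np.linalg.norm(x[:, None, :] - centers[None, :, :], axis=2)
        labels = dist.argmin(axis=1)
        if labels.min() == labels.max():       # degenerate: everything landed in one cluster
            break
        centers = np.stack([x[labels == c].mean(axis=0) for c in (0, 1)])
    return labels, centers

def controller(x, labels, centers, min_size=20, min_separation=4.0):
    """Hypothetical controller: accept a split only if both halves are large and well separated."""
    sizes = np.bincount(labels, minlength=2)
    if sizes.min() < min_size:
        return False
    spread = np.mean([x[labels == c].std(axis=0).mean() for c in (0, 1)])
    gap = np.linalg.norm(centers[0] - centers[1])
    return gap > min_separation * spread

def adaptive_split(x, idx=None, depth=0, max_depth=8):
    """Recursively split in two until the controller says stop; returns a list of leaf index arrays."""
    idx = np.arange(len(x)) if idx is None else idx
    if depth >= max_depth or len(idx) < 2:
        return [idx]
    labels, centers = two_means(x[idx])
    if not controller(x[idx], labels, centers):
        return [idx]
    return (adaptive_split(x, idx[labels == 0], depth + 1, max_depth)
            + adaptive_split(x, idx[labels == 1], depth + 1, max_depth))
```

Because each branch decides for itself whether to split, the number of clusters never has to be fixed in advance; on a few well-separated synthetic clusters the recursion should stop with one leaf per cluster.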
Figure 4: Encoder architectures: (a) Transformer encoder block, (b) 1D-CNN encoder block, (c)
RNN encoder block.
In this model, we maximize mutual information as the contrastive loss Ji et al. [2019]. Because we are
dealing with a binary-classification task, the joint probability distribution is given by the 2×2 matrix M,
as shown in Equation 1.
$$M = \frac{1}{n} \sum_{i=1}^{n} \Phi(x_i) \cdot \Phi(x'_i)^{\top} \qquad (1)$$
The output Φ(x) ∈ [0, 1]^2 can be interpreted as the distribution of a discrete random variable over 2
classes. Because we generally consider symmetric problems, we redefine M as (M + M^T)/2.
$$I(M) = \sum_{c=1}^{2} \sum_{c'=1}^{2} M_{cc'} \cdot \ln \frac{M_{cc'}}{M_{c} \cdot M_{c'}} \qquad (2)$$
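where M_c and M_c' are the marginals of M (its row and column sums). Maximizing I(M) in Equation 2 trades
off two goals: the raw data and the augmented data should receive the same class assignment, and both
classes should remain in use, since the marginal terms penalize assigning all samples to a single class.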
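Equations 1 and 2 can be checked with a minimal numpy sketch, assuming Φ outputs softmax probabilities over the 2 classes for a batch of raw spikes and for their augmented views (arrays `p` and `p_aug`, hypothetical names). It follows the IIC formulation of Ji et al. [2019], including the symmetrization (M + M^T)/2; the training loss is the negative of I(M).

```python
import numpy as np

def joint_matrix(p, p_aug):
    """Equation 1: average outer product of the two views' class distributions, then symmetrized."""
    m = p.T @ p_aug / len(p)          # (2, 2); equals (1/n) * sum_i Phi(x_i) Phi(x'_i)^T
    m = (m + m.T) / 2.0
    return m / m.sum()

def mutual_information(p, p_aug, eps=1e-12):
    """Equation 2: I(M) = sum_{c,c'} M_cc' * ln(M_cc' / (M_c * M_c'))."""
    m = joint_matrix(p, p_aug)
    m_row = m.sum(axis=1, keepdims=True)   # marginal M_c
    m_col = m.sum(axis=0, keepdims=True)   # marginal M_c'
    # eps keeps the log finite; entries with M_cc' = 0 then contribute exactly 0
    ratio = np.maximum(m, eps) / np.maximum(m_row * m_col, eps)
    return float(np.sum(m * np.log(ratio)))

def iic_loss(p, p_aug):
    # minimizing this loss maximizes the mutual information between the two views
    return -mutual_information(p, p_aug)
```

The two extremes illustrate the trade-off: confident, consistent, and balanced assignments reach the maximum ln 2, while identical uniform outputs for every sample give zero.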
2.5 Dataset
In-vivo dataset: We used the hc3 dataset from CRCNS Mizuseki et al. [2013], which contains recordings made
from multiple hippocampal areas in Long-Evans rats. The raw (broadband) data were recorded at 20 kHz. We
selected a segment of data with 8 channels, each with 32 sampling points, containing spikes from three
neurons. The segment contains 6,000 spikes in total, 2,000 from each neuron.
Simulated data: The simulated dataset was derived from Kilosort Pachitariu et al. [2016]. It has 10
channels, each with 81 sampling points, and contains spikes from eleven neurons: 22,000 spikes in total,
2,000 from each neuron.
3 Experiment
In this section, we compare our results with an unsupervised autoencoder model, AE Eom et al. [2021],
which reconstructs its own input, and with two contrastive learning models, SimCLR Chen et al. [2020] and
SimCSE Gao et al. [2021]. We also report the initialization of the model weights, the choice of optimizer,
and the learning rate.
To train our model, we experimented with the three architectures in Fig. 4. The batch size is not the same
for every backbone network and training set. The number of epochs for each backbone network node depends on
the amount of training data and usually ranges from 15 to 60. Over several training runs, we found that
initializing the model parameters from a normal distribution with a variance of 0.1 allowed the model to
converge faster. We use the Adam optimizer instead of SGD to avoid getting trapped in local optima, with the
learning rate initialized to 0.005. All experiments were performed on one NVIDIA GeForce RTX 2080 Ti GPU.
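The initialization and optimizer settings can be sketched in numpy as follows. Only the variance of 0.1 and the learning rate of 0.005 come from the text; β1 = 0.9, β2 = 0.999, and ε = 1e-8 are the usual Adam defaults and are assumed here. Note that a variance of 0.1 corresponds to a standard deviation of about 0.316. The names `init_weights` and `Adam` are illustrative.

```python
import numpy as np

def init_weights(shape, rng, variance=0.1):
    # N(0, 0.1): a variance of 0.1 is a standard deviation of sqrt(0.1), about 0.316
    return rng.normal(0.0, np.sqrt(variance), size=shape)

class Adam:
    """Minimal Adam update; lr = 0.005 as in the text, betas and eps are assumed defaults."""

    def __init__(self, lr=0.005, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.b1, self.b2, self.eps = lr, beta1, beta2, eps
        self.m = self.v = None
        self.t = 0

    def step(self, w, grad):
        if self.m is None:
            self.m, self.v = np.zeros_like(w), np.zeros_like(w)
        self.t += 1
        self.m = self.b1 * self.m + (1 - self.b1) * grad          # first-moment estimate
        self.v = self.b2 * self.v + (1 - self.b2) * grad ** 2     # second-moment estimate
        m_hat = self.m / (1 - self.b1 ** self.t)                  # bias corrections
        v_hat = self.v / (1 - self.b2 ** self.t)
        return w - self.lr * m_hat / (np.sqrt(v_hat) + self.eps)

# Usage: minimize the toy objective ||w - 3||^2 from a N(0, 0.1) initialization.
rng = np.random.default_rng(0)
w = init_weights((4,), rng)
opt = Adam()
for _ in range(3000):
    w = opt.step(w, 2 * (w - 3.0))
```

Because Adam normalizes each step by the running gradient magnitude, early steps have size close to the learning rate regardless of gradient scale, which is one reason a fixed rate such as 0.005 can work across the three backbones.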
4 Results
We tested three versions of the model, ACCM-R, ACCM-C, and ACCM-T, and compared their accuracy and running
time with those of an unsupervised model, AE Eom et al. [2021], and two contrastive learning models, SimCLR
Chen et al. [2020] and SimCSE Gao et al. [2021].
Table 1: Results on the simulated dataset
Model     Accuracy   Running time
AE        83.2%      6 min 13 s
SimCLR    87.8%      8 min 17 s
SimCSE    61.1%      3 min 3 s
ACCM-R    95.3%      21 min 3 s
ACCM-C    96.5%      5 min 41 s
ACCM-T    98.4%      11 min 4 s
Table 2: Results on the in-vivo dataset
Model     Accuracy   Running time
AE        71.2%      1 min 24 s
SimCLR    97.9%      1 min 46 s
SimCSE    82.1%      1 min 21 s
ACCM-R    98.7%      1 min 9 s
ACCM-C    99.9%      1 min 6 s
ACCM-T    99.9%      58 s
Tables 1 and 2 show the results of our experiments: the second column is the accuracy, and the third column
is the running time the model needs to reach its highest accuracy. Our model is provided in three versions
for different situations. If fast running time with relatively high accuracy is needed, ACCM-C is a good
choice. ACCM-T is the fastest model on the in-vivo data and achieves the highest accuracy on the simulated
data.
From the tables, we observe that SimCSE is much less accurate than SimCLR on both the in-vivo and the
simulated datasets. SimCSE augments spike data only through dropout, whereas SimCLR augments spike data with
our data augmentation methods, which indicates the advantage of our augmentations. We also observe that
SimCLR is much more accurate when the number of clusters is small (97.9% with three neurons in the in-vivo
data) than when it is large (87.8% with eleven neurons in the simulated data). This suggests that in spike
sorting, multiple binary-classification tasks can reach higher accuracy more easily than a single
multi-classification task. Finally, the two tables show that our model achieves high accuracy on both the
in-vivo and the simulated datasets, indicating its advantage in handling spike sorting.
5 Conclusion
In neuroscience research, brain-computer interfaces provide a way to study the direct causal relationship
between small populations of neurons and specific external outputs. Identifying the spike waveform of each
neuron is the basis for distinguishing the activity of individual neurons within a small population. In this
paper, we introduced ACCM, a self-supervised framework for learning representations of spike data. To this
end, we proposed five methods for augmenting spike data. We proposed three model variants, ACCM-C, ACCM-R,
and ACCM-T, because different practitioners weigh accuracy and runtime differently when classifying spikes.
We also introduced contrastive learning into spike sorting for the first time, achieving state-of-the-art
accuracy without the need for electrode maps. Finally, because the number of clusters cannot be determined
in advance during spike sorting, we proposed an adaptive classification algorithm that transforms the
multi-classification task into multiple binary-classification tasks.
References
Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey Hinton. A simple framework for contrastive
learning of visual representations. In International Conference on Machine Learning, pages 1597–1607. PMLR,
2020.
Junsik Eom, In Yong Park, Sewon Kim, Hanbyol Jang, Sanggeon Park, Yeowool Huh, and Dosik Hwang.
Deep-learned spike representations and sorting via an ensemble of auto-encoders. Neural Networks, 134:
131–142, 2021.
Nir Even-Chen, Dante G Muratore, Sergey D Stavisky, Leigh R Hochberg, Jaimie M Henderson, Boris Murmann,
and Krishna V Shenoy. Power-saving design opportunities for wireless intracortical brain–computer interfaces.
Nature Biomedical Engineering, 4(10):984–996, 2020.
Tianyu Gao, Xingcheng Yao, and Danqi Chen. SimCSE: Simple contrastive learning of sentence embeddings.
arXiv preprint arXiv:2104.08821, 2021.
Sarah Gibson, Jack W Judy, and Dejan Markovic. Comparison of spike-sorting algorithms for future hardware
implementation. In 2008 30th Annual International Conference of the IEEE Engineering in Medicine and
Biology Society, pages 5015–5020. IEEE, 2008.
Xu Ji, Joao F Henriques, and Andrea Vedaldi. Invariant information clustering for unsupervised image
classification and segmentation. In Proceedings of the IEEE/CVF International Conference on Computer Vision,
pages 9865–9874, 2019.
Yinghao Li, Shuai Tang, and Virginia R de Sa. Supervised spike sorting using deep convolutional siamese
network and hierarchical clustering. unpublished thesis, 2019.
Kenji Mizuseki, Anton Sirota, Eva Pastalkova, Kamran Diba, and György Buzsáki. Multiple single unit
recordings from different rat hippocampal and entorhinal regions while the animals were performing multiple
behavioral tasks. CRCNS.org, 2013.
Mostafa Neo Mohsenvand, Mohammad Rasool Izadi, and Pattie Maes. Contrastive representation learning for
electroencephalogram classification. In Machine Learning for Health, pages 238–253. PMLR, 2020.
Marius Pachitariu, Nicholas A Steinmetz, Shabnam N Kadir, Matteo Carandini, and Kenneth D Harris. Fast and
accurate spike sorting of high-channel count probes with Kilosort. Advances in Neural Information Processing
Systems, 29, 2016.
Hernan Gonzalo Rey, Carlos Pedreira, and Rodrigo Quian Quiroga. Past, present and future of spike sorting
techniques. Brain Research Bulletin, 119:106–117, 2015.
Muhammad Saif-ur Rehman, Omair Ali, Susanne Dyck, Robin Lienkämper, Marita Metzler, Yaroslav Parpaley,
Jörg Wellmer, Charles Liu, Brian Lee, Spencer Kellis, et al. SpikeDeep-Classifier: A deep-learning based fully
automatic offline spike sorting algorithm. Journal of Neural Engineering, 18(1):016009, 2021.
Changyu Seong, Wonjae Lee, and Dongsuk Jeon. A multi-channel spike sorting processor with accurate
clustering algorithm using convolutional autoencoder. IEEE Transactions on Biomedical Circuits and Systems,
15(6):1441–1453, 2021.
Daniel B Silversmith, Reza Abiri, Nicholas F Hardy, Nikhilesh Natraj, Adelyn Tu-Chan, Edward F Chang, and
Karunesh Ganguly. Plug-and-play control of a brain–computer interface through neural map stabilization.
Nature Biotechnology, 39(3):326–335, 2021.
Francis R Willett, Donald T Avansino, Leigh R Hochberg, Jaimie M Henderson, and Krishna V Shenoy.
High-performance brain-to-text communication via handwriting. Nature, 593(7858):249–254, 2021.
Yuanmeng Yan, Rumei Li, Sirui Wang, Fuzheng Zhang, Wei Wu, and Weiran Xu. ConSERT: A contrastive
framework for self-supervised sentence representation transfer. arXiv preprint arXiv:2105.11741, 2021.
Shengjie Zheng, Wenyi Li, Lang Qian, Chenggang He, and Xiaojian Li. A spiking neural network based on
neural manifold for augmenting intracortical brain-computer interface data. arXiv preprint arXiv:2204.05132,
2022.