Spatio-Temporal Transformer Networks for Trajectory Prediction in Autonomous Driving
Abstract
To safely and rationally participate in dense and heterogeneous traffic, autonomous vehicles
need to sufficiently analyze the motion patterns of surrounding traffic-agents and
accurately predict their future trajectories. This is challenging because the trajectories of
traffic-agents are influenced not only by the traffic-agents themselves but also by their
spatial interactions with each other. Previous methods usually rely on the sequential step-by-step
processing of Long Short-Term Memory networks (LSTMs) and merely extract the inter-
actions between spatial neighbors for a single type of traffic-agent. We propose the Spatio-
Temporal Transformer Networks (S2TNet), which models the spatio-temporal interactions
with a spatio-temporal Transformer and handles temporal sequences with a temporal Trans-
former. We feed additional category, shape and heading information into our networks
to handle the heterogeneity of traffic-agents. The proposed method outperforms state-of-
the-art methods on the ApolloScape Trajectory dataset by more than 7% on both the weighted
sum of Average and Final Displacement Error.
Keywords: Trajectory prediction, Transformer, Autonomous Driving.
1. Introduction
Autonomous driving is an innovative and advanced research field that can reduce the num-
ber of road fatalities, increase traffic efficiency, decrease environmental pollution and give
mobility to handicapped members of our society (Milakis et al. (2017)). In order to achieve
the desired goals and avoid collisions with other agents, autonomous vehicles need the
ability to perceive the environment and make intelligent decisions. As a part of perception,
trajectory prediction can well reflect the future behaviors of surrounding agents and build
a bridge between perception and decision-making. However, complex temporal prediction
is inevitably accompanied by spatial agent-agent interactions at the same time, especially
in dense and highly dynamic traffic composed of heterogeneous traffic-agents, including
pedestrians, cyclists and human drivers. Heterogeneity means that these traffic-agents have
diverse shapes, sizes, dynamics and behaviors. Moreover, a variety of potentially reason-
able spatial interactions between traffic-agents may occur, e.g. human drivers may overtake
another vehicle or slow down to follow other vehicles (Lefèvre et al. (2014)). Consequently,
trajectory prediction is a challenging task that plays an important role in autonomous
driving.
Classical methods treat traffic-agents as individual entities without any spatial inter-
actions and abstract their motion as kinematic and dynamic models (Brännström et al.
(2010)), Gaussian Processes (Rasmussen (2003)), etc., making it difficult to compre-
hend complex scenarios or accomplish long-term predictions. With the success of deep
neural networks, recent trajectory prediction methods mainly focus on using these net-
works to extract features on spatial and temporal dimensions (Alahi et al. (2016); Huang
et al. (2019); Ivanovic and Pavone (2019); Mohamed et al. (2020)). Long Short-Term
Memory networks (LSTMs) are widely used for modeling temporal features. The LSTMs
are based on consecutively processing sequences and storing the latent states to represent
knowledge about the motion of traffic-agents (Giuliari et al. (2021)). However, LSTM-based
methods remember history with a single vector of limited capacity and regularly have
difficulty handling complex temporal dependencies (Vaswani et al. (2017)). Subsequently,
pooling mechanisms (Deo and Trivedi (2018)), attention mechanisms (Ivanovic and Pavone
(2019)) and graph convolution mechanisms (Li et al. (2019); Yu et al. (2020)) have been used to
model spatial interactions. The limitation of these methods is that they only model the
interactions of spatially proximal traffic-agents and ignore the influence of traffic-agents
beyond the given spatial limits. This assumption may work well when traffic-agents move
slowly, but it loses efficacy as speeds increase. Besides, the majority of trajectory
prediction algorithms are developed for homogeneous traffic-agents in a single scene,
corresponding to pedestrians in crowds (Alahi et al. (2016)) or moving vehicles on
a highway (Deo and Trivedi (2018)). These methods may have great limitations in dealing
with dense urban environments where heterogeneous traffic-agents coexist and interact with
each other.
In this paper, we address all these limitations by employing Spatio-Temporal Trans-
former Networks (S2TNet) for heterogeneous traffic-agent trajectory prediction. S2TNet
is proposed based on the vanilla Transformer architecture, which discards the sequential
nature of data and models features with only the effective self-attention mechanism. For
the spatial dimension, we propose spatial self-attention mechanism to capture the interac-
tions between all traffic-agents in the road network, not limited to the interactions between
spatial neighbors. For the temporal dimension, a temporal convolution network (TCN) is
adopted to extract temporal dependencies between consecutive frames and is combined with spatial
self-attention to form the spatio-temporal Transformer where a set of new spatio-temporal
features are obtained. Based on the temporal self-attention mechanism, the temporal Transformer
refines the temporal features of each traffic-agent independently and produces the future
trajectories auto-regressively. In addition to history trajectories, we feed additional
shape, heading and category features into our networks to handle the heterogeneity of traffic-
agents. The main contributions of this paper are summarized as follows:
2. Background
2.1. Problem Formulation
Trajectory prediction aims to accurately predict the future long-term trajectories of traffic-
agents, given their history trajectories and other information such as shapes and categories.
The input of S2TNet is
$X = [x_1, \cdots, x_{t_{obs}}]$  (1)

where

$x_i = \{(x_0^i, y_0^i, l_0^i, w_0^i, \theta_0^i, \tau_0^i, \cdots, x_n^i, y_n^i, l_n^i, w_n^i, \theta_n^i, \tau_n^i) \mid i \in (1 : t_{obs})\}$  (2)
are the history feature vectors (including global coordinates x and y, lengths l, widths
w, headings θ and categories τ ) of n traffic-agents being predicted in a road network. The
subscript n in (2) refers to all agents in general and varies with different scenes. We currently
take into account five types of traffic-agents, $\tau \in \{1, 2, 3, 4, 5\}$, representing small vehicles,
big vehicles, pedestrians, cyclists and others, respectively. We hold that additional features,
when available for each traffic-agent, can help handle the heterogeneity of traffic-agents and
improve trajectory accuracy.
The output of S2TNet is
$Y = [y_{t_{obs}+1}, \cdots, y_{t_{fut}}]$  (3)

where

$y_i = \{(x_0^i, y_0^i, \cdots, x_n^i, y_n^i) \mid i \in (t_{obs}+1 : t_{fut})\}$  (4)
are the future feature vectors including global coordinates x and y. It is noted that S2TNet
outputs the future positions of all observed traffic-agents simultaneously rather than merely
predicting the location of one specific traffic-agent.
With the objective to hierarchically represent the trajectory sequences, we construct
a spatio-temporal graph G = (V, E) on a trajectory sequence with N traffic-agents and
T frames, featuring both intra-frame and inter-frame connections. In this graph, the node
set $V = \{x_i^t \mid t \in (1, T), i \in (1, N)\}$ includes all the feature vectors of traffic-agents, and
E represents the set of edges connecting the nodes. We use the terms node and traffic-agent
interchangeably in the following description. The edge set E consists of two subsets. The first
subset depicts the virtual spatial connections between traffic-agents in the same frame, denoted
as $E_S = \{(x_i^t, x_j^t) \mid i, j \in (1, N), t \in (1, T)\}$. The second subset contains the temporal
edges connecting the same traffic-agent in consecutive frames, $E_T = \{(x_i^t, x_i^{t+1}) \mid i \in
(1, N), t \in (1, T-1)\}$.
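To make this input layout concrete, the following minimal sketch assembles per-frame node features into a dense tensor; the function name, fixed agent count and zero-padding convention are illustrative assumptions, not the paper's implementation.

```python
import numpy as np

T_OBS, N_AGENTS, N_FEATURES = 6, 4, 6  # 6 history frames; (x, y, l, w, theta, tau)

def build_input_tensor(tracks):
    """tracks: dict mapping agent id -> list of T_OBS feature tuples.
    Returns X with shape (T_OBS, N_AGENTS, N_FEATURES); absent agents
    are zero-padded (an assumed convention)."""
    X = np.zeros((T_OBS, N_AGENTS, N_FEATURES), dtype=np.float32)
    for slot, (agent_id, states) in enumerate(sorted(tracks.items())):
        for t, state in enumerate(states):
            X[t, slot] = state
    return X

# Example: a small vehicle (tau = 1) and a pedestrian (tau = 3).
tracks = {
    7: [(1.0 * t, 0.0, 4.5, 1.8, 0.0, 1) for t in range(T_OBS)],
    9: [(0.0, 0.5 * t, 0.5, 0.5, np.pi / 2, 3) for t in range(T_OBS)],
}
X = build_input_tensor(tracks)
print(X.shape)  # (6, 4, 6)
```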
[Architecture figure: history features (coordinates x, y, length l, width w, heading θ, category τ) are concatenated, embedded by an FC layer and combined with positional encoding; N× spatio-temporal Transformer encoder layers (spatial self-attention, temporal convolution, Add & Norm) feed a temporal Transformer encoder and a temporal Transformer decoder (masked temporal self-attention, separable convolution), and a trajectory generator outputs the predicted coordinates.]
Self-attention computes scaled dot-product attention as

$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$  (5)

where $d_k$ is the dimension of each query. The division by $\sqrt{d_k}$ is used to stabilize gradients.
Adding a multi-head attention mechanism further improves the performance of self-attention.
It provides multiple representation sub-spaces for self-attention and enables the
model to jointly attend to information from different sub-spaces at different positions.
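For concreteness, the following is a minimal PyTorch sketch of scaled dot-product attention (Eq. (5)) and its multi-head extension; the module layout and tensor shapes are our assumptions rather than S2TNet's exact implementation.

```python
import torch
import torch.nn.functional as F

def attention(Q, K, V):
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5  # scale to stabilize gradients
    return F.softmax(scores, dim=-1) @ V

class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_model, h):
        super().__init__()
        assert d_model % h == 0
        self.h, self.d_k = h, d_model // h
        self.wq = torch.nn.Linear(d_model, d_model)
        self.wk = torch.nn.Linear(d_model, d_model)
        self.wv = torch.nn.Linear(d_model, d_model)
        self.wo = torch.nn.Linear(d_model, d_model)  # output projection

    def forward(self, x):
        B, L, _ = x.shape
        # project, then split into h heads: (B, h, L, d_k)
        q, k, v = (w(x).view(B, L, self.h, self.d_k).transpose(1, 2)
                   for w in (self.wq, self.wk, self.wv))
        out = attention(q, k, v)                # each head attends independently
        out = out.transpose(1, 2).reshape(B, L, self.h * self.d_k)
        return self.wo(out)                     # concatenate heads and project

x = torch.randn(2, 4, 64)  # (batch, nodes, d_model)
print(MultiHeadAttention(64, 8)(x).shape)  # torch.Size([2, 4, 64])
```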
Figure 2: Spatial and Temporal Self-Attention. (a) The spatial interactions of node 4 in
frame t are modeled. $n_i^t$ (i = 1, 2, 3, 4) are the embeddings of node i. $m_{4j}^t$ (j =
1, 2, 3, 4) is the message passed from node j to node 4. (b) The temporal correlations
between frames are computed in the temporal Transformer, where the nodes are
independent of each other.
layer. In order to further capture the temporal dependencies across all history frames, we
post-process the input embeddings with a second temporal Transformer encoder. The temporal
Transformer decoder refines the output embeddings based on the spatio-temporal features
provided by the encoders and on the output embeddings produced from previously predicted
coordinates. Finally, the trajectory generator outputs all the traffic-agents' future trajectories
$Y_{(t_{obs}+1, t_{fut})}$ simultaneously by decoding the output embeddings.
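This encode-refine-generate loop can be sketched as below; `encoder`, `decoder` and `generator` are hypothetical stand-ins with assumed tensor interfaces, not the paper's actual modules.

```python
import torch

def predict(encoder, decoder, generator, X, t_fut):
    """X: history features, shape (B, t_obs, N, F); returns (B, t_fut, N, 2)."""
    memory = encoder(X)                       # spatio-temporal features of history
    coords = [X[:, -1, :, :2]]                # seed with the last observed (x, y)
    for _ in range(t_fut):
        prev = torch.stack(coords, dim=1)     # previously predicted coordinates
        emb = decoder(prev, memory)           # refine output embeddings
        coords.append(generator(emb)[:, -1])  # next positions for all agents
    return torch.stack(coords[1:], dim=1)
```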
The messages sent from all nodes j to node i are normalized over the weights of the spatial
edges and summed to obtain a single attention head of node i:

$\mathrm{head}_i^t = \sum_j \mathrm{softmax}\left(\frac{m_{ij}^t}{\sqrt{d_k}}\right) v_j$  (8)
By repeating this embedding extraction process h times, the resulting attention heads are
concatenated and projected to output embeddings with a fully connected layer:
where

$\mathrm{head}_i = \mathrm{softmax}\left(\frac{Q_i K_i^T}{\sqrt{d_k}}\right) V_i$  (12)

and $Q_i$, $K_i$ and $V_i$ are the query, key and value matrices learned from the embeddings of input
node i.
Instead of the fully connected network used in the vanilla Transformer, the second sub-layer is
a separable convolution (Chollet (2016)), which achieves higher accuracy.
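For reference, a minimal PyTorch sketch of a depthwise separable 1-D convolution in the sense of Chollet (2016) follows; the kernel size and padding choices are illustrative assumptions.

```python
import torch

class SeparableConv1d(torch.nn.Module):
    """Per-channel (depthwise) convolution followed by a 1x1 (pointwise) one."""
    def __init__(self, in_ch, out_ch, kernel_size=3):
        super().__init__()
        self.depthwise = torch.nn.Conv1d(
            in_ch, in_ch, kernel_size, padding=kernel_size // 2, groups=in_ch)
        self.pointwise = torch.nn.Conv1d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):              # x: (batch, channels, time)
        return self.pointwise(self.depthwise(x))

x = torch.randn(2, 64, 6)              # 6 history frames, 64 channels
print(SeparableConv1d(64, 64)(x).shape)  # torch.Size([2, 64, 6])
```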
Decoder To inject the relative position information of previous output trajectories into the
decoder, we add positional encodings to the output embeddings:
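The encoding formula itself is elided in this extraction; assuming the sinusoidal positional encoding of the vanilla Transformer (Vaswani et al. (2017)), a minimal sketch is:

```python
import torch

def positional_encoding(length, d_model):
    """Standard sinusoidal encoding; d_model is assumed even."""
    pos = torch.arange(length, dtype=torch.float32).unsqueeze(1)
    i = torch.arange(0, d_model, 2, dtype=torch.float32)
    angles = pos / torch.pow(10000.0, i / d_model)   # (length, d_model/2)
    pe = torch.zeros(length, d_model)
    pe[:, 0::2] = torch.sin(angles)
    pe[:, 1::2] = torch.cos(angles)
    return pe

emb = torch.randn(6, 128)                 # output embeddings for 6 frames
emb = emb + positional_encoding(6, 128)   # inject relative position information
```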
use Adam (Kingma and Ba (2014)) as the optimizer and impose a warmup-based learning rate
variation strategy, with the warmup step set to 5000. Random rotation is applied for data
augmentation during training.
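The schedule formula is elided in this extraction; assuming the warmup schedule of Vaswani et al. (2017), which matches the warmup_steps = 5000 setting, a sketch is (the model dimension of 128 is also an assumption):

```python
def learning_rate(step, d_model=128, warmup_steps=5000):
    """Rate rises linearly over warmup, then decays as step^-0.5."""
    step = max(step, 1)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

print(learning_rate(1000) < learning_rate(5000) > learning_rate(20000))  # True
```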
4. Experiments
4.1. Dataset and Evaluation Metrics
Our model is evaluated on ApolloScape Trajectory dataset (Ma et al. (2019)) which is
collected by Apollo autonomous vehicles. The ApolloScape Trajectory dataset contains
images, point clouds, and manually annotated trajectories. It is gathered under various
lighting conditions and traffic densities in Beijing, China. More specifically, it comprises
vastly complex traffic flows mixed with vehicles, riders, and pedestrians. The dataset in-
cludes 53 minutes of training sequences and 50 minutes of testing sequences captured at 2 frames
per second. We need to predict six future frames based on six history frames. Because the
test set of the ApolloScape Trajectory dataset is not public, we obtain the results of our model
and the baselines by uploading predictions to the ApolloScape Trajectory Leaderboard¹.
Two metrics are used to evaluate model performance: the Average Displacement Error
(ADE) (Pellegrini et al. (2009)) and the Final Displacement Error (FDE). ADE is the
mean Euclidean distance between predicted positions and ground-truth positions over the
whole prediction horizon, and FDE is the Euclidean distance between the predicted and
ground-truth positions at the final frame. ADE thus reflects the average prediction
performance, while FDE reflects the prediction accuracy at the end point. Because the
trajectories of heterogeneous traffic-agents differ in scale, we use the following weighted
sum of ADE (WSADE) and weighted sum of FDE (WSFDE) as metrics:
$WSADE = D_v \cdot ADE_v + D_p \cdot ADE_p + D_b \cdot ADE_b$  (17)

$WSFDE = D_v \cdot FDE_v + D_p \cdot FDE_p + D_b \cdot FDE_b$  (18)
where $D_v = 0.20$, $D_p = 0.58$, and $D_b = 0.22$ are related to the reciprocals of the average
velocities of vehicles, pedestrians and cyclists in the dataset.
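To make the metric computation concrete, here is a minimal Python sketch of ADE, FDE and the weighted sums of Eqs. (17)-(18); the array shapes and example values are our own illustration.

```python
import numpy as np

def ade(pred, gt):
    """Mean Euclidean distance over all predicted time steps; (T, 2) arrays."""
    return np.mean(np.linalg.norm(pred - gt, axis=-1))

def fde(pred, gt):
    """Euclidean distance at the final predicted time step."""
    return np.linalg.norm(pred[-1] - gt[-1])

D_V, D_P, D_B = 0.20, 0.58, 0.22  # weights for vehicles, pedestrians, cyclists

def weighted_sum(err_v, err_p, err_b):
    """WSADE or WSFDE from per-class errors, as in Eqs. (17)-(18)."""
    return D_V * err_v + D_P * err_p + D_B * err_b

pred = np.array([[1.0, 0.0], [2.0, 0.0], [3.0, 0.5]])
gt   = np.array([[1.0, 0.0], [2.0, 0.5], [3.0, 0.0]])
print(ade(pred, gt), fde(pred, gt))  # 0.333..., 0.5
```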
4.2. Baselines
To evaluate the performance of S2TNet, we compare S2TNet with a wide range of baselines,
including:
• Constant Velocity (CV): We use the average velocity of history trajectories as the
constant velocity during the future to predict trajectories.
• TrafficPredict: An LSTM-based method using a hierarchical architecture (Ma et al.
(2019)).
• StarNet: (Zhu et al. (2019)) builds a star topology to consider the collective influence
among all pedestrians.
• Social LSTM (S-LSTM): (Alahi et al. (2016)) uses an LSTM to extract individual pedestrian
features and devises a social pooling mechanism to capture neighbor information.
• Social GAN (S-GAN): (Gupta et al. (2018)) predicts socially plausible futures by a
conditional GAN.
• Transformer: (Giuliari et al. (2021)) uses the vanilla temporal Transformer to model
each pedestrian separately, without any human-human or scene interaction terms.
• STAR: (Yu et al. (2020)) interleaves spatial and temporal Transformers to capture the
social interaction between pedestrians.
• TPNet: (Fang et al. (2020)) first generates a candidate set of future trajectories, then
gets the final predictions by classifying and refining the candidates.
• GRIP++: (Li et al. (2019)) is the SOTA trajectory predictor, which uses an enhanced
graph to represent the interactions of close objects and applies ST-GCNs to extract
spatio-temporal features.
1. http://apolloscape.auto/leader_board.html
• S2TNet has the ability to forecast long-horizon trajectories for different categories of
traffic-agents. After observing 6 frames (3 s) of history trajectories, S2TNet
accurately predicts trajectories over a 3-second horizon. Moreover, S2TNet does
well in cases of sharp vehicle turns, e.g. Fig. 3(a) and (b). As the prediction length
increases, the prediction results of S2TNet remain realistic and
its cumulative error is lower than that of GRIP++, e.g. Fig. 3(c) and (d).
• S2TNet is able to model spatio-temporal interactions accurately. In the top-right por-
tion of Fig. 3(e) and (f), a vehicle travels in the opposite direction to an unknown traffic-
agent. While the predicted trajectories of GRIP++ deviate from the ground truth,
S2TNet precisely captures the interactive routes.
• S2TNet successfully identifies stationary traffic-agents. In the lower-left of Fig. 3(e)
and (f), two vehicles decelerate to a near standstill. Compared with GRIP++, S2TNet
correctly predicts the corresponding stationary trajectories.
Figure 3: Visualized prediction results in heterogeneous and dense traffic. S2TNet success-
fully captures spatio-temporal information and outperforms the SOTA model,
GRIP++. (a, b, c, d) Comparison of the future trajectories of different types of
traffic-agents between the two methods. (e, f) The prediction results of GRIP++
and S2TNet in a complete traffic scene.
with (8). This indicates that the temporal self-attention mechanism effectively
improves the ability to extract temporal information.
• More features, higher accuracy. Instead of feeding all features into S2TNet, we input
only history trajectories in (6). We find that richer information helps the network
understand the heterogeneity of traffic-agents.
• Spatial self-attention over the whole scene is better than attention within a given
spatial limit. We use a masked attention mechanism in (7) to ignore influence beyond
a given spatial limit (15 m), as (Li et al. (2019)) does; a sketch of such a mask follows this
list. We find that traffic-agents across the whole scene have a great influence on the
accuracy of trajectory prediction.
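Below is a minimal sketch of such a distance-based attention mask; the tensor shapes and the hard 15 m cut-off implementation are our assumptions.

```python
import torch

def distance_mask(positions, limit=15.0):
    """positions: (N, 2) agent coordinates in one frame.
    Returns an (N, N) boolean mask, True where attention is allowed."""
    dist = torch.cdist(positions, positions)  # pairwise Euclidean distances
    return dist <= limit

def masked_attention(Q, K, V, mask):
    scores = Q @ K.transpose(-2, -1) / Q.size(-1) ** 0.5
    scores = scores.masked_fill(~mask, float('-inf'))  # suppress distant agents
    return torch.softmax(scores, dim=-1) @ V
```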
         Components                     Performance
     SS   TCN   TE   TD   HF   LM     (WSADE/WSFDE)
(1)   ×    ×    SC   SC   A    W      1.2300/2.2949
(2)   ×    ✓    SC   SC   A    W      1.2189/2.2570
(3)   ✓    ×    SC   SC   A    W      1.2500/2.3561
(4)   ✓    ✓    ×    SC   A    W      1.2674/2.4086
(5)   ✓    ✓    FC   FC   A    W      1.1945/2.2613
(6)   ✓    ✓    SC   SC   C    W      1.2170/2.3036
(7)   ✓    ✓    SC   SC   A    N      1.2686/2.3548
(8)   ✓    ✓    SC   SC   A    W      1.1679/2.1798
5. Conclusion
In this paper, we propose S2TNet, a Transformer-based framework to predict the trajec-
tories of heterogeneous traffic-agents around autonomous vehicles. The spatio-temporal
Transformer is designed to capture spatio-temporal interactions between all traffic-agents,
not limited to spatial neighbors. The temporal Transformer is utilized to enhance the
modeling of temporal dependencies for each traffic-agent.
Acknowledgments
This research is supported by the National Natural Science Foundation of China (No. 61790563).
References
Alexandre Alahi, Kratarth Goel, Vignesh Ramanathan, Alexandre Robicquet, Li Fei-Fei,
and Silvio Savarese. Social lstm: Human trajectory prediction in crowded spaces. In
Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages
961–971, 2016.
Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E. Hinton. Layer normalization. arXiv preprint arXiv:1607.06450, 2016.
Mattias Brännström, Erik Coelingh, and Jonas Sjöberg. Model-based threat assessment
for avoiding arbitrary vehicle collisions. IEEE Transactions on Intelligent Transportation
Systems, 11(3):658–669, 2010.
Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov,
and Sergey Zagoruyko. End-to-end object detection with transformers. In European
Conference on Computer Vision, pages 213–229. Springer, 2020.
Rohan Chandra, Uttaran Bhattacharya, Aniket Bera, and Dinesh Manocha. Traphic: Tra-
jectory prediction in dense and heterogeneous traffic using weighted interactions. In
Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition,
pages 8483–8492, 2019.
François Chollet. Xception: Deep learning with depthwise separable convolutions. CoRR,
abs/1610.02357, 2016. URL http://arxiv.org/abs/1610.02357.
Nachiket Deo and Mohan M Trivedi. Convolutional social pooling for vehicle trajectory
prediction. In Proceedings of the IEEE Conference on Computer Vision and Pattern
Recognition Workshops, pages 1468–1476, 2018.
Liangji Fang, Qinhong Jiang, Jianping Shi, and Bolei Zhou. Tpnet: Trajectory proposal
network for motion prediction. In Proceedings of the IEEE/CVF Conference on Computer
Vision and Pattern Recognition, June 2020.
Francesco Giuliari, Irtiza Hasan, Marco Cristani, and Fabio Galasso. Transformer networks
for trajectory forecasting. In 2020 25th International Conference on Pattern Recognition,
pages 10335–10342. IEEE, 2021.
Agrim Gupta, Justin Johnson, Li Fei-Fei, Silvio Savarese, and Alexandre Alahi. So-
cial GAN: socially acceptable trajectories with generative adversarial networks. CoRR,
abs/1803.10892, 2018. URL http://arxiv.org/abs/1803.10892.
Yingfan Huang, Huikun Bi, Zhaoxin Li, Tianlu Mao, and Zhaoqi Wang. Stgat: Model-
ing spatial-temporal interactions for human trajectory prediction. In Proceedings of the
IEEE/CVF International Conference on Computer Vision, pages 6272–6281, 2019.
Boris Ivanovic and Marco Pavone. The trajectron: Probabilistic multi-agent trajectory
modeling with dynamic spatiotemporal graphs. In Proceedings of the IEEE/CVF Inter-
national Conference on Computer Vision, pages 2375–2384, 2019.
Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv
preprint arXiv:1412.6980, 2014.
Vineet Kosaraju, Amir Sadeghian, Roberto Martín-Martín, Ian Reid, S Hamid Rezatofighi,
and Silvio Savarese. Social-bigat: Multimodal trajectory forecasting using bicycle-gan
and graph attention networks. arXiv preprint arXiv:1907.03395, 2019.
Stéphanie Lefèvre, Dizan Vasquez, and Christian Laugier. A survey on motion prediction
and risk assessment for intelligent vehicles. ROBOMECH journal, 1(1):1–14, 2014.
Xin Li, Xiaowen Ying, and Mooi Choo Chuah. GRIP: graph-based interaction-aware tra-
jectory prediction. CoRR, abs/1907.07792, 2019. URL http://arxiv.org/abs/1907.07792.
Yicheng Liu, Jinghuai Zhang, Liangji Fang, Qinhong Jiang, and Bolei Zhou. Multimodal
motion prediction with stacked transformers. arXiv preprint arXiv:2103.11624, 2021.
Yuexin Ma, Xinge Zhu, Sibo Zhang, Ruigang Yang, Wenping Wang, and Dinesh Manocha.
Trafficpredict: Trajectory prediction for heterogeneous traffic-agents. In Proceedings of
the AAAI Conference on Artificial Intelligence, volume 33, pages 6120–6127, 2019.
Dimitris Milakis, Bart Van Arem, and Bert Van Wee. Policy and society related implications
of automated driving: A review of literature and directions for future research. Journal
of Intelligent Transportation Systems, 21(4):324–348, 2017.
Abduallah Mohamed, Kun Qian, Mohamed Elhoseiny, and Christian Claudel. Social-stgcnn:
A social spatio-temporal graph convolutional neural network for human trajectory pre-
diction. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern
Recognition, pages 14424–14432, 2020.
S. Pellegrini, A. Ess, K. Schindler, and L. van Gool. You’ll never walk alone: Modeling
social behavior for multi-target tracking. In 2009 IEEE 12th International Conference
on Computer Vision, pages 261–268, 2009. doi: 10.1109/ICCV.2009.5459260.
Amir Sadeghian, Vineet Kosaraju, Ali Sadeghian, Noriaki Hirose, Hamid Rezatofighi, and
Silvio Savarese. Sophie: An attentive gan for predicting paths compliant to social and
physical constraints. In Proceedings of the IEEE/CVF Conference on Computer Vision
and Pattern Recognition, pages 1349–1358, 2019.
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N
Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. arXiv preprint
arXiv:1706.03762, 2017.
Cunjun Yu, Xiao Ma, Jiawei Ren, Haiyu Zhao, and Shuai Yi. Spatio-temporal graph
transformer networks for pedestrian trajectory prediction. CoRR, abs/2005.08514, 2020.
URL https://arxiv.org/abs/2005.08514.
Pu Zhang, Wanli Ouyang, Pengfei Zhang, Jianru Xue, and Nanning Zheng. Sr-lstm:
State refinement for lstm towards pedestrian trajectory prediction. In Proceedings of
the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 12085–
12094, 2019.
Yanliang Zhu, Deheng Qian, Dongchun Ren, and Huaxia Xia. Starnet: Pedestrian trajectory
prediction using deep neural network in star topology. CoRR, abs/1906.01797, 2019. URL
http://arxiv.org/abs/1906.01797.