This study developed a robust, crop-aware transformer architecture for global-scale multi-crop yield prediction to address spatial heterogeneity, phenological dependencies and yield-limiting uncertainties by integrating domain-specific attention mechanisms with heteroscedastic uncertainty estimation within an ensemble meta-learning framework. The proposed methodology is illustrated in Fig 1.
Dataset description and preprocessing
This research was conducted at the Department of Computer Science and Applications, MDU Rohtak, in 2025. The experiment was implemented in Python 3.13.2 on a system with an Intel i7 processor and 16 GB of RAM. The model was trained on the GlobalCropYield5min dataset, which provides gridded annual yield estimates at a 5-arcminute spatial resolution (approximately 9 km at the equator) for maize, rice, wheat and soybeans from 1982 to 2015 (Cao et al., 2025). Structured in NetCDF format with temporal, latitudinal and longitudinal dimensions, this dataset integrates official production statistics and satellite observations to ensure statistical consistency.
At a 5-arcminute resolution, the global grid contains approximately 9.3 million grid cells, of which approximately 1 million represent agricultural land with usable yield observations across the study period. The data exhibited significant statistical dispersion, with coefficients of variation ranging from 35% for soybeans to 67% for maize, reflecting diverse global agro-environmental conditions. Geographically, these crops show distinct spatial distributions, ranging from the North American Corn Belt and the Argentine Pampas to the monsoon and temperate wheat belt regions of Asia.
The raw data underwent a five-stage processing pipeline. A 5 × 5 kernel arithmetic mean smoothing was first applied to reduce high-frequency noise from measurement errors and small-scale heterogeneity while preserving regional yield patterns. To reduce computational requirements and avoid pseudo-replication from adjacent, highly correlated pixels, every 4th pixel was sampled in both the latitude and longitude directions, resulting in an effective spacing of ~20 arcminutes. For each sample, an 8-year window was constructed, consisting of a 4-year input sequence, a 3-year prediction gap and the target year. This 3-year gap prevents temporal leakage from year-to-year autocorrelation while maintaining sufficient historical context. Grid cells with more than 30% missing data in the 8-year window were excluded, as were physically impossible yield values exceeding 20 metric tons per hectare for any crop and non-agricultural land cover classifications. Missing-data gaps shorter than three consecutive years were filled using linear interpolation with forward/backward boundary filling to maintain temporal continuity. To prevent crop imbalance from biasing model training, 15,000 observations per crop were randomly sampled from the filtered dataset, resulting in a balanced final dataset of 60,000 samples. Finally, a strict temporal holdout split was implemented, in which all training and cross-validation data came from 2013 and earlier (N = 55,469 samples), while the test set comprised 2014-2015 (N = 4,531 samples). This ensures that the model never sees future years during training, mimicking operational forecasting conditions.
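The 8-year window construction (4-year input, 3-year gap, target year) can be sketched as follows; the function and variable names are illustrative, not taken from the study's released code:

```python
import numpy as np

def build_windows(series, years, input_len=4, gap=3, max_missing_frac=0.30):
    """Sketch of the 8-year window construction: a 4-year input sequence,
    a 3-year prediction gap, then the target year (4 + 3 + 1 = 8 years)."""
    samples = []
    window = input_len + gap + 1            # 8-year window
    for start in range(len(series) - window + 1):
        block = series[start:start + window]
        # Drop windows with more than 30% missing values
        if np.isnan(block).mean() > max_missing_frac:
            continue
        x = block[:input_len]               # 4-year input sequence
        y = block[-1]                       # target year (3-year gap skipped)
        if np.isnan(y):
            continue
        samples.append((x, years[start + window - 1], y))
    return samples

# Example: a 10-year series yields 10 - 8 + 1 = 3 candidate windows
yrs = np.arange(2000, 2010)
vals = np.linspace(2.0, 4.0, 10)
print(len(build_windows(vals, yrs)))  # → 3
```

The gap years (`block[input_len:-1]`) are deliberately never used as inputs, which is what blocks year-to-year autocorrelation from leaking into the target.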
Sixteen features were engineered, prioritizing conservatism over raw performance. The full set comprised four temporal sequence features (lagged yields at t-3 to t-6), two temporal trend features, three spatial features, three climate zone features, one phenological feature and three secular trend features.
For each sample, the features are organized into a tensor of shape (4 time steps × 16 features), in which time-varying features change across time steps, whereas static features are repeated across all four time steps to maintain consistent tensor dimensions for the transformer input. Correlations were computed between all features and the target yields, with a maximum correlation of 0.58. The highest correlations came from recent historical yields, which is expected and appropriate because recent performance informs near-term predictions.
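Assembling such a sample tensor amounts to tiling the static features across the time axis; a minimal sketch, assuming an illustrative 6/10 split between time-varying and static features:

```python
import numpy as np

def make_sample_tensor(temporal, static):
    """Illustrative sketch: temporal features vary per time step, while
    static features (spatial, climate-zone, phenological, etc.) are
    repeated across all 4 steps so every sample is a fixed (4, 16)
    tensor suitable as transformer input."""
    n_steps = temporal.shape[0]                   # 4 time steps
    static_tiled = np.tile(static, (n_steps, 1))  # repeat static row per step
    return np.concatenate([temporal, static_tiled], axis=1)

temporal = np.random.rand(4, 6)   # e.g. lagged yields and trend features
static = np.random.rand(10)       # e.g. spatial/climate/phenology features
x = make_sample_tensor(temporal, static)
print(x.shape)  # → (4, 16)
```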
Cross-validation protocol
A two-stage validation protocol was implemented that combined temporal holdout with geographic blocking to prevent temporal and spatial leakage. In the temporal split, all model training, hyperparameter tuning and cross-validation used data from 2013 or earlier; the test set was held out completely until the final evaluation, ensuring that no future information leaked into model development. For geographic blocking, 5-fold cross-validation was performed within the training data using GroupKFold with geographic blocks instead of random sampling. Each block was constructed at a 0.5° resolution by grouping all pixels falling in the same 0.5° latitude-longitude cell. All pixels within the same geographic block were assigned entirely to training or validation within each fold, never split between them.
This blocking prevents spatial leakage, in which model training on nearby pixels can artificially inflate the validation performance. Our ablation studies (Section 3.6) quantify this effect: random cross-validation inflates R² by 1.9 percentage points compared with geographic blocking, demonstrating substantial spatial autocorrelation in global crop yield data.
Finally, each fold contained approximately 44,375 training samples and 11,094 validation samples.
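The blocking scheme can be illustrated with a minimal standalone splitter; scikit-learn's GroupKFold provides the same whole-block guarantee, and this sketch only mimics its behavior:

```python
import numpy as np

def block_ids(lat, lon, res=0.5):
    """Assign each pixel to a 0.5-degree geographic block; all pixels
    sharing a block must land in the same CV fold (prevents spatial leakage)."""
    return np.floor(lat / res).astype(int) * 100000 + np.floor(lon / res).astype(int)

def group_kfold(groups, n_splits=5, seed=42):
    """Minimal GroupKFold-style splitter (sketch): whole blocks are
    assigned to folds, never split between training and validation."""
    uniq = np.unique(groups)
    rng = np.random.default_rng(seed)
    rng.shuffle(uniq)
    for val_blocks in np.array_split(uniq, n_splits):
        val_mask = np.isin(groups, val_blocks)
        yield np.where(~val_mask)[0], np.where(val_mask)[0]

lat = np.random.uniform(-60, 60, 1000)
lon = np.random.uniform(-180, 180, 1000)
g = block_ids(lat, lon)
for train_idx, val_idx in group_kfold(g):
    # No block appears in both training and validation
    assert not set(g[train_idx]) & set(g[val_idx])
```

Randomly splitting pixels instead of blocks would place near-identical neighboring cells on both sides of the split, which is exactly the inflation the ablation in Section 3.6 quantifies.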
Proposed model: Crop-aware transformer architecture
The core predictive engine is a modified transformer architecture designed to process multimodal agricultural data. It is a hybrid architecture in which the first stage processes raw spatiotemporal features and outputs a preliminary yield prediction along with an uncertainty estimate, modulated according to the specific crop type. It contains five main components: (i) crop embeddings providing 32-dimensional learned representations per crop; (ii) an input projection mapping the combined features (16 original + 32 embedding = 48 dimensions) to the hidden space (128 dimensions); (iii) crop-aware multi-head attention with separate Query, Key and Value projection matrices per crop type (4 crops × 3 matrices × 128 × 128 = 196,608 dedicated parameters), where for crop c:

Q_c = X W_CQ, K_c = X W_CK, V_c = X W_CV;

(iv) a four-layer transformer encoder; and (v) dual output heads f_μ and f_σ for the yield prediction and its uncertainty. Each crop's Q, K and V representations are then processed through standard scaled dot-product multi-head attention (four heads, d_k = 32 per head):

Attention(Q_c, K_c, V_c) = softmax(Q_c K_cᵀ / √d_k) V_c.
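A NumPy sketch of the crop-routed attention described above; the weights here are random placeholders standing in for the learned per-crop projection matrices:

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def crop_aware_attention(X, crop_id, W_Q, W_K, W_V, n_heads=4):
    """Sketch of crop-routed attention: each crop c owns its W_Q[c],
    W_K[c], W_V[c] (128x128), so Q_c = X @ W_Q[c], etc.; the resulting
    Q, K, V then pass through standard multi-head attention."""
    n_steps, d = X.shape                # (time steps, hidden dim)
    d_k = d // n_heads                  # 32 per head when d = 128
    Q, K, V = X @ W_Q[crop_id], X @ W_K[crop_id], X @ W_V[crop_id]
    # Split into heads: (n_heads, n_steps, d_k)
    def heads(M):
        return M.reshape(n_steps, n_heads, d_k).transpose(1, 0, 2)
    Qh, Kh, Vh = heads(Q), heads(K), heads(V)
    scores = Qh @ Kh.transpose(0, 2, 1) / np.sqrt(d_k)
    out = softmax(scores) @ Vh          # (n_heads, n_steps, d_k)
    return out.transpose(1, 0, 2).reshape(n_steps, d)

rng = np.random.default_rng(0)
W_Q, W_K, W_V = (rng.normal(size=(4, 128, 128)) * 0.05 for _ in range(3))
X = rng.normal(size=(4, 128))           # 4 time steps, hidden dim 128
out = crop_aware_attention(X, crop_id=2, W_Q=W_Q, W_K=W_K, W_V=W_V)
print(out.shape)  # → (4, 128)
```

The routing is what makes the attention crop-aware: the same input X produces different attention patterns depending on which crop's projection matrices are selected.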
Training configuration
The model was trained using the AdamW optimizer with learning rate 5×10⁻⁴, weight decay 10⁻⁴, β₁ = 0.9, β₂ = 0.999 and ε = 10⁻⁸. A cosine annealing learning rate schedule with T_max = 50 epochs decayed the rate to a minimum of 0. Dropout regularization (p = 0.1) was applied in the transformer layers, gradients were clipped at max_norm = 1.0 to prevent exploding gradients, and mixed-precision training was used (FP16 forward passes, FP32 gradients). Early stopping monitored validation R² with a patience of 10 epochs; models typically converged in 30-45 epochs. Comprehensive checkpointing saved model states, optimizer states, normalization scalers and training progress after each cross-validation fold and periodically during final training, enabling exact reproducibility and resumption after interruptions.
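The cosine annealing schedule has a simple closed form; a sketch assuming the standard formulation with base rate 5×10⁻⁴, minimum rate 0 and T_max = 50:

```python
import math

def cosine_annealing_lr(epoch, base_lr=5e-4, lr_min=0.0, t_max=50):
    """Cosine-annealing schedule (standard form, as in PyTorch's
    CosineAnnealingLR): starts at base_lr = 5e-4 and decays smoothly
    to lr_min = 0 over t_max = 50 epochs."""
    return lr_min + 0.5 * (base_lr - lr_min) * (1 + math.cos(math.pi * epoch / t_max))

print(round(cosine_annealing_lr(0), 6))   # → 0.0005
print(round(cosine_annealing_lr(25), 6))  # → 0.00025
print(round(cosine_annealing_lr(50), 6))  # → 0.0
```

Because the decay flattens near the end of training, the schedule pairs naturally with the 10-epoch early-stopping patience: late epochs make only small parameter updates.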
Baseline models
We compared four strong baseline models using identical features, validation protocols and the same train-test splits: Ridge Regression (L2 regularization α =1.0), Random Forest (100 trees, depth 20), XGBoost (200 estimators, depth 8, learning rate 0.05) and LightGBM (200 estimators, depth 8, learning rate 0.05). All baselines used RobustScaler normalization applied separately for each cross-validation fold.
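The per-fold RobustScaler normalization (median centering, IQR scaling) can be sketched without scikit-learn; the key point is that statistics are fit on the training fold only and reused on validation data:

```python
import numpy as np

def robust_scale_fit(X):
    """RobustScaler sketch: center on the median and scale by the IQR,
    fit on training folds only so validation statistics never leak in."""
    med = np.median(X, axis=0)
    iqr = np.percentile(X, 75, axis=0) - np.percentile(X, 25, axis=0)
    iqr[iqr == 0] = 1.0                 # guard against constant features
    return med, iqr

def robust_scale_apply(X, med, iqr):
    return (X - med) / iqr

rng = np.random.default_rng(1)
X_train, X_val = rng.normal(size=(100, 16)), rng.normal(size=(20, 16))
med, iqr = robust_scale_fit(X_train)        # fit on the training fold only
Xt = robust_scale_apply(X_train, med, iqr)
Xv = robust_scale_apply(X_val, med, iqr)    # reuse training statistics
print(np.allclose(np.median(Xt, axis=0), 0))  # → True
```

Median/IQR scaling is less sensitive to the heavy-tailed yield outliers noted earlier than mean/standard-deviation scaling would be.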
Algorithm 1: Crop-aware transformer
Require: D: dataset {(X_i, y_i, c_i)}; B: geographic blocks; X_train: multiscale spatial features; Y_train: yield targets; C: crop ID vector; θ: initial model parameters; E_max = 50; α = 5×10⁻⁴; batch size β = 32.
Ensure: θ*: optimized transformer parameters; CV and test performance.
1: Fit RobustScalers (features and targets) on training data; save scalers.
2: Temporal split: D_train ← {year ≤ 2013} (N = 55,469); D_test ← {year > 2013} (N = 4,531).
3: Geographic CV: create 5 folds via GroupKFold on D_train with groups = B (0.5° blocks).
4: for fold f = 1 to 5 do
5:   if checkpoint exists then load results; continue.
6:   Split: D_f^train, D_f^val ← fold f (blocks stay together).
7:   Initialize model θ_f (seed = 42 + f): embeddings E_c ∈ ℝ^32; projection W_proj ∈ ℝ^(48×128).
8:   Initialize transformer weights and crop-aware projections {W_CQ, W_CK, W_CV ∈ ℝ^(128×128)} for each crop c; encoder: 4 layers; heads: f_μ, f_σ.
9:   optimizer ← AdamW(θ_f, lr = α, wd = 10⁻⁴); scheduler ← cosine annealing LR (T_max = 50).
10:  for epoch e = 1 to E_max do
11:    Training: for each batch (X_b, Y_b, C_b) of size β do
12:      e_c ← E_{c_b} (lookup); X_48 ← concat[X_b, e_c]; H ← W_proj X_48.
13:      Crop routing: Q, K, V ← 0; for c ∈ {0, 1, 2, 3} do
14:        mask ← (c_b = c); Q[mask] ← W_CQ H[mask]; K, V similarly.
15:      scores ← QKᵀ / √d_k; attn ← softmax(scores); out ← attn · V.
16:      H_enc ← transformer encoder(out); H_pool ← avg pool(H_enc).
17:      ŷ ← f_μ(H_pool); σ̂ ← Softplus[f_σ(H_pool)] + 10⁻⁶.
18:      Loss ← heteroscedastic negative log-likelihood of Y_b given (ŷ, σ̂).
19:      Backward; clip_grad(θ_f, 1.0); optimizer.step().
20:    scheduler.step().
21:    Validation: ŷ_val, y_val ← evaluate(D_f^val); val_r2 ← R²(y_val, ŷ_val).
22:    if val_r2 > best_r2 then best_r2 ← val_r2; save checkpoint; else patience++.
23:    if patience ≥ 10 then break.
24:  Inverse transform: ŷ_orig ← scaler_yield⁻¹(ŷ_val); compute val_r2_orig, val_rmse_orig.
25: Compute cv_r2 ← mean(fold_r2s) ± std(fold_r2s).
26: Train final model on full D_train (85% train, 15% val for early stopping).
27: Test: evaluate on D_test; compute test_r2, test_rmse, test_mae, uncertainties.
28: return θ*, {cv_r2, test_r2, test_rmse, test_mae, predictions, uncertainties}.
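The dual output heads of steps 17-18 can be illustrated as follows; the Gaussian negative log-likelihood is an assumed loss form consistent with the heteroscedastic heads f_μ and f_σ, not stated explicitly in the text:

```python
import numpy as np

def softplus(z):
    # Numerically stable softplus: log(1 + exp(z))
    return np.log1p(np.exp(-np.abs(z))) + np.maximum(z, 0)

def heteroscedastic_outputs(h_pool, W_mu, W_sigma):
    """Sketch of the dual output heads: f_mu predicts the yield, f_sigma
    predicts a strictly positive per-sample standard deviation via
    Softplus + 1e-6, matching step 17 of Algorithm 1."""
    mu = h_pool @ W_mu
    sigma = softplus(h_pool @ W_sigma) + 1e-6
    return mu, sigma

def gaussian_nll(y, mu, sigma):
    """Assumed heteroscedastic Gaussian NLL (mean over the batch)."""
    return np.mean(0.5 * (np.log(sigma**2) + (y - mu)**2 / sigma**2))

rng = np.random.default_rng(0)
h_pool = rng.normal(size=(8, 128))              # pooled encoder outputs
W_mu, W_sigma = rng.normal(size=(2, 128)) * 0.05
mu, sigma = heteroscedastic_outputs(h_pool, W_mu, W_sigma)
print(sigma.min() > 0)  # → True
```

Under this loss, the model is penalized for being confidently wrong (small σ̂, large error) and for being needlessly uncertain (large σ̂ when the error is small), which is what drives the per-sample uncertainty estimates.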
Stage 2: Ensemble meta-learning strategy
This is the final predictive stage, in which the trained deep learning model is treated as a feature extractor, generating predictions and uncertainty estimates for the training set. These outputs are fed into a LightGBM regressor (the meta-learner), which learns to correct residual biases in the model outputs, leveraging the strength of gradient-boosted decision trees on tabular representations and improving the calibration of the uncertainty estimates.
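The stacking idea can be sketched with a simple least-squares meta-learner standing in for LightGBM (the substitution is purely to keep the illustration self-contained):

```python
import numpy as np

def meta_features(y_hat, sigma_hat, crop_ids, n_crops=4):
    """Build the 6-dimensional meta-feature vector
    [y_hat, sigma_hat, one_hot(crop)] fed to the meta-learner."""
    one_hot = np.eye(n_crops)[crop_ids]
    return np.column_stack([y_hat, sigma_hat, one_hot])

rng = np.random.default_rng(2)
n = 500
y_true = rng.normal(5, 1, n)
y_hat = y_true + rng.normal(0, 0.3, n) + 0.2   # biased stage-1 predictions
sigma_hat = np.full(n, 0.3)
crops = rng.integers(0, 4, n)
M = meta_features(y_hat, sigma_hat, crops)
# Least-squares stand-in for the LightGBM meta-learner: it learns to
# correct the residual bias in the transformer's predictions.
A = np.column_stack([M, np.ones(n)])
w, *_ = np.linalg.lstsq(A, y_true, rcond=None)
y_ens = A @ w
print(abs(np.mean(y_ens - y_true)) < abs(np.mean(y_hat - y_true)))  # → True
```

The meta-learner never sees the raw 16 input features, only the stage-1 outputs and the crop identity, so it acts strictly as a residual corrector rather than a second full model.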
Algorithm 2: Ensemble meta-learning inference strategy
Require: θ* from Algorithm 1; D_train, D_test.
Ensure: ŷ_ensemble, σ̂²_calibrated.
Note: provides marginal improvement (+0.2% R²); primary results use Algorithm 1 only.
1: Feature extraction: load model θ*; M_train ← [].
2: for each (X_i, C_i) ∈ D_train do
3:   ŷ_i, σ̂_i ← model(X_i, C_i).
4:   v_i ← [ŷ_i, σ̂_i, one_hot(C_i)] (meta-features, dim = 6).
5:   append v_i to M_train.
6: Meta-learner: g ← LightGBM(n = 200, depth = 8, lr = 0.05).
7: Train g on (M_train, y_train).
8: Temperature calibration: split D_train into train_meta (85%), val_meta (15%).
9: for T ∈ {0.5, 0.75, 1.0, 1.25, 1.5, 2.0} do
10:   σ̂_cal ← σ̂_val · T; ECE ← compute_calibration_error(σ̂_cal).
11:   if ECE < best_ECE then T_opt ← T.
12: Inference: for each (X_new, C_new) ∈ D_test do
13:   ŷ_trans, σ̂_trans ← model(X_new, C_new).
14:   v_new ← [ŷ_trans, σ̂_trans, one_hot(C_new)].
15:   ŷ_ensemble ← g(v_new); σ̂²_cal ← σ̂²_trans · T_opt.
16: ensemble_r2 ← R²(y_test, ŷ_ensemble).
17: return ŷ_ensemble, σ̂²_calibrated, improvement (typically +0.002 to +0.003).
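Steps 8-11 can be illustrated with a coverage-based calibration error; the exact ECE definition used in the study is not specified, so this proxy and the synthetic data are assumptions:

```python
import numpy as np
from math import erf, sqrt

def coverage_calibration_error(y, mu, sigma, levels=(0.5, 0.68, 0.9, 0.95)):
    """Assumed ECE proxy for regression: compare empirical coverage of
    central Gaussian intervals against their nominal level, averaged
    over a few confidence levels."""
    err = 0.0
    for p in levels:
        # z such that P(|Z| < z) = p for a standard normal (coarse grid search)
        zs = np.linspace(0.01, 4, 2000)
        z = zs[np.argmin([abs(erf(zz / sqrt(2)) - p) for zz in zs])]
        coverage = np.mean(np.abs(y - mu) < z * sigma)
        err += abs(coverage - p)
    return err / len(levels)

def fit_temperature(y, mu, sigma, grid=(0.5, 0.75, 1.0, 1.25, 1.5, 2.0)):
    """Grid-search the temperature T minimizing calibration error when
    the predicted sigmas are rescaled as sigma * T (steps 9-11)."""
    errs = [coverage_calibration_error(y, mu, sigma * T) for T in grid]
    return grid[int(np.argmin(errs))]

rng = np.random.default_rng(3)
mu = rng.normal(5, 1, 4000)
y = mu + rng.normal(0, 0.6, 4000)        # true noise sd = 0.6
sigma = np.full(4000, 0.3)               # model is overconfident by 2x
print(fit_temperature(y, mu, sigma))     # → 2.0
```

Because temperature scaling rescales all sigmas by one scalar, it fixes systematic over- or under-confidence without changing the point predictions at all.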