"""
interp_rocket.py - Interpretable ROCKET for Time Series Classification
A standalone, fully transparent reimplementation of MultiRocket (Tan et al.,
2022) with complete kernel-level interpretability. Designed for scientific
applications where understanding why a classifier makes its decisions is as
important as its accuracy.
Inspired by the transparent parameter storage in msROCKET (Lundy and O'Toole,
2021), which was based on the vanilla ROCKET classifier.
WHAT THIS DOES:
MultiRocket uses 84 deterministic base kernels (from MiniRocket) applied
at multiple dilations to both the raw signal and its first-order
difference. Four pooling operators (PPV, MPV, MIPV, LSPV) extract
features from each convolution output. A linear classifier (Ridge) trains
on these features.
Unlike sktime/aeon implementations, every parameter is stored in plain
numpy arrays with documented indexing, making it possible to:
1. Trace any feature back to its kernel, dilation, pooling op, and
signal type via decode_feature_index()
2. Visualize what each important kernel detects in a time series
3. Map classifier importance to temporal regions of the input
4. Identify robust features via cross-validation stability analysis
5. Test feature significance via permutation importance (PIMP)
6. Decompose feature contributions as redundant, synergistic, or
independent using information-theoretic methods
7. Assess temporal sensitivity via model-agnostic occlusion
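ARCHITECTURE:
84 base kernels x B biases per kernel x 4 pooling ops x 2 representations,
where the B biases are spread over D dilations and D depends on series
length (capped by max_dilations_per_kernel, default 16). Dilations and the
number of biases per dilation depend only on series length; the bias
values themselves are fitted to the training data.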
FEATURE SELECTION:
Feature stability analysis (FSA) is the recommended method. It
identifies features that are consistently ranked as important across
cross-validation folds (Meinshausen and Buhlmann, 2010; Saeys et al.,
2008). Permutation importance (PIMP) provides an independent statistical
test, using a random forest, of whether a feature's importance exceeds
chance (Altmann et al., 2010). Recursive feature elimination (RFE) is
available but not recommended as the primary method because it is
sensitive to the random seed and the data split.
INTERPRETABILITY TOOLS:
- Temporal importance profiles (differential activation method)
- Receptive field diagrams (feature RF at peak discriminative location)
- Class-mean activation maps (kernel response on class-averaged signals)
- Aggregate activation (importance-weighted sum with differential)
- Multi-kernel summary (binary activation heatmap across features)
- Temporal occlusion sensitivity (per-trial and aggregate)
- Confusion-conditioned activation maps (correct vs. misclassified)
- Information decomposition (redundant/synergistic/independent)
- Kernel similarity network (correlation structure among features)
- Feature distribution analysis (per-class histograms)
COLOR PALETTE:
All plotting functions use a consistent tab10 hex palette defined at
module level (TAB10, POOLING_COLORS, INFO_COLORS).
KEY DIFFERENCES FROM SKTIME/AEON:
- All kernel weights, dilations, biases stored as accessible numpy arrays
- Complete feature to kernel to timepoint traceability
- Integrated visualization and analysis suite
- Feature stability analysis for robust feature selection
- Permutation importance with statistical testing (PIMP)
- Information-theoretic feature decomposition
- Class balancing via random oversampling for imbalanced data
- NumPy 2.x compatible
- Single-file, no framework dependencies beyond numpy/numba/sklearn/matplotlib
EXTENSIONS (in extensions/ directory):
- AMEE evaluation: perturbation-based saliency map ranking
- TSHAP integration: instance-level Shapley value attributions
- Channel selection: classifier-agnostic multivariate channel selection
- Kernel explorer: interactive tool for exploring kernels and pooling
REFERENCES:
Altmann, A., Tolosi, L., Sander, O., & Lengauer, T. (2010). Permutation
importance: a corrected feature importance measure. Bioinformatics,
26(10), 1340-1347.
Lundy, C., & O'Toole, J. M. (2021). Random convolution kernels with
multi-scale decomposition for preterm EEG inter-burst detection. In 2021
29th European Signal Processing Conference (EUSIPCO) (pp. 1182-1186).
Meinshausen, N., & Buhlmann, P. (2010). Stability selection. Journal of
the Royal Statistical Society: Series B, 72(4), 417-473.
Narayanan, N. S., Kimchi, E. Y., & Laubach, M. (2005). Redundancy and
synergy of neuronal ensembles in motor cortex. Journal of Neuroscience,
25(17), 4207-4216.
Tan, C. W., Dempster, A., Bergmeir, C., & Webb, G. I. (2022).
MultiRocket: multiple pooling operators and transformations for fast and
effective time series classification. Data Mining and Knowledge
Discovery, 36(5), 1623-1646.
Uribarri, G., Barone, F., Ansuini, A., & Fransen, E. (2024).
Detach-ROCKET: sequential feature selection for time series
classification with random convolutional kernels. Data Mining and
Knowledge Discovery, 38(6), 3922-3947.
USAGE:
import interp_rocket as IR
model = IR.InterpRocket(max_dilations_per_kernel=16)
model.fit(X_train, y_train)
metrics = model.evaluate(X_test, y_test)
# Feature stability analysis
stability = IR.cv_feature_stability(X_train, y_train)
stable = IR.get_stable_features(stability, threshold=0.8)
# Visualization (constrained by stable features)
model.plot_temporal_importance(X_test, y_test, feature_mask=stable)
IR.plot_receptive_field_diagram(model, X_test, y_test, feature_mask=stable)
# Permutation importance
pimp = IR.permutation_importance_test(model, X_train, y_train)
IR.plot_permutation_importance(pimp, model=model)
# Cross-validation
results = IR.cross_validate(X, y, n_repeats=10, n_folds=10, n_jobs=-2)
REQUIREMENTS:
numpy, numba (>=0.50), scikit-learn, matplotlib
Compatible with: NumPy 2.x, Python 3.10+
Author: Mark Laubach (American University, Department of Neuroscience)
Developed with Claude (Anthropic) as AI coding assistant.
License: BSD-3-Clause
"""
__version__ = "0.6.1"
import numpy as np
from itertools import combinations
from numba import njit, prange
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from sklearn.linear_model import RidgeClassifierCV
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    f1_score,
    matthews_corrcoef,
    confusion_matrix,
)
# ============================================================================
# COLOR PALETTE (tab10 as hex, used throughout all plotting functions)
# ============================================================================
TAB10 = [
    "#1f77b4",  # blue
    "#ff7f0e",  # orange
    "#2ca02c",  # green
    "#d62728",  # red
    "#9467bd",  # purple
    "#8c564b",  # brown
    "#e377c2",  # pink
    "#7f7f7f",  # gray
    "#bcbd22",  # olive
    "#17becf",  # cyan
]
POOLING_COLORS = {
    "PPV": "#1f77b4",  # blue
    "MPV": "#ff7f0e",  # orange
    "MIPV": "#2ca02c",  # green
    "LSPV": "#9467bd",  # purple
}
INFO_COLORS = {
    "redundant": "#ff7f0e",  # orange
    "synergistic": "#1f77b4",  # blue
    "independent": "#7f7f7f",  # gray
}
# ============================================================================
# SECTION 1: THE 84 BASE KERNELS
# ============================================================================
#
# MiniRocket/MultiRocket use 84 deterministic kernels of length 9.
# Each kernel has weights from {-1, 2}: six positions get -1, three get 2.
# The 84 kernels enumerate all C(9,3) = 84 ways to choose which 3 of 9
# positions receive the weight 2 (the rest get -1).
def _generate_base_kernels():
    """
    Generate the 84 deterministic MiniRocket base kernels.

    Returns
    -------
    kernels : ndarray, shape (84, 9), dtype float32
        Each row is a length-9 kernel with weights in {-1, 2}.
    indices : ndarray, shape (84, 3), dtype int32
        The 3 positions (of 9) that receive weight 2 in each kernel.
    """
    indices = np.array([combo for combo in combinations(range(9), 3)], dtype=np.int32)
    kernels = np.full((84, 9), -1.0, dtype=np.float32)
    for i, idx in enumerate(indices):
        kernels[i, idx] = 2.0
    return kernels, indices
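
# Illustrative sketch (not used by InterpRocket; the helper name is
# hypothetical): a pure-Python rebuild of the 84 base kernels that makes
# two properties of the construction easy to check. Each kernel has six
# weights of -1 and three of 2, so its weights sum to zero. Away from the
# zero-padded edges, a constant input therefore produces no response, and
# the kernels respond only to the shape of the signal.
import itertools


def _sketch_base_kernels():
    """Return the 84 length-9 kernels as lists of ints in {-1, 2}."""
    kernels = []
    # Same enumeration order as _generate_base_kernels: (0, 1, 2) ... (6, 7, 8)
    for positions in itertools.combinations(range(9), 3):
        weights = [-1] * 9
        for p in positions:
            weights[p] = 2
        kernels.append(weights)
    return kernels


# Usage: every kernel is distinct and sums to zero.
#   kernels = _sketch_base_kernels()
#   all(sum(w) == 0 for w in kernels)  # True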
# ============================================================================
# SECTION 2: DILATION FITTING
# ============================================================================
#
# Dilations control the temporal scale each kernel operates at.
# For a kernel of length 9 with dilation d, the receptive field spans
# 1 + 8*d timepoints. The set of dilations is chosen so that the largest
# dilation produces a receptive field just under the series length.
@njit(cache=True)
def _fit_dilations(input_length, num_features, max_dilations_per_kernel):
    """
    Determine dilations and features-per-dilation for a given series length.

    Adapted from the MiniRocket/MultiRocket dilation scheme:
    - the largest dilation is approximately (input_length - 1) / (9 - 1),
      truncated to an integer (and at least 1), so the receptive field
      1 + 8 * dilation fits within the series
    - dilations are exponentially spaced from 2^0 up to that maximum and
      truncated to integers
    - biases (features before pooling) are distributed across dilations as
      evenly as possible, using cumulative integer rounding

    Parameters
    ----------
    input_length : int
        Length of input time series.
    num_features : int
        Target number of biases. The realized count is
        84 * sum(num_features_per_dilation). This equals num_features
        rounded down to a multiple of 84 only when the series is long
        enough to support true_max_dilations_per_kernel dilations; shorter
        series yield fewer biases.
    max_dilations_per_kernel : int
        Maximum number of distinct dilations to use.

    Returns
    -------
    dilations : ndarray of int32
        The dilation values to use.
    num_features_per_dilation : ndarray of int32
        How many features (biases) per kernel to generate at each dilation.
    """
    num_kernels = 84
    num_features_per_kernel = num_features // num_kernels
    true_max_dilations_per_kernel = min(
        num_features_per_kernel, max_dilations_per_kernel
    )
    multiplier = num_features_per_kernel / true_max_dilations_per_kernel
    # Maximum dilation such that receptive field fits in input
    max_exponent = np.log2((input_length - 1) / (9 - 1))
    max_exponent = max(max_exponent, 0)
    num_dilations = min(true_max_dilations_per_kernel, int(max_exponent) + 1)
    # Exponentially spaced dilations
    dilations = np.zeros(num_dilations, dtype=np.int32)
    for i in range(num_dilations):
        dilations[i] = np.int32(2 ** (i * max_exponent / max(num_dilations - 1, 1)))
    # Distribute features across dilations by cumulative rounding: the first
    # i + 1 dilations receive floor((i + 1) * multiplier) biases in total
    num_features_per_dilation = np.zeros(num_dilations, dtype=np.int32)
    for i in range(num_dilations):
        num_features_per_dilation[i] = np.int32(
            (i + 1) * multiplier - np.sum(num_features_per_dilation)
        )
    return dilations, num_features_per_dilation
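
# Illustrative sketch (hypothetical helper, pure Python, not used by
# InterpRocket): the same arithmetic as _fit_dilations above, written
# without numba so the schedule can be inspected for a given series
# length. Worked through by hand, a length-1000 series with
# num_features=10000 and max_dilations=16 gets only 7 dilations
# (1, 2, 4, 11, 24, 55, 124) and 52 biases per kernel. That is
# 84 * 52 = 4368 biases per representation, well below the 10000 requested.
import math


def _sketch_fit_dilations(input_length, num_features=10000, max_dilations=16):
    num_features_per_kernel = num_features // 84
    true_max = min(num_features_per_kernel, max_dilations)
    multiplier = num_features_per_kernel / true_max
    # Largest exponent e whose receptive field 1 + 8 * 2**e fits in the
    # series, clamped at 0 for series shorter than one undilated kernel
    ratio = (input_length - 1) / 8
    max_exponent = math.log2(ratio) if ratio > 1 else 0.0
    num_dilations = min(true_max, int(max_exponent) + 1)
    dilations = [
        int(2 ** (i * max_exponent / max(num_dilations - 1, 1)))
        for i in range(num_dilations)
    ]
    # Cumulative rounding, as in _fit_dilations
    per_dilation = []
    for i in range(num_dilations):
        per_dilation.append(int((i + 1) * multiplier - sum(per_dilation)))
    return dilations, per_dilation


# Usage:
#   dilations, per_dilation = _sketch_fit_dilations(1000)
#   84 * sum(per_dilation)  # biases per representation for this length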
# ============================================================================
# SECTION 3: QUANTILE GENERATION (for bias selection)
# ============================================================================
@njit(cache=True)
def _quantiles(n):
    """
    Generate low-discrepancy quantiles using the golden ratio sequence.

    These quantiles are used to sample biases from the distribution of
    convolution outputs on training data. The golden ratio sequence
    produces quasi-random, well-distributed quantiles.

    Parameters
    ----------
    n : int
        Number of quantiles to generate.

    Returns
    -------
    quantiles : ndarray of float32
        Values in (0, 1), well-distributed.
    """
    phi = (np.sqrt(np.float32(5.0)) + 1.0) / 2.0
    quantiles = np.zeros(n, dtype=np.float32)
    for i in range(n):
        quantiles[i] = ((i + 1) * phi) % 1.0
    return quantiles
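
# Illustrative sketch (hypothetical helper, pure Python, not used by
# InterpRocket): the golden-ratio sequence behind _quantiles. The
# fractional parts of multiples of phi form a low-discrepancy sequence, so
# even a short prefix covers (0, 1) fairly evenly. As a result, the biases
# chosen for one kernel and dilation sample different parts of the
# convolution-output distribution instead of clustering.
import math


def _sketch_quantiles(n):
    phi = (math.sqrt(5.0) + 1.0) / 2.0
    return [((i + 1) * phi) % 1.0 for i in range(n)]


# Usage: the first three values are about 0.618, 0.236 and 0.854, and the
# first ten leave no gap in [0, 1] wider than about 0.146.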
# ============================================================================
# SECTION 4: BIAS FITTING
# ============================================================================
#
# Biases are the only truly data-dependent parameters. They are set by
# convolving a sample of training instances with each kernel at each
# dilation, then selecting quantiles of the resulting convolution outputs
# as bias thresholds.
@njit(fastmath=True, cache=True)
def _fit_biases(X, dilations, num_features_per_dilation, quantiles, random_state_seed):
    """
    Fit biases from training data for all 84 kernels × all dilations.

    Parameters
    ----------
    X : ndarray, shape (n_instances, n_timepoints), dtype float32
        Training time series.
    dilations : ndarray of int32
        Dilation values.
    num_features_per_dilation : ndarray of int32
        Number of features (biases) per dilation.
    quantiles : ndarray of float32
        Quantile positions for bias selection.
    random_state_seed : int
        Seed for reproducibility.

    Returns
    -------
    biases : ndarray of float32
        One bias per feature. Length = 84 * sum(num_features_per_dilation).
    """
    np.random.seed(random_state_seed)
    num_instances, input_length = X.shape
    # The 84 index patterns (positions of weight=2 in each kernel),
    # regenerated here inside the numba context because njit code cannot
    # call the itertools-based _generate_base_kernels
    indices_raw = np.zeros((84, 3), dtype=np.int32)
    count = 0
    for i in range(9):
        for j in range(i + 1, 9):
            for k in range(j + 1, 9):
                indices_raw[count, 0] = i
                indices_raw[count, 1] = j
                indices_raw[count, 2] = k
                count += 1
    num_kernels = 84
    num_dilations = len(dilations)
    num_features_total = num_kernels * np.sum(num_features_per_dilation)
    biases = np.zeros(num_features_total, dtype=np.float32)
    # Draw min(n_instances, 100) training examples (with replacement) for
    # each kernel/dilation pair
    num_examples = min(num_instances, 100)
    feature_idx = 0
    for dilation_index in range(num_dilations):
        dilation = dilations[dilation_index]
        padding = ((9 - 1) * dilation) // 2
        num_features_this_dilation = num_features_per_dilation[dilation_index]
        for kernel_index in range(num_kernels):
            # Indices of the 3 positions with weight 2
            i0 = indices_raw[kernel_index, 0]
            i1 = indices_raw[kernel_index, 1]
            i2 = indices_raw[kernel_index, 2]
            # Collect convolution outputs from a sample of training instances.
            # With padding = 4 * dilation, n_conv equals input_length.
            n_conv = input_length + 2 * padding - (9 - 1) * dilation
            if n_conv < 1:
                n_conv = 1
            C_all = np.zeros(num_examples * n_conv, dtype=np.float32)
            c_idx = 0
            for example_index in range(num_examples):
                # Random selection of training examples
                ex = np.random.randint(num_instances)
                x = X[ex]
                for t in range(n_conv):
                    # Sum of all 9 (zero-padded) positions at this dilation
                    total = np.float32(0.0)
                    for pos in range(9):
                        input_idx = t + pos * dilation - padding
                        if 0 <= input_idx < input_length:
                            total += x[input_idx]
                    # The kernel has weight -1 everywhere except 2 at
                    # (i0, i1, i2). Since 2 = -1 + 3, the convolution is
                    # conv = -sum_all + 3 * sum_at_indices
                    sum_at_indices = np.float32(0.0)
                    for idx_val in (i0, i1, i2):
                        input_idx = t + idx_val * dilation - padding
                        if 0 <= input_idx < input_length:
                            sum_at_indices += x[input_idx]
                    conv_val = -total + 3.0 * sum_at_indices
                    C_all[c_idx] = conv_val
                    c_idx += 1
            # Select biases as quantiles of the convolution output
            # (index floor(q * (n - 1)) of the sorted values, no interpolation)
            C_sorted = np.sort(C_all[:c_idx])
            for feature_count in range(num_features_this_dilation):
                q = quantiles[feature_idx]
                bias_index = int(q * (c_idx - 1))
                if bias_index < 0:
                    bias_index = 0
                if bias_index >= c_idx:
                    bias_index = c_idx - 1
                biases[feature_idx] = C_sorted[bias_index]
                feature_idx += 1
    return biases
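
# Illustrative sketch (hypothetical helpers, pure Python, not used by
# InterpRocket). The first helper computes a direct dot-product
# convolution with one base kernel and zero padding of 4 * dilation on
# each side. The second computes the -sum_all + 3 * sum_at_indices
# shortcut used in _fit_biases. The two must agree, because each weight
# of 2 equals the shared -1 plus an extra 3. The last helper picks a bias
# the way _fit_biases does: sort the outputs and take the element at
# floor(q * (n - 1)).


def _sketch_conv_direct(x, weights, dilation):
    padding = 4 * dilation
    out = []
    for t in range(len(x)):
        acc = 0.0
        for pos, w in enumerate(weights):
            idx = t + pos * dilation - padding
            if 0 <= idx < len(x):  # taps outside the series read zero padding
                acc += w * x[idx]
        out.append(acc)
    return out


def _sketch_conv_shortcut(x, positions, dilation):
    padding = 4 * dilation
    out = []
    for t in range(len(x)):
        taps = [t + pos * dilation - padding for pos in range(9)]
        total = sum(x[i] for i in taps if 0 <= i < len(x))
        at_idx = sum(x[taps[p]] for p in positions if 0 <= taps[p] < len(x))
        out.append(-total + 3.0 * at_idx)
    return out


def _sketch_select_bias(conv_values, q):
    ordered = sorted(conv_values)
    return ordered[int(q * (len(ordered) - 1))]


# Usage: for kernel positions (1, 4, 7), both convolutions return one value
# per input timepoint and match element by element.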
# ============================================================================
# SECTION 5: TRANSFORM — The Core Convolution + Pooling
# ============================================================================
#
# This is where the features are actually extracted. For each series:
# 1. Convolve with each of 84 kernels at each dilation
# 2. Subtract each bias → get binary indicator (>0 or not)
# 3. Compute 4 pooling operators: PPV, MPV, MIPV, LSPV
#
# MultiRocket does this for both raw signal and first-order difference.
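
# Illustrative sketch (hypothetical helper, pure Python, not used by
# InterpRocket): the four pooling operators listed above, applied to one
# convolution output C and one bias. The definitions follow _transform
# below: MPV averages the bias-shifted values, and MIPV is -1 when nothing
# exceeds the bias. For C = [1, -2, 3, 4, -1, 5] and bias 0, the positive
# positions are 0, 2, 3 and 5, giving PPV = 4/6, MPV = 13/4, MIPV = 10/4
# and LSPV = 2.


def _sketch_pooling(C, bias):
    shifted = [c - bias for c in C]
    positive = [t for t, v in enumerate(shifted) if v > 0]
    ppv = len(positive) / len(C)
    mpv = sum(shifted[t] for t in positive) / len(positive) if positive else 0.0
    mipv = sum(positive) / len(positive) if positive else -1.0
    # Longest run of consecutive positive shifted values
    lspv = run = 0
    for v in shifted:
        run = run + 1 if v > 0 else 0
        lspv = max(lspv, run)
    return ppv, mpv, mipv, float(lspv)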
@njit(fastmath=True, parallel=True, cache=True)
def _transform(X, dilations, num_features_per_dilation, biases):
    """
    Transform time series using MiniRocket-style convolution + 4 pooling ops.

    Parameters
    ----------
    X : ndarray, shape (n_instances, n_timepoints), dtype float32
    dilations : ndarray of int32
    num_features_per_dilation : ndarray of int32
    biases : ndarray of float32

    Returns
    -------
    features : ndarray, shape (n_instances, n_biases * 4), dtype float32
        4 features per bias, stored contiguously as PPV, MPV, MIPV, LSPV.
    """
    num_instances, input_length = X.shape
    num_kernels = 84
    num_dilations = len(dilations)
    num_features_per_rep = num_kernels * np.sum(num_features_per_dilation)
    # 4 pooling operators per bias
    features = np.zeros((num_instances, num_features_per_rep * 4), dtype=np.float32)
    # Regenerate indices inside numba
    indices_raw = np.zeros((84, 3), dtype=np.int32)
    count = 0
    for i in range(9):
        for j in range(i + 1, 9):
            for k in range(j + 1, 9):
                indices_raw[count, 0] = i
                indices_raw[count, 1] = j
                indices_raw[count, 2] = k
                count += 1
    for instance_idx in prange(num_instances):
        x = X[instance_idx]
        feature_idx = 0
        for dilation_index in range(num_dilations):
            dilation = dilations[dilation_index]
            padding = ((9 - 1) * dilation) // 2
            num_features_this_dilation = num_features_per_dilation[dilation_index]
            n_conv = input_length + 2 * padding - (9 - 1) * dilation
            if n_conv < 1:
                n_conv = 1
            for kernel_index in range(num_kernels):
                i0 = indices_raw[kernel_index, 0]
                i1 = indices_raw[kernel_index, 1]
                i2 = indices_raw[kernel_index, 2]
                # Compute full convolution output
                C = np.zeros(n_conv, dtype=np.float32)
                for t in range(n_conv):
                    total = np.float32(0.0)
                    for pos in range(9):
                        input_idx = t + pos * dilation - padding
                        if 0 <= input_idx < input_length:
                            total += x[input_idx]
                    sum_at_indices = np.float32(0.0)
                    for idx_val in (i0, i1, i2):
                        input_idx = t + idx_val * dilation - padding
                        if 0 <= input_idx < input_length:
                            sum_at_indices += x[input_idx]
                    C[t] = -total + 3.0 * sum_at_indices
                # For each bias, compute the 4 pooling operators
                for feature_count in range(num_features_this_dilation):
                    bias = biases[feature_idx]
                    # ---- PPV: Proportion of Positive Values ----
                    ppv_count = 0
                    # ---- MPV: Mean of Positive Values ----
                    mpv_sum = np.float32(0.0)
                    mpv_count = 0
                    # ---- MIPV: Mean of Indices of Positive Values ----
                    mipv_sum = np.float32(0.0)
                    # ---- LSPV: Longest Stretch of Positive Values ----
                    lspv_max = 0
                    lspv_current = 0
                    for t in range(n_conv):
                        val = C[t] - bias  # shifted value
                        if val > 0:
                            ppv_count += 1
                            mpv_sum += val
                            mpv_count += 1
                            mipv_sum += t
                            lspv_current += 1
                            if lspv_current > lspv_max:
                                lspv_max = lspv_current
                        else:
                            lspv_current = 0
                    ppv = np.float32(ppv_count) / np.float32(n_conv)
                    mpv = (
                        mpv_sum / np.float32(mpv_count)
                        if mpv_count > 0
                        else np.float32(0.0)
                    )
                    mipv = (
                        mipv_sum / np.float32(ppv_count)
                        if ppv_count > 0
                        else np.float32(-1.0)
                    )
                    lspv = np.float32(lspv_max)
                    # Store: 4 features per bias, contiguously
                    base = feature_idx * 4
                    features[instance_idx, base] = ppv
                    features[instance_idx, base + 1] = mpv
                    features[instance_idx, base + 2] = mipv
                    features[instance_idx, base + 3] = lspv
                    feature_idx += 1
    return features
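
# Illustrative sketch (hypothetical helper, pure Python, not used by
# InterpRocket). The module's own decode_feature_index() may use a
# different signature and return type. This shows how one column of the
# concatenated feature matrix maps back to its representation, dilation
# index, kernel, bias and pooling operator. It relies on two facts from
# _transform above: the loop order is dilation, then kernel, then bias,
# and the 4 pooling values for each bias are stored contiguously. The raw
# block precedes the first-difference block.

_SKETCH_POOLING = ("PPV", "MPV", "MIPV", "LSPV")


def _sketch_decode_column(col, nfpd_raw, nfpd_diff):
    n_raw_cols = 84 * sum(nfpd_raw) * 4
    rep, nfpd = ("raw", nfpd_raw) if col < n_raw_cols else ("diff", nfpd_diff)
    local = col if rep == "raw" else col - n_raw_cols
    bias_idx, pool = divmod(local, 4)
    start = 0
    for d, n_per_kernel in enumerate(nfpd):
        block = 84 * n_per_kernel  # biases at this dilation, over all kernels
        if bias_idx < start + block:
            kernel, within = divmod(bias_idx - start, n_per_kernel)
            return rep, d, kernel, within, _SKETCH_POOLING[pool]
        start += block
    raise IndexError("column index out of range")


# Usage: with 2 biases per kernel at dilation 0 and 3 at dilation 1, column 5
# is the MPV of the second bias of kernel 0 at dilation 0 (raw signal).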
# ============================================================================
# SECTION 6: PER-TIMEPOINT ACTIVATION MAP
# ============================================================================
@njit(fastmath=True, cache=True)
def compute_activation_map(x, kernel_index, dilation, bias):
    """
    Compute the per-timepoint convolution output for a single kernel+dilation+bias.

    Returns the raw convolution output (before bias subtraction) and the
    binary activation (after bias subtraction), allowing visualization of
    exactly where this kernel "fires" on the input.

    Parameters
    ----------
    x : ndarray, shape (n_timepoints,), dtype float32
    kernel_index : int
        Which of the 84 base kernels (0-83).
    dilation : int
    bias : float32

    Returns
    -------
    conv_output : ndarray, shape (n_conv,)
        Raw convolution values (n_conv equals n_timepoints).
    activation : ndarray, shape (n_conv,)
        Binary: 1 where conv_output > bias, else 0.
    time_indices : ndarray, shape (n_conv,)
        The center timepoint each convolution position maps to.
    """
    input_length = len(x)
    padding = ((9 - 1) * dilation) // 2
    # Regenerate indices
    indices_raw = np.zeros((84, 3), dtype=np.int32)
    count = 0
    for i in range(9):
        for j in range(i + 1, 9):
            for k in range(j + 1, 9):
                indices_raw[count, 0] = i
                indices_raw[count, 1] = j
                indices_raw[count, 2] = k
                count += 1
    i0 = indices_raw[kernel_index, 0]
    i1 = indices_raw[kernel_index, 1]
    i2 = indices_raw[kernel_index, 2]
    n_conv = input_length + 2 * padding - (9 - 1) * dilation
    if n_conv < 1:
        n_conv = 1
    conv_output = np.zeros(n_conv, dtype=np.float32)
    activation = np.zeros(n_conv, dtype=np.float32)
    time_indices = np.zeros(n_conv, dtype=np.float32)
    for t in range(n_conv):
        # Center of receptive field (equal to t, since padding = 4 * dilation)
        center = t - padding + 4 * dilation
        time_indices[t] = np.float32(center)
        total = np.float32(0.0)
        for pos in range(9):
            input_idx = t + pos * dilation - padding
            if 0 <= input_idx < input_length:
                total += x[input_idx]
        sum_at_indices = np.float32(0.0)
        for idx_val in (i0, i1, i2):
            input_idx = t + idx_val * dilation - padding
            if 0 <= input_idx < input_length:
                sum_at_indices += x[input_idx]
        conv_val = -total + 3.0 * sum_at_indices
        conv_output[t] = conv_val
        activation[t] = np.float32(1.0) if conv_val > bias else np.float32(0.0)
    return conv_output, activation, time_indices
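
# Illustrative sketch (hypothetical helper, pure Python, not used by
# InterpRocket): which input samples feed the activation at convolution
# position t. Because padding = 4 * dilation, position t is centred on
# timepoint t and reads t - 4d, t - 3d, ..., t + 4d, spanning the
# 1 + 8d receptive field described in SECTION 2. Taps outside the series
# read zero padding and are dropped here.


def _sketch_receptive_field(t, dilation, input_length):
    taps = [t + (pos - 4) * dilation for pos in range(9)]
    return [i for i in taps if 0 <= i < input_length]


# Usage: an activation at t = 10 with dilation 2 depends on the even
# timepoints 2 through 18.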
# ============================================================================
# SECTION 7: MUTUAL INFORMATION
# ============================================================================
#
# Information-theoretic classification metric. Measures how much knowing
# the predicted label reduces uncertainty about the true label.
# Ported from R code by Mark Laubach (version from 2005).
def mutual_information(y_true=None, y_pred=None, cm=None, base=2):
    """
    Calculate mutual information between true and predicted labels.

    Parameters
    ----------
    y_true : array-like, optional
        True class labels.
    y_pred : array-like, optional
        Predicted class labels.
    cm : array-like, optional
        Pre-computed confusion matrix (rows=true, cols=predicted).
    base : int or float, default=2
        Logarithm base. Use 2 for bits, np.e for nats.

    Returns
    -------
    mi : float
        Mutual information in specified units (bits if base=2).
    """
    if cm is None:
        if y_true is None or y_pred is None:
            raise ValueError("Must provide either (y_true, y_pred) or cm")
        cm = confusion_matrix(y_true, y_pred)
    else:
        cm = np.asarray(cm)
    total = cm.sum()
    if total == 0:
        return 0.0
    p_joint = cm / total
    p_true = p_joint.sum(axis=1)
    p_pred = p_joint.sum(axis=0)
    mi = 0.0
    n_classes_true, n_classes_pred = p_joint.shape
    for i in range(n_classes_true):
        for j in range(n_classes_pred):
            if p_joint[i, j] > 0 and p_true[i] > 0 and p_pred[j] > 0:
                mi += (
                    p_joint[i, j]
                    * np.log(p_joint[i, j] / (p_true[i] * p_pred[j]))
                    / np.log(base)
                )
    return mi
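
# Illustrative sketch (hypothetical helper, stdlib only, not used by
# InterpRocket): the same plug-in estimate as mutual_information(),
# I = sum_ij p_ij * log2(p_ij / (p_i * p_j)), computed from a non-empty
# confusion matrix given as nested lists. Here p_i and p_j are the row
# (true) and column (predicted) marginals. A perfect classifier on two
# balanced classes carries 1 bit. A classifier whose predictions are
# independent of the truth carries 0 bits. With k classes the maximum is
# log2(k).
import math


def _sketch_mutual_information(cm):
    total = sum(sum(row) for row in cm)
    p_true = [sum(row) / total for row in cm]
    p_pred = [sum(col) / total for col in zip(*cm)]
    mi = 0.0
    for i, row in enumerate(cm):
        for j, count in enumerate(row):
            if count > 0:  # empty cells contribute 0 (p * log p -> 0)
                p = count / total
                mi += p * math.log2(p / (p_true[i] * p_pred[j]))
    return mi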
def _compute_all_metrics(y_true, y_pred):
    """
    Compute all classification metrics.

    Returns
    -------
    metrics : dict with keys:
        'accuracy', 'balanced_accuracy', 'f1_macro', 'f1_weighted',
        'mcc', 'mutual_info'
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    return {
        "accuracy": float(accuracy_score(y_true, y_pred)),
        "balanced_accuracy": float(balanced_accuracy_score(y_true, y_pred)),
        "f1_macro": float(f1_score(y_true, y_pred, average="macro", zero_division=0)),
        "f1_weighted": float(
            f1_score(y_true, y_pred, average="weighted", zero_division=0)
        ),
        "mcc": float(matthews_corrcoef(y_true, y_pred)),
        "mutual_info": float(mutual_information(y_true=y_true, y_pred=y_pred)),
    }
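
# Illustrative sketch (hypothetical helper, stdlib only, not used by
# InterpRocket): the random oversampling that InterpRocket.fit below
# applies when class_weight='balanced'. Every original index is kept.
# Minority classes then receive extra indices drawn with replacement until
# each class matches the majority count, and the result is shuffled. This
# sketch uses random.Random rather than numpy's default_rng, so for the
# same seed the drawn indices differ from those in fit().
import random
from collections import Counter


def _sketch_oversample_indices(y, seed=0):
    rng = random.Random(seed)
    counts = Counter(y)
    target = max(counts.values())
    indices = []
    for cls in counts:
        cls_idx = [i for i, label in enumerate(y) if label == cls]
        extra = [rng.choice(cls_idx) for _ in range(target - len(cls_idx))]
        indices.extend(cls_idx + extra)
    rng.shuffle(indices)
    return indices


# Usage: X_balanced = [X[i] for i in idx]; y_balanced = [y[i] for i in idx]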
# ============================================================================
# SECTION 8: THE INTERPRETABLE ROCKET CLASS
# ============================================================================
class InterpRocket(BaseEstimator, ClassifierMixin):
    """
    Interpretable MultiRocket classifier.

    Provides full traceability from classifier decision → feature importance →
    kernel identity → temporal activation pattern.

    Inherits from sklearn.base.BaseEstimator and ClassifierMixin, providing
    get_params(), set_params(), and a standard score(X, y) that returns
    accuracy as a scalar for compatibility with sklearn pipelines. For the
    full multi-metric evaluation, use evaluate(X, y).

    Parameters
    ----------
    max_dilations_per_kernel : int, default=16
        Maximum number of dilation values per kernel.
    num_features : int, default=10000
        Target number of biases per representation. Each bias yields 4
        pooled features and there are 2 representations (raw signal and
        first difference), so the feature matrix has at most
        8 * num_features columns. The realized bias count is a multiple of
        84 and can be smaller than the target for short series (see
        _fit_dilations).
    random_state : int, default=0
        Seed for reproducibility (only affects bias fitting and class balancing).
    alpha_range : array-like or None, default=None
        Candidate Ridge regularization strengths passed to RidgeClassifierCV.
        If None, np.logspace(-10, 10, 20) is used.
    class_weight : str or None, default=None
        If 'balanced', randomly oversample minority class(es) to match
        the majority class count before fitting. Resampling is applied to
        the raw time series before the transform.
    """

    # The 4 pooling operator names, in feature order
    POOLING_NAMES = ["PPV", "MPV", "MIPV", "LSPV"]

    def __init__(
        self,
        max_dilations_per_kernel=16,  # suggested for use with I-ROCKET
        num_features=10000,
        random_state=0,
        alpha_range=None,
        class_weight=None,
    ):
        self.max_dilations_per_kernel = max_dilations_per_kernel
        self.num_features = num_features
        self.random_state = random_state
        # Stored unchanged, per sklearn convention; the default grid is
        # resolved in fit(). (`alpha_range or default` would raise a
        # ValueError for a numpy array with more than one element.)
        self.alpha_range = alpha_range
        self.class_weight = class_weight
        # Will be set during fit()
        self.base_kernels_ = None
        self.base_indices_ = None
        self.dilations_raw_ = None
        self.dilations_diff_ = None
        self.num_features_per_dilation_raw_ = None
        self.num_features_per_dilation_diff_ = None
        self.biases_raw_ = None
        self.biases_diff_ = None
        self.classifier_ = None
        self.scaler_ = None
        self.n_features_per_rep_ = None
        self.classes_ = None

    def fit(self, X, y):
        """
        Fit the MultiRocket transform and Ridge classifier.

        Parameters
        ----------
        X : ndarray, shape (n_instances, n_timepoints)
            Training time series. Will be converted to float32.
        y : array-like, shape (n_instances,)
            Class labels.

        Returns
        -------
        self

        Notes
        -----
        If class_weight='balanced', the minority class(es) are randomly
        oversampled (with replacement) to match the majority class count
        before fitting. This resampling is applied to the raw time series
        before the transform, ensuring the convolution outputs and bias
        quantiles reflect the balanced distribution.
        """
        X = np.asarray(X, dtype=np.float32)
        y = np.asarray(y)
        self.classes_ = np.unique(y)
        # --- Class balancing via random oversampling ---
        if self.class_weight == "balanced":
            classes, counts = np.unique(y, return_counts=True)
            max_count = counts.max()
            oversample_idx = []
            rng = np.random.default_rng(self.random_state)
            for cls, cnt in zip(classes, counts):
                cls_idx = np.where(y == cls)[0]
                if cnt < max_count:
                    extra = rng.choice(cls_idx, size=max_count - cnt, replace=True)
                    cls_idx = np.concatenate([cls_idx, extra])
                oversample_idx.append(cls_idx)
            oversample_idx = np.concatenate(oversample_idx)
            rng.shuffle(oversample_idx)
            X = X[oversample_idx]
            y = y[oversample_idx]
            print(
                f" Class balancing: oversampled to {len(y)} instances "
                f"({max_count} per class)"
            )
        n_instances, input_length = X.shape
        print(f"InterpRocket.fit: {n_instances} instances × {input_length} timepoints")
        print(f" Classes: {self.classes_}")
        # Generate the 84 base kernels
        self.base_kernels_, self.base_indices_ = _generate_base_kernels()
        # --- Raw representation ---
        print(" Fitting dilations (raw)...")
        self.dilations_raw_, self.num_features_per_dilation_raw_ = _fit_dilations(
            input_length, self.num_features, self.max_dilations_per_kernel
        )
        n_features_raw = 84 * np.sum(self.num_features_per_dilation_raw_)
        quantiles_raw = _quantiles(n_features_raw)
        print(
            f" Fitting biases (raw): {n_features_raw} biases across "
            f"{len(self.dilations_raw_)} dilations..."
        )
        self.biases_raw_ = _fit_biases(
            X,
            self.dilations_raw_,
            self.num_features_per_dilation_raw_,
            quantiles_raw,
            self.random_state,
        )
        # --- First-difference representation ---
        X_diff = np.diff(X, axis=1).astype(np.float32)
        diff_length = X_diff.shape[1]
        print(" Fitting dilations (diff)...")
        self.dilations_diff_, self.num_features_per_dilation_diff_ = _fit_dilations(
            diff_length, self.num_features, self.max_dilations_per_kernel
        )
        n_features_diff = 84 * np.sum(self.num_features_per_dilation_diff_)
        quantiles_diff = _quantiles(n_features_diff)
        print(
            f" Fitting biases (diff): {n_features_diff} biases across "
            f"{len(self.dilations_diff_)} dilations..."
        )
        self.biases_diff_ = _fit_biases(
            X_diff,
            self.dilations_diff_,
            self.num_features_per_dilation_diff_,
            quantiles_diff,
            self.random_state + 1,  # different seed for diff
        )
        self.n_features_per_rep_ = (n_features_raw, n_features_diff)
        # --- Transform ---
        print(" Transforming training data...")
        X_features = self._transform(X)
        print(f" Feature matrix: {X_features.shape}")
        # --- Standardize features ---
        # Ridge regression's L2 penalty is sensitive to feature scale.
        # PPV features are in [0,1] while LSPV and MPV can have much larger
        # ranges. Without standardization, the penalty disproportionately
        # shrinks large-scale features regardless of their discriminative
        # value. This matches the standard ROCKET pipeline.
        print(" Standardizing features...")
        self.scaler_ = StandardScaler(with_mean=True)
        X_features = self.scaler_.fit_transform(X_features)
        # --- Fit classifier ---
        print(" Fitting RidgeClassifierCV...")
        # Resolve the default alpha grid here so that __init__ keeps the
        # user's argument unchanged
        alphas = (
            np.logspace(-10, 10, 20)
            if self.alpha_range is None
            else np.asarray(self.alpha_range)
        )
        self.classifier_ = RidgeClassifierCV(alphas=alphas)
        self.classifier_.fit(X_features, y)
        train_acc = self.classifier_.score(X_features, y)
        print(f" Training accuracy: {train_acc:.4f}")
        print(f" Selected alpha: {self.classifier_.alpha_:.4f}")
        return self
    def _transform(self, X):
        """
        Transform time series to feature vectors.

        Returns the raw-signal features followed by the first-difference
        features. Within each block, the 4 pooling values of each bias are
        stored contiguously ([PPV, MPV, MIPV, LSPV] for bias 0, then for
        bias 1, and so on), and biases are ordered by dilation, then kernel,
        then bias within that kernel and dilation.

        Parameters
        ----------
        X : ndarray, shape (n_instances, n_timepoints), dtype float32

        Returns
        -------
        features : ndarray, shape (n_instances, n_total_features)
        """
        X = np.asarray(X, dtype=np.float32)
        # Raw features
        features_raw = _transform(
            X,
            self.dilations_raw_,
            self.num_features_per_dilation_raw_,
            self.biases_raw_,
        )
        # First-difference features
        X_diff = np.diff(X, axis=1).astype(np.float32)
        features_diff = _transform(
            X_diff,
            self.dilations_diff_,
            self.num_features_per_dilation_diff_,
            self.biases_diff_,
        )
        return np.concatenate([features_raw, features_diff], axis=1)

    def transform(self, X):
        """Public transform method. Returns raw (unscaled) features."""
        return self._transform(X)

    def predict(self, X):
        """Predict class labels."""
        X_features = self._transform(X)
        X_features = self.scaler_.transform(X_features)
        return self.classifier_.predict(X_features)

    def score(self, X, y):
        """
        Return classification accuracy as a scalar (sklearn convention).

        This method exists for compatibility with sklearn pipelines,
        GridSearchCV, and cross_val_score. For the full multi-metric
        evaluation, use evaluate(X, y).

        Parameters
        ----------
        X : ndarray, shape (n_instances, n_timepoints)
        y : array-like

        Returns
        -------
        accuracy : float
        """
        y_pred = self.predict(X)
        return float(accuracy_score(np.asarray(y), y_pred))

    def evaluate(self, X, y):
        """
        Evaluate on test data, returning multiple metrics.

        Parameters
        ----------
        X : ndarray, shape (n_instances, n_timepoints)
        y : array-like

        Returns
        -------
        metrics : dict with keys:
            'accuracy', 'balanced_accuracy', 'f1_macro', 'f1_weighted',
            'mcc' (Matthews correlation coefficient),
            'mutual_info' (bits)
        """
        y_pred = self.predict(X)
        return _compute_all_metrics(np.asarray(y), y_pred)