Skip to content

Commit 1277ecd

Browse files
Give NominalVariables pretty names and update dprint tests
1 parent 418b89f commit 1277ecd

4 files changed

Lines changed: 57 additions & 48 deletions

File tree

pytensor/graph/basic.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,8 @@ def _reduce(self):
695695
return cls, (self.id, self.type)
696696

697697
def _str(self):
698+
if self.name is not None:
699+
return self.name
698700
return f"*{self.id}-{var_type.__str__(self)}"
699701

700702
new_type = type(
@@ -707,6 +709,11 @@ def _str(self):
707709
return cls.__instances__[(typ, id)]
708710

709711
def __init__(self, id: _IdType, typ: _TypeType, name: str | None = None):
712+
if hasattr(self, "id"):
713+
# Cache hit — only update name if explicitly provided
714+
if name is not None:
715+
self.name = name
716+
return
710717
self.id = id
711718
super().__init__(type=typ, name=name)
712719

pytensor/graph/fg.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -965,7 +965,7 @@ def __init__(
965965
from pytensor.graph.basic import FrozenApply
966966

967967
nominal_inputs = tuple(
968-
NominalVariable(i, inp.type, name=inp.name) for i, inp in enumerate(inputs)
968+
NominalVariable(i, inp.type, name=f"i{i}") for i, inp in enumerate(inputs)
969969
)
970970

971971
memo: dict[Variable, Variable] = dict(zip(inputs, nominal_inputs, strict=True))
@@ -1013,6 +1013,8 @@ def __init__(
10131013

10141014
self.inputs: tuple[Variable, ...] = nominal_inputs
10151015
self.outputs: tuple[Variable, ...] = frozen_outputs
1016+
for i, out in enumerate(frozen_outputs):
1017+
out.name = f"o{i}"
10161018
self._variables: set[Variable] | None = None
10171019
self._apply_nodes: set[Apply] | None = None
10181020
self._clients: dict[Variable, list[ClientType]] | None = None

tests/compile/test_builders.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -904,10 +904,10 @@ def test_debugprint():
904904
905905
OpFromGraph{inline=False} [id A]
906906
← Add [id E]
907-
├─ *0-<Matrix(float64, shape=(?, ?))> [id F]
907+
├─ i0 [id F]
908908
└─ Mul [id G]
909-
├─ *1-<Matrix(float64, shape=(?, ?))> [id H]
910-
└─ *2-<Matrix(float64, shape=(?, ?))> [id I]
909+
├─ i1 [id H]
910+
└─ i2 [id I]
911911
"""
912912

913913
for truth, out in zip(exp_res.split("\n"), lines, strict=True):

tests/scan/test_printing.py

Lines changed: 44 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ def test_debugprint_sitsot():
5959
6060
Scan{scan_fn, while_loop=False, inplace=none} [id C]
6161
← Mul [id U] (inner_out_sit_sot-0)
62-
├─ *0-<Vector(float64, shape=(?,))> [id V] -> [id E] (inner_in_sit_sot-0)
63-
└─ *1-<Vector(float64, shape=(?,))> [id W] -> [id L] (inner_in_non_seqs-0)
62+
├─ i0 [id V] -> [id E] (inner_in_sit_sot-0)
63+
└─ i1 [id W] -> [id L] (inner_in_non_seqs-0)
6464
"""
6565

6666
for truth, out in zip(expected_output.split("\n"), lines, strict=True):
@@ -116,8 +116,8 @@ def test_debugprint_sitsot_no_extra_info():
116116
117117
Scan{scan_fn, while_loop=False, inplace=none} [id C]
118118
← Mul [id U]
119-
├─ *0-<Vector(float64, shape=(?,))> [id V] -> [id E]
120-
└─ *1-<Vector(float64, shape=(?,))> [id W] -> [id L]
119+
├─ i0 [id V] -> [id E]
120+
└─ i1 [id W] -> [id L]
121121
"""
122122

123123
for truth, out in zip(expected_output.split("\n"), lines, strict=True):
@@ -183,10 +183,10 @@ def test_debugprint_nitsot():
183183
184184
Scan{scan_fn, while_loop=False, inplace=none} [id B]
185185
← Mul [id X] (inner_out_nit_sot-0)
186-
├─ *0-<Scalar(float64, shape=())> [id Y] -> [id S] (inner_in_seqs-0)
186+
├─ i0 [id Y] -> [id S] (inner_in_seqs-0)
187187
└─ Pow [id Z]
188-
├─ *2-<Scalar(float64, shape=())> [id BA] -> [id W] (inner_in_non_seqs-0)
189-
└─ *1-<Scalar(int64, shape=())> [id BB] -> [id U] (inner_in_seqs-1)
188+
├─ i2 [id BA] -> [id W] (inner_in_non_seqs-0)
189+
└─ i1 [id BB] -> [id U] (inner_in_seqs-1)
190190
"""
191191

192192
for truth, out in zip(expected_output.split("\n"), lines, strict=True):
@@ -264,21 +264,21 @@ def compute_A_k(A, k):
264264
Scan{scan_fn, while_loop=False, inplace=none} [id B]
265265
← Mul [id Y] (inner_out_nit_sot-0)
266266
├─ ExpandDims{axis=0} [id Z]
267-
│ └─ *0-<Scalar(float64, shape=())> [id BA] -> [id S] (inner_in_seqs-0)
267+
│ └─ i0 [id BA] -> [id S] (inner_in_seqs-0)
268268
└─ Pow [id BB]
269269
├─ Subtensor{i} [id BC]
270270
│ ├─ Subtensor{start:} [id BD]
271271
│ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id BE] (outer_out_sit_sot-0)
272-
│ │ │ ├─ *3-<Scalar(int32, shape=())> [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
272+
│ │ │ ├─ i3 [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
273273
│ │ │ ├─ SetSubtensor{:stop} [id BG] (outer_in_sit_sot-0)
274274
│ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BH]
275275
│ │ │ │ │ ├─ Add [id BI]
276-
│ │ │ │ │ │ ├─ *3-<Scalar(int32, shape=())> [id BF] -> [id X] (inner_in_non_seqs-1)
276+
│ │ │ │ │ │ ├─ i3 [id BF] -> [id X] (inner_in_non_seqs-1)
277277
│ │ │ │ │ │ └─ Subtensor{i} [id BJ]
278278
│ │ │ │ │ │ ├─ Shape [id BK]
279279
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BL]
280280
│ │ │ │ │ │ │ └─ Second [id BM]
281-
│ │ │ │ │ │ │ ├─ *2-<Vector(float64, shape=(?,))> [id BN] -> [id W] (inner_in_non_seqs-0)
281+
│ │ │ │ │ │ │ ├─ i2 [id BN] -> [id W] (inner_in_non_seqs-0)
282282
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BO]
283283
│ │ │ │ │ │ │ └─ 1.0 [id BP]
284284
│ │ │ │ │ │ └─ 0 [id BQ]
@@ -291,16 +291,16 @@ def compute_A_k(A, k):
291291
│ │ │ │ └─ ScalarFromTensor [id BT]
292292
│ │ │ │ └─ Subtensor{i} [id BJ]
293293
│ │ │ │ └─ ···
294-
│ │ │ └─ *2-<Vector(float64, shape=(?,))> [id BN] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
294+
│ │ │ └─ i2 [id BN] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
295295
│ │ └─ 1 [id BU]
296296
│ └─ -1 [id BV]
297297
└─ ExpandDims{axis=0} [id BW]
298-
└─ *1-<Scalar(int64, shape=())> [id BX] -> [id U] (inner_in_seqs-1)
298+
└─ i1 [id BX] -> [id U] (inner_in_seqs-1)
299299
300300
Scan{scan_fn, while_loop=False, inplace=none} [id BE]
301301
← Mul [id BY] (inner_out_sit_sot-0)
302-
├─ *0-<Vector(float64, shape=(?,))> [id BZ] -> [id BG] (inner_in_sit_sot-0)
303-
└─ *1-<Vector(float64, shape=(?,))> [id CA] -> [id BN] (inner_in_non_seqs-0)
302+
├─ i0 [id BZ] -> [id BG] (inner_in_sit_sot-0)
303+
└─ i1 [id CA] -> [id BN] (inner_in_non_seqs-0)
304304
"""
305305

306306
for truth, out in zip(expected_output.split("\n"), lines, strict=True):
@@ -354,27 +354,27 @@ def compute_A_k(A, k):
354354
Inner graphs:
355355
356356
Scan{scan_fn, while_loop=False, inplace=none} [id E]
357-
*0-<Scalar(float64, shape=())> [id Y] -> [id U] (inner_in_seqs-0)
358-
*1-<Scalar(int64, shape=())> [id Z] -> [id W] (inner_in_seqs-1)
359-
*2-<Vector(float64, shape=(?,))> [id BA] -> [id C] (inner_in_non_seqs-0)
360-
*3-<Scalar(int32, shape=())> [id BB] -> [id B] (inner_in_non_seqs-1)
357+
i0 [id Y] -> [id U] (inner_in_seqs-0)
358+
i1 [id Z] -> [id W] (inner_in_seqs-1)
359+
i2 [id BA] -> [id C] (inner_in_non_seqs-0)
360+
i3 [id BB] -> [id B] (inner_in_non_seqs-1)
361361
← Mul [id BC] (inner_out_nit_sot-0)
362362
├─ ExpandDims{axis=0} [id BD]
363-
│ └─ *0-<Scalar(float64, shape=())> [id Y] (inner_in_seqs-0)
363+
│ └─ i0 [id Y] (inner_in_seqs-0)
364364
└─ Pow [id BE]
365365
├─ Subtensor{i} [id BF]
366366
│ ├─ Subtensor{start:} [id BG]
367367
│ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id BH] (outer_out_sit_sot-0)
368-
│ │ │ ├─ *3-<Scalar(int32, shape=())> [id BB] (inner_in_non_seqs-1) (n_steps)
368+
│ │ │ ├─ i3 [id BB] (inner_in_non_seqs-1) (n_steps)
369369
│ │ │ ├─ SetSubtensor{:stop} [id BI] (outer_in_sit_sot-0)
370370
│ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BJ]
371371
│ │ │ │ │ ├─ Add [id BK]
372-
│ │ │ │ │ │ ├─ *3-<Scalar(int32, shape=())> [id BB] (inner_in_non_seqs-1)
372+
│ │ │ │ │ │ ├─ i3 [id BB] (inner_in_non_seqs-1)
373373
│ │ │ │ │ │ └─ Subtensor{i} [id BL]
374374
│ │ │ │ │ │ ├─ Shape [id BM]
375375
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BN]
376376
│ │ │ │ │ │ │ └─ Second [id BO]
377-
│ │ │ │ │ │ │ ├─ *2-<Vector(float64, shape=(?,))> [id BA] (inner_in_non_seqs-0)
377+
│ │ │ │ │ │ │ ├─ i2 [id BA] (inner_in_non_seqs-0)
378378
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BP]
379379
│ │ │ │ │ │ │ └─ 1.0 [id BQ]
380380
│ │ │ │ │ │ └─ 0 [id BR]
@@ -387,18 +387,18 @@ def compute_A_k(A, k):
387387
│ │ │ │ └─ ScalarFromTensor [id BU]
388388
│ │ │ │ └─ Subtensor{i} [id BL]
389389
│ │ │ │ └─ ···
390-
│ │ │ └─ *2-<Vector(float64, shape=(?,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
390+
│ │ │ └─ i2 [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
391391
│ │ └─ 1 [id BV]
392392
│ └─ -1 [id BW]
393393
└─ ExpandDims{axis=0} [id BX]
394-
└─ *1-<Scalar(int64, shape=())> [id Z] (inner_in_seqs-1)
394+
└─ i1 [id Z] (inner_in_seqs-1)
395395
396396
Scan{scan_fn, while_loop=False, inplace=none} [id BH]
397-
*0-<Vector(float64, shape=(?,))> [id BY] -> [id BI] (inner_in_sit_sot-0)
398-
*1-<Vector(float64, shape=(?,))> [id BZ] -> [id BA] (inner_in_non_seqs-0)
397+
i0 [id BY] -> [id BI] (inner_in_sit_sot-0)
398+
i1 [id BZ] -> [id BA] (inner_in_non_seqs-0)
399399
← Mul [id CA] (inner_out_sit_sot-0)
400-
├─ *0-<Vector(float64, shape=(?,))> [id BY] (inner_in_sit_sot-0)
401-
└─ *1-<Vector(float64, shape=(?,))> [id BZ] (inner_in_non_seqs-0)
400+
├─ i0 [id BY] (inner_in_sit_sot-0)
401+
└─ i1 [id BZ] (inner_in_non_seqs-0)
402402
"""
403403

404404
for truth, out in zip(expected_output.split("\n"), lines, strict=True):
@@ -470,11 +470,11 @@ def fn(a_m2, a_m1, b_m2, b_m1):
470470
471471
Scan{scan_fn, while_loop=False, inplace=none} [id C]
472472
← Add [id BB] (inner_out_mit_sot-0)
473-
├─ *1-<Scalar(int64, shape=())> [id BC] -> [id E] (inner_in_mit_sot-0-1)
474-
└─ *0-<Scalar(int64, shape=())> [id BD] -> [id E] (inner_in_mit_sot-0-0)
473+
├─ i1 [id BC] -> [id E] (inner_in_mit_sot-0-1)
474+
└─ i0 [id BD] -> [id E] (inner_in_mit_sot-0-0)
475475
← Add [id BE] (inner_out_mit_sot-1)
476-
├─ *3-<Scalar(int64, shape=())> [id BF] -> [id O] (inner_in_mit_sot-1-1)
477-
└─ *2-<Scalar(int64, shape=())> [id BG] -> [id O] (inner_in_mit_sot-1-0)
476+
├─ i3 [id BF] -> [id O] (inner_in_mit_sot-1-1)
477+
└─ i2 [id BG] -> [id O] (inner_in_mit_sot-1-0)
478478
"""
479479

480480
for truth, out in zip(expected_output.split("\n"), lines, strict=True):
@@ -597,19 +597,19 @@ def test_debugprint_mitmot():
597597
Scan{grad_of_scan_fn, while_loop=False, inplace=none} [id B]
598598
← Add [id CK] (inner_out_mit_mot-0-0)
599599
├─ Mul [id CL]
600-
│ ├─ *2-<Vector(float64, shape=(?,))> [id CM] -> [id BJ] (inner_in_mit_mot-0-0)
601-
│ └─ *5-<Vector(float64, shape=(?,))> [id CN] -> [id O] (inner_in_non_seqs-0)
602-
└─ *3-<Vector(float64, shape=(?,))> [id CO] -> [id BJ] (inner_in_mit_mot-0-1)
600+
│ ├─ i2 [id CM] -> [id BJ] (inner_in_mit_mot-0-0)
601+
│ └─ i5 [id CN] -> [id O] (inner_in_non_seqs-0)
602+
└─ i3 [id CO] -> [id BJ] (inner_in_mit_mot-0-1)
603603
← Add [id CP] (inner_out_sit_sot-0)
604604
├─ Mul [id CQ]
605-
│ ├─ *2-<Vector(float64, shape=(?,))> [id CM] -> [id BJ] (inner_in_mit_mot-0-0)
606-
│ └─ *0-<Vector(float64, shape=(?,))> [id CR] -> [id X] (inner_in_seqs-0)
607-
└─ *4-<Vector(float64, shape=(?,))> [id CS] -> [id CC] (inner_in_sit_sot-0)
605+
│ ├─ i2 [id CM] -> [id BJ] (inner_in_mit_mot-0-0)
606+
│ └─ i0 [id CR] -> [id X] (inner_in_seqs-0)
607+
└─ i4 [id CS] -> [id CC] (inner_in_sit_sot-0)
608608
609609
Scan{scan_fn, while_loop=False, inplace=none} [id F]
610610
← Mul [id CT] (inner_out_sit_sot-0)
611-
├─ *0-<Vector(float64, shape=(?,))> [id CR] -> [id H] (inner_in_sit_sot-0)
612-
└─ *1-<Vector(float64, shape=(?,))> [id CU] -> [id O] (inner_in_non_seqs-0)
611+
├─ i0 [id CR] -> [id H] (inner_in_sit_sot-0)
612+
└─ i1 [id CU] -> [id O] (inner_in_non_seqs-0)
613613
"""
614614

615615
for truth, out in zip(expected_output.split("\n"), lines, strict=True):
@@ -655,11 +655,11 @@ def no_shared_fn(n, x_tm1, M):
655655
Scan{scan_fn, while_loop=False, inplace=all} [id B]
656656
← Composite{switch(lt(0, i0), 1, 0)} [id K] (inner_out_sit_sot-0)
657657
└─ Subtensor{i, j, k} [id L]
658-
├─ *2-<Tensor3(float64, shape=(20000, 2, 2))> [id M] -> [id J] (inner_in_non_seqs-0)
658+
├─ i2 [id M] -> [id J] (inner_in_non_seqs-0)
659659
├─ ScalarFromTensor [id N]
660-
│ └─ *0-<Scalar(int64, shape=())> [id O] -> [id D] (inner_in_seqs-0)
660+
│ └─ i0 [id O] -> [id D] (inner_in_seqs-0)
661661
├─ ScalarFromTensor [id P]
662-
│ └─ *1-<Scalar(int64, shape=())> [id Q] -> [id E] (inner_in_sit_sot-0)
662+
│ └─ i1 [id Q] -> [id E] (inner_in_sit_sot-0)
663663
└─ 0 [id R]
664664
665665
Composite{switch(lt(0, i0), 1, 0)} [id K]

0 commit comments

Comments
 (0)