Skip to content

Commit 10974e8

Browse files
committed
TOML experiments for #28
1 parent 892fe3c commit 10974e8

4 files changed

Lines changed: 134 additions & 36 deletions

File tree

countess/core/config.py

Lines changed: 66 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@
55
import re
66
from configparser import ConfigParser
77

8+
from tomllib import load as toml_load
9+
from tomli_w import dump as toml_dump
10+
811
from countess.core.pipeline import PipelineGraph, PipelineNode
912
from countess.core.plugins import load_plugin
1013

1114
logger = logging.getLogger(__name__)
1215

1316

14-
def read_config_dict(name: str, base_dir: str, config_dict: dict) -> PipelineNode:
17+
def read_config_meta(name: str, config_dict: dict) -> PipelineNode:
1518
if "_module" in config_dict:
1619
module_name = config_dict["_module"]
1720
class_name = config_dict["_class"]
@@ -37,7 +40,7 @@ def read_config_dict(name: str, base_dir: str, config_dict: dict) -> PipelineNod
3740

3841
# XXX check version and hash_digest and emit warnings.
3942

40-
node = PipelineNode(
43+
return PipelineNode(
4144
name=name,
4245
uuid=config_dict.get("_uuid"),
4346
plugin=plugin,
@@ -46,11 +49,33 @@ def read_config_dict(name: str, base_dir: str, config_dict: dict) -> PipelineNod
4649
sort_column=int(sort[0]),
4750
sort_descending=bool(int(sort[1])),
4851
)
52+
53+
54+
def read_config_toml(filename: str) -> PipelineGraph:
55+
base_dir = os.path.dirname(filename)
56+
pipeline_graph = PipelineGraph()
57+
nodes_by_name : dict[str, PipelineNode] = {}
58+
with open(filename, "rb") as fh:
59+
doc = toml_load(fh)
60+
for name, config_dict in doc.items():
61+
node = read_config_meta(name, config_dict)
62+
node.plugin.set_config({
63+
k: v
64+
for k, v in config_dict.items()
65+
if not k.startswith("_")
66+
}, base_dir)
67+
for key, val in config_dict.items():
68+
if key.startswith("_parent."):
69+
node.add_parent(nodes_by_name[val])
70+
pipeline_graph.nodes.append(node)
71+
nodes_by_name[name] = node
72+
73+
def read_config_dict(name: str, base_dir: str, config_dict: dict) -> PipelineNode:
74+
node = read_config_meta(name, config_dict)
4975
for key, val in config_dict.items():
5076
if not key.startswith("_"):
5177
node.set_config(key, ast.literal_eval(val), base_dir)
52-
return node
53-
78+
return node
5479

5580
def read_config(
5681
filenames: list[str],
@@ -110,31 +135,53 @@ def write_config_node_string(node: PipelineNode, base_dir: str = ""):
110135
return buf.getvalue()
111136

112137

113-
def write_config_node(node: PipelineNode, cp: ConfigParser, base_dir: str):
114-
cp.add_section(node.name)
138+
def get_config_meta(node: PipelineNode) -> dict[str, str]:
139+
meta = { "_uuid": node.uuid }
140+
if node.sort_column:
141+
desc = 1 if node.sort_descending else 0
142+
meta["_sort"] = "%d %d" % (node.sort_column, desc)
115143
if node.plugin:
116-
cp[node.name].update(
117-
{
118-
"_uuid": node.uuid,
119-
"_module": node.plugin.__module__,
120-
"_class": node.plugin.__class__.__name__,
121-
"_version": node.plugin.version,
122-
"_hash": node.plugin.hash(),
123-
"_sort": "%d %d" % (node.sort_column, 1 if node.sort_descending else 0),
124-
}
125-
)
144+
meta.update({
145+
"_module": node.plugin.__module__,
146+
"_class": node.plugin.__class__.__name__,
147+
"_version": node.plugin.version,
148+
"_hash": node.plugin.hash(),
149+
})
126150
if node.position:
127151
xx, yy = node.position
128-
cp[node.name]["_position"] = "%d %d" % (xx * 1000, yy * 1000)
152+
meta["_position"] = "%d %d" % (xx * 1000, yy * 1000)
129153
if node.notes:
130-
cp[node.name]["_notes"] = node.notes
154+
meta["_notes"] = node.notes
131155
for n, parent in enumerate(node.parent_nodes):
132-
cp[node.name][f"_parent.{n}"] = parent.name
156+
meta[f"_parent.{n}"] = parent.name
157+
return meta
158+
159+
160+
def write_config_node(node: PipelineNode, cp: ConfigParser, base_dir: str):
161+
cp.add_section(node.name)
162+
if node.plugin:
163+
cp[node.name].update(get_config_meta(node))
133164
if node.plugin:
134165
node.load_config()
135166
for k, v in node.plugin.get_parameters("", base_dir):
136167
cp[node.name][k] = repr(v)
137168

169+
def make_config_toml_node(node: PipelineNode, base_dir: str = ""):
170+
tab = get_config_meta(node)
171+
node.load_config()
172+
tab.update(node.plugin.get_config(base_dir))
173+
return tab
174+
175+
def make_config_toml(pipeline_graph: PipelineGraph, base_dir: str = ""):
176+
return {
177+
node.name: make_config_toml_node(node, base_dir)
178+
for node in pipeline_graph.traverse_nodes()
179+
}
180+
181+
def write_config_toml(pipeline_graph: PipelineGraph, filename: str):
182+
base_dir = os.path.dirname(filename)
183+
with open(filename, "wb") as fh:
184+
toml_dump(make_config_toml(pipeline_graph, base_dir), fh)
138185

139186
def export_config_graphviz(pipeline_graph: PipelineGraph, filename: str):
140187
with open(filename, "w", encoding="utf-8") as fh:

countess/core/parameters.py

Lines changed: 60 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,17 @@ def copy_and_set_value(self, value):
4545
new.value = value
4646
return new
4747

48+
def copy_and_set_config(self, config, base_dir="."):
49+
new = self.copy()
50+
new.set_config(config, base_dir)
51+
return new
52+
53+
def get_config(self, base_dir="."):
54+
raise NotImplementedError(f"Implement {self.__class__.__name__}.get_config()")
55+
56+
def set_config(self, value, base_dir="."):
57+
raise NotImplementedError(f"Implement {self.__class__.__name__}.set_config()")
58+
4859
def get_parameters(self, key, base_dir="."):
4960
raise NotImplementedError(f"Implement {self.__class__.__name__}.get_parameters()")
5061

@@ -82,6 +93,12 @@ def reset_value(self):
8293
def copy(self):
8394
return self.__class__(self.label, self._value)
8495

96+
def get_config(self, base_dir="."):
97+
return self._value
98+
99+
def set_config(self, value, base_dir="."):
100+
self._value = value
101+
85102
def get_parameters(self, key, base_dir="."):
86103
return ((key, self.value),)
87104

@@ -113,6 +130,12 @@ def reset_value(self, value: Any):
113130
def copy(self):
114131
return self.__class__(self.label, self._values)
115132

133+
def get_config(self, base_dir="."):
134+
return self._values
135+
136+
def set_config(self, value, base_dir="."):
137+
self._values = set(value)
138+
116139
def get_parameters(self, key, base_dir="."):
117140
yield from ((f"{key}.{n}", v) for n, v in enumerate(self._values))
118141

@@ -345,20 +368,24 @@ def set_base_dir(self, base_dir):
345368
self.value = self.get_file_path()
346369
self.base_dir = base_dir
347370

348-
def get_parameters(self, key, base_dir="."):
349-
if self.value:
350-
try:
351-
if base_dir:
352-
path = os.path.relpath(self.get_file_path(), base_dir)
353-
else:
354-
path = os.path.abspath(self.get_file_path())
355-
except ValueError:
356-
# relpath can fail on Windows
357-
path = self.get_file_path()
358-
else:
359-
path = None
371+
def get_config(self, base_dir="."):
372+
if not self.value:
373+
return None
374+
try:
375+
if base_dir:
376+
return os.path.relpath(self.get_file_path(), base_dir)
377+
else:
378+
return os.path.abspath(self.get_file_path())
379+
except ValueError:
380+
# relpath can fail on Windows
381+
return self.get_file_path()
382+
383+
def set_config(self, value, base_dir="."):
384+
self.value = value
385+
self.set_base_dir(base_dir)
360386

361-
return [(key, path)]
387+
def get_parameters(self, key, base_dir="."):
388+
return [(key, self.get_config(base_dir))]
362389

363390
def copy(self) -> "FileBaseParam":
364391
return self.__class__(self.label, self.value, file_types=self.file_types, base_dir=self.base_dir)
@@ -836,6 +863,16 @@ def __setattr__(self, name: str, value: None) -> None:
836863
else:
837864
super().__setattr__(name, value)
838865

866+
def get_config(self, base_dir="."):
867+
return { k: v.get_config(base_dir) for k, v in self.params.items() }
868+
869+
def set_config(self, value, base_dir="."):
870+
for k, v in value:
871+
if k in self.params:
872+
self.params[k].set_config(v, base_dir)
873+
else:
874+
logger.error("Unmatched Parameter Key %s", k)
875+
839876
def get_parameters(self, key, base_dir=".") -> Iterable[Tuple[str, str]]:
840877
for subkey, param in self.params.items():
841878
yield from param.get_parameters(f"{key}.{subkey}" if key else subkey, base_dir)
@@ -951,6 +988,16 @@ def __contains__(self, item):
951988
def __iter__(self):
952989
return self.params.__iter__()
953990

991+
def get_config(self, base_dir="."):
992+
return [ p.get_config(base_dir) for p in self.params ]
993+
994+
def set_config(self, value, base_dir="."):
995+
self.params = [
996+
self.copy_and_set_config(v, base_dir)
997+
for v in value
998+
]
999+
self.relabel()
1000+
9541001
def get_parameters(self, key, base_dir="."):
9551002
for n, p in enumerate(self.params):
9561003
yield from p.get_parameters(f"{key}.{n}", base_dir)

countess/gui/main.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from duckdb import DuckDBPyRelation
1616

1717
from countess import VERSION
18-
from countess.core.config import config_to_graph, export_config_graphviz, graph_to_config, read_config, write_config
18+
from countess.core.config import config_to_graph, export_config_graphviz, graph_to_config, read_config, write_config, write_config_toml
1919
from countess.core.pipeline import PipelineGraph
2020
from countess.core.plugins import get_plugin_classes
2121
from countess.gui.config import PluginConfigurator
@@ -568,9 +568,12 @@ def config_save(self, filename=None):
568568
)
569569
if not filename:
570570
return
571-
if not filename.endswith(".ini"):
572-
filename = filename + ".ini"
573-
write_config(self.graph, filename)
571+
if filename.endswith(".toml"):
572+
write_config_toml(self.graph, filename)
573+
else:
574+
if not filename.endswith(".ini"):
575+
filename = filename + ".ini"
576+
write_config(self.graph, filename)
574577
self.config_filename = filename
575578
self.config_changed = False
576579
self.update_title()

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ dependencies = [
2929
'rapidfuzz~=3.13.0',
3030
'scipy~=1.15.3',
3131
'tkinterweb~=3.23.5',
32+
'tomli-w~=1.2.0',
3233
'ttkthemes~=3.2.2',
3334
'typing_extensions~=4.14.0',
3435
]

0 commit comments

Comments
 (0)