Skip to content

Commit 8c5cf65

Browse files
typing: adding types to utilities
1 parent 41ae6d3 commit 8c5cf65

12 files changed

Lines changed: 195 additions & 126 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ python_version = "3.12"
144144
mypy_path = "stubs"
145145

146146
# Exclude specific directories from type checking will try to add them back gradually
147-
exclude = "(?x)(^temoa/extensions/|^temoa/utilities/|^stubs/)"
147+
exclude = "(?x)(^temoa/extensions/|^stubs/)"
148148

149149
# Strict typing for our own code
150150
disallow_untyped_defs = true

stubs/pyomo/core/base/component.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ class ComponentData(ComponentBase):
159159
def __idiv__(self, other: typingAny) -> typingAny: ...
160160
def __itruediv__(self, other: typingAny) -> typingAny: ...
161161
def __ipow__(self, other: typingAny) -> typingAny: ...
162+
def set_value(self, val, skip_validation: bool = False) -> None: ...
162163

163164
class ActiveComponentData(ComponentData):
164165
def __init__(self, component) -> None: ...

temoa/utilities/capacity_analyzer.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@
55
"""
66

77
import itertools
8-
import os.path
98
import sqlite3
9+
from typing import Any
1010

11-
from definitions import PROJECT_ROOT
1211
from matplotlib import pyplot as plt
1312

1413
# Written by: J. F. Hyink
@@ -17,34 +16,31 @@
1716

1817
# Created on: 7/18/23
1918

20-
# filename of db to analyze...
21-
db = 'US_9R_8D_CT500.sqlite'
19+
# filepath of db to analyze...
20+
source_db_file = 'US_9R_8D_CT500.sqlite'
21+
2222

23-
source_db_file = os.path.join(PROJECT_ROOT, 'data_files', 'untracked_data', db)
2423
print(source_db_file)
25-
res = []
24+
res: list[Any] = []
2625
try:
27-
con = sqlite3.connect(source_db_file)
28-
cur = con.cursor()
29-
cur.execute('SELECT max_cap FROM max_capacity')
30-
for row in cur:
31-
res.append(row)
26+
with sqlite3.connect(source_db_file) as con:
27+
cur: sqlite3.Cursor = con.cursor()
28+
cur.execute('SELECT max_cap FROM max_capacity')
29+
for row in cur:
30+
res.append(row)
3231

3332
except sqlite3.Error as e:
3433
print(e)
3534

36-
finally:
37-
con.close()
38-
3935
# chain them together into a list
40-
caps = list(itertools.chain(*res))
36+
caps: list[float] = list(itertools.chain(*res))
4137

4238
cutoff = 1 # GW : An arbitrary cutoff between big and small capacity systems.
43-
small_cap_sources = [c for c in caps if c <= cutoff]
44-
large_cap_sources = [c for c in caps if c > cutoff]
39+
small_cap_sources: list[float] = [c for c in caps if c <= cutoff]
40+
large_cap_sources: list[float] = [c for c in caps if c > cutoff]
4541

46-
aggregate_small_cap = sum(small_cap_sources)
47-
aggregate_large_cap = sum(large_cap_sources)
42+
aggregate_small_cap: float = sum(small_cap_sources)
43+
aggregate_large_cap: float = sum(large_cap_sources)
4844

4945
print(f'{len(small_cap_sources)} small cap sources account for: {aggregate_small_cap: 0.1f} GW')
5046
print(f'{len(large_cap_sources)} large cap sources account for: {aggregate_large_cap: 0.1f} GW')
@@ -56,8 +52,8 @@
5652
# make a cumulative contribution plot, and find a 5% cutoff
5753
cutoff_num_sources = 0
5854
caps.sort()
59-
total_cap = sum(caps)
60-
cumulative_caps = [
55+
total_cap: float = sum(caps)
56+
cumulative_caps: list[float] = [
6157
caps[0] / total_cap,
6258
]
6359
for i, cap in enumerate(caps[1:]):

temoa/utilities/clear_db_outputs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import sys
88
from pathlib import Path
99

10-
basic_output_tables = [
10+
basic_output_tables: list[str] = [
1111
'output_built_capacity',
1212
'output_cost',
1313
'output_curtailment',
@@ -19,15 +19,15 @@
1919
'output_objective',
2020
'output_retired_capacity',
2121
]
22-
optional_output_tables = ['output_flow_out_summary', 'myopic_efficiency']
22+
optional_output_tables: list[str] = ['output_flow_out_summary', 'myopic_efficiency']
2323

2424
if len(sys.argv) != 2:
2525
print('this utility file expects a CLA for the path to the database to clear')
2626
sys.exit(-1)
2727

28-
target_db_str = sys.argv[1]
28+
target_db_str: str = sys.argv[1]
2929

30-
proceed = input('This will clear ALL output tables in ' + target_db_str + '? (y/n): ')
30+
proceed: str = input('This will clear ALL output tables in ' + target_db_str + '? (y/n): ')
3131
if proceed == 'y':
3232
target_db = Path(target_db_str)
3333
if not target_db.exists():

temoa/utilities/database_util.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,13 @@
77
periods, and regions.
88
"""
99

10+
from __future__ import annotations
11+
1012
import os
1113
import re
1214
import sqlite3
1315
from os import PathLike
16+
from typing import Any, cast
1417

1518
import deprecated
1619
import pandas as pd
@@ -55,9 +58,9 @@ def __init__(self, database_path: str | PathLike[str], scenario: str | None = No
5558

5659
def close(self) -> None:
5760
"""Closes the database cursor and connection."""
58-
if self.cur:
61+
if hasattr(self, 'cur') and self.cur:
5962
self.cur.close()
60-
if self.con:
63+
if hasattr(self, 'con') and self.con:
6164
self.con.close()
6265

6366
@staticmethod
@@ -115,6 +118,8 @@ def get_time_peridos_for_flags(self, flags: list[str] | None = None) -> set[int]
115118
query = f'SELECT period FROM time_period WHERE flag IN ({in_clause})'
116119

117120
self.cur.execute(query)
121+
# cast to int because sqlite might return strings or ints depending on how data was inserted
122+
# but type hint says set[int]
118123
return {int(row[0]) for row in self.cur}
119124

120125
def get_technologies_for_flags(self, flags: list[str] | None = None) -> set[str]:
@@ -125,7 +130,7 @@ def get_technologies_for_flags(self, flags: list[str] | None = None) -> set[str]
125130
in_clause = ', '.join(f"'{flag}'" for flag in flags)
126131
query = f'SELECT tech FROM Technology WHERE flag IN ({in_clause})'
127132

128-
return {row[0] for row in self.cur.execute(query)}
133+
return {cast('str', row[0]) for row in self.cur.execute(query)}
129134

130135
def get_commodities_and_tech(
131136
self, inp_comm: str | None, inp_tech: str | None, region: str | None
@@ -171,7 +176,7 @@ def get_commodities_for_flags(self, flags: list[str] | None = None) -> set[str]:
171176
in_clause = ', '.join(f"'{flag}'" for flag in flags)
172177
query = f'SELECT name FROM Commodity WHERE flag IN ({in_clause})'
173178

174-
return {row[0] for row in self.cur.execute(query)}
179+
return {cast('str', row[0]) for row in self.cur.execute(query)}
175180

176181
def get_commodities_by_technology(
177182
self, region: str | None, comm_type: str = 'input'
@@ -187,11 +192,11 @@ def get_commodities_by_technology(
187192
if region:
188193
query += f" WHERE region LIKE '%{region}%'"
189194

190-
return {tuple(row) for row in self.cur.execute(query)}
195+
return {cast('tuple[str, str]', row) for row in self.cur.execute(query)}
191196

192197
def get_capacity_for_tech_and_period(
193198
self, tech: str | None = None, period: int | None = None, region: str | None = None
194-
) -> pd.DataFrame | pd.Series:
199+
) -> pd.DataFrame | pd.Series[Any]:
195200
"""Retrieves capacity data, aggregated by technology."""
196201
if not self.scenario:
197202
raise ValueError('A scenario must be set for output-related queries')

temoa/utilities/db_migration_to_v3.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import sys
1313
from collections import defaultdict
1414
from pathlib import Path
15+
from typing import Any
1516

1617
parser = argparse.ArgumentParser()
1718
parser.add_argument(
@@ -36,9 +37,9 @@
3637
new_db_name = legacy_db.stem + '_v3.sqlite'
3738
new_db_path = Path(legacy_db.parent, new_db_name)
3839

39-
con_old = sqlite3.connect(legacy_db)
40-
con_new = sqlite3.connect(new_db_path)
41-
cur = con_new.cursor()
40+
con_old: sqlite3.Connection = sqlite3.connect(legacy_db)
41+
con_new: sqlite3.Connection = sqlite3.connect(new_db_path)
42+
cur: sqlite3.Cursor = con_new.cursor()
4243

4344
# bring in the new schema and execute
4445
with open(schema_file) as src:
@@ -50,7 +51,7 @@
5051

5152
# table mapping for DIRECT transfers
5253
# fmt: off
53-
direct_transfer_tables = [
54+
direct_transfer_tables: list[tuple[str, str]] = [
5455
("", "CapacityCredit"),
5556
("", "CapacityFactorProcess"),
5657
("", "CapacityFactorTech"),
@@ -105,14 +106,14 @@
105106
("SegFrac", "TimeSegmentFraction"),
106107
]
107108

108-
units_added_tables = [
109+
units_added_tables: list[tuple[str, str]] = [
109110
("", "MaxActivityGroup"),
110111
("", "MaxCapacityGroup"),
111112
("", "MinCapacityGroup"),
112113
("", "MinActivityGroup"),
113114
]
114115

115-
sequence_added_tables = [
116+
sequence_added_tables: list[tuple[str, str]] = [
116117
("time_season", "TimeSeason"),
117118
("time_periods", "time_period"),
118119
("time_of_day", "TimeOfDay"),
@@ -126,12 +127,16 @@
126127
if old_name == '':
127128
old_name = new_name
128129

129-
new_columns = [c[1] for c in con_new.execute(f'PRAGMA table_info({new_name});').fetchall()]
130-
old_columns = [c[1] for c in con_old.execute(f'PRAGMA table_info({old_name});').fetchall()]
130+
new_columns: list[str] = [
131+
c[1] for c in con_new.execute(f'PRAGMA table_info({new_name});').fetchall()
132+
]
133+
old_columns: list[str] = [
134+
c[1] for c in con_old.execute(f'PRAGMA table_info({old_name});').fetchall()
135+
]
131136
cols = str(old_columns[0 : len(new_columns)])[1:-1].replace("'", '')
132137

133138
try:
134-
data = con_old.execute(f'SELECT {cols} FROM {old_name}').fetchall()
139+
data: list[Any] = con_old.execute(f'SELECT {cols} FROM {old_name}').fetchall()
135140
except sqlite3.OperationalError:
136141
print('TABLE NOT FOUND: ' + old_name)
137142
data = []
@@ -222,10 +227,11 @@
222227
# let's ensure all the non-global entries are consistent (same techs in each region)
223228
skip_rps = False
224229
try:
225-
rps_entries = con_old.execute('SELECT * FROM tech_rps').fetchall()
230+
rps_entries: list[tuple[str, str, str]] = con_old.execute('SELECT * FROM tech_rps').fetchall()
226231
except sqlite3.OperationalError:
227232
print('source does not appear to include RPS techs...skipping')
228233
skip_rps = True
234+
rps_entries = []
229235
if not skip_rps:
230236
for region, tech, _notes in rps_entries:
231237
groups[region].add(tech)
@@ -239,7 +245,7 @@
239245
for group, techs in groups.items():
240246
print(f'group: {group} mismatches: {common ^ techs}')
241247
if group != 'global':
242-
techs_common &= not common ^ techs
248+
techs_common &= not (common ^ techs)
243249
if not techs_common:
244250
print(
245251
'combining RPS techs failed. Some regions are not same. Must be done '
@@ -357,7 +363,7 @@
357363
data = con_old.execute(read_qry).fetchall()
358364
if unlim_cap_present:
359365
# need to convert null -> 0 for unlim_cap to match new schema that does not allow null
360-
new_data = []
366+
new_data: list[Any] = []
361367
for row in data:
362368
new_row = list(row)
363369
if new_row[4] is None:

temoa/utilities/db_migration_v3_1_to_v4.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import re
1717
import sqlite3
1818
from pathlib import Path
19+
from typing import Any
1920

2021
# ---------- Mapping configuration ----------
2122
CUSTOM_MAP: dict[str, str] = {
@@ -93,7 +94,7 @@ def map_token_no_cascade(token: str) -> str:
9394
return to_snake_case(token)
9495

9596

96-
def get_table_info(conn: sqlite3.Connection, table: str) -> list[tuple]:
97+
def get_table_info(conn: sqlite3.Connection, table: str) -> list[tuple[Any, ...]]:
9798
try:
9899
return conn.execute(f'PRAGMA table_info({table});').fetchall()
99100
except sqlite3.OperationalError:
@@ -131,7 +132,7 @@ def migrate_direct_table(
131132
return len(filtered)
132133

133134

134-
def migrate_all(args) -> None:
135+
def migrate_all(args: argparse.Namespace) -> None:
135136
src = Path(args.source)
136137
schema = Path(args.schema)
137138
out = Path(args.out) if args.out else src.with_suffix('.v4.sqlite')

0 commit comments

Comments
 (0)