11import os
2+ from typing import Dict
23
34import matplotlib .pyplot as plt
45import plotly .graph_objects as go
56from plotly .subplots import make_subplots
67
7- import maxplotlib .backends .matplotlib .utils as plt_utils
8+ from maxplotlib .backends .matplotlib .utils import (
9+ set_size ,
10+ setup_plotstyle ,
11+ setup_tex_fonts ,
12+ )
813from maxplotlib .subfigure .line_plot import LinePlot
914from maxplotlib .subfigure .tikz_figure import TikzFigure
15+ from maxplotlib .utils .options import Backends
1016
1117
1218class Canvas :
13- def __init__ (self , ** kwargs ):
19+ def __init__ (
20+ self ,
21+ nrows : int = 1 ,
22+ ncols : int = 1 ,
23+ figsize : tuple | None = None ,
24+ caption : str | None = None ,
25+ description : str | None = None ,
26+ label : str | None = None ,
27+ fontsize : int = 14 ,
28+ dpi : int = 300 ,
29+ width : str = "17cm" ,
30+ ratio : str = "golden" , # TODO Add literal
31+ gridspec_kw : Dict = {"wspace" : 0.08 , "hspace" : 0.1 },
32+ ):
1433 """
1534 Initialize the Canvas class for multiple subplots.
1635
1736 Parameters:
1837 nrows (int): Number of subplot rows. Default is 1.
1938 ncols (int): Number of subplot columns. Default is 1.
2039 figsize (tuple): Figure size.
40+ caption (str): Caption for the figure.
41+ description (str): Description for the figure.
42+ label (str): Label for the figure.
43+ fontsize (int): Font size. Default is 14.
44+ dpi (int): DPI for the figure. Default is 300.
45+ width (str): Width of the figure. Default is "17cm".
46+ ratio (str): Aspect ratio. Default is "golden".
47+ gridspec_kw (dict): Gridspec keyword arguments. Default is {"wspace": 0.08, "hspace": 0.1}.
2148 """
2249
23- # nrows=1, ncols=1, caption=None, description=None, label=None, figsize=None
24- self ._nrows = kwargs .get ("nrows" , 1 )
25- self ._ncols = kwargs .get ("ncols" , 1 )
26- self ._figsize = kwargs .get ("figsize" , None )
27- self ._caption = kwargs .get ("caption" , None )
28- self ._description = kwargs .get ("description" , None )
29- self ._label = kwargs .get ("label" , None )
30- self ._fontsize = kwargs .get ("fontsize" , 14 )
31- self ._dpi = kwargs .get ("dpi" , 300 )
32- # self._width = kwargs.get("width", 426.79135)
33- self ._width = kwargs .get ("width" , "17cm" )
34- self ._ratio = kwargs .get ("ratio" , "golden" )
35- self ._gridspec_kw = kwargs .get ("gridspec_kw" , {"wspace" : 0.08 , "hspace" : 0.1 })
50+ self ._nrows = nrows
51+ self ._ncols = ncols
52+ self ._figsize = figsize
53+ self ._caption = caption
54+ self ._description = description
55+ self ._label = label
56+ self ._fontsize = fontsize
57+ self ._dpi = dpi
58+ self ._width = width
59+ self ._ratio = ratio
60+ self ._gridspec_kw = gridspec_kw
3661 self ._plotted = False
3762
3863 # Dictionary to store lines for each subplot
@@ -196,11 +221,11 @@ def add_subplot(
196221 def savefig (
197222 self ,
198223 filename ,
199- backend = "matplotlib" ,
200- layers = None ,
201- layer_by_layer = False ,
202- verbose = False ,
203- plot = True ,
224+ backend : Backends = "matplotlib" ,
225+ layers : list | None = None ,
226+ layer_by_layer : bool = False ,
227+ verbose : bool = False ,
228+ plot : bool = True ,
204229 ):
205230 filename_no_extension , extension = os .path .splitext (filename )
206231 if backend == "matplotlib" :
@@ -238,7 +263,7 @@ def savefig(
238263 if verbose :
239264 print (f"Saved { full_filepath } " )
240265
241- def plot (self , backend = "matplotlib" , savefig = False , layers = None ):
266+ def plot (self , backend : Backends = "matplotlib" , savefig = False , layers = None ):
242267 if backend == "matplotlib" :
243268 return self .plot_matplotlib (savefig = savefig , layers = layers )
244269 elif backend == "plotly" :
@@ -263,9 +288,9 @@ def plot_matplotlib(self, savefig=False, layers=None, usetex=False):
263288 filename (str, optional): Filename to save the figure.
264289 """
265290
266- tex_fonts = plt_utils . setup_tex_fonts (fontsize = self .fontsize , usetex = usetex )
291+ tex_fonts = setup_tex_fonts (fontsize = self .fontsize , usetex = usetex )
267292
268- plt_utils . setup_plotstyle (
293+ setup_plotstyle (
269294 tex_fonts = tex_fonts ,
270295 axes_grid = True ,
271296 axes_grid_which = "major" ,
@@ -276,7 +301,7 @@ def plot_matplotlib(self, savefig=False, layers=None, usetex=False):
276301 if self ._figsize is not None :
277302 fig_width , fig_height = self ._figsize
278303 else :
279- fig_width , fig_height = plt_utils . set_size (
304+ fig_width , fig_height = set_size (
280305 width = self ._width ,
281306 ratio = self ._ratio ,
282307 dpi = self .dpi ,
@@ -313,7 +338,7 @@ def plot_plotly(self, show=True, savefig=None, usetex=False):
313338 savefig (str, optional): Filename to save the figure if provided.
314339 """
315340
316- tex_fonts = plt_utils . setup_tex_fonts (
341+ tex_fonts = setup_tex_fonts (
317342 fontsize = self .fontsize ,
318343 usetex = usetex ,
319344 ) # adjust or redefine for Plotly if needed
@@ -322,7 +347,7 @@ def plot_plotly(self, show=True, savefig=None, usetex=False):
322347 if self ._figsize is not None :
323348 fig_width , fig_height = self ._figsize
324349 else :
325- fig_width , fig_height = plt_utils . set_size (
350+ fig_width , fig_height = set_size (
326351 width = self ._width ,
327352 ratio = self ._ratio ,
328353 )
0 commit comments