import matplotlib.pyplot as plt # type: ignore
import os
import yaml
from .stylebuilder import Style, Journal, assets_dir, builtin_yaml_filename, user_yaml_filename, user_style_dir
# --- Constants ---
golden = (1 + 5**0.5) / 2 # golden ratio
pt = 1 / 72.27 # points to inches
cm = 1 / 2.54 # centimeters to inches
# --- Defaults ---
_default_style = "aanda"
_current_style = None
# --- Registry ---
def _load_style_registry():
# Load built-in styles
builtin = {}
if builtin_yaml_filename.exists():
with open(builtin_yaml_filename, "r") as file:
raw = yaml.safe_load(file) or {}
for name, entry in raw.items():
builtin[name] = Style(
name=name,
onecol=entry.get("onecol"),
twocol=entry.get("twocol"),
fullpage=entry.get("fullpage"),
mplstyle=assets_dir.joinpath(entry.get("mplstyle")).as_posix() if entry.get("mplstyle") else None,
)
# Load user styles (override built-ins)
user = {}
if os.path.exists(user_yaml_filename):
with open(user_yaml_filename, "r") as file:
raw = yaml.safe_load(file) or {}
for name, entry in raw.items():
user[name] = Style(
name=name,
onecol=entry.get("onecol"),
twocol=entry.get("twocol"),
fullpage=entry.get("fullpage"),
mplstyle=os.path.join(user_style_dir, entry.get("mplstyle")) if entry.get("mplstyle") else None,
)
# User styles override built-ins
registry = {**builtin, **user}
return registry
_style_registry = _load_style_registry()
[docs]
def available_styles():
"""Return a list of available styles."""
return list(_style_registry.keys())
[docs]
def get_style(style=None):
"""Return a Style instance (from string or directly)."""
if isinstance(style, Style):
return style
if style is not None:
if style not in _style_registry:
raise ValueError(f"Style '{style}' not found. Available: {available_styles()}")
return _style_registry[style]
# Use current style if set, otherwise fall back to default
active_style = _current_style if _current_style is not None else _default_style
return _style_registry[active_style]
[docs]
def set_style(style=None):
"""Apply the style. Does nothing if already set."""
global _current_style
s = get_style(style)
if _current_style == s.name:
return
if s.mplstyle is not None:
plt.style.use(s.mplstyle)
_current_style = s.name
[docs]
def set_journal(journal=None):
"""Apply the journal style. Does nothing if already set."""
return set_style(journal)
[docs]
def restore():
"""Restore matplotlib's default style."""
plt.style.use('default')
global _current_style
_current_style = None
def _apply_style_local(style=None):
"""Apply a style locally without changing global state."""
s = get_style(style)
if s.mplstyle is not None:
plt.style.use(s.mplstyle)
return s
[docs]
def setup_figsize(style=None, twocols=False, fullpage=False, width_fraction=1.0, height_ratio=None):
"""Return (width, height) in inches for the style."""
# If style is explicitly provided, apply it locally without changing global state
# Otherwise, use the current global style
if style is not None:
s = _apply_style_local(style)
else:
s = _apply_style_local() # Uses current global style
if fullpage:
if s.fullpage is not None:
base_width = s.fullpage
else:
base_width = max(filter(None, [s.onecol, s.twocol, 0]))
if base_width == 0:
base_width = None
else:
base_width = s.twocol if twocols else s.onecol
if base_width is None:
raise ValueError(f"Style '{s.name}' does not support the selected layout.")
height = base_width / golden if height_ratio is None else base_width * height_ratio
width = base_width * width_fraction
return width, height
[docs]
def subplots(style=None, twocols=False, fullpage=False, width_fraction=1.0, height_ratio=None, journal=None, **kwargs):
"""Create subplots with style-appropriate dimensions."""
width, height = setup_figsize(style or journal, twocols, fullpage, width_fraction, height_ratio)
return plt.subplots(figsize=(width, height), **kwargs)
# --- Style Manager (mimics matplotlib.style interface) ---
[docs]
class _StyleManager:
"""
Manager for style operations, providing a matplotlib-like interface.
Mimics matplotlib.style.use() syntax while integrating with PubPlotLib's
style framework.
"""
[docs]
def use(self, style=None):
"""
Apply a style globally. Mimics matplotlib.style.use() syntax.
Parameters
----------
style : str or Style, optional
Name of the style to apply, or a Style instance.
Examples
--------
>>> pplt.style.use('aanda')
"""
return set_style(style)
[docs]
def available(self):
"""
Return list of available styles.
Returns
-------
list
Names of all available styles.
Examples
--------
>>> pplt.style.available()
['aanda', 'apj', 'presentation']
"""
return available_styles()
[docs]
def get(self, style=None):
"""
Get a Style instance.
Parameters
----------
style : str or Style, optional
Name of the style or a Style instance.
If None, returns the current active style.
Returns
-------
Style
The requested Style instance.
Examples
--------
>>> s = pplt.style.get('aanda')
>>> print(s.onecol, s.twocol)
"""
return get_style(style)
[docs]
def current(self):
"""
Get the name of the currently active style.
Returns
-------
str
Name of the current style.
Examples
--------
>>> pplt.style.current()
'aanda'
"""
global _current_style
return _current_style if _current_style is not None else _default_style
[docs]
def restore(self):
"""
Restore matplotlib's default style.
Examples
--------
>>> pplt.style.restore()
"""
return restore()
# Create the style namespace
style = _StyleManager()