Source code for pubplotlib.formatter
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import LogFormatterSciNotation as BaseLogFormatterSciNotation
from matplotlib.ticker import ScalarFormatter as BaseScalarFormatter
# ===================
# FORMATTERS
# ===================
[docs]
class LogFormatterSciNotation(BaseLogFormatterSciNotation):
    """
    Custom logarithmic formatter with configurable scientific notation bounds.

    Shows decimal notation for values whose magnitude lies strictly between
    `low` and `high`, and the label produced by matplotlib's
    ``LogFormatterSciNotation`` outside this range.

    Parameters
    ----------
    low : float, default 1e-3
        Lower bound (exclusive) of the magnitudes shown in decimal notation.
    high : float, default 1e3
        Upper bound (exclusive) of the magnitudes shown in decimal notation.
    **kwargs
        Additional arguments passed to matplotlib's LogFormatterSciNotation.
    """

    def __init__(self, low=1e-3, high=1e3, **kwargs):
        super().__init__(**kwargs)
        self.low = low
        self.high = high

    def __call__(self, x, pos=None, **kwargs):
        """
        Return the label for the tick value `x` at position `pos`.

        The base-class label is computed first; an empty label (a tick the
        base class decided not to annotate) is passed through untouched.
        """
        out = super().__call__(x, pos, **kwargs)
        if out == '':
            return out

        # Compare the magnitude, not the signed value, so that negative
        # ticks (e.g. on a symlog axis) are formatted like their positive
        # counterparts instead of always falling back to 10^n notation.
        if self.low < abs(x) < self.high:
            # fix_minus applies matplotlib's minus-sign convention to the
            # plain-text label of a negative value.
            return self.fix_minus(f"{x:g}")
        return out
[docs]
class ScalarFormatter(BaseScalarFormatter):
    """
    Custom scalar formatter with configurable scientific notation bounds.

    Shows decimal notation for values whose magnitude lies strictly between
    `low` and `high`, and scientific notation outside this range.

    Parameters
    ----------
    low : float, default 1e-3
        Lower bound for decimal notation.
    high : float, default 1e3
        Upper bound for decimal notation.
    **kwargs
        Additional arguments passed to matplotlib's ScalarFormatter.
    """

    def __init__(self, low=1e-3, high=1e3, **kwargs):
        super().__init__(**kwargs)
        self.low = low
        self.high = high
        self.set_scientific(True)
        # The power limits are the base-10 exponents of the two bounds.
        self.set_powerlimits((np.log10(low), np.log10(high)))
        self.set_useOffset(False)
        self.set_useMathText(True)

    def __call__(self, x, pos=None, **kwargs):
        """
        Return the label for the tick value `x` at position `pos`.

        The base-class label is computed first; an empty label is passed
        through untouched.
        """
        out = super().__call__(x, pos, **kwargs)
        if out == '':
            return out

        # When the base formatter has factored out a common power of ten
        # (or an additive offset), its labels are expressed relative to the
        # multiplier shown at the end of the axis.  Returning the raw value
        # for some ticks would mix scaled and unscaled labels on the same
        # axis (e.g. "500" next to "1.0" under a x10^3 multiplier), so the
        # base label is kept for every tick in that case.
        if self.orderOfMagnitude != 0 or self.offset != 0:
            return out

        # Compare the magnitude so negative ticks are treated like their
        # positive counterparts.
        if self.low < abs(x) < self.high:
            # fix_minus applies matplotlib's minus-sign convention to the
            # plain-text label of a negative value.
            return self.fix_minus(f"{x:g}")
        return out
[docs]
class SelectiveFormatter(LogFormatterSciNotation):
    """
    Formatter that only displays labels for specified tick values.

    Useful for showing labels only on major ticks while keeping minor ticks
    unlabeled.

    Parameters
    ----------
    tick_labels : array-like
        The numerical values for which labels should be shown.
    *args
        Additional positional arguments passed to LogFormatterSciNotation.
    rtol : float, default 1e-5
        Relative tolerance used when matching a tick value against
        `tick_labels`.
    atol : float, default 0.0
        Absolute tolerance used when matching a tick value against
        `tick_labels`.  It defaults to zero because tick values on a
        logarithmic axis can be arbitrarily small; a fixed absolute tolerance
        would make every tick below that tolerance match every other one.
    **kwargs
        Additional keyword arguments passed to LogFormatterSciNotation.

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> import matplotlib.ticker as ticker
    >>>
    >>> # Setup selective formatting
    >>> tickvalues = generate_log_ticks(0.01, 100)
    >>> major_ticks = generate_powers_of_10_ticks(0.01, 100)
    >>> minor_ticks = filter_major_ticks(tickvalues, major_ticks)
    >>> formatter = SelectiveFormatter(major_ticks)
    >>>
    >>> # Apply to axes
    >>> fig, ax = plt.subplots()
    >>> ax.xaxis.set_major_formatter(formatter)
    >>> ax.xaxis.set_major_locator(ticker.FixedLocator(major_ticks))
    >>> ax.xaxis.set_minor_locator(ticker.FixedLocator(minor_ticks))
    """

    def __init__(self, tick_labels, *args, rtol=1e-5, atol=0.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.tick_labels = np.asarray(tick_labels, dtype=float)
        self.rtol = rtol
        self.atol = atol

    def __call__(self, x, pos=None):
        """
        Return the label for `x` if it matches one of the selected tick
        values, and an empty string otherwise.
        """
        # One vectorised comparison against all selected values; an empty
        # selection yields an empty mask, for which .any() is False.
        matches = np.isclose(x, self.tick_labels, rtol=self.rtol, atol=self.atol)
        if matches.any():
            return super().__call__(x, pos)
        return ""
# ===================
# AXIS CONFIGURATION
# ===================
[docs]
def set_formatter(ax=None, low=0.01, high=100, axis='both'):
    """
    Apply custom formatting to axis that preserves matplotlib's tick placement.

    Uses decimal notation for values between `low` and `high`, and scientific
    notation outside this range.  Only formatters that are instances of
    matplotlib's ``LogFormatterSciNotation`` or ``ScalarFormatter`` are
    replaced; any other formatter already installed on an axis is left alone.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or list/tuple/ndarray of Axes, optional
        Axes to format.  Containers may be nested or multi-dimensional (such
        as the array returned by ``plt.subplots(2, 2)``).  If None, all axes
        of the current figure are formatted.
    low : float, default 0.01
        Lower bound for decimal notation.
    high : float, default 100
        Upper bound for decimal notation.
    axis : {'x', 'y', 'both'}, default 'both'
        Which axes to format.

    Raises
    ------
    ValueError
        If `axis` is not one of 'x', 'y' or 'both'.

    Examples
    --------
    >>> fig, ax = plt.subplots()
    >>> ax.loglog(x, y)
    >>> set_formatter(ax, low=0.01, high=100)
    """
    if axis not in ('x', 'y', 'both'):
        raise ValueError(
            f"axis must be one of 'x', 'y' or 'both', got {axis!r}"
        )

    def make_replacement(current):
        # Pick the custom formatter matching the type of the installed one;
        # None means "leave this formatter untouched".
        if isinstance(current, BaseLogFormatterSciNotation):
            return LogFormatterSciNotation(low=low, high=high)
        if isinstance(current, BaseScalarFormatter):
            return ScalarFormatter(low=low, high=high)
        return None

    def wrap_axis(axis_obj):
        major = make_replacement(axis_obj.get_major_formatter())
        if major is not None:
            axis_obj.set_major_formatter(major)

        minor = make_replacement(axis_obj.get_minor_formatter())
        if minor is not None:
            axis_obj.set_minor_formatter(minor)

    def iter_axes(obj):
        # Flatten nested lists/tuples and N-dimensional arrays of axes so
        # that grids of subplots can be passed directly.
        if isinstance(obj, np.ndarray):
            for item in obj.flat:
                yield from iter_axes(item)
        elif isinstance(obj, (list, tuple)):
            for item in obj:
                yield from iter_axes(item)
        else:
            yield obj

    if ax is None:
        axes = plt.gcf().get_axes()
    else:
        axes = iter_axes(ax)

    for ax_single in axes:
        if axis in ('x', 'both'):
            wrap_axis(ax_single.xaxis)
        if axis in ('y', 'both'):
            wrap_axis(ax_single.yaxis)