
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples\text_only\basic_units.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_text_only_basic_units.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_text_only_basic_units.py:


===========
Basic Units
===========

.. GENERATED FROM PYTHON SOURCE LINES 7-378







.. code-block:: default


    import math

    import numpy as np

    import matplotlib.units as units
    import matplotlib.ticker as ticker


    class ProxyDelegate:
        """Descriptor that builds a new proxy object on each attribute access.

        *fn_name* is the name of the method to be proxied and *proxy_type* is
        a callable invoked as ``proxy_type(fn_name, obj)``.
        """

        def __init__(self, fn_name, proxy_type):
            self.fn_name, self.proxy_type = fn_name, proxy_type

        def __get__(self, obj, objtype=None):
            # *objtype* is accepted for the descriptor protocol but unused.
            make_proxy = self.proxy_type
            return make_proxy(self.fn_name, obj)


    class TaggedValueMeta(type):
        """Metaclass that installs proxy descriptors listed in ``_proxies``.

        For every ``name -> proxy class`` entry of the new class's ``_proxies``
        mapping, a ``ProxyDelegate`` is attached under that name unless the
        class already resolves an attribute with the same name.
        """

        def __init__(self, name, bases, dict):
            for fn_name, proxy_type in self._proxies.items():
                if hasattr(self, fn_name):
                    # Keep whatever the class (or a base) already provides.
                    continue
                setattr(self, fn_name, ProxyDelegate(fn_name, proxy_type))


    class PassThroughProxy:
        """Forward a named method call to ``obj.proxy_target`` unchanged."""

        def __init__(self, fn_name, obj):
            self.fn_name = fn_name
            # The object whose method of the same name will be invoked.
            self.target = obj.proxy_target

        def __call__(self, *args):
            # Look the method up at call time and return its result as-is.
            return getattr(self.target, self.fn_name)(*args)


    class ConvertArgsProxy(PassThroughProxy):
        """Proxy that converts every argument to the owner's unit first.

        The underlying method receives the bare (untagged) values.
        """

        def __init__(self, fn_name, obj):
            super().__init__(fn_name, obj)
            self.unit = obj.unit

        def __call__(self, *args):
            tagged = []
            for arg in args:
                try:
                    tagged.append(arg.convert_to(self.unit))
                except AttributeError:
                    # No conversion support: tag the argument with our unit.
                    tagged.append(TaggedValue(arg, self.unit))
            bare_values = tuple([item.get_value() for item in tagged])
            return super().__call__(*bare_values)


    class ConvertReturnProxy(PassThroughProxy):
        """Proxy that tags the forwarded call's result with the owner's unit."""

        def __init__(self, fn_name, obj):
            super().__init__(fn_name, obj)
            self.unit = obj.unit

        def __call__(self, *args):
            result = super().__call__(*args)
            if result is NotImplemented:
                # Propagate the sentinel untouched so Python can try the
                # reflected operation.
                return NotImplemented
            return TaggedValue(result, self.unit)


    class ConvertAllProxy(PassThroughProxy):
        """Proxy that converts arguments and tags the result with a unit.

        Each argument that supports ``convert_to`` is converted to the owner's
        unit (when possible) and unwrapped; other arguments are passed through
        unchanged. The unit of the result is chosen by ``unit_resolver`` from
        the owner's unit followed by the unit of each argument.
        """

        def __init__(self, fn_name, obj):
            PassThroughProxy.__init__(self, fn_name, obj)
            self.unit = obj.unit

        def __call__(self, *args):
            """Call the proxied method and return a ``TaggedValue``.

            Returns ``NotImplemented`` when an argument carries a unit but
            cannot be converted, when the proxied method itself returns
            ``NotImplemented``, or when no result unit can be resolved.
            """
            converted_args = []
            arg_units = [self.unit]
            for a in args:
                convertible = hasattr(a, 'convert_to')
                if hasattr(a, 'get_unit') and not convertible:
                    # if this arg has a unit type but no conversion ability,
                    # this operation is prohibited
                    return NotImplemented

                if convertible:
                    try:
                        a = a.convert_to(self.unit)
                    except Exception:
                        # Deliberate best effort: if the conversion fails the
                        # argument keeps its own unit, and the resolver below
                        # decides whether the unit combination is allowed.
                        pass
                    arg_units.append(a.get_unit())
                    converted_args.append(a.get_value())
                else:
                    # At this point the argument has neither ``convert_to``
                    # nor ``get_unit`` (the guard above already returned for
                    # unit-only objects), so it is a plain unitless value.
                    converted_args.append(a)
                    arg_units.append(None)
            converted_args = tuple(converted_args)
            ret = PassThroughProxy.__call__(self, *converted_args)
            if ret is NotImplemented:
                return NotImplemented
            ret_unit = unit_resolver(self.fn_name, arg_units)
            if ret_unit is NotImplemented:
                return NotImplemented
            return TaggedValue(ret, ret_unit)


    class TaggedValue(metaclass=TaggedValueMeta):
        """A value paired with a unit.

        Attribute access that does not start with a double underscore is
        forwarded to the wrapped value whenever the value has that attribute
        and the instance's class does not define it itself (see
        ``__getattribute__``).
        """

        # Dunder methods the metaclass installs as proxy descriptors. Names
        # that already resolve on the class (for example ``__len__``, defined
        # below) are skipped by the metaclass.
        _proxies = {'__add__': ConvertAllProxy,
                    '__sub__': ConvertAllProxy,
                    '__mul__': ConvertAllProxy,
                    '__rmul__': ConvertAllProxy,
                    '__cmp__': ConvertAllProxy,
                    '__lt__': ConvertAllProxy,
                    '__gt__': ConvertAllProxy,
                    '__len__': PassThroughProxy}

        def __new__(cls, value, unit):
            # generate a new subclass for value
            value_class = type(value)
            try:
                subcls = type(f'TaggedValue_of_{value_class.__name__}',
                              (cls, value_class), {})
                return object.__new__(subcls)
            except TypeError:
                # Creating or instantiating the mixed subclass failed; fall
                # back to a plain instance of *cls*.
                return object.__new__(cls)

        def __init__(self, value, unit):
            self.value = value
            self.unit = unit
            # The proxies forward their calls to this object.
            self.proxy_target = self.value

        def __getattribute__(self, name):
            if name.startswith('__'):
                return object.__getattribute__(self, name)
            variable = object.__getattribute__(self, 'value')
            # Prefer the wrapped value's attribute unless this instance's
            # class defines the name directly.
            if hasattr(variable, name) and name not in self.__class__.__dict__:
                return getattr(variable, name)
            return object.__getattribute__(self, name)

        def __array__(self, dtype=object, copy=None):
            """Return the wrapped value as an array of *dtype*.

            *copy* is accepted because newer NumPy releases pass it as a
            keyword to ``__array__``; it is otherwise unused here, since
            ``astype`` already produces a new array.
            """
            return np.asarray(self.value).astype(dtype)

        def __array_wrap__(self, array, context=None, return_scalar=False):
            """Re-tag a NumPy result with this value's unit.

            *context* and *return_scalar* are optional so the hook works both
            with callers passing two positional arguments and with newer NumPy
            releases that also pass ``return_scalar``; neither is used.
            """
            return TaggedValue(array, self.unit)

        def __repr__(self):
            return f'TaggedValue({self.value!r}, {self.unit!r})'

        def __str__(self):
            return str(self.value) + ' in ' + str(self.unit)

        def __len__(self):
            return len(self.value)

        def __iter__(self):
            # Return a generator expression rather than use `yield`, so that
            # TypeError is raised by iter(self) if appropriate when checking for
            # iterability.
            return (TaggedValue(inner, self.unit) for inner in self.value)

        def get_compressed_copy(self, mask):
            """Return a new TaggedValue holding only the unmasked entries."""
            new_value = np.ma.masked_array(self.value, mask=mask).compressed()
            return TaggedValue(new_value, self.unit)

        def convert_to(self, unit):
            """Return this value expressed in *unit*.

            ``self`` is returned unchanged when *unit* is falsy or equal to
            the current unit. If the current unit raises ``AttributeError``
            during conversion (e.g. it has no ``convert_value_to``), the
            result wraps ``self`` tagged with *unit*.
            """
            if unit == self.unit or not unit:
                return self
            try:
                new_value = self.unit.convert_value_to(self.value, unit)
            except AttributeError:
                new_value = self
            return TaggedValue(new_value, unit)

        def get_value(self):
            """Return the wrapped value."""
            return self.value

        def get_unit(self):
            """Return the unit attached to the value."""
            return self.unit


    class BasicUnit:
        """A named unit with a table of conversions to other units.

        Parameters
        ----------
        name : str
            Short name of the unit; used by ``repr``.
        fullname : str, optional
            Long name returned by ``str``; defaults to *name*.
        """

        def __init__(self, name, fullname=None):
            self.name = name
            if fullname is None:
                fullname = name
            self.fullname = fullname
            # Maps a target unit to a callable converting a value into it.
            self.conversions = {}

        def __repr__(self):
            return f'BasicUnit({self.name})'

        def __str__(self):
            return self.fullname

        def __call__(self, value):
            """Tag *value* with this unit."""
            return TaggedValue(value, self)

        def __mul__(self, rhs):
            """Tag *rhs* with this unit, or combine units if *rhs* has one.

            Returns ``NotImplemented`` when the resolver rejects the unit
            combination.
            """
            value = rhs
            unit = self
            if hasattr(rhs, 'get_unit'):
                value = rhs.get_value()
                unit = rhs.get_unit()
                unit = unit_resolver('__mul__', (self, unit))
            if unit is NotImplemented:
                return NotImplemented
            return TaggedValue(value, unit)

        def __rmul__(self, lhs):
            return self*lhs

        def __array_wrap__(self, array, context=None, return_scalar=False):
            """Tag a NumPy result with this unit.

            *context* and *return_scalar* are optional so the hook works both
            with callers passing two positional arguments and with newer NumPy
            releases that also pass ``return_scalar``; neither is used.
            """
            return TaggedValue(array, self)

        def __array__(self, t=None, context=None, copy=None):
            """Return a one-element array ``[1]``, cast to *t* if given.

            *context* and *copy* are accepted for compatibility with the
            arguments NumPy may pass to ``__array__`` and are unused; a new
            array is created on every call.
            """
            ret = np.array([1])
            if t is not None:
                return ret.astype(t)
            return ret

        def add_conversion_factor(self, unit, factor):
            """Register a conversion to *unit* that multiplies by *factor*."""
            def convert(x):
                return x*factor
            self.conversions[unit] = convert

        def add_conversion_fn(self, unit, fn):
            """Register *fn* as the conversion from this unit to *unit*."""
            self.conversions[unit] = fn

        def get_conversion_fn(self, unit):
            """Return the conversion callable for *unit* (KeyError if none)."""
            return self.conversions[unit]

        def convert_value_to(self, value, unit):
            """Convert *value* to *unit* (KeyError if none is registered)."""
            conversion_fn = self.conversions[unit]
            return conversion_fn(value)

        def get_unit(self):
            """Return this unit itself."""
            return self


    class UnitResolver:
        """Pick the unit that results from combining operand units."""

        def addition_rule(self, units):
            """Return the first unit, or NotImplemented if neighbours differ."""
            neighbours = zip(units[:-1], units[1:])
            if any(left != right for left, right in neighbours):
                return NotImplemented
            return units[0]

        def multiplication_rule(self, units):
            """Return the single truthy unit; NotImplemented if several."""
            present = list(filter(None, units))
            if len(present) > 1:
                return NotImplemented
            return present[0]

        # Operation name -> rule; rules are plain functions here, so they are
        # called with the resolver passed explicitly.
        op_dict = {
            '__mul__': multiplication_rule,
            '__rmul__': multiplication_rule,
            '__add__': addition_rule,
            '__radd__': addition_rule,
            '__sub__': addition_rule,
            '__rsub__': addition_rule}

        def __call__(self, operation, units):
            rule = self.op_dict.get(operation)
            if rule is None:
                # Unknown operation: no unit can be resolved.
                return NotImplemented
            return rule(self, units)


    # Shared resolver instance looked up by ConvertAllProxy and
    # BasicUnit.__mul__ to decide the unit of a result.
    unit_resolver = UnitResolver()

    # Length units, convertible in both directions (1 inch = 2.54 cm).
    cm = BasicUnit('cm', 'centimeters')
    inch = BasicUnit('inch', 'inches')
    inch.add_conversion_factor(cm, 2.54)
    cm.add_conversion_factor(inch, 1/2.54)

    # Angle units, convertible in both directions.
    radians = BasicUnit('rad', 'radians')
    degrees = BasicUnit('deg', 'degrees')
    radians.add_conversion_factor(degrees, 180.0/np.pi)
    degrees.add_conversion_factor(radians, np.pi/180.0)

    # Time-related units. Only conversions *from* seconds are registered
    # here: to hertz via the reciprocal, and to minutes via a factor.
    secs = BasicUnit('s', 'seconds')
    hertz = BasicUnit('Hz', 'Hertz')
    minutes = BasicUnit('min', 'minutes')

    secs.add_conversion_fn(hertz, lambda x: 1./x)
    secs.add_conversion_factor(minutes, 1/60.0)


    # radians formatting
    def rad_fn(x, pos=None):
        """Format *x*, a value in radians, as a multiple of pi/2.

        *pos* is accepted so the function can be used as a tick formatter
        callback; it is not used.
        """
        # Number of half-pi steps, nudged away from zero before truncation.
        nudge = 0.25 if x >= 0 else -0.25
        n = int((x / np.pi) * 2.0 + nudge)

        # Labels that are written without an explicit numeric coefficient.
        fixed_labels = {
            0: '0',
            1: r'$\pi/2$',
            2: r'$\pi$',
            -1: r'$-\pi/2$',
            -2: r'$-\pi$',
        }
        if n in fixed_labels:
            return fixed_labels[n]
        if n % 2 == 0:
            return fr'${n//2}\pi$'
        return fr'${n}\pi/2$'


    class BasicUnitConverter(units.ConversionInterface):
        """Matplotlib converter for ``BasicUnit`` / ``TaggedValue`` data."""

        @staticmethod
        def _is_numlike(val):
            """Return whether *val*, or its first unmasked item, is a number.

            Local replacement for ``ConversionInterface.is_numlike``, which is
            deprecated and later removed in newer Matplotlib releases, so that
            ``convert`` does not depend on it.
            """
            # Function-scope import keeps the class self-contained.
            from numbers import Number

            if np.iterable(val):
                for item in val:
                    if item is np.ma.masked:
                        continue
                    return isinstance(item, Number)
                return False
            return isinstance(val, Number)

        @staticmethod
        def axisinfo(unit, axis):
            """Return AxisInfo instance for x and unit."""

            if unit == radians:
                return units.AxisInfo(
                    majloc=ticker.MultipleLocator(base=np.pi/2),
                    majfmt=ticker.FuncFormatter(rad_fn),
                    label=unit.fullname,
                )
            elif unit == degrees:
                return units.AxisInfo(
                    majloc=ticker.AutoLocator(),
                    majfmt=ticker.FormatStrFormatter(r'$%i^\circ$'),
                    label=unit.fullname,
                )
            elif unit is not None:
                if hasattr(unit, 'fullname'):
                    return units.AxisInfo(label=unit.fullname)
                elif hasattr(unit, 'unit'):
                    # *unit* is itself a tagged object; label with its unit.
                    return units.AxisInfo(label=unit.unit.fullname)
            return None

        @staticmethod
        def convert(val, unit, axis):
            """Convert *val* to plain numbers expressed in *unit*.

            Values that are already numeric are returned unchanged. For an
            iterable, a float array is returned in which masked entries become
            NaN and items without ``convert_to`` are stored as they are.
            """
            if BasicUnitConverter._is_numlike(val):
                return val
            if np.iterable(val):
                if isinstance(val, np.ma.MaskedArray):
                    val = val.astype(float).filled(np.nan)
                out = np.empty(len(val))
                for i, thisval in enumerate(val):
                    if np.ma.is_masked(thisval):
                        out[i] = np.nan
                    else:
                        try:
                            out[i] = thisval.convert_to(unit).get_value()
                        except AttributeError:
                            out[i] = thisval
                return out
            if np.ma.is_masked(val):
                return np.nan
            else:
                return val.convert_to(unit).get_value()

        @staticmethod
        def default_units(x, axis):
            """Return the default unit for x or None."""
            if np.iterable(x):
                # Only the first item is inspected.
                for thisx in x:
                    return getattr(thisx, 'unit', None)
            # Non-iterable or empty input: use its own unit, if it has one.
            return getattr(x, 'unit', None)


    def cos(x):
        """Return the cosine of a tagged angle, or a list for an iterable.

        Each value is converted to radians before ``math.cos`` is applied.
        """
        def _cos_of(angle):
            return math.cos(angle.convert_to(radians).get_value())

        if not np.iterable(x):
            return _cos_of(x)
        return [_cos_of(val) for val in x]


    # Register a single converter instance for both unit objects and tagged
    # values in Matplotlib's unit registry.
    units.registry[BasicUnit] = units.registry[TaggedValue] = BasicUnitConverter()


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.006 seconds)


.. _sphx_glr_download_auto_examples_text_only_basic_units.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: basic_units.py <basic_units.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: basic_units.ipynb <basic_units.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
