import unittest
from smlr.penalties import (L1Penalty, 
                            L2Penalty, 
                            ElasticNetPenalty, 
                            SENPenalty,
                            NoPenalty)
import numpy as np


class TestPenalties(unittest.TestCase):
    """Unit tests for the penalty classes in ``smlr.penalties``.

    The expected values below encode the scaling conventions these tests
    assume for each penalty's ``value``:

    * ``L1Penalty(lam=1.0)``  -> the plain L1 norm (sum of absolute values);
    * ``L2Penalty(lam=2.0)``  -> the squared L2 norm (Frobenius for matrices);
    * ``ElasticNetPenalty(lam=2.0, alpha=0.5)`` -> ``||b||_1 + 0.5 * ||b||_2^2``;
    * ``SENPenalty(lam1=1.0)`` -> the nuclear norm of a matrix;
    * ``NoPenalty()``         -> 0.
    """

    def test_vector_value(self):
        """Check L1, L2 and elastic-net penalty values on a 1-D vector."""
        # Fixed seed so any failure is reproducible.  Standard-normal draws
        # include negative entries, which exercises the absolute value in the
        # L1 terms (np.random.random only yields values in [0, 1)).
        rng = np.random.RandomState(0)
        beta = rng.standard_normal(size=20)

        l1 = L1Penalty(lam=1.0)
        l2 = L2Penalty(lam=2.0)
        enet = ElasticNetPenalty(lam=2.0, alpha=0.5)

        l1_norm = np.linalg.norm(beta, 1)
        sq_l2_norm = np.linalg.norm(beta, 2) ** 2

        self.assertAlmostEqual(l1.value(beta), l1_norm)
        self.assertAlmostEqual(l2.value(beta), sq_l2_norm)
        self.assertAlmostEqual(enet.value(beta), l1_norm + 0.5 * sq_l2_norm)

    def test_matrix_value(self):
        """Check every penalty's value on a small fixed 3x3 matrix."""
        beta = np.array([
            [1.0, 2.0, 3.0],
            [4.0, 1.0, -1.0],
            [1.5, 2.2, 3.3],
        ])

        l1 = L1Penalty(lam=1.0)
        l2 = L2Penalty(lam=2.0)
        enet = ElasticNetPenalty(lam=2.0, alpha=0.5)
        sen = SENPenalty(lam1=1.0)
        no = NoPenalty()

        # 19.0 is the hand-computed sum of |entries| of beta; it is kept as a
        # literal so the L1 check does not rely on the same numpy routine.
        sum_abs = 19.0
        # np.linalg.norm on a 2-D array with no `ord` is the Frobenius norm.
        sq_frobenius = np.linalg.norm(beta) ** 2

        self.assertAlmostEqual(l1.value(beta), sum_abs)
        self.assertAlmostEqual(l2.value(beta), sq_frobenius)
        self.assertAlmostEqual(enet.value(beta), sum_abs + 0.5 * sq_frobenius)
        self.assertAlmostEqual(sen.value(beta), np.linalg.norm(beta, 'nuc'))
        self.assertAlmostEqual(no.value(beta), 0.0)

    def test_prox(self):
        """Check the L1 proximal operator (soft-thresholding) with step 0.5."""
        beta = np.array([1.0, 2.0, -1.0, 0.0, -0.5, -0.2, 0.2])
        l1 = L1Penalty(lam=1.0)

        # Entries with |x| <= 0.5 collapse to zero; the rest shrink toward
        # zero by 0.5.
        expected = np.array([0.5, 1.5, -0.5, 0.0, 0.0, 0.0, 0.0])

        # assert_allclose reports the mismatching elements on failure, unlike
        # assertTrue(np.all(np.isclose(...))).  rtol/atol match np.isclose's
        # defaults so the tolerance is unchanged.
        np.testing.assert_allclose(
            l1.prox(beta, 0.5), expected, rtol=1e-05, atol=1e-08
        )

    
    

        

    


        