from __future__ import annotations

from contextlib import contextmanager
from contextvars import ContextVar

from sqlalchemy.exc import CompileError
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import Insert, text
from sqlalchemy.dialects import postgresql
import psycopg2


# Current ON CONFLICT mode for PostgreSQL inserts: read by ``prefix_inserts`` and
# temporarily overridden by the ``on_conflict()`` context manager. Outside any
# ``on_conflict()`` block the mode is "do-nothing"; note that ``on_conflict()``
# itself defaults to "restrict" (i.e. leave inserts unmodified).
_import_mode = ContextVar("import-mode", default="do-nothing")

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from ..database import Database

# https://stackoverflow.com/questions/33307250/postgresql-on-conflict-in-sqlalchemy/62305344#62305344
@contextmanager
def on_conflict(action="restrict"):
    """Temporarily choose how PostgreSQL inserts handle key conflicts.

    Within the ``with`` block, ``_import_mode`` holds *action*. On exit the
    previous mode is restored, whether the block finishes normally or raises.

    ``prefix_inserts`` acts on ``"do-update"`` and ``"do-nothing"``. Any other
    value, including the default ``"restrict"``, leaves inserts unmodified.
    """
    reset_token = _import_mode.set(action)
    try:
        yield
    finally:
        # Undo our ``set`` so the enclosing context's mode comes back.
        _import_mode.reset(reset_token)


# @compiles(Insert, "postgresql")
# NOTE: the registration above is commented out, so this hook is inactive unless
# it is registered elsewhere (e.g. ``compiles(Insert, "postgresql")(prefix_inserts)``).
def prefix_inserts(insert, compiler, **kw):
    """Conditionally adapt insert statements to use on-conflict resolution (a PostgreSQL feature).

    Written as a SQLAlchemy ``@compiles`` hook for ``Insert``. What it does depends
    on the mode stored in ``_import_mode`` (set via ``on_conflict``):

    * ``"do-update"``: attach ``ON CONFLICT (<primary key>) DO UPDATE SET ...``.
      The SET values come from the statement's bound values, excluding
      primary-key columns, names that are not table columns, and ``None``
      values. It falls back to ``"do-nothing"`` in two cases: nothing is left
      to update, or the table has no primary key to use as the conflict target.
      PostgreSQL requires a target for DO UPDATE.
    * ``"do-nothing"``: attach ``ON CONFLICT (<primary key>) DO NOTHING``. When
      the table has no primary key, attach a bare ``ON CONFLICT DO NOTHING``.
    * anything else (e.g. ``"restrict"``): compile the insert unchanged.

    The conflict clause is stored on ``insert._post_values_clause``, i.e. the
    statement object is mutated in place.

    :param insert: the ``Insert`` construct being compiled.
    :param compiler: the active SQL compiler whose ``visit_insert`` renders the SQL.
    :returns: the result of ``compiler.visit_insert(insert, **kw)``.
    """
    action = _import_mode.get()
    table = insert.table
    # Materialize the key columns. An empty list means "no primary key", in which
    # case a conflict target would render as ``ON CONFLICT ()`` -- invalid SQL.
    pk_columns = list(table.primary_key)

    if action == "do-update":
        vals = {}
        if pk_columns:
            try:
                # Compile without an explicit dialect just to read the bound values.
                params = insert.compile().params
            except CompileError:
                params = {}
            vals = {
                name: value
                for name, value in params.items()
                if (
                    name not in table.primary_key
                    and name in table.columns
                    and value is not None
                )
            }
        if vals:
            insert._post_values_clause = postgresql.dml.OnConflictDoUpdate(
                index_elements=pk_columns, set_=vals
            )
        else:
            action = "do-nothing"

    if action == "do-nothing":
        # A conflict target is optional for DO NOTHING, so omit it when there is no PK.
        insert._post_values_clause = postgresql.dml.OnConflictDoNothing(
            index_elements=pk_columns or None
        )
    return compiler.visit_insert(insert, **kw)


# Guard flag so ``_setup_psycopg2_wait_callback`` skips re-installing the wait callback.
# NOTE(review): this is a ContextVar, so the flag is per-context. A new thread, or code
# running in a separately copied context, starts from the default ``False``. Meanwhile
# psycopg2's wait callback is presumably process-global, so the callback may be
# re-installed. Confirm a ContextVar (rather than a plain module global) is intended.
_psycopg2_setup_was_run = ContextVar("psycopg2-setup-was-run", default=False)


def _setup_psycopg2_wait_callback():
    """Set up the wait callback for PostgreSQL connections. This allows for query cancellation with Ctrl-C.

    Installs ``psycopg2.extras.wait_select`` via
    ``psycopg2.extensions.set_wait_callback``. Installation is skipped when
    ``_psycopg2_setup_was_run`` is already true in the current context. The
    flag is only set after a successful installation.
    """
    # TODO: we might want to do this only once on engine creation
    # https://github.com/psycopg/psycopg2/issues/333
    if _psycopg2_setup_was_run.get():
        return
    # ``import psycopg2`` is not guaranteed to load the ``extras`` submodule, so
    # import the needed names explicitly rather than relying on attribute access
    # succeeding only because some other module happened to import it first.
    from psycopg2.extensions import set_wait_callback
    from psycopg2.extras import wait_select

    set_wait_callback(wait_select)
    _psycopg2_setup_was_run.set(True)


def table_exists(db: Database, table_name: str, schema: str = "public") -> bool:
    """Return whether ``schema.table_name`` is listed in ``information_schema.tables``.

    Runs a parameterized ``SELECT EXISTS (...)`` through ``db.session`` and
    returns the single scalar it yields.

    Note: in PostgreSQL, ``information_schema.tables`` also lists views. It only
    shows relations the current user has some privilege on.
    """
    # Same statement text as before, spelled out piece by piece (note the
    # trailing space after ``tables``).
    query = text(
        "SELECT EXISTS (\n"
        "        SELECT FROM information_schema.tables \n"
        "        WHERE table_schema = :schema\n"
        "          AND table_name = :table_name\n"
        "    );"
    )
    bind_params = {"schema": schema, "table_name": table_name}
    result = db.session.execute(query, params=bind_params)
    return result.scalar()

