"""Data coordinator for Navien Smart integration."""

import json
import logging
import uuid
from datetime import timedelta

from homeassistant.core import HomeAssistant
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator

from .const import AWS_IOT_ENDPOINT, AWS_IOT_REGION, DOMAIN
from .navien_api import NavienAPI

_LOGGER = logging.getLogger(__name__)

# awsiotsdk is an optional dependency: when it cannot be imported the
# integration still loads, but without the MQTT push channel.
try:
    from awscrt import auth, mqtt
    from awsiot import mqtt_connection_builder
except ImportError:
    HAS_MQTT = False
    _LOGGER.warning(
        "awsiotsdk not installed - MQTT disabled, using REST polling only. "
        "Install with: pip install awsiotsdk"
    )
else:
    HAS_MQTT = True

# Coordinator refresh cadence: short interval while MQTT is unavailable,
# long interval (periodic credential/connection refresh) while MQTT is active.
POLL_INTERVAL = timedelta(seconds=30)
MQTT_REFRESH_INTERVAL = timedelta(minutes=50)


class NavienCoordinator(DataUpdateCoordinator):
    """Navien Smart data coordinator.

    Manages authentication, the device list and the MQTT connection, and
    exposes the latest reported status of each device.

    Status is kept in ``_device_status`` (device id -> reported state) and is
    only written by the MQTT message handler. The payload-less REST
    ``control_device`` calls issued here are meant to make devices publish
    their status; their responses are not stored.
    """

    def __init__(self, hass: HomeAssistant, username: str, password: str) -> None:
        """Create the coordinator.

        Args:
            hass: Home Assistant instance the coordinator runs in.
            username: Account user name, kept for later re-logins.
            password: Account password, kept for later re-logins.
        """
        super().__init__(
            hass,
            _LOGGER,
            name=DOMAIN,
            update_interval=MQTT_REFRESH_INTERVAL if HAS_MQTT else POLL_INTERVAL,
        )
        self._username = username
        self._password = password
        self.api = NavienAPI()
        # Live awscrt MQTT connection, or None while disconnected.
        self._mqtt_connection = None
        self._mqtt_connected = False
        # device id -> last reported state received over MQTT.
        self._device_status: dict[str, dict] = {}
        self._devices: list[dict] = []
        # Ids already announced to the new-device callbacks.
        self._known_device_ids: set[str] = set()
        self._new_device_callbacks: list = []

    @property
    def devices(self) -> list[dict]:
        """Return the device list from the most recent successful fetch."""
        return self._devices

    def register_new_device_callback(self, callback_fn) -> None:
        """Register a callback for newly discovered devices.

        The callback is invoked synchronously with the list of device dicts
        whose ``deviceId`` had not been seen before.
        """
        self._new_device_callbacks.append(callback_fn)

    def get_device_status(self, device_id: str) -> dict | None:
        """Return the last reported state for ``device_id``, or None if unknown."""
        return self._device_status.get(device_id)

    async def async_setup(self) -> None:
        """Run the initial setup: log in, fetch devices, start MQTT.

        Errors from the login and the device fetch propagate to the caller;
        per-device initialization failures are only logged.
        """
        _LOGGER.info("Navien Smart setup starting")
        await self.api.login(self._username, self._password)
        _LOGGER.info("Login success: user=%s", self.api.user_id)

        self._devices = await self.api.get_devices()
        self._known_device_ids = {d["deviceId"] for d in self._devices}
        _LOGGER.info("Found %d device(s)", len(self._devices))

        # Start MQTT if available
        if HAS_MQTT:
            await self._start_mqtt()
        else:
            _LOGGER.info("MQTT unavailable, using REST polling (every 30s)")

        # Initialize devices (triggers status publish)
        for dev in self._devices:
            try:
                resp = await self.api.control_device(dev)
                _LOGGER.debug("Init device %s: %s", dev.get("deviceId"), resp)
            except Exception as err:
                _LOGGER.warning(
                    "Failed to initialize device %s: %s", dev.get("deviceId"), err
                )

    async def _async_update_data(self) -> dict:
        """Run the periodic refresh.

        Refreshes the AWS credentials (falling back to a full login),
        reconnects MQTT with the new credentials, re-syncs the device list
        and asks every device for fresh status.

        Returns:
            The shared ``device id -> reported state`` mapping.

        Raises:
            Exception: Whatever ``api.login`` raises when the credential
                refresh and the fallback login both fail.
        """
        # Refresh credentials
        try:
            await self.api.refresh_aws_credentials()
            _LOGGER.debug("AWS credentials refreshed")
        except Exception as err:
            _LOGGER.warning("AWS credential refresh failed, re-logging in: %s", err)
            try:
                await self.api.login(self._username, self._password)
            except Exception as login_err:
                _LOGGER.error("Re-login failed: %s", login_err)
                raise

        # Reconnect MQTT so the connection is signed with the fresh credentials.
        if HAS_MQTT:
            await self._stop_mqtt()
            await self._start_mqtt()

        # Re-fetch device list (detect new devices). Best effort: a failure
        # here must not fail the whole update.
        try:
            await self._sync_devices()
        except Exception as err:
            _LOGGER.warning("Failed to refresh device list: %s", err)

        # Request fresh status from all devices
        await self._request_status_all()

        return self._device_status

    async def _sync_devices(self) -> list[dict]:
        """Fetch the device list, store it, and announce unseen devices.

        Returns:
            The devices whose ``deviceId`` was not known before (may be empty).

        Raises:
            Exception: Errors from ``api.get_devices`` or from a registered
                new-device callback propagate to the caller.
        """
        fresh_devices = await self.api.get_devices()
        new_devices = [
            dev for dev in fresh_devices
            if dev["deviceId"] not in self._known_device_ids
        ]
        # Always adopt the fresh list so devices removed from the account
        # are dropped as well.
        self._devices = fresh_devices
        if not new_devices:
            return new_devices

        self._known_device_ids.update(dev["deviceId"] for dev in new_devices)
        _LOGGER.info(
            "Discovered %d new device(s): %s",
            len(new_devices),
            [d.get("modelName") for d in new_devices],
        )
        for notify in self._new_device_callbacks:
            notify(new_devices)
        return new_devices

    async def _request_status_all(self) -> None:
        """Send a payload-less control request to every known device.

        Best effort: a failure for one device is logged and does not stop
        the remaining requests.
        """
        for dev in self._devices:
            try:
                await self.api.control_device(dev)
            except Exception as err:
                _LOGGER.debug(
                    "Status request failed for device %s: %s",
                    dev.get("deviceId"),
                    err,
                )

    # ── MQTT ──

    @staticmethod
    def _deep_merge(base: dict, update: dict) -> dict:
        """Return a copy of ``base`` with ``update`` merged in recursively.

        Nested dicts present on both sides are merged key by key; any other
        value in ``update`` replaces the one in ``base``. Neither argument
        is modified at the top level.
        """
        merged = base.copy()
        for key, value in update.items():
            current = merged.get(key)
            if isinstance(current, dict) and isinstance(value, dict):
                merged[key] = NavienCoordinator._deep_merge(current, value)
            else:
                merged[key] = value
        return merged

    @staticmethod
    def _parse_status_event(data: dict) -> tuple[str, dict, bool] | None:
        """Extract ``(device_id, reported, is_full_state)`` from an MQTT event.

        Returns None for events that are skipped: topics other than
        ``update/documents`` / ``update/accepted``, an empty ``reported``
        section, or a ``reported`` section without ``info.deviceId``.
        """
        event_topic = data.get("topic", "")
        if "update/documents" in event_topic:
            # Full state: payload.current.state.reported
            reported = (
                data.get("payload", {})
                .get("current", {})
                .get("state", {})
                .get("reported")
            )
            is_full_state = True
        elif "update/accepted" in event_topic:
            # Partial state: payload.state.reported
            reported = data.get("payload", {}).get("state", {}).get("reported")
            is_full_state = False
        else:
            # delta or other messages - skip
            return None

        if not reported:
            return None
        device_id = reported.get("info", {}).get("deviceId")
        if not device_id:
            return None
        return device_id, reported, is_full_state

    def _apply_status(self, device_id: str, reported: dict, is_full_state: bool) -> None:
        """Store a reported state and notify listeners.

        Scheduled onto the event loop by the MQTT message handler so that
        ``_device_status`` is only mutated from the loop thread.

        Args:
            device_id: Id of the device the state belongs to.
            reported: The ``reported`` section taken from the MQTT event.
            is_full_state: True to replace the stored state, False to
                deep-merge ``reported`` into it.
        """
        try:
            if is_full_state:
                # Full state from documents - replace entirely
                self._device_status[device_id] = reported
                _LOGGER.info(
                    "MQTT full state [%s]: mode=%s, heater=%s",
                    device_id,
                    reported.get("operationMode"),
                    bool(reported.get("heater")),
                )
            else:
                # Partial state from accepted - merge into existing
                existing = self._device_status.get(device_id, {})
                self._device_status[device_id] = self._deep_merge(existing, reported)
                _LOGGER.debug(
                    "MQTT partial merge [%s]: %s",
                    device_id,
                    list(reported.keys()),
                )

            # NOTE(review): per the Home Assistant docs this call also resets
            # the coordinator's refresh timer, so steady MQTT traffic may
            # postpone the periodic credential refresh - confirm intended.
            self.async_set_updated_data(self._device_status)
        except Exception:
            _LOGGER.exception("Error processing MQTT message")

    async def _start_mqtt(self) -> None:
        """Connect to AWS IoT Core MQTT and subscribe to device status.

        No-op when awsiotsdk is missing or no AWS credentials are available.
        On a connect or subscribe failure the coordinator falls back to the
        short polling interval so the next update retries soon; once the
        subscription is active the long refresh interval is restored.
        """
        if not HAS_MQTT:
            return

        aws_creds = self.api.aws_credentials
        if not aws_creds:
            _LOGGER.warning("No AWS credentials, skipping MQTT")
            return

        client_id = f"{uuid.uuid4()}-U{self.api.user_seq}"

        def _connect():
            """Build and open the connection (blocking; runs in the executor)."""
            credentials_provider = auth.AwsCredentialsProvider.new_static(
                access_key_id=aws_creds["accessKeyId"],
                secret_access_key=aws_creds["secretKey"],
                session_token=aws_creds["sessionToken"],
            )

            new_connection = mqtt_connection_builder.websockets_with_default_aws_signing(
                endpoint=AWS_IOT_ENDPOINT,
                region=AWS_IOT_REGION,
                credentials_provider=credentials_provider,
                client_id=client_id,
                clean_session=True,
                keep_alive_secs=30,
            )

            new_connection.connect().result(timeout=10)
            return new_connection

        try:
            connection = await self.hass.async_add_executor_job(_connect)
        except Exception as err:
            self._mqtt_connected = False
            _LOGGER.error("MQTT connection failed: %s", err)
            _LOGGER.info("Falling back to REST polling")
            self.update_interval = POLL_INTERVAL
            return

        self._mqtt_connection = connection
        self._mqtt_connected = True
        _LOGGER.info("MQTT connected (clientId=%s)", client_id)

        # Subscribe to device status changes
        topic_filter = f"{self.api.home_seq}/mate/+"

        # The parameter names must stay ``topic`` / ``payload``: the awscrt
        # client passes them as keyword arguments.
        def on_message(topic, payload, **kwargs):
            """Parse one MQTT message and hand the state to the event loop."""
            try:
                data = json.loads(payload)

                # Only serialize the payload when it will actually be logged.
                if _LOGGER.isEnabledFor(logging.DEBUG):
                    event_topic = data.get("topic", "")
                    _LOGGER.debug(
                        "MQTT [%s]: %s",
                        event_topic.rsplit("/", 1)[-1],
                        json.dumps(data, ensure_ascii=False, default=str)[:500],
                    )

                parsed = self._parse_status_event(data)
                if parsed is None:
                    return

                # This callback runs outside the event loop; apply the state
                # and notify HA entities on the loop.
                self.hass.loop.call_soon_threadsafe(self._apply_status, *parsed)
            except Exception:
                _LOGGER.exception("Error processing MQTT message")

        def _subscribe():
            """Subscribe and wait for the ack (blocking; runs in the executor)."""
            sub_future, _ = connection.subscribe(
                topic=topic_filter,
                qos=mqtt.QoS.AT_LEAST_ONCE,
                callback=on_message,
            )
            sub_future.result(timeout=10)

        try:
            await self.hass.async_add_executor_job(_subscribe)
        except Exception as err:
            _LOGGER.error("MQTT subscribe failed: %s", err)
            # Without a subscription no status arrives; retry on the short
            # interval instead of waiting for the long MQTT refresh.
            self.update_interval = POLL_INTERVAL
            return

        _LOGGER.info("MQTT subscribed to %s", topic_filter)
        # MQTT is delivering again: leave the polling fallback so a working
        # connection is not torn down on every short poll.
        self.update_interval = MQTT_REFRESH_INTERVAL

    async def _stop_mqtt(self) -> None:
        """Disconnect MQTT if connected; disconnect errors are only logged."""
        connection = self._mqtt_connection
        if not connection:
            return

        def _disconnect():
            """Close the connection (blocking; runs in the executor)."""
            try:
                connection.disconnect().result(timeout=5)
            except Exception as err:
                _LOGGER.debug("MQTT disconnect failed (ignored): %s", err)

        await self.hass.async_add_executor_job(_disconnect)
        self._mqtt_connection = None
        self._mqtt_connected = False
        _LOGGER.debug("MQTT disconnected")

    # ── MQTT Reconnect ──

    async def _reconnect_mqtt(self) -> None:
        """Refresh credentials (or fully re-login) and reconnect MQTT.

        Raises:
            Exception: Whatever ``api.login`` raises when the fallback login
                fails.
        """
        _LOGGER.info("MQTT reconnecting: refreshing credentials...")
        await self._stop_mqtt()
        try:
            await self.api.refresh_aws_credentials()
        except Exception:
            _LOGGER.warning("Credential refresh failed, full re-login")
            await self.api.login(self._username, self._password)
        await self._start_mqtt()

    # ── Control (REST API) ──

    async def async_control(self, device: dict, payload: dict | None = None):
        """Send a control command via REST, re-logging in once on failure.

        Args:
            device: Device dict as returned by the device list.
            payload: Command payload; None sends a payload-less request.

        Returns:
            The API response of the last attempt.
        """
        resp = await self.api.control_device(device, payload)
        # Codes 200 and 407 are accepted without a retry.
        # NOTE(review): the meaning of 407 is not visible here - confirm.
        if resp.get("code") not in (200, 407):
            # Session might be invalidated (e.g. app login), re-login and retry
            _LOGGER.warning("Control failed (code=%s), re-logging in...", resp.get("code"))
            await self._relogin()
            resp = await self.api.control_device(device, payload)
        return resp

    async def _relogin(self) -> None:
        """Perform a full re-login and reconnect MQTT."""
        _LOGGER.info("Re-login started")
        await self.api.login(self._username, self._password)
        _LOGGER.info("Re-login success, reconnecting MQTT...")
        await self._stop_mqtt()
        await self._start_mqtt()

    # ── Public ──

    async def async_refresh_devices(self) -> None:
        """Re-login, reconnect MQTT, and refresh the device list.

        Raises:
            Exception: Login errors propagate directly; errors while
                fetching devices or notifying callbacks are logged and
                re-raised.
        """
        _LOGGER.info("Manual refresh: re-login starting...")

        await self._relogin()

        try:
            await self._sync_devices()
            await self._request_status_all()
            _LOGGER.info(
                "Manual refresh complete: %d device(s)", len(self._devices)
            )
        except Exception as err:
            _LOGGER.error("Manual device refresh failed: %s", err)
            raise

    async def async_shutdown(self) -> None:
        """Shut down: disconnect MQTT, close the API client, stop the coordinator."""
        await self._stop_mqtt()
        await self.api.close()
        await super().async_shutdown()
