#!/usr/bin/env python3
"""Integration tests for python3-filterpy.

Each test simulates a realistic filtering scenario and asserts that the
filter converges to the true value within a numerical tolerance.
"""

import sys
import numpy as np

# Fixed seed: every run draws the identical noise stream, so the numeric
# pass/fail thresholds asserted below are deterministic.
np.random.seed(42)

# Global result counters, incremented by check().
PASS = 0
FAIL = 0


def check(name, condition, detail=''):
    """Record one test outcome and print a PASS/FAIL line.

    Increments the module-level PASS or FAIL counter according to
    *condition*; *detail*, when non-empty, is appended to FAIL lines.
    """
    global PASS, FAIL
    if not condition:
        suffix = f': {detail}' if detail else ''
        print(f'FAIL  {name}' + suffix)
        FAIL += 1
        return
    print(f'PASS  {name}')
    PASS += 1


# ---------------------------------------------------------------------------
# 1. KalmanFilter — constant-velocity 1-D tracking
# ---------------------------------------------------------------------------
def test_kalman_filter():
    """Linear KF tracking a 1-D target that moves one unit per step."""
    from filterpy.kalman import KalmanFilter

    dt = 1.0
    kf = KalmanFilter(dim_x=2, dim_z=1)
    kf.F = np.array([[1., dt], [0., 1.]])      # constant-velocity transition
    kf.H = np.array([[1., 0.]])                # position-only measurement
    kf.R = np.array([[5.]])                    # measurement noise variance
    kf.Q = np.array([[0.1, 0.1], [0.1, 0.1]])  # process noise
    kf.x = np.array([[0.], [1.]])
    kf.P = np.eye(2) * 500.                    # start from a vague prior

    truth = 0.
    early, late = [], []
    for step in range(60):
        truth += 1.0
        measurement = truth + np.random.randn() * np.sqrt(5.)
        kf.predict()
        kf.update(np.array([[measurement]]))
        abs_err = abs(kf.x[0, 0] - truth)
        if step < 10:
            early.append(abs_err)
        if step >= 50:
            late.append(abs_err)

    # Convergence: errors at the end should beat errors at the start.
    check('KalmanFilter: converges (late error < early error)',
          np.mean(late) < np.mean(early))
    check('KalmanFilter: late position error < 2.0',
          np.mean(late) < 2.0,
          f'mean late error={np.mean(late):.3f}')
    check('KalmanFilter: covariance is positive definite',
          np.all(np.linalg.eigvals(kf.P) > 0))


# ---------------------------------------------------------------------------
# 2. UnscentedKalmanFilter — nonlinear radar (range/bearing) tracking
# ---------------------------------------------------------------------------
def test_ukf():
    """UKF tracking a 2-D constant-velocity target via range/bearing radar."""
    from filterpy.kalman import UnscentedKalmanFilter, MerweScaledSigmaPoints

    def fx(x, dt):
        # Constant-velocity motion along both axes: [px, vx, py, vy].
        return np.array([x[0] + x[1]*dt, x[1], x[2] + x[3]*dt, x[3]])

    def hx(x):
        # Nonlinear radar measurement: range and bearing to the target.
        rng = np.sqrt(x[0]**2 + x[2]**2)
        bearing = np.arctan2(x[2], x[0])
        return np.array([rng, bearing])

    dt = 0.1
    sigmas = MerweScaledSigmaPoints(4, alpha=0.1, beta=2., kappa=-1.)
    ukf = UnscentedKalmanFilter(dim_x=4, dim_z=2, dt=dt, fx=fx, hx=hx,
                                points=sigmas)
    ukf.x = np.array([100., 0., 100., -1.])
    ukf.P = np.diag([1., 0.1, 1., 0.1])
    ukf.R = np.diag([0.5, 0.01])
    ukf.Q = np.diag([0.01, 0.001, 0.01, 0.001])

    state = np.array([100., 0., 100., -1.])
    late = []
    for step in range(80):
        state = fx(state, dt)
        # Per-component measurement noise matches R's diagonal.
        z = hx(state) + np.random.randn(2) * np.sqrt([0.5, 0.01])
        ukf.predict()
        ukf.update(z)
        if step >= 60:
            late.append(np.linalg.norm(ukf.x[[0, 2]] - state[[0, 2]]))

    check('UKF: late position error < 3.0',
          np.mean(late) < 3.0,
          f'mean={np.mean(late):.3f}')
    check('UKF: covariance is symmetric',
          np.allclose(ukf.P, ukf.P.T, atol=1e-10))


# ---------------------------------------------------------------------------
# 3. GHFilter — constant-velocity g-h filter
# ---------------------------------------------------------------------------
def test_gh_filter():
    """g-h filter following a unit-velocity ramp with noisy measurements."""
    from filterpy.gh import GHFilter

    ghf = GHFilter(x=0., dx=1., dt=1., g=0.6, h=0.02)
    truth = 0.
    late = []
    for step in range(50):
        truth += 1.0
        ghf.update(truth + np.random.randn() * 1.5)
        if step >= 40:
            late.append(abs(ghf.x - truth))

    check('GHFilter: late position error < 3.0',
          np.mean(late) < 3.0,
          f'mean={np.mean(late):.3f}')
    # The target moves at +1/step, so the velocity estimate must be positive.
    check('GHFilter: velocity estimate positive',
          ghf.dx > 0)


# ---------------------------------------------------------------------------
# 4. discrete_bayes — 1-D hallway localisation
# ---------------------------------------------------------------------------
def test_discrete_bayes():
    """Discrete Bayes update/predict on a 1-D hallway localisation problem."""
    from filterpy.discrete_bayes import update, predict

    # Hallway map: 1 = door, 0 = wall.
    hallway = np.array([1, 1, 0, 0, 0, 0, 0, 0, 1, 0])
    prior = np.ones(10) / 10.

    # Sensor reports "door": door cells get likelihood 0.75, walls 0.1,
    # so the posterior should concentrate on the door positions.
    likelihood = np.array([0.75 if cell == 1 else 0.1 for cell in hallway])
    posterior = update(likelihood, prior)

    check('discrete_bayes: posterior sums to 1',
          abs(posterior.sum() - 1.0) < 1e-10,
          f'sum={posterior.sum()}')
    check('discrete_bayes: posterior concentrates on doors',
          posterior[0] > posterior[4],
          f'door={posterior[0]:.3f} wall={posterior[4]:.3f}')

    # Move one cell to the right with a noisy (0.1/0.8/0.1) motion kernel.
    moved = predict(posterior, 1, np.array([0.1, 0.8, 0.1]))
    check('discrete_bayes: predict output sums to 1',
          abs(moved.sum() - 1.0) < 1e-10,
          f'sum={moved.sum()}')


# ---------------------------------------------------------------------------
# 5. Monte Carlo resampling — weights concentrate after resampling
# ---------------------------------------------------------------------------
def test_monte_carlo():
    """Every resampling scheme should draw mostly high-weight particles."""
    from filterpy.monte_carlo import (systematic_resample, residual_resample,
                                      stratified_resample, multinomial_resample)

    N = 200
    # Put all probability mass on the first 10 particles.
    weights = np.zeros(N)
    weights[:10] = 1.0 / 10.

    resamplers = [('systematic', systematic_resample),
                  ('residual',   residual_resample),
                  ('stratified', stratified_resample),
                  ('multinomial', multinomial_resample)]
    for name, resample in resamplers:
        idx = resample(weights)
        check(f'MC resample {name}: returns N indices',
              len(idx) == N)
        check(f'MC resample {name}: indices in range',
              idx.max() < N and idx.min() >= 0)
        # Nearly every draw should come from the high-weight region (0-9).
        frac_high = np.mean(idx < 10)
        check(f'MC resample {name}: high-weight particles dominate',
              frac_high > 0.8,
              f'fraction from high-weight region={frac_high:.2f}')


# ---------------------------------------------------------------------------
# 6. EnsembleKalmanFilter — constant signal tracking
# ---------------------------------------------------------------------------
def test_ensemble_kalman():
    """Ensemble KF estimating a constant scalar from noisy measurements."""
    from filterpy.kalman import EnsembleKalmanFilter

    def hx(x):
        # Direct observation of the single state variable.
        return np.array([x[0]])

    def fx(x, dt):
        # Constant-signal process model: state is unchanged by prediction.
        return x

    ensemble_size = 50
    enkf = EnsembleKalmanFilter(x=np.array([0.]), P=np.array([[2.]]), dim_z=1,
                                dt=1., N=ensemble_size, hx=hx, fx=fx)
    enkf.R = np.array([[1.]])

    target = 5.0
    late = []
    for step in range(40):
        z = np.array([target + np.random.randn()])
        enkf.predict()
        enkf.update(z)
        if step >= 30:
            late.append(abs(enkf.x[0] - target))

    check('EnKF: late error < 1.0',
          np.mean(late) < 1.0,
          f'mean={np.mean(late):.3f}')


# ---------------------------------------------------------------------------
# 7. FixedLagSmoother — smoother improves on filter
# ---------------------------------------------------------------------------
def test_fixed_lag_smoother():
    """Fixed-lag smoother over a constant-velocity track; sanity-check output."""
    from filterpy.kalman import FixedLagSmoother

    dt = 1.
    fls = FixedLagSmoother(dim_x=2, dim_z=1, N=4)   # lag of 4 steps
    fls.F = np.array([[1., dt], [0., 1.]])
    fls.H = np.array([[1., 0.]])
    fls.R = np.array([[10.]])
    fls.Q = np.array([[0.01, 0.], [0., 0.01]])
    fls.x = np.array([[0.], [1.]])
    fls.P = np.eye(2) * 100.

    truth = 0.
    for _ in range(50):
        truth += 1.0
        fls.smooth(truth + np.random.randn() * np.sqrt(10.))

    check('FixedLagSmoother: produces smoothed estimates',
          len(fls.xSmooth) > 0)
    check('FixedLagSmoother: smoothed values are finite',
          all(np.isfinite(fls.xSmooth[-1].flatten())))


# ---------------------------------------------------------------------------
# Run all tests
# ---------------------------------------------------------------------------
# Run every scenario in fixed order (order matters: the shared RNG stream
# is seeded once at the top of the file), then report the aggregate result.
for _scenario in (test_kalman_filter,
                  test_ukf,
                  test_gh_filter,
                  test_discrete_bayes,
                  test_monte_carlo,
                  test_ensemble_kalman,
                  test_fixed_lag_smoother):
    _scenario()

print(f'\n{PASS} passed, {FAIL} failed')
sys.exit(0 if FAIL == 0 else 1)
