'''
Скрипт control.py должен:
- добавить в таблицу tit3_alive поле control типа int;
- используя модель, заполнить новую колонку предсказанными
  моделью данными: выжил -> 1, не выжил -> 0;
- найти число ошибок и их процент.

Источник: https://www.phind.com
'''

import psycopg2
import pickle
from catboost import CatBoostClassifier

def add_control_column():
    """Добавление колонки control в таблицу tit3_alive"""
    conn = psycopg2.connect(
        dbname='titanic',
        user="pupkin",
        password="1q2w3e"
    )

    cur = conn.cursor()
    cur.execute("""
        ALTER TABLE tit3_alive
        ADD COLUMN IF NOT EXISTS control INTEGER;
    """)
    conn.commit()
    conn.close()

def predict_and_update():
    """Предсказание выживаемости и обновление таблицы"""
    conn = psycopg2.connect(
        dbname='titanic',
        user="pupkin",
        password="1q2w3e"
    )

    cur = conn.cursor()

    # Загружаем модель
    model = CatBoostClassifier()
    model.load_model('titanic_survival_model.cbm', format='cbm')

    # Получаем все ID из tit3_norm
    cur.execute("SELECT id FROM tit3_norm")
    all_ids = [row[0] for row in cur.fetchall()]

    # Предсказываем для всех ID
    predictions = []
    for id_batch in range(0, len(all_ids), 1000):
        batch = all_ids[id_batch:id_batch + 1000]
        cur.execute("""
            SELECT
                age_norm,
                pclass_norm,
                CASE WHEN sex = 'male' THEN 0 ELSE 1 END as sex,
                sibsp_norm,
                parch_norm,
                fare_norm,
                CASE WHEN embarked = 'S' THEN 0
                     WHEN embarked = 'C' THEN 1
                     ELSE 2 END as embarked
            FROM tit3_norm
            WHERE id = ANY(%s)
        """, (batch,))

        batch_data = cur.fetchall()
        pred = model.predict(batch_data)
        predictions.extend(pred)

    # Обновляем таблицу
    for id_val, pred_val in zip(all_ids, predictions):
        cur.execute("""
            UPDATE tit3_alive
            SET control = %s
            WHERE id = %s
        """, (int(pred_val), id_val))

    conn.commit()
    conn.close()

def calculate_errors():
    """Расчёт количества и процента ошибок"""
    conn = psycopg2.connect(
        dbname='titanic',
        user="pupkin",
        password="1q2w3e"
    )

    cur = conn.cursor()
    cur.execute("""
        SELECT
            COUNT(*) as total,
            SUM(CASE WHEN survived != control THEN 1 ELSE 0 END) as errors
        FROM tit3_alive
    """)

    total, errors = cur.fetchone()
    error_rate = (errors / total) * 100 if total > 0 else 0

    conn.close()

    return total, errors, error_rate

if __name__ == "__main__":
    try:
        # Добавляем колонку control
        add_control_column()

        # Заполняем предсказаниями
        predict_and_update()

        # Вычисляем ошибки
        total, errors, error_rate = calculate_errors()

        print(f"Всего записей: {total}")
        print(f"Количество ошибок: {errors}")
        print(f"Процент ошибок: {error_rate:.2f}%")

    except Exception as e:
        print(f"Ошибка: {str(e)}")


# End 205control.py