'''
Скрипт predict.py должен:

- считать модель из файла;
- получить из аргумента ID пассажира;
- дать ответ выжил этот пассажир или нет.

Источник: https://www.phind.com
'''

import psycopg2
import pickle
from catboost import CatBoostClassifier

def predict_survival(passenger_id):
    """Предсказание выживаемости пассажира по ID"""
    # Подключение к базе данных
    conn = psycopg2.connect(
        dbname='titanic',
        user="pupkin",
        password="1q2w3e"
    )

    cur = conn.cursor()

    # Получаем данные пассажира
    cur.execute("""
        SELECT
            age_norm,
            pclass_norm,
            CASE WHEN sex = 'male' THEN 0 ELSE 1 END as sex,
            sibsp_norm,
            parch_norm,
            fare_norm,
            CASE WHEN embarked = 'S' THEN 0
                 WHEN embarked = 'C' THEN 1
                 ELSE 2 END as embarked
        FROM tit3_norm
        WHERE id = %s
    """, (passenger_id,))

    passenger_data = cur.fetchone()
    conn.close()

    if not passenger_data:
        raise ValueError(f"Пассажир с ID {passenger_id} не найден")

    # Загружаем обученную модель
    model = CatBoostClassifier()
    model.load_model('titanic_survival_model.cbm', format='cbm')

    # Делаем предсказание
    prediction = model.predict([passenger_data])[0]

    return bool(prediction)

if __name__ == "__main__":
    import sys

    if len(sys.argv) != 2:
        print("Использование: python script.py ")
        print("Пример: python script.py 42")
        sys.exit(1)

    try:
        passenger_id = int(sys.argv[1])
        if passenger_id <= 0:
            raise ValueError("ID должен быть положительным числом")
    except ValueError as e:
        print(f"Ошибка: Неверный формат ID. ID должен быть положительным целым числом")
        print("Пример использования: python script.py 42")
        sys.exit(1)

    try:
        survived = predict_survival(passenger_id)
        print(f"Пассажир с ID {passenger_id} {'выжил' if survived else 'не выжил'}")
    except Exception as e:
        print(f"Ошибка: {str(e)}")



# End predict.py