from __future__ import absolute_import
from flask_sqlalchemy.query import Query as BaseQuery
from sqlalchemy import desc, func
from dostadmin import db
from dostadmin.mixins import TimestampMixin


class ProgramseqQuery(BaseQuery):
    """Custom query class for ``Programseq`` (installed via ``query_class``).

    The methods are meant to be called on ``Programseq.query``. ``Programseq``
    is defined further down this module; the name is only looked up when a
    method runs, so the forward reference is safe.
    """

    def get_program_seq_with_id(self, programseq_id):
        """Return the row whose primary key is ``programseq_id``, or ``None``."""
        return self.filter(Programseq.id == programseq_id).first()

    def get_first_sequence_of_program(self, program_id):
        """Return a row of ``program_id`` with ``sequence_index == 1``, or ``None``.

        NOTE(review): no ordering is applied, so if several rows match (e.g.
        rows of different ``sequence_type``) the database picks one.
        """
        return self.filter(
            Programseq.program_id == program_id, Programseq.sequence_index == 1
        ).first()

    def get_week_two_sequence_by_program(self, program_id):
        """Return the week 2 / day 1 row of ``program_id``, or ``None``."""
        return self.filter(
            Programseq.program_id == program_id,
            Programseq.week == 2,
            Programseq.day == 1,
        ).first()

    def get_programseq_id_with_content_id(self, content_id):
        """Return the id of the lowest-``sequence_index`` row for ``content_id``.

        Returns ``None`` when no row references ``content_id``.
        """
        programseq = (
            self.filter(Programseq.content_id == content_id)
            .order_by(Programseq.sequence_index)
            .first()
        )
        return programseq.id if programseq else None

    def get_program_len(self, program_id):
        """Return the number of distinct ``week`` values used by ``program_id``.

        Computed as ``COUNT(DISTINCT week)``, which is portable across
        databases (``Query.distinct(<expr>)`` is PostgreSQL-only ``DISTINCT ON``
        and degrades to a plain ``DISTINCT`` elsewhere, counting rows instead of
        weeks). Rows whose ``week`` is NULL do not count as a week. Returns 0
        when the program has no rows.
        """
        return (
            self.filter(Programseq.program_id == program_id)
            .with_entities(func.count(Programseq.week.distinct()))
            .scalar()
        )

    def get_program_len_by_content(self, program_id):
        """Return the total number of rows belonging to ``program_id``.

        NOTE(review): the name suggests this counts content items; it counts
        rows, which is the same thing only if each row holds distinct content.
        """
        return self.filter(Programseq.program_id == program_id).count()

    def get_program_seq_by_content_and_program(self, program_id, content_id):
        """Return a row matching both ``program_id`` and ``content_id``, or ``None``."""
        return self.filter(
            Programseq.program_id == program_id,
            Programseq.content_id == content_id,
        ).first()

    def get_all_sequence(self):
        """Return every row reachable by this query as a list."""
        return self.all()

    def get_all_sequence_by_sequence_type(self, sequence_type="program"):
        """Return all rows whose ``sequence_type`` equals ``sequence_type``.

        The default ``"program"`` mirrors ``Programseq.SequenceType.PROGRAM``;
        it is spelled out because ``Programseq`` does not exist yet when this
        default is evaluated.
        """
        return self.filter(Programseq.sequence_type == sequence_type).all()

    def get_last_content_of_experiment_conditions(self):
        """Return the highest-``sequence_index`` experiment row per group.

        Groups are (``experiment_condition_id``, ``program_id``) pairs among
        rows whose ``sequence_type`` is ``SequenceType.EXPERIMENT``. Each result
        row exposes ``content_id``, ``experiment_condition_id``, ``program_id``
        and ``sequence_index``. Rows without a ``sequence_index`` are ignored.
        """
        # Rank rows inside each group, highest sequence_index first.
        ranked = (
            db.session.query(
                Programseq.content_id,
                Programseq.experiment_condition_id,
                Programseq.program_id,
                Programseq.sequence_index,
                func.row_number()
                .over(
                    partition_by=(
                        Programseq.experiment_condition_id,
                        Programseq.program_id,
                    ),
                    # id breaks ties so the chosen row is deterministic.
                    order_by=(
                        Programseq.sequence_index.desc(),
                        Programseq.id.desc(),
                    ),
                )
                .label("rn"),
            )
            .filter(
                Programseq.sequence_type == Programseq.SequenceType.EXPERIMENT,
                # PostgreSQL sorts NULLs first under DESC; without this a row
                # lacking a sequence_index would be ranked as the last content.
                Programseq.sequence_index.is_not(None),
            )
            .subquery()
        )

        # Keep only the top-ranked row of each group.
        return (
            db.session.query(
                ranked.c.content_id,
                ranked.c.experiment_condition_id,
                ranked.c.program_id,
                ranked.c.sequence_index,
            )
            .filter(ranked.c.rn == 1)
            .all()
        )


class CustomizedProgramSeq:
    """Plain (non-ORM) object carrying the same core fields as a ``Programseq`` row.

    Presumably used where a ``Programseq``-shaped value is needed without a
    database row. Confirm against callers.
    """

    def __init__(
        self, programseq_id, content_id, program_id, sequence_index, day, week
    ):
        # The sequence id is exposed as ``id``, matching the ORM model's column
        # name; the remaining fields keep their names.
        self.id = programseq_id
        self.content_id, self.program_id = content_id, program_id
        self.sequence_index = sequence_index
        self.day, self.week = day, week


class Programseq(TimestampMixin, db.Model):
    """One step of a program or experiment sequence (table ``programseq``).

    Each row places a content item inside a program at ``sequence_index`` and
    a ``week``/``day`` slot. ``sequence_type`` stores one of the
    ``SequenceType`` strings; ``experiment_condition_id`` is nullable and is
    presumably only set on experiment rows (TODO confirm).
    """

    __tablename__ = "programseq"
    query_class = ProgramseqQuery

    class SequenceType:
        """String values stored in the ``sequence_type`` column."""

        PROGRAM = "program"
        EXPERIMENT = "experiment"

    id = db.Column(db.Integer, primary_key=True)
    content_id = db.Column(db.Integer, db.ForeignKey("content.id"))
    program_id = db.Column(db.Integer, db.ForeignKey("program.id"))
    # Position within the sequence; nullable, so queries ordering on it must
    # account for NULLs.
    sequence_index = db.Column(db.Integer)
    week = db.Column(db.Integer)
    day = db.Column(db.Integer)
    experiment_condition_id = db.Column(
        db.Integer, db.ForeignKey("experiment_conditions.id"), nullable=True
    )
    # One of SequenceType.PROGRAM / SequenceType.EXPERIMENT.
    sequence_type = db.Column(db.String(50))

    def __repr__(self):
        # The leading newline (presumably intentional) puts each row on its own
        # line when a list of rows is printed.
        return (
            f"\n Programseq: id_ {self.id}, program: {self.program_id}, "
            f"content: {self.content_id}, sequence: {self.sequence_index}"
        )

    @classmethod
    def check_if_last_content(cls, content_id, program_id):
        """Return True if ``content_id`` sits at the highest ``sequence_index`` of ``program_id``.

        Returns False when the program has no rows with a ``sequence_index``.

        NOTE(review): rows of every ``sequence_type`` for the program are
        considered; confirm that experiment rows should take part.
        """
        last_content = (
            cls.query.filter(
                cls.program_id == program_id,
                # PostgreSQL sorts NULLs first under DESC, which would make a
                # row without a sequence_index look like the last one.
                cls.sequence_index.is_not(None),
            )
            # id breaks ties between equal indexes so the result is stable.
            .order_by(desc(cls.sequence_index), desc(cls.id))
            .first()
        )

        if last_content is None:
            return False

        return last_content.content_id == content_id
