from sortedcontainers import SortedDict
from dostadmin import app_logger
from dostadmin.db_model import CustomizedProgramSeq, Programseq, ExperimentConditions


class ProgramSequenceService:
    """Build and query in-memory indexes over ``Programseq`` rows.

    The "formatted" indexes built here are nested dicts whose leaves are
    ``SortedDict`` instances, so iterating a leaf follows ascending key order.
    """

    def get_program_type_sequences(self):
        """Return the rows of ``get_all_sequence_by_sequence_type()`` as wrappers.

        Each row is copied into a ``CustomizedProgramSeq`` built from its id,
        content id, program id, sequence index, day and week, in that order.
        """
        return [
            CustomizedProgramSeq(
                program_seq.id,
                program_seq.content_id,
                program_seq.program_id,
                program_seq.sequence_index,
                program_seq.day,
                program_seq.week,
            )
            for program_seq in Programseq.query.get_all_sequence_by_sequence_type()
        ]

    def get_module_wise_formatted_program_sequence(self):
        """Group all program sequences by program and by module (week).

        Returns a dict shaped as::

            {
                program_id: {
                    "M<week>_Prog": SortedDict({sequence_index: sequence}),
                    "M<week>_Prog_module_length": None,
                    "M<week>_Exp<experiment_id>_<group>": SortedDict(...),
                    "M<week>_Exp<experiment_id>_<group>_module_length": condition.value,
                },
            }

        Experiment-type sequences are placed under a key built from their
        experiment condition; every other sequence goes under ``M<week>_Prog``.
        The ``*_module_length`` entry is written when a module key is first seen.

        Returns ``None`` (and logs an error) if an experiment sequence points to
        an experiment condition not returned by ``get_experiment_log()``.
        """
        program_sequences = Programseq.query.get_all_sequence()
        experiment_conditions_by_id = {
            condition.id: condition
            for condition in ExperimentConditions.query.get_experiment_log()
        }
        formatted_program_sequence = {}

        for sequence in program_sequences:
            experiment_condition = None
            module_key = f"M{sequence.week}_Prog"

            if sequence.sequence_type == Programseq.SequenceType.EXPERIMENT:
                experiment_condition = experiment_conditions_by_id.get(
                    sequence.experiment_condition_id
                )
                if not experiment_condition:
                    # The whole index is discarded in this case; log why so the
                    # None returned to callers is diagnosable.
                    app_logger.error(
                        f"Program Sequence Service: Experiment condition {sequence.experiment_condition_id} not found for program sequence id {sequence.id}"
                    )
                    return None

                module_key = f"M{sequence.week}_Exp{experiment_condition.experiment_id}_{experiment_condition.experiment_group_name}"

            program_modules = formatted_program_sequence.setdefault(
                sequence.program_id, {}
            )
            module_sequences = program_modules.get(module_key)

            if module_sequences is None:
                program_modules[module_key] = SortedDict(
                    [[sequence.sequence_index, sequence]]
                )
                program_modules[f"{module_key}_module_length"] = (
                    None if experiment_condition is None else experiment_condition.value
                )
            else:
                module_sequences[sequence.sequence_index] = sequence

        return formatted_program_sequence

    def get_content_wise_formatted_program_sequence(self):
        """Group all program sequences by program and by content.

        Returns ``{program_id: {content_id: SortedDict({sequence.id: sequence})}}``.

        Every entry is keyed by the sequence's ``id``. A content bucket can hold
        both program- and experiment-type rows, and ``sequence_index`` is not
        guaranteed to be unique across them, whereas ``id`` is. Mixing the two
        key kinds could overwrite unrelated rows.
        """
        formatted_program_sequence = {}

        for sequence in Programseq.query.get_all_sequence():
            program_contents = formatted_program_sequence.setdefault(
                sequence.program_id, {}
            )
            content_sequences = program_contents.get(sequence.content_id)

            if content_sequences is None:
                program_contents[sequence.content_id] = SortedDict(
                    [[sequence.id, sequence]]
                )
            else:
                content_sequences[sequence.id] = sequence

        return formatted_program_sequence

    def get_current_program_sequence(
        self,
        formatted_program_sequences,
        content_id,
        experience,
        special_content_id_list,
    ):
        """Return the program sequence the user is on for ``content_id``.

        Args:
            formatted_program_sequences: dict holding
                ``"content_wise_formatted_program_sequence_dict"`` (the output of
                ``get_content_wise_formatted_program_sequence``).
            content_id: content whose sequence is wanted; may be ``None``.
            experience: object exposing ``program_id``, ``experiment_id``,
                ``experiment_group_name`` and ``phone``.
            special_content_id_list: content ids for which a missing sequence
                must not be logged as an error.

        Returns:
            For a user enrolled in an experiment, the experiment sequence whose
            condition matches the user's experiment id and group. If there is
            no such sequence, the first program-type sequence for the content.
            For a user not enrolled, the first program-type sequence.
            ``None`` when nothing applies.

        Raises:
            KeyError: if ``experience.program_id`` is not in the index.
        """
        experiment_id = experience.experiment_id
        is_enrolled_in_experiment = experiment_id is not None
        experiment_group_name = (
            experience.experiment_group_name if is_enrolled_in_experiment else None
        )
        content_wise_formatted_program_sequence_dict = formatted_program_sequences.get(
            "content_wise_formatted_program_sequence_dict"
        )
        program_sequences = content_wise_formatted_program_sequence_dict[
            experience.program_id
        ].get(content_id)
        error_message = f"Program Sequence Service: Program Sequence not found for content id {content_id} for user number {experience.phone}"

        if program_sequences is None:
            if content_id is None:
                app_logger.warning(error_message)
            elif content_id not in special_content_id_list:
                app_logger.error(error_message)
            return None

        program_sequence_for_content = None

        for sequence in program_sequences.values():
            if (
                program_sequence_for_content is None
                and sequence.sequence_type == Programseq.SequenceType.PROGRAM
            ):
                # The first program-type sequence is the fallback for everyone.
                program_sequence_for_content = sequence
                if not is_enrolled_in_experiment:
                    # Nothing can override it for a non-enrolled user, so skip
                    # the per-row experiment-condition lookups.
                    break

            if (
                not is_enrolled_in_experiment
                or sequence.sequence_type != Programseq.SequenceType.EXPERIMENT
            ):
                continue

            experiment_condition = ExperimentConditions.query.get_by_id(
                sequence.experiment_condition_id
            )

            if (
                experiment_condition
                and experiment_condition.experiment_id == experiment_id
                and experiment_condition.experiment_group_name == experiment_group_name
            ):
                program_sequence_for_content = sequence
                break

        if not program_sequence_for_content:
            if content_id not in special_content_id_list:
                app_logger.error(error_message)
            return None

        return program_sequence_for_content

    def get_next_program_sequence(
        self,
        formatted_program_sequences,
        previous_content_id,
        experience,
        special_content_id_list,
    ):
        """Return the sequence to deliver after ``previous_content_id``.

        Returns:
            ``(next_sequence, is_module_getting_changed)``:

            - If no sequence resolves for ``previous_content_id`` (for example,
              nothing was delivered yet): the first sequence of week 1 with
              ``False``, or ``(None, None)`` if week 1 has none.
            - Otherwise: the next ``sequence_index`` in the same module with
              ``False``.
            - If the module is exhausted: the first sequence of the next week,
              or of the week after that, with ``True``.
            - ``(None, False)`` if neither of those weeks has a sequence.
        """
        is_module_getting_changed = False
        module_wise_formatted_program_sequence_dict = formatted_program_sequences.get(
            "module_wise_formatted_program_sequence_dict"
        )

        previous_delivered_program_sequence = self.get_current_program_sequence(
            formatted_program_sequences,
            previous_content_id,
            experience,
            special_content_id_list,
        )

        if previous_delivered_program_sequence is None:
            # Nothing resolvable was delivered yet: start from the program's first module.
            next_sequence = self.get_first_sequence_of_a_week(
                formatted_program_sequences, 1, experience
            )

            if next_sequence is None:
                return None, None

            return (next_sequence, is_module_getting_changed)

        # Find the module the last delivered content belongs to.
        current_module_sequence = module_wise_formatted_program_sequence_dict[
            experience.program_id
        ][
            self.get_key_name_for_module_based_program_sequence(
                previous_delivered_program_sequence, experience
            )
        ]

        next_sequence = current_module_sequence.get(
            previous_delivered_program_sequence.sequence_index + 1
        )

        if next_sequence is None:
            # Module exhausted: try the next week, then the one after it, so a
            # single missing week is tolerated.
            # TODO: make the look-ahead distance dynamic.
            for week_offset in (1, 2):
                next_sequence = self.get_first_sequence_of_a_week(
                    formatted_program_sequences,
                    previous_delivered_program_sequence.week + week_offset,
                    experience,
                )
                if next_sequence:
                    is_module_getting_changed = True
                    break

        return (next_sequence, is_module_getting_changed)

    def get_key_name_for_module_based_program_sequence(self, sequence, experience):
        """Return the module-wise index key under which ``sequence`` is stored.

        Experiment sequences are keyed by the user's experiment id and group.
        ``get_current_program_sequence`` only selects an experiment sequence
        whose condition matches those values.
        """
        if sequence.sequence_type == Programseq.SequenceType.EXPERIMENT:
            return f"M{sequence.week}_Exp{experience.experiment_id}_{experience.experiment_group_name}"

        return f"M{sequence.week}_Prog"

    def is_content_exist_in_program_sequence(
        self, formatted_program_sequences, content_id, experience
    ):
        """Return True if ``content_id`` has any sequence in the user's program.

        Raises:
            KeyError: if ``experience.program_id`` is not in the content-wise index.
        """
        content_wise_formatted_program_sequence_dict = formatted_program_sequences.get(
            "content_wise_formatted_program_sequence_dict"
        )

        return (
            content_wise_formatted_program_sequence_dict[experience.program_id].get(
                content_id
            )
            is not None
        )

    def get_first_sequence_of_a_week(
        self, formatted_program_sequences, week_number, experience
    ):
        """Return the lowest-index sequence of module ``week_number`` for the user.

        A non-empty experiment module for the user's experiment id and group
        takes precedence over the program module of the same week.

        Returns ``None`` in three cases:

        - ``is_module_valid_for_experiment`` rejects the week;
        - no module exists for the week;
        - any exception is raised (errors are logged, not propagated).
        """
        try:
            if not self.is_module_valid_for_experiment(
                formatted_program_sequences, week_number, experience
            ):
                return None

            program_modules = formatted_program_sequences.get(
                "module_wise_formatted_program_sequence_dict"
            )[experience.program_id]
            module_sequence = program_modules.get(f"M{week_number}_Prog")

            if experience.experiment_id is not None:
                experiment_sequence = program_modules.get(
                    f"M{week_number}_Exp{experience.experiment_id}_{experience.experiment_group_name}"
                )

                if experiment_sequence:
                    module_sequence = experiment_sequence

            if module_sequence is None:
                ### Need to make it an error level log later.
                app_logger.warning(
                    f"Program Sequence Service: First program sequence cannot be found for week {week_number} and program id {experience.program_id} for user number {experience.phone}."
                )
                return None

            # Module dicts are SortedDicts keyed by sequence_index, so the first
            # value is the sequence with the smallest index.
            return next(iter(module_sequence.values()))
        except Exception as error:
            # Deliberate best-effort: any failure is logged and reported as "no sequence".
            app_logger.error(
                f"Program Sequence Service: First program sequence cannot be found for week {week_number} and program id {experience.program_id} for user number {experience.phone}. Error message: {error}"
            )
            return None

    def is_module_valid_for_experiment(
        self, formatted_program_sequences, week_number, experience
    ):
        """Return False if ``week_number`` exceeds the experiment's module length.

        Looks up ``M<week>_Exp<experiment_id>_<group>_module_length`` in the
        module-wise index for the user's program. A missing or falsy entry means
        the week is valid. This includes users not enrolled in an experiment,
        whose key contains ``ExpNone_None``.

        NOTE(review): the length is read only from the *same* week's key. Weeks
        past the experiment's last module typically have no such key, so they
        are reported valid. Confirm whether the comparison was meant to use
        the experiment's length regardless of week.

        Raises:
            KeyError: if ``experience.program_id`` is not in the module-wise index.
        """
        module_length_key = f"M{week_number}_Exp{experience.experiment_id}_{experience.experiment_group_name}_module_length"
        max_module_for_experiment = formatted_program_sequences.get(
            "module_wise_formatted_program_sequence_dict"
        )[experience.program_id].get(module_length_key)

        if max_module_for_experiment and week_number > int(max_module_for_experiment):
            return False

        return True

    def create_program_seq_dicts(self, program_seq_list):
        """Index sequence wrappers by program and by sequence id.

        Returns:
            ``(program_sequence, content_sequence_map)``:

            - ``program_sequence`` maps
              ``program_id -> SortedDict({sequence_index: seq})``;
            - ``content_sequence_map`` maps ``seq.id -> seq.sequence_index``.
        """
        program_sequence = {}
        content_sequence_map = {}
        for pseq_obj in program_seq_list:
            if pseq_obj.program_id in program_sequence:
                program_sequence[pseq_obj.program_id][pseq_obj.sequence_index] = pseq_obj
            else:
                program_sequence[pseq_obj.program_id] = SortedDict(
                    [[pseq_obj.sequence_index, pseq_obj]]
                )

            content_sequence_map[pseq_obj.id] = pseq_obj.sequence_index

        return program_sequence, content_sequence_map

    def is_last_content_of_program(self, program_sequence_list, content_id):
        """Return True if ``content_id`` belongs to the program's final sequence.

        Args:
            program_sequence_list: mapping ``sequence_index -> sequence`` for one
                program (e.g. a value of ``create_program_seq_dicts()[0]``).
            content_id: content to check.

        The final sequence is the one with the largest key. Keys need not be
        contiguous or start at 1.
        """
        if not program_sequence_list:
            return False

        last_sequence_index = max(program_sequence_list)
        return program_sequence_list[last_sequence_index].content_id == content_id

    def is_last_content_of_module(self, program_sequence_list, content_id):
        """Return True if ``content_id`` closes a module.

        That is the case when some sequence carrying ``content_id`` is the
        program's final sequence, or when the next sequence in key order has
        ``day == 1`` (it starts a new module). Keys need not be contiguous.
        """
        sequence_indexes = sorted(program_sequence_list)
        last_position = len(sequence_indexes) - 1

        for position, sequence_index in enumerate(sequence_indexes):
            if program_sequence_list[sequence_index].content_id != content_id:
                continue

            if position == last_position:
                return True

            next_sequence = program_sequence_list[sequence_indexes[position + 1]]
            if next_sequence.day == 1:
                return True

        return False

    def get_last_content_for_programs(self, cursor):
        """Return the content id of the final sequence for each program track.

        A track is either a program's program-type sequences, or its experiment
        sequences whose condition has ``key = 'max_module_length'``. For a
        program track the final module is its highest week. For an experiment
        track it is the condition's ``value``. Within that module, the row with
        the highest ``sequence_index`` is taken.

        Args:
            cursor: DB-API cursor. The SQL uses PostgreSQL ``::`` casts.

        Returns:
            List of ``(program_id, content_id, experiment_id, experiment_group_name)``
            tuples. ``experiment_id`` and the group name are ``None`` for
            program-type tracks.
        """
        sql = """
            WITH max_module_for_program_type_sequences as (
                SELECT program_id, null::integer as experiment_id, null::integer as experiment_condition_id, null as experiment_group_name, max(week) as max_module FROM programseq
                WHERE programseq.sequence_type = 'program'
                GROUP BY 1, 2, 3, 4
            ), max_module_for_experiment_type_sequences as (
                SELECT programseq.program_id, experiment_conditions.experiment_id, programseq.experiment_condition_id, experiment_conditions.experiment_group_name, experiment_conditions.value::Integer as max_module FROM programseq
                JOIN experiment_conditions
                ON experiment_conditions.id = programseq.experiment_condition_id
                AND experiment_conditions.key = 'max_module_length'
                GROUP BY 1, 2, 3, 4, 5
            ), max_module as (
                SELECT * FROM max_module_for_program_type_sequences
                UNION ALL
                SELECT * FROM max_module_for_experiment_type_sequences
            ), max_sequence_index_module_wise as (
                SELECT
                max_module.program_id, max_module.experiment_id, max_module.experiment_group_name, max_module,
                max_module.experiment_condition_id, max(programseq.sequence_index) as max_sequence_index_for_week
                FROM max_module
                JOIN programseq
                ON programseq.program_id = max_module.program_id
                AND programseq.week = max_module
                AND CASE
                		WHEN max_module.experiment_condition_id IS NULL
                		THEN programseq.experiment_condition_id IS NULL
                		ELSE programseq.experiment_condition_id = max_module.experiment_condition_id
                	END
                GROUP BY 1, 2, 3, 4, 5
            )
            SELECT
                msimw.program_id,
                msimw.experiment_id,
                msimw.experiment_group_name,
				programseq.content_id
            FROM max_sequence_index_module_wise as msimw
            JOIN programseq
            on programseq.program_id = msimw.program_id
            	AND programseq.week = msimw.max_module
            	AND programseq.sequence_index = msimw.max_sequence_index_for_week
            	AND CASE
                		WHEN msimw.experiment_condition_id IS NULL
                		THEN programseq.experiment_condition_id IS NULL
                		ELSE programseq.experiment_condition_id = msimw.experiment_condition_id
                	END
        """
        cursor.execute(sql)

        # Column order follows the SELECT list above; the tuple order is the
        # one callers already consume.
        return [
            (program_id, content_id, experiment_id, group_name)
            for program_id, experiment_id, group_name, content_id in cursor.fetchall()
        ]
