from __future__ import absolute_import
from flask_sqlalchemy.query import Query as BaseQuery
from sqlalchemy.dialects.postgresql import JSONB
from dostadmin import db
from dostadmin.mixins import TimestampMixin


class CohortQuery(BaseQuery):
    """Custom query class attached to :class:`Cohort` via ``query_class``.

    Adds small lookup helpers on top of the stock Flask-SQLAlchemy query.
    """

    def get_by_id(self, cohort_id):
        """Look up a single cohort by primary key.

        :param cohort_id: value compared against ``Cohort.id``.
        :returns: the first matching ``Cohort`` row, or ``None`` when no row
            matches.
        """
        # Filter on the explicit column rather than ``filter_by`` so the
        # predicate always targets Cohort even if this query has been joined.
        matching_cohorts = self.filter(Cohort.id == cohort_id)
        return matching_cohorts.first()


class Cohort(TimestampMixin, db.Model):
    """ORM model for a row of the ``cohort`` table.

    A cohort belongs to an experiment (``experiment_id``), spans a date
    range (``start_date``/``end_date``), and stores free-form JSONB
    ``inputs``/``outputs``. Queries issued through ``Cohort.query`` use
    :class:`CohortQuery`, so helpers such as ``get_by_id`` are available
    there.
    """

    __tablename__ = "cohort"
    # Makes ``Cohort.query`` build CohortQuery instances instead of the
    # default Flask-SQLAlchemy query class.
    query_class = CohortQuery

    # NOTE(review): TimestampMixin presumably contributes created/updated
    # timestamp columns — confirm in dostadmin.mixins.
    id = db.Column(db.Integer, primary_key=True)
    # Human-readable label; limited to 20 characters by the column type.
    name = db.Column(db.String(20))
    # Owning experiment; references experiment.id (nullable, no ondelete set).
    experiment_id = db.Column(db.Integer, db.ForeignKey("experiment.id"))
    # Date-only bounds of the cohort; inclusivity is not enforced here.
    start_date = db.Column(db.Date)
    end_date = db.Column(db.Date)
    # PostgreSQL JSONB payloads; their expected structure is not defined in
    # this module — TODO document the schema where it is produced/consumed.
    inputs = db.Column(JSONB)
    outputs = db.Column(JSONB)

    # Rows of CohortDetails whose cohort_id matches this cohort's id.
    # ``backref`` also adds a ``CohortDetails.cohort`` attribute pointing back
    # here. Presumably one-to-many (a list) — confirm against CohortDetails.
    cohort_details = db.relationship(
        "CohortDetails",
        backref="cohort",
        primaryjoin="Cohort.id == CohortDetails.cohort_id",
    )
