from bson import json_util
from django.core.exceptions import SuspiciousOperation
from django.utils.encoding import force_str
from mongoengine import Document, fields, OperationError
from django.contrib.sessions.backends.base import CreateError, SessionBase
from django.utils import timezone

from django.conf import settings
import datetime

# Name of the MongoDB collection that backs MongoSession (used in its ``meta``).
SESSION_COLLECTION = "django_session"
# When True, SessionStore stores ``self.encode(data)`` in ``session_data`` and
# decodes it again on load; when False the raw session data is stored as-is.
# NOTE(review): MongoSession.session_data is declared as a StringField, so the
# False setting presumably also needs that field switched to a DictField — confirm.
MONGOENGINE_SESSION_DATA_ENCODE = True

class MongoSession(Document):
    """MongoEngine document holding one Django session record."""

    meta = {
        "collection": SESSION_COLLECTION,
        # NOTE(review): session_key is declared as the mongoengine id field here
        # (instead of primary_key=True), while the separate ``_id`` field below
        # still maps the stored Mongo _id — confirm this mapping is intended.
        "id_field": "session_key"
    }

    _id = fields.ObjectIdField()
    # Session key; must be unique across documents (at most 40 characters).
    session_key = fields.StringField(unique=True, max_length=40)
    # Session payload, encoded when MONGOENGINE_SESSION_DATA_ENCODE is True.
    session_data = fields.StringField()  # if MONGOENGINE_SESSION_DATA_ENCODE else fields.DictField()
    # Sessions whose expire_date is not in the future are treated as expired.
    expire_date = fields.DateTimeField()

    def get_decoded(self):
        """Return this record's session data in decoded form.

        Mirrors ``SessionStore.load()``: the stored value is only run through
        ``SessionStore.decode`` when MONGOENGINE_SESSION_DATA_ENCODE is enabled;
        otherwise it was stored raw and is returned unchanged.
        """
        if not MONGOENGINE_SESSION_DATA_ENCODE:
            return self.session_data
        return SessionStore().decode(self.session_data)


class SessionStore(SessionBase):
    """Django session engine that persists sessions as ``MongoSession`` documents."""

    def create(self):
        """Allocate a fresh session key and persist an initial record for it.

        New keys are generated until ``save(must_create=True)`` succeeds
        without raising ``CreateError`` (i.e. the key was not already taken).
        """
        while True:
            self._session_key = self._get_new_session_key()
            try:
                self.save(must_create=True)
            except CreateError:
                # Key collision: try again with another key.
                continue
            self.modified = True
            return

    def save(self, must_create=False):
        """Persist the current session data and expiry date to MongoDB.

        :param must_create: when True, the write must create a brand-new
            document. If a document with this key already exists, or the
            insert fails, ``CreateError`` is raised instead of overwriting
            another session. When False, only an already-existing document is
            updated; a session deleted in the meantime is not re-inserted.
        :raises CreateError: ``must_create`` is True and the key is taken.
        :raises OperationError: the update failed for another reason.
        """
        if self.session_key is None:
            # create() picks a key and performs the must_create save itself,
            # so there is nothing left to write here.
            return self.create()

        data = self._get_session(no_load=must_create)
        if MONGOENGINE_SESSION_DATA_ENCODE:
            session_data = self.encode(data)
        else:
            session_data = data
        expire_date = self.get_expiry_date()

        if must_create:
            self._insert_new(session_data, expire_date)
        else:
            # Update-only (no upsert) so that a session deleted concurrently
            # stays deleted instead of being resurrected by this save.
            MongoSession.objects(session_key=self.session_key).update_one(
                set__session_data=session_data,
                set__expire_date=expire_date,
            )

    def _insert_new(self, session_data, expire_date):
        """Insert a new document for ``self.session_key``; raise CreateError if taken."""
        # Explicit check so an existing session is never overwritten.
        if self.exists(self.session_key):
            raise CreateError
        record = MongoSession(
            session_key=self.session_key,
            session_data=session_data,
            expire_date=expire_date,
        )
        try:
            # A concurrent insert of the same key is expected to be rejected by
            # the unique constraint on session_key and surface as OperationError.
            record.save(force_insert=True)
        except OperationError as err:
            raise CreateError from err

    def delete(self, session_key=None):
        """Delete the session *session_key*, or this store's own session if not given.

        Does nothing when neither key is set (falsy).
        """
        target_key = session_key or self.session_key
        if target_key:
            MongoSession.objects(session_key=target_key).delete()

    def load(self):
        """Return the data of this store's unexpired session, or ``{}``.

        When no unexpired document matches, or decoding raises
        SuspiciousOperation, the session key is cleared so that a subsequent
        ``save()`` goes through ``create()`` and issues a new key.
        """
        try:
            record = MongoSession.objects(
                session_key=self.session_key, expire_date__gt=timezone.now()
            )[0]

            # Remember the stored Mongo _id ('' when the document has none).
            self._mongo_id = record._id or ''

            if MONGOENGINE_SESSION_DATA_ENCODE:
                return self.decode(force_str(record.session_data))
            return record.session_data
        except (IndexError, SuspiciousOperation):
            self._session_key = None
            return {}

    @classmethod
    def clear_expired(cls):
        """Delete every session document whose expire_date is in the past."""
        MongoSession.objects(expire_date__lt=timezone.now()).delete()

    def exists(self, session_key):
        """Return True if a document with *session_key* exists (expired or not)."""
        return bool(MongoSession.objects(session_key=session_key).first())

class BSONSerializer:
    """Session serializer built on ``bson.json_util``.

    Unlike a plain JSON serializer, it can round-trip BSON-specific values
    such as ``ObjectId``.
    """

    def dumps(self, obj):
        """Serialize *obj* to compact JSON and return it as ASCII bytes."""
        compact_json = json_util.dumps(obj, separators=(",", ":"))
        return compact_json.encode("ascii")

    def loads(self, data):
        """Decode ASCII bytes produced by :meth:`dumps` back into Python objects."""
        json_text = data.decode("ascii")
        return json_util.loads(json_text)