Source code for RestAuth.backends.memory_backend

# -*- coding: utf-8 -*-
#
# This file is part of RestAuth (https://restauth.net).
#
# RestAuth is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# RestAuth is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with RestAuth.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import unicode_literals

from collections import defaultdict
from datetime import datetime

from django.conf import settings
from django.utils import six

from RestAuth.backends.base import GroupInstance
from RestAuth.backends.base import UserInstance
from RestAuth.common.errors import GroupExists
from RestAuth.common.errors import GroupNotFound
from RestAuth.common.errors import PropertyExists
from RestAuth.common.errors import PropertyNotFound
from RestAuth.common.errors import UserExists
from RestAuth.common.errors import UserNotFound


class MemoryUserInstance(UserInstance):
    """User instance kept by the in-memory user backend.

    On top of what :py:class:`UserInstance` stores, it remembers the
    password exactly as it was passed in (no hashing happens here).
    """

    def __init__(self, id, username, password=None):
        """Create a new in-memory user.

        :param id: Identifier forwarded to :py:class:`UserInstance`.
        :param username: Name of the user, forwarded to the base class.
        :param password: Password stored verbatim, or ``None``.
        """
        parent_init = super(MemoryUserInstance, self).__init__
        parent_init(id, username)
        self.password = password


class MemoryGroupInstance(GroupInstance):
    """Group instance kept by the in-memory group backend.

    Direct members are tracked by username. ``parents`` and ``children``
    hold related group instances: :py:meth:`add_subgroup` links a subgroup
    to this group, and :py:meth:`members` of a group includes the members
    of its parents (up to a recursion depth).
    """

    def __init__(self, id, name, service):
        super(MemoryGroupInstance, self).__init__(id, name, service)
        self._members = set()  # usernames of direct members
        self.parents = set()  # groups this group is a subgroup of
        self.children = set()  # subgroups of this group

    def add_user(self, user):
        """Add *user* (by username) as a direct member."""
        self._members.add(user.username)

    def rm_user(self, user):
        """Remove *user* from the direct members.

        :raises UserNotFound: If *user* is not a direct member.
        """
        try:
            self._members.remove(user.username)
        except KeyError:
            raise UserNotFound(user.username)

    def members(self, depth=None):
        """Return the set of usernames that are members of this group.

        :param depth: How many levels of parent groups to include. ``None``
            means ``settings.GROUP_RECURSION_DEPTH``; ``0`` (or anything
            lower) returns only the direct members.
        """
        if depth is None:
            depth = settings.GROUP_RECURSION_DEPTH

        users = set(self._members)
        # "<= 0" rather than "== 0": a negative start value would otherwise
        # never hit the stop condition and recurse forever on cycles.
        if depth <= 0:
            return users

        for parent in self.parents:
            users |= parent.members(depth=depth - 1)
        return users

    def is_member(self, user):
        """Return True if *user* is a direct or inherited member."""
        return user.username in self.members()

    def subgroups(self, filter=True):
        """Return the subgroups of this group.

        :param filter: If True, return a list restricted to subgroups of the
            same service; otherwise return a copy of all subgroups as a set.
        """
        if filter:
            return [g for g in self.children if g.service == self.service]
        else:
            return self.children.copy()

    def add_subgroup(self, subgroup):
        """Make *subgroup* a subgroup of this group (links both ways)."""
        subgroup.parents.add(self)
        self.children.add(subgroup)

    def rm_subgroup(self, subgroup):
        """Unlink *subgroup* from this group.

        :raises GroupNotFound: If *subgroup* is not a subgroup of this group.
        """
        # Check first so a missing subgroup never leaves a half-updated link.
        if subgroup not in self.children:
            raise GroupNotFound(subgroup.name)
        self.children.remove(subgroup)
        subgroup.parents.discard(self)

    def __eq__(self, other):
        if not isinstance(other, GroupInstance):
            return NotImplemented
        return self.service == other.service and self.name == other.name

    def __ne__(self, other):
        # Python 2 does not derive __ne__ from __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # Defining __eq__ without __hash__ makes instances unhashable on
        # Python 3, which breaks the parents/children sets. Equal groups
        # always share a name, so hashing the name agrees with __eq__.
        return hash(self.name)


class MemoryUserBackend(object):
    """Dummy backend that stores all users in memory (for debugging purposes).

    Please note the obvious: This backend should *never be used in a
    production environment*. Any restart of the server software will
    completely wipe all data.
    """

    def __init__(self):
        # Maps username -> MemoryUserInstance.
        self._users = {}

    def get(self, username):
        """Return the stored user instance for *username*.

        :raises UserNotFound: If no such user exists.
        """
        try:
            return self._users[username]
        except KeyError:
            raise UserNotFound(username)

    def list(self):
        """Return an iterator over all stored usernames."""
        return six.iterkeys(self._users)

    def create(self, username, password=None, properties=None,
               property_backend=None, dry=False, transaction=True):
        """Create a new user and store its initial properties.

        A ``'date joined'`` property (local time, ``%Y-%m-%d %H:%M:%S``) is
        added unless *properties* already contains one. The caller's
        *properties* dict is not modified.

        :param property_backend: Receives the properties via
            ``set_multiple``; if ``None``, no properties are stored.
        :param dry: If True, nothing is stored; the instance is still returned.
        :raises UserExists: If *username* is already taken.
        """
        if username in self._users:
            raise UserExists(username)
        user = MemoryUserInstance(id(username), username, password)

        # Copy so the caller's dict does not gain 'date joined' as a
        # side effect.
        properties = dict(properties or {})
        if 'date joined' not in properties:
            stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            properties['date joined'] = stamp

        # property_backend defaults to None; calling through it would raise
        # AttributeError.
        if property_backend is not None:
            property_backend.set_multiple(user, properties, dry=dry,
                                          transaction=transaction)

        if not dry:
            self._users[username] = user
        return user

    def exists(self, username):
        """Return True if a user named *username* is stored."""
        return username in self._users

    def check_password(self, username, password):
        """Return True if *password* equals the stored password.

        Users whose stored password is ``None`` or empty never match.

        :raises UserNotFound: If no such user exists.
        """
        try:
            user = self._users[username]
        except KeyError:
            raise UserNotFound(username)
        if not user.password:
            return False
        return user.password == password

    def set_password(self, username, password=None):
        """Replace the stored password of *username* (``None`` clears it).

        :raises UserNotFound: If no such user exists.
        """
        try:
            user = self._users[username]
        except KeyError:
            raise UserNotFound(username)
        user.password = password

    def remove(self, username):
        """Delete the user *username*.

        :raises UserNotFound: If no such user exists.
        """
        try:
            del self._users[username]
        except KeyError:
            raise UserNotFound(username)

    def testSetUp(self):
        pass

    def testTearDown(self):
        self._users = {}
class MemoryPropertyBackend(object):
    """Dummy backend that stores all properties in memory (for debugging).

    Please note the obvious: This backend should *never be used in a
    production environment*. Any restart of the server software will
    completely wipe all data.
    """

    def __init__(self):
        # Maps username -> {property key: value}.
        self._properties = defaultdict(dict)

    def _props(self, user):
        """Return the property dict of *user* without creating an entry.

        Indexing the defaultdict directly would insert an empty dict for every
        user that is merely queried.
        """
        return self._properties.get(user.username, {})

    def list(self, user):
        """Return a copy of all properties of *user*."""
        return dict(self._props(user))

    def create(self, user, key, value, dry=False, transaction=True):
        """Add a new property and return ``(key, value)``.

        :param dry: If True, nothing is stored.
        :raises PropertyExists: If *user* already has property *key*.
        """
        if key in self._props(user):
            # Name the offending property, like PropertyNotFound(key) does.
            raise PropertyExists(key)
        if not dry:
            self._properties[user.username][key] = value
        return key, value

    def get(self, user, key):
        """Return the value of property *key* of *user*.

        :raises PropertyNotFound: If the property is not set.
        """
        try:
            return self._props(user)[key]
        except KeyError:
            raise PropertyNotFound(key)

    def set(self, user, key, value, dry=False, transaction=True):
        """Set property *key* and return ``(key, previous value or None)``."""
        old = self._props(user).get(key)
        if not dry:
            self._properties[user.username][key] = value
        return key, old

    def set_multiple(self, user, props, dry=False, transaction=True):
        """Set all key/value pairs of *props* for *user* (unless *dry*)."""
        if not dry:
            self._properties[user.username].update(props)

    def remove(self, user, key):
        """Delete property *key* of *user*.

        :raises PropertyNotFound: If the property is not set.
        """
        try:
            del self._props(user)[key]
        except KeyError:
            raise PropertyNotFound(key)

    def testSetUp(self):
        pass

    def testTearDown(self):
        self._properties = defaultdict(dict)
class MemoryGroupBackend(object):
    """Dummy backend that stores all groups in memory (for debugging).

    Please note the obvious: This backend should *never be used in a
    production environment*. Any restart of the server software will
    completely wipe all data.
    """

    def __init__(self):
        # Maps service id (None for groups without a service) to a
        # {group name: MemoryGroupInstance} dict.
        self._groups = defaultdict(dict)

    def _service(self, service):
        """Return the key used for *service* in ``self._groups``."""
        if service is None:
            return service
        else:
            return service.id

    def _lookup(self, service):
        """Return the groups of *service* without creating an empty entry.

        Indexing the defaultdict directly would insert an empty dict for
        every service that is merely queried.
        """
        return self._groups.get(self._service(service), {})

    def get(self, name, service=None):
        """Return the stored group *name* of *service*.

        :raises GroupNotFound: If no such group exists.
        """
        try:
            return self._lookup(service)[name]
        except KeyError:
            raise GroupNotFound(name)

    def list(self, service, user=None):
        """Return the group names of *service*.

        If *user* is given, only groups that *user* is a (direct or
        inherited) member of are returned, as a list; otherwise an iterator
        over all names is returned.
        """
        groups = self._lookup(service)
        if user is None:
            return six.iterkeys(groups)
        return [k for k, v in six.iteritems(groups) if v.is_member(user)]

    def create(self, name, service=None, dry=False, transaction=True):
        """Create and return a new group (not stored if *dry*).

        :raises GroupExists: If *service* already has a group *name*.
        """
        if name in self._lookup(service):
            raise GroupExists(name)
        group = MemoryGroupInstance(service=service, id=id(name), name=name)
        if not dry:
            self._groups[self._service(service)][name] = group
        return group

    def exists(self, name, service=None):
        """Return True if *service* has a group called *name*."""
        return name in self._lookup(service)

    def add_user(self, group, user):
        """Add *user* to the stored instance of *group*.

        :raises GroupNotFound: If *group* is not stored in this backend.
        """
        try:
            stored = self._lookup(group.service)[group.name]
        except KeyError:
            raise GroupNotFound(group.name)
        stored.add_user(user)

    def members(self, group, depth=None):
        """Return a list of usernames that are members of *group*."""
        return list(group.members(depth=depth))

    def is_member(self, group, user):
        return group.is_member(user)

    def rm_user(self, group, user):
        return group.rm_user(user)

    def add_subgroup(self, group, subgroup):
        group.add_subgroup(subgroup)

    def subgroups(self, group, filter=True):
        """Return subgroup names if *filter*, else the subgroup instances."""
        subgroups = group.subgroups(filter=filter)
        if filter:
            return [g.name for g in subgroups]
        else:
            return subgroups

    def rm_subgroup(self, group, subgroup):
        group.rm_subgroup(subgroup)

    def remove(self, group):
        """Delete *group* and unlink it from all related groups.

        :raises GroupNotFound: If *group* is not stored in this backend.
        """
        service_id = self._service(group.service)
        try:
            stored = self._groups.get(service_id, {}).pop(group.name)
        except KeyError:
            raise GroupNotFound(group.name)

        # Without this, a deleted group would stay in its subgroups'
        # ``parents`` and keep contributing its members to them.
        for parent in stored.parents:
            parent.children.discard(stored)
        for child in stored.children:
            child.parents.discard(stored)
        stored.parents.clear()
        stored.children.clear()

    def parents(self, group):
        """Return a list of the groups *group* is a subgroup of."""
        return list(group.parents)

    def testSetUp(self):
        pass

    def testTearDown(self):
        self._groups = defaultdict(dict)