Skip to content
55 changes: 16 additions & 39 deletions zeus/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,19 @@
from helios.models import Election, Poll, Trustee, Voter
from heliosauth.models import User

from zeus.log import init_election_logger, init_poll_logger, _locals, \
_close_logger
from zeus.utils import resolve_ip
from zeus.log import init_election_logger, init_poll_logger, _locals

import logging
logger = logging.getLogger(__name__)


AUTH_RE = re.compile('Basic (\w+[=]*)')

def get_ip(request):
ip = request.META.get('HTTP_X_FORWARDER_FOR', None)
if not ip:
ip = request.META.get('REMOTE_ADDR')
return ip

def class_method(func):
def wrapper(self, request, *args, **kwargs):
Expand Down Expand Up @@ -52,61 +55,35 @@ def inner(request, *args, **kwargs):
allow_manager = getattr(func, '_allow_manager', False)
_check_access = check_access
user = request.zeususer
user_id = None
if user.is_authenticated():
try:
user_id = user.user_id
_locals.user_id = user_id
_locals.user_id = user.user_id
except Exception:
raise PermissionDenied("Election cannot be accessed by you")
ip = resolve_ip(request)
_locals.ip = ip
_locals.ip = get_ip(request)

if allow_manager and user.is_manager:
_check_access = False

logging_locals = {
'user_id': user_id,
'ip': ip
}

if 'election_uuid' in kwargs:
uuid = kwargs.pop('election_uuid')
election = get_object_or_404(Election, uuid=uuid)
if not user.can_access_election(election) and _check_access:
raise PermissionDenied("Election cannot be accessed by you")
kwargs['election'] = election
setattr(election, '_logging_locals', logging_locals)

if 'poll_uuid' in kwargs:
uuid = kwargs.pop('poll_uuid')
poll = get_object_or_404(Poll, uuid=uuid)
if not user.can_access_poll(poll) and _check_access:
raise PermissionDenied("Poll cannot be accessed by you")
kwargs['poll'] = poll
setattr(poll, '_logging_locals', logging_locals)

resp = func(request, *args, **kwargs)
if 'poll' in kwargs:
_close_logger(kwargs['poll'])
if 'election' in kwargs:
_close_logger(kwargs['election'])

return resp
return func(request, *args, **kwargs)
return inner
return wrapper


def poll_voter_or_admin_required(func):
@election_view()
@wraps(func)
def wrapper(request, *args, **kwargs):
if not request.zeususer.is_voter and not request.zeususer.is_admin:
raise PermissionDenied("Voter or admin can only access this view.")
return func(request, *args, **kwargs)
return wrapper


def poll_voter_required(func):
@election_view()
@wraps(func)
Expand Down Expand Up @@ -237,7 +214,7 @@ def from_request(self, request):
user = None
try:
users = get_users_from_request(request)
user = filter(lambda x:bool(x), users)[0]
user = [x for x in users if bool(x)][0]
except IndexError:
pass
return ZeusUser(user)
Expand All @@ -260,7 +237,8 @@ def __init__(self, user_obj):
self.is_trustee = True

if isinstance(self._user, Voter):
self.is_voter = True
if not self._user.excluded_at:
self.is_voter = True

@property
def user_id(self):
Expand Down Expand Up @@ -340,7 +318,7 @@ def get_users_from_request(request):
user, admin, trustee, voter = None, None, None, None

# identify user and admin
if session.has_key(USER_SESSION_KEY):
if USER_SESSION_KEY in session:
user = request.session[USER_SESSION_KEY]
try:
user = User.objects.get(pk=user)
Expand All @@ -350,7 +328,7 @@ def get_users_from_request(request):
pass

# idenitfy voter
if session.has_key(VOTER_SESSION_KEY):
if VOTER_SESSION_KEY in session:
voter = request.session[VOTER_SESSION_KEY]

try:
Expand Down Expand Up @@ -400,7 +378,7 @@ def get_users_from_request(request):
admin = None

# cleanup duplicate logins
if len(filter(lambda x:bool(x), [voter, trustee, admin])) > 1:
if len([x for x in [voter, trustee, admin] if bool(x)]) > 1:
if voter:
if trustee:
del session[TRUSTEE_SESSION_KEY]
Expand All @@ -414,10 +392,9 @@ def get_users_from_request(request):

def allow_manager_access(func):
func._allow_manager = True
func.func_globals['foo'] = 'bar'
func.__globals__['foo'] = 'bar'
return func


def make_shibboleth_login_url(endpoint):
shibboleth_login = reverse('shibboleth_login', kwargs={'endpoint': endpoint})
url = '/'.join(s.strip('/') for s in filter(bool,[
Expand Down
Loading