Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature "consolidated updates" #63

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions pyramid_redis_sessions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def factory(request, new_session_id=get_unique_session_id):

set_cookie = functools.partial(
_set_cookie,
session,
cookie_name=cookie_name,
cookie_max_age=cookie_max_age,
cookie_path=cookie_path,
Expand All @@ -233,13 +234,20 @@ def factory(request, new_session_id=get_unique_session_id):
)
cookie_callback = functools.partial(
_cookie_callback,
session,
session_cookie_was_valid=session_cookie_was_valid,
cookie_on_exception=cookie_on_exception,
set_cookie=set_cookie,
delete_cookie=delete_cookie,
)
request.add_response_callback(cookie_callback)

finished_callback = functools.partial(
_finished_callback,
session
)
request.add_finished_callback(finished_callback)

return session

return factory
Expand All @@ -264,6 +272,7 @@ def _get_session_id_from_cookie(request, cookie_name, secret):


def _set_cookie(
session,
request,
response,
cookie_name,
Expand All @@ -274,7 +283,11 @@ def _set_cookie(
cookie_httponly,
secret,
):
cookieval = signed_serialize(request.session.session_id, secret)
"""
`session` is via functools.partial
`request` and `response` are appended by add_response_callback
"""
cookieval = signed_serialize(session.session_id, secret)
response.set_cookie(
cookie_name,
value=cookieval,
Expand All @@ -291,15 +304,19 @@ def _delete_cookie(response, cookie_name, cookie_path, cookie_domain):


def _cookie_callback(
session,
request,
response,
session_cookie_was_valid,
cookie_on_exception,
set_cookie,
delete_cookie,
):
"""Response callback to set the appropriate Set-Cookie header."""
session = request.session
"""
Response callback to set the appropriate Set-Cookie header.
`session` is via functools.partial
`request` and `response` are appended by add_response_callback
"""
if session._invalidated:
if session_cookie_was_valid:
delete_cookie(response=response)
Expand All @@ -313,3 +330,17 @@ def _cookie_callback(
# still need to delete the existing cookie for the session that the
# request started with (as the session has now been invalidated).
delete_cookie(response=response)


def _finished_callback(
session,
request,
):
"""Finished callback to persist a cookie if needed.
`session` is via functools.partial
`request` is appended by add_finished_callback
"""
if session._session_state.please_persist:
session.do_persist()
elif session._session_state.please_refresh:
session.do_refresh()
17 changes: 17 additions & 0 deletions pyramid_redis_sessions/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@


class _SessionState(object):
# markers for update
please_persist = None
please_refresh = None

def __init__(self, session_id, managed_dict, created, timeout, new):
self.session_id = session_id
self.managed_dict = managed_dict
Expand Down Expand Up @@ -156,6 +160,19 @@ def invalidate(self):
# self._session_state) after this will trigger the creation of a new
# session with a new session_id.

def do_persist(self):
"""actually and immediately persist to Redis backend"""
# Redis is `key, value, timeout`
# StrictRedis is `key, timeout, value`
# this package uses StrictRedis
self.redis.setex(self.session_id, self.timeout, self.to_redis(), )
self._session_state.please_persist = False

def do_refresh(self):
"""actually and immediately refresh the TTL to Redis backend"""
self.redis.expire(self.session_id, self.timeout)
self._session_state.please_refresh = False

# dict modifying methods decorated with @persist
@persist
def __delitem__(self, key):
Expand Down
13 changes: 13 additions & 0 deletions pyramid_redis_sessions/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
from ..compat import cPickle


class DummySessionState(object):
please_persist = None
please_refresh = None


class DummySession(object):
def __init__(self, session_id, redis, timeout=300,
serialize=cPickle.dumps):
Expand All @@ -12,6 +17,7 @@ def __init__(self, session_id, redis, timeout=300,
self.serialize = serialize
self.managed_dict = {}
self.created = float()
self._session_state = DummySessionState()

def to_redis(self):
return self.serialize({
Expand Down Expand Up @@ -42,6 +48,13 @@ def get(self, key):
def set(self, key, value):
self.store[key] = value

def setex(self, key, timeout, value):
# Redis is `key, value, timeout`
# StrictRedis is `key, timeout, value`
# this package uses StrictRedis
self.store[key] = value
self.timeouts[key] = timeout

def delete(self, *keys):
for key in keys:
del self.store[key]
Expand Down
1 change: 1 addition & 0 deletions pyramid_redis_sessions/tests/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@ def test_adjusted_session_timeout_persists(self):
request = self._make_request()
inst = self._makeOne(request)
inst.adjust_timeout_for_session(555)
inst.do_persist()
session_id = inst.session_id
cookieval = self._serialize(session_id)
request.cookies['session'] = cookieval
Expand Down
10 changes: 10 additions & 0 deletions pyramid_redis_sessions/tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,26 +91,30 @@ def test_delitem(self):
inst = self._set_up_session_in_Redis_and_makeOne()
inst['key'] = 'val'
del inst['key']
inst.do_persist()
session_dict_in_redis = inst.from_redis()['managed_dict']
self.assertNotIn('key', inst)
self.assertNotIn('key', session_dict_in_redis)

def test_setitem(self):
inst = self._set_up_session_in_Redis_and_makeOne()
inst['key'] = 'val'
inst.do_persist()
session_dict_in_redis = inst.from_redis()['managed_dict']
self.assertIn('key', inst)
self.assertIn('key', session_dict_in_redis)

def test_getitem(self):
inst = self._set_up_session_in_Redis_and_makeOne()
inst['key'] = 'val'
inst.do_persist()
session_dict_in_redis = inst.from_redis()['managed_dict']
self.assertEqual(inst['key'], session_dict_in_redis['key'])

def test_contains(self):
inst = self._set_up_session_in_Redis_and_makeOne()
inst['key'] = 'val'
inst.do_persist()
session_dict_in_redis = inst.from_redis()['managed_dict']
self.assert_('key' in inst)
self.assert_('key' in session_dict_in_redis)
Expand All @@ -125,6 +129,7 @@ def test_keys(self):
inst['key1'] = ''
inst['key2'] = ''
inst_keys = inst.keys()
inst.do_persist()
session_dict_in_redis = inst.from_redis()['managed_dict']
persisted_keys = session_dict_in_redis.keys()
self.assertEqual(inst_keys, persisted_keys)
Expand All @@ -134,6 +139,7 @@ def test_items(self):
inst['a'] = 1
inst['b'] = 2
inst_items = inst.items()
inst.do_persist()
session_dict_in_redis = inst.from_redis()['managed_dict']
persisted_items = session_dict_in_redis.items()
self.assertEqual(inst_items, persisted_items)
Expand All @@ -151,6 +157,7 @@ def test_get(self):
inst['key'] = 'val'
get_from_inst = inst.get('key')
self.assertEqual(get_from_inst, 'val')
inst.do_persist()
session_dict_in_redis = inst.from_redis()['managed_dict']
get_from_redis = session_dict_in_redis.get('key')
self.assertEqual(get_from_inst, get_from_redis)
Expand Down Expand Up @@ -183,6 +190,7 @@ def test_update(self):
inst.update(to_be_updated)
self.assertEqual(inst['a'], 'overriden')
self.assertEqual(inst['b'], 2)
inst.do_persist()
session_dict_in_redis = inst.from_redis()['managed_dict']
self.assertEqual(session_dict_in_redis['a'], 'overriden')
self.assertEqual(session_dict_in_redis['b'], 2)
Expand Down Expand Up @@ -399,6 +407,7 @@ def test_mutablevalue_changed(self):
tmp = inst['a']
tmp['3'] = 3
inst.changed()
inst.do_persist()
session_dict_in_redis = inst.from_redis()['managed_dict']
self.assertEqual(session_dict_in_redis['a'], {'1':1, '2':2, '3':3})

Expand Down Expand Up @@ -446,5 +455,6 @@ def test_adjust_timeout_for_session(self):
inst = self._set_up_session_in_Redis_and_makeOne(timeout=100)
adjusted_timeout = 200
inst.adjust_timeout_for_session(adjusted_timeout)
inst.do_persist()
self.assertEqual(inst.timeout, adjusted_timeout)
self.assertEqual(inst.from_redis()['timeout'], adjusted_timeout)
15 changes: 10 additions & 5 deletions pyramid_redis_sessions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,28 +133,33 @@ def _parse_settings(settings):

return options


def refresh(wrapped):
"""
Decorator to reset the expire time for this session's key in Redis.
This will mark the `_session_state.please_refresh` as True, to be
handled in a callback.
To immediately persist a session, call `session.do_refresh`.
"""
def wrapped_refresh(session, *arg, **kw):
result = wrapped(session, *arg, **kw)
session.redis.expire(session.session_id, session.timeout)
session._session_state.please_refresh = True
return result

return wrapped_refresh


def persist(wrapped):
"""
Decorator to persist in Redis all the data that needs to be persisted for
this session and reset the expire time.
This will mark the `_session_state.please_persist` as True, to be
handled in a callback.
To immediately persist a session, call `session.do_persist`.
"""
def wrapped_persist(session, *arg, **kw):
result = wrapped(session, *arg, **kw)
with session.redis.pipeline() as pipe:
pipe.set(session.session_id, session.to_redis())
pipe.expire(session.session_id, session.timeout)
pipe.execute()
session._session_state.please_persist = True
return result

return wrapped_persist