diff --git a/pyramid_redis_sessions/__init__.py b/pyramid_redis_sessions/__init__.py index 072b67d..25aeedd 100644 --- a/pyramid_redis_sessions/__init__.py +++ b/pyramid_redis_sessions/__init__.py @@ -217,6 +217,7 @@ def factory(request, new_session_id=get_unique_session_id): set_cookie = functools.partial( _set_cookie, + session, cookie_name=cookie_name, cookie_max_age=cookie_max_age, cookie_path=cookie_path, @@ -233,6 +234,7 @@ def factory(request, new_session_id=get_unique_session_id): ) cookie_callback = functools.partial( _cookie_callback, + session, session_cookie_was_valid=session_cookie_was_valid, cookie_on_exception=cookie_on_exception, set_cookie=set_cookie, @@ -240,6 +242,12 @@ def factory(request, new_session_id=get_unique_session_id): ) request.add_response_callback(cookie_callback) + finished_callback = functools.partial( + _finished_callback, + session + ) + request.add_finished_callback(finished_callback) + return session return factory @@ -264,6 +272,7 @@ def _get_session_id_from_cookie(request, cookie_name, secret): def _set_cookie( + session, request, response, cookie_name, @@ -274,7 +283,11 @@ def _set_cookie( cookie_httponly, secret, ): - cookieval = signed_serialize(request.session.session_id, secret) + """ + `session` is via functools.partial + `request` and `response` are appended by add_response_callback + """ + cookieval = signed_serialize(session.session_id, secret) response.set_cookie( cookie_name, value=cookieval, @@ -291,6 +304,7 @@ def _delete_cookie(response, cookie_name, cookie_path, cookie_domain): def _cookie_callback( + session, request, response, session_cookie_was_valid, @@ -298,8 +312,11 @@ def _cookie_callback( set_cookie, delete_cookie, ): - """Response callback to set the appropriate Set-Cookie header.""" - session = request.session + """ + Response callback to set the appropriate Set-Cookie header. + `session` is via functools.partial + `request` and `response` are appended by add_response_callback + """ if session._invalidated: if session_cookie_was_valid: delete_cookie(response=response) @@ -313,3 +330,17 @@ def _cookie_callback( # still need to delete the existing cookie for the session that the # request started with (as the session has now been invalidated). delete_cookie(response=response) + + +def _finished_callback( + session, + request, + ): + """Finished callback to persist a cookie if needed. + `session` is via functools.partial + `request` is appended by add_finished_callback + """ + if session._session_state.please_persist: + session.do_persist() + elif session._session_state.please_refresh: + session.do_refresh() diff --git a/pyramid_redis_sessions/session.py b/pyramid_redis_sessions/session.py index ca14071..38d2ae0 100644 --- a/pyramid_redis_sessions/session.py +++ b/pyramid_redis_sessions/session.py @@ -17,6 +17,10 @@ class _SessionState(object): + # markers for update + please_persist = None + please_refresh = None + def __init__(self, session_id, managed_dict, created, timeout, new): self.session_id = session_id self.managed_dict = managed_dict @@ -156,6 +160,19 @@ def invalidate(self): # self._session_state) after this will trigger the creation of a new # session with a new session_id. + def do_persist(self): + """actually and immediately persist to Redis backend""" + # Redis is `key, value, timeout` + # StrictRedis is `key, timeout, value` + # this package uses StrictRedis + self.redis.setex(self.session_id, self.timeout, self.to_redis(), ) + self._session_state.please_persist = False + + def do_refresh(self): + """actually and immediately refresh the TTL to Redis backend""" + self.redis.expire(self.session_id, self.timeout) + self._session_state.please_refresh = False + # dict modifying methods decorated with @persist @persist def __delitem__(self, key): diff --git a/pyramid_redis_sessions/tests/__init__.py b/pyramid_redis_sessions/tests/__init__.py index bad8e30..79db256 100644 --- a/pyramid_redis_sessions/tests/__init__.py +++ b/pyramid_redis_sessions/tests/__init__.py @@ -3,6 +3,11 @@ from ..compat import cPickle +class DummySessionState(object): + please_persist = None + please_refresh = None + + class DummySession(object): def __init__(self, session_id, redis, timeout=300, serialize=cPickle.dumps): @@ -12,6 +17,7 @@ def __init__(self, session_id, redis, timeout=300, self.serialize = serialize self.managed_dict = {} self.created = float() + self._session_state = DummySessionState() def to_redis(self): return self.serialize({ @@ -42,6 +48,13 @@ def get(self, key): def set(self, key, value): self.store[key] = value + def setex(self, key, timeout, value): + # Redis is `key, value, timeout` + # StrictRedis is `key, timeout, value` + # this package uses StrictRedis + self.store[key] = value + self.timeouts[key] = timeout + def delete(self, *keys): for key in keys: del self.store[key] diff --git a/pyramid_redis_sessions/tests/test_factory.py b/pyramid_redis_sessions/tests/test_factory.py index 4f00955..d280c7b 100644 --- a/pyramid_redis_sessions/tests/test_factory.py +++ b/pyramid_redis_sessions/tests/test_factory.py @@ -465,6 +465,7 @@ def test_adjusted_session_timeout_persists(self): request = self._make_request() inst = self._makeOne(request) inst.adjust_timeout_for_session(555) + inst.do_persist() session_id = inst.session_id cookieval = self._serialize(session_id) request.cookies['session'] = cookieval diff --git a/pyramid_redis_sessions/tests/test_session.py b/pyramid_redis_sessions/tests/test_session.py index a3bf658..37e37ac 100644 --- a/pyramid_redis_sessions/tests/test_session.py +++ b/pyramid_redis_sessions/tests/test_session.py @@ -91,6 +91,7 @@ def test_delitem(self): inst = self._set_up_session_in_Redis_and_makeOne() inst['key'] = 'val' del inst['key'] + inst.do_persist() session_dict_in_redis = inst.from_redis()['managed_dict'] self.assertNotIn('key', inst) self.assertNotIn('key', session_dict_in_redis) @@ -98,6 +99,7 @@ def test_delitem(self): def test_setitem(self): inst = self._set_up_session_in_Redis_and_makeOne() inst['key'] = 'val' + inst.do_persist() session_dict_in_redis = inst.from_redis()['managed_dict'] self.assertIn('key', inst) self.assertIn('key', session_dict_in_redis) @@ -105,12 +107,14 @@ def test_setitem(self): def test_getitem(self): inst = self._set_up_session_in_Redis_and_makeOne() inst['key'] = 'val' + inst.do_persist() session_dict_in_redis = inst.from_redis()['managed_dict'] self.assertEqual(inst['key'], session_dict_in_redis['key']) def test_contains(self): inst = self._set_up_session_in_Redis_and_makeOne() inst['key'] = 'val' + inst.do_persist() session_dict_in_redis = inst.from_redis()['managed_dict'] self.assert_('key' in inst) self.assert_('key' in session_dict_in_redis) @@ -125,6 +129,7 @@ def test_keys(self): inst['key1'] = '' inst['key2'] = '' inst_keys = inst.keys() + inst.do_persist() session_dict_in_redis = inst.from_redis()['managed_dict'] persisted_keys = session_dict_in_redis.keys() self.assertEqual(inst_keys, persisted_keys) @@ -134,6 +139,7 @@ def test_items(self): inst['a'] = 1 inst['b'] = 2 inst_items = inst.items() + inst.do_persist() session_dict_in_redis = inst.from_redis()['managed_dict'] persisted_items = session_dict_in_redis.items() self.assertEqual(inst_items, persisted_items) @@ -151,6 +157,7 @@ def test_get(self): inst['key'] = 'val' get_from_inst = inst.get('key') self.assertEqual(get_from_inst, 'val') + inst.do_persist() session_dict_in_redis = inst.from_redis()['managed_dict'] get_from_redis = session_dict_in_redis.get('key') self.assertEqual(get_from_inst, get_from_redis) @@ -183,6 +190,7 @@ def test_update(self): inst.update(to_be_updated) self.assertEqual(inst['a'], 'overriden') self.assertEqual(inst['b'], 2) + inst.do_persist() session_dict_in_redis = inst.from_redis()['managed_dict'] self.assertEqual(session_dict_in_redis['a'], 'overriden') self.assertEqual(session_dict_in_redis['b'], 2) @@ -399,6 +407,7 @@ def test_mutablevalue_changed(self): tmp = inst['a'] tmp['3'] = 3 inst.changed() + inst.do_persist() session_dict_in_redis = inst.from_redis()['managed_dict'] self.assertEqual(session_dict_in_redis['a'], {'1':1, '2':2, '3':3}) @@ -446,5 +455,6 @@ def test_adjust_timeout_for_session(self): inst = self._set_up_session_in_Redis_and_makeOne(timeout=100) adjusted_timeout = 200 inst.adjust_timeout_for_session(adjusted_timeout) + inst.do_persist() self.assertEqual(inst.timeout, adjusted_timeout) self.assertEqual(inst.from_redis()['timeout'], adjusted_timeout) diff --git a/pyramid_redis_sessions/util.py b/pyramid_redis_sessions/util.py index f99e207..87314e5 100644 --- a/pyramid_redis_sessions/util.py +++ b/pyramid_redis_sessions/util.py @@ -133,28 +133,33 @@ def _parse_settings(settings): return options + def refresh(wrapped): """ Decorator to reset the expire time for this session's key in Redis. + This will mark the `_session_state.please_refresh` as True, to be + handled in a callback. + To immediately persist a session, call `session.do_refresh`. """ def wrapped_refresh(session, *arg, **kw): result = wrapped(session, *arg, **kw) - session.redis.expire(session.session_id, session.timeout) + session._session_state.please_refresh = True return result return wrapped_refresh + def persist(wrapped): """ Decorator to persist in Redis all the data that needs to be persisted for this session and reset the expire time. + This will mark the `_session_state.please_persist` as True, to be + handled in a callback. + To immediately persist a session, call `session.do_persist`. """ def wrapped_persist(session, *arg, **kw): result = wrapped(session, *arg, **kw) - with session.redis.pipeline() as pipe: - pipe.set(session.session_id, session.to_redis()) - pipe.expire(session.session_id, session.timeout) - pipe.execute() + session._session_state.please_persist = True return result return wrapped_persist