# Plain text  |  314 lines  |  10.56 KB  (file-viewer header, not part of the module)

# Copyright 2015 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for oauth2client.contrib.multiprocess_file_storage."""

import contextlib
import datetime
import json
import multiprocessing
import os
import tempfile

import fasteners
import mock
from six import StringIO
import unittest2

from oauth2client import client
from oauth2client.contrib import multiprocess_file_storage

from ..http_mock import HttpMockSequence


@contextlib.contextmanager
def scoped_child_process(target, **kwargs):
    """Run ``target`` in a child process for the duration of a ``with`` block.

    ``target`` is called in the child as
    ``target(die_event, ready_event, **kwargs)``. The ``with`` body is entered
    only after the child sets ``ready_event``. On exit ``die_event`` is set to
    ask the child to finish and the child is joined.

    Args:
        target: Callable executed in the child process.
        **kwargs: Extra keyword arguments forwarded to ``target``.

    Raises:
        RuntimeError: If the child exits without ever setting
            ``ready_event``.
    """
    die_event = multiprocessing.Event()
    ready_event = multiprocessing.Event()
    process = multiprocessing.Process(
        target=target, args=(die_event, ready_event), kwargs=kwargs)
    process.start()
    try:
        # Poll rather than block indefinitely: a child that crashes before
        # signalling readiness would otherwise hang the caller forever.
        while not ready_event.wait(0.1):
            # Re-check the event after the liveness test so that a child
            # which signals and then exits right away is not misreported.
            if not process.is_alive() and not ready_event.is_set():
                raise RuntimeError(
                    'Child process exited before signalling that it was '
                    'ready.')
        yield
    finally:
        die_event.set()
        process.join(5)
        # Do not leave a stuck child behind once the grace period is over.
        if process.is_alive():
            process.terminate()
            process.join()


def _create_test_credentials(expiration=None):
    """Build an ``OAuth2Credentials`` fixture populated with fake values.

    Args:
        expiration: Token expiry to use. When falsy, the expiry is set to
            one hour after the current ``utcnow()`` time.

    Returns:
        The newly constructed ``client.OAuth2Credentials`` instance.
    """
    if not expiration:
        one_hour = datetime.timedelta(seconds=3600)
        expiration = datetime.datetime.utcnow() + one_hour

    return client.OAuth2Credentials(
        'foo',
        'test-client-id',
        'cOuDdkfjxxnv+',
        '1/0/a.df219fjls0',
        expiration,
        'https://www.google.com/accounts/o8/oauth2/token',
        'refresh_checker/1.0')


def _generate_token_response_http(new_token='new_token'):
    """Return a mock HTTP object that serves one canned token response.

    Args:
        new_token: Value placed in the ``access_token`` field of the JSON
            response body.

    Returns:
        An ``HttpMockSequence`` holding a single ``200`` response whose body
        is the JSON-encoded token payload.
    """
    payload = {
        'access_token': new_token,
        'expires_in': '3600',
    }
    canned_responses = [({'status': '200'}, json.dumps(payload))]
    return HttpMockSequence(canned_responses)


class MultiprocessStorageBehaviorTests(unittest2.TestCase):
    """Behavioral tests for ``MultiprocessFileStorage`` using a temp file."""

    def setUp(self):
        # Only the path is needed; close the descriptor mkstemp opened.
        filehandle, self.filename = tempfile.mkstemp(
            'oauth2client_test.data')
        os.close(filehandle)

    def tearDown(self):
        # Remove the data file and its companion lock file independently so
        # that a failure to delete one does not skip cleanup of the other.
        for path in (self.filename, '{0}.lock'.format(self.filename)):
            try:
                os.unlink(path)
            except OSError:  # pragma: NO COVER
                pass

    def test_basic_operations(self):
        """Credentials can be stored, re-read from disk, and deleted."""
        credentials = _create_test_credentials()

        store = multiprocess_file_storage.MultiprocessFileStorage(
            self.filename, 'basic')

        # Save credentials
        store.put(credentials)
        credentials = store.get()

        self.assertIsNotNone(credentials)
        self.assertEqual('foo', credentials.access_token)

        # Reset internal cache, ensure credentials were saved.
        store._backend._credentials = {}
        credentials = store.get()

        self.assertIsNotNone(credentials)
        self.assertEqual('foo', credentials.access_token)

        # Delete credentials
        store.delete()
        credentials = store.get()

        self.assertIsNone(credentials)

    def test_single_process_refresh(self):
        """A refresh updates both the credentials and the backing store."""
        store = multiprocess_file_storage.MultiprocessFileStorage(
            self.filename, 'single-process')
        credentials = _create_test_credentials()
        credentials.set_store(store)

        http = _generate_token_response_http()
        credentials.refresh(http)
        self.assertEqual(credentials.access_token, 'new_token')

        retrieved = store.get()
        self.assertEqual(retrieved.access_token, 'new_token')

    def test_multi_process_refresh(self):
        """Two processes refreshing concurrently only refresh once."""
        # This will test that two processes attempting to refresh credentials
        # will only refresh once.
        store = multiprocess_file_storage.MultiprocessFileStorage(
            self.filename, 'multi-process')
        credentials = _create_test_credentials()
        credentials.set_store(store)
        store.put(credentials)

        def child_process_func(
                die_event, ready_event, check_event):  # pragma: NO COVER
            store = multiprocess_file_storage.MultiprocessFileStorage(
                self.filename, 'multi-process')

            credentials = store.get()
            self.assertIsNotNone(credentials)

            # Make sure this (child) process gets to refresh first: once the
            # lock is held, signal the parent and hold the lock until the
            # parent has checked that it cannot acquire it.
            original_acquire_lock = store.acquire_lock

            def replacement_acquire_lock(*args, **kwargs):
                result = original_acquire_lock(*args, **kwargs)
                ready_event.set()
                check_event.wait()
                return result

            credentials.store.acquire_lock = replacement_acquire_lock

            http = _generate_token_response_http('b')
            credentials.refresh(http)

            self.assertEqual(credentials.access_token, 'b')

        check_event = multiprocessing.Event()
        with scoped_child_process(child_process_func, check_event=check_event):
            # The lock should be currently held by the child process.
            self.assertFalse(
                store._backend._process_lock.acquire(blocking=False))
            check_event.set()

            # The child process will refresh first, so we should end up
            # with 'b' as the token.
            http = mock.Mock()
            credentials.refresh(http=http)
            self.assertEqual(credentials.access_token, 'b')
            # The parent must not have issued its own token request.
            self.assertFalse(http.request.called)

        retrieved = store.get()
        self.assertEqual(retrieved.access_token, 'b')

    def test_read_only_file_fail_lock(self):
        """The store goes read-only when the lock file is held elsewhere."""
        credentials = _create_test_credentials()

        # Grab the lock in another process, preventing this process from
        # acquiring the lock.
        def child_process(die_event, ready_event):  # pragma: NO COVER
            lock = fasteners.InterProcessLock(
                '{0}.lock'.format(self.filename))
            with lock:
                ready_event.set()
                die_event.wait()

        with scoped_child_process(child_process):
            store = multiprocess_file_storage.MultiprocessFileStorage(
                self.filename, 'fail-lock')
            store.put(credentials)
            self.assertTrue(store._backend._read_only)

        # These credentials should still be in the store's memory-only cache.
        self.assertIsNotNone(store.get())


class MultiprocessStorageUnitTests(unittest2.TestCase):
    """Unit tests for helpers and backend internals of the storage module."""

    def setUp(self):
        # Only the path is needed; close the descriptor mkstemp opened.
        filehandle, self.filename = tempfile.mkstemp(
            'oauth2client_test.data')
        os.close(filehandle)

    def tearDown(self):
        # Some tests delete the data file themselves, so remove each path
        # independently: a missing data file must not skip lock-file cleanup.
        for path in (self.filename, '{0}.lock'.format(self.filename)):
            try:
                os.unlink(path)
            except OSError:  # pragma: NO COVER
                pass

    def test__create_file_if_needed(self):
        """Returns False for an existing file, True when it creates one."""
        self.assertFalse(
            multiprocess_file_storage._create_file_if_needed(self.filename))
        os.unlink(self.filename)
        self.assertTrue(
            multiprocess_file_storage._create_file_if_needed(self.filename))
        self.assertTrue(
            os.path.exists(self.filename))

    def test__get_backend(self):
        """The same filename yields the same backend object."""
        backend_one = multiprocess_file_storage._get_backend('file_a')
        backend_two = multiprocess_file_storage._get_backend('file_a')
        backend_three = multiprocess_file_storage._get_backend('file_b')

        self.assertIs(backend_one, backend_two)
        self.assertIsNot(backend_one, backend_three)

    def test__read_write_credentials_file(self):
        """Written credentials round-trip; invalid entries are skipped."""
        credentials = _create_test_credentials()
        contents = StringIO()

        multiprocess_file_storage._write_credentials_file(
            contents, {'key': credentials})

        contents.seek(0)
        data = json.load(contents)
        self.assertEqual(data['file_version'], 2)
        self.assertTrue(data['credentials']['key'])

        # Read it back.
        contents.seek(0)
        results = multiprocess_file_storage._load_credentials_file(contents)
        self.assertEqual(
            results['key'].access_token, credentials.access_token)

        # Add an invalid credential and try reading it back. It should ignore
        # the invalid one but still load the valid one.
        data['credentials']['invalid'] = '123'
        results = multiprocess_file_storage._load_credentials_file(
            StringIO(json.dumps(data)))
        self.assertNotIn('invalid', results)
        self.assertEqual(
            results['key'].access_token, credentials.access_token)

    def test__load_credentials_file_invalid_json(self):
        """Malformed JSON loads as an empty credentials mapping."""
        contents = StringIO('{[')
        self.assertEqual(
            multiprocess_file_storage._load_credentials_file(contents), {})

    def test__load_credentials_file_no_file_version(self):
        """A file without ``file_version`` loads as an empty mapping."""
        contents = StringIO('{}')
        self.assertEqual(
            multiprocess_file_storage._load_credentials_file(contents), {})

    def test__load_credentials_file_bad_file_version(self):
        """A file with ``file_version`` 1 loads as an empty mapping."""
        contents = StringIO(json.dumps({'file_version': 1}))
        self.assertEqual(
            multiprocess_file_storage._load_credentials_file(contents), {})

    def test__load_credentials_no_open_file(self):
        """``_load_credentials`` must not update the cache in this state."""
        backend = multiprocess_file_storage._get_backend(self.filename)
        # Any attempt to update the cached credentials fails the test.
        backend._credentials = mock.Mock()
        backend._credentials.update.side_effect = AssertionError()
        backend._load_credentials()

    def test_acquire_lock_nonexistent_file(self):
        """No file is opened when the lock fails and the file is missing."""
        backend = multiprocess_file_storage._get_backend(self.filename)
        os.unlink(self.filename)
        # Simulate failing to obtain the inter-process lock.
        backend._process_lock = mock.Mock()
        backend._process_lock.acquire.return_value = False
        backend.acquire_lock()
        self.assertIsNone(backend._file)

    def test_release_lock_with_no_file(self):
        """``release_lock`` succeeds when read-only with no open file."""
        backend = multiprocess_file_storage._get_backend(self.filename)
        backend._file = None
        backend._read_only = True
        # Acquire the thread lock first; presumably release_lock releases
        # it -- confirm against the backend implementation.
        backend._thread_lock.acquire()
        backend.release_lock()

    def test__refresh_predicate(self):
        """The predicate is true for invalid or expired credentials only."""
        backend = multiprocess_file_storage._get_backend(self.filename)

        credentials = _create_test_credentials()
        self.assertFalse(backend._refresh_predicate(credentials))

        credentials.invalid = True
        self.assertTrue(backend._refresh_predicate(credentials))

        credentials = _create_test_credentials(
            expiration=(
                datetime.datetime.utcnow() - datetime.timedelta(seconds=3600)))
        self.assertTrue(backend._refresh_predicate(credentials))


# Run the test cases in this module when it is executed directly as a script.
if __name__ == '__main__':  # pragma: NO COVER
    unittest2.main()