# No shebang line, this module is meant to be imported
#
# Copyright 2014 Oliver Palmer
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test Utilities
==============
Functions and classes mainly used during the unittests.
"""
import json
import os
import time
import warnings
import uuid
from unittest import TestCase
try:
from httplib import (
OK, CREATED, ACCEPTED, NO_CONTENT, BAD_REQUEST, UNAUTHORIZED,
FORBIDDEN, NOT_FOUND, NOT_ACCEPTABLE, INTERNAL_SERVER_ERROR, CONFLICT,
UNSUPPORTED_MEDIA_TYPE, METHOD_NOT_ALLOWED, TEMPORARY_REDIRECT)
except ImportError:
from http.client import (
OK, CREATED, ACCEPTED, NO_CONTENT, BAD_REQUEST, UNAUTHORIZED,
FORBIDDEN, NOT_FOUND, NOT_ACCEPTABLE, INTERNAL_SERVER_ERROR, CONFLICT,
UNSUPPORTED_MEDIA_TYPE, METHOD_NOT_ALLOWED, TEMPORARY_REDIRECT)
try:
from UserDict import UserDict
except ImportError:
from collections import UserDict
try:
import blinker
except ImportError:
blinker = NotImplemented
from flask import Response, json_available
from sqlalchemy.exc import SAWarning
from werkzeug.utils import cached_property
from pyfarm.master.application import get_application, db, before_request
class JsonResponseMixin(object):
    """
    Mixin which adds a ``json`` helper to a response class for use
    in tests.
    """

    @cached_property
    def json(self):
        """
        The response body decoded as UTF-8 and parsed as JSON.  The value
        is computed on first access and cached by ``cached_property``.

        :raises NotImplementedError:
            Raised if flask reports that json support is not available.
        """
        if json_available:
            body_text = self.data.decode("utf-8")
            return json.loads(body_text)

        raise NotImplementedError  # pragma: no cover
def make_test_response(response_class=None):
    """
    Build a response class for tests which combines ``response_class``
    with :class:`JsonResponseMixin`.

    :param response_class:
        The response class to extend.

    :return:
        A new ``TestResponse`` class deriving from ``response_class`` and
        :class:`JsonResponseMixin`, or ``None`` when no ``response_class``
        was provided.
    """
    if response_class is not None:
        class TestResponse(response_class, JsonResponseMixin):
            pass

        return TestResponse

    return None
class BaseTestCase(TestCase):
    """
    Base test case which, for every test, builds the application, a test
    client, a pushed request context and the database tables (see
    :meth:`setUp`) and drops the tables again afterwards (see
    :meth:`tearDown`).

    :meth:`build_environment` must be called before any test runs,
    otherwise :meth:`setUp` fails the test.
    """
    # Flipped to True by build_environment(); setUp() refuses to run
    # while this is still False.
    ENVIRONMENT_SETUP = False

    # Snapshot of the process environment taken when the class is defined.
    ORIGINAL_ENVIRONMENT = os.environ.copy()

    # Do not truncate diffs in assertion failure messages.
    maxDiff = None

    @classmethod
    def build_environment(cls):
        """
        Sets up the current environment with some values for
        unittesting.  This must be used before any other code imports the
        models, because the table prefix set here is only read when the
        models are initially imported.

        .. warning::
            This classmethod should not be used outside of a testing context
        """
        # Override the table prefix so tests are not done in the same table
        # namespace as other tests.  Note that although 'db' is imported
        # up above the table prefix itself is not used until the models
        # are initially imported (below).
        from pyfarm.master.config import config

        # Month, day, year, hour, minute, second.  The leading directive
        # used to be "%M" (minutes) which repeated the minutes and left
        # the month out of the prefix entirely.
        config["table_prefix"] = "test%s_" % time.strftime("%m%d%Y%H%M%S")

        # import all the models we have so the relationships
        # can be setup properly
        from pyfarm.models.disk import AgentDisk
        from pyfarm.models.agent import Agent
        from pyfarm.models.job import Job
        from pyfarm.models.jobtype import JobType
        from pyfarm.models.software import (
            Software, SoftwareVersion, JobSoftwareRequirement,
            JobTypeSoftwareRequirement)
        from pyfarm.models.tag import Tag
        from pyfarm.models.task import Task
        from pyfarm.models.user import User
        from pyfarm.models.jobqueue import JobQueue
        from pyfarm.models.gpu import GPU
        from pyfarm.models.jobgroup import JobGroup

        # set ENVIRONMENT_SETUP so the tests will run
        cls.ENVIRONMENT_SETUP = True

    def setup_warning_filter(self):
        """Installs an "ignore" filter for :class:`SAWarning`."""
        for warning_class in (SAWarning, ):
            warnings.simplefilter("ignore", warning_class)

    def teardown_warning_filter(self):
        """
        Removes every filter entry that :meth:`setup_warning_filter`
        added to ``warnings.filters``.
        """
        for warning_class in (SAWarning, ):
            # This tuple is compared against the entries of
            # ``warnings.filters`` in order to find the ones to remove.
            warning_entry = ("ignore", None, warning_class, None, 0)

            # setup_warning_filter() runs once per test so the same
            # entry may be present more than once.
            while warning_entry in warnings.filters:
                warnings.filters.remove(warning_entry)

    def setup_app(self):
        """
        Constructs the application object and assigns the instance
        variables for tests.  If you're testing the master your
        subclass will probably need to extend this method.

        Sets ``self.app``, ``self._original_response_class`` and
        ``self._context``.
        """
        # The whole process environment is handed to get_application()
        # as keyword arguments; a random app name is only used when the
        # environment does not already provide one.
        environment = os.environ.copy()
        environment.setdefault("app_name", uuid.uuid4().hex)
        self.app = get_application(**environment)

        @self.app.before_request
        def before_request_handler():
            return before_request()

        # construct response class so we can use the json methods
        # in our handlers
        self._original_response_class = self.app.response_class
        self.app.response_class = make_test_response(self.app.response_class)

        # construct and push the context
        # NOTE(review): nothing in this class pops this context again;
        # confirm whether teardown_app() should call self._context.pop()
        self._context = self.app.test_request_context()
        self._context.push()

    def setup_client(self, app):
        """returns the test client from the given application instance"""
        self.client = app.test_client()

    def setup_database(self):
        """Creates all database tables."""
        db.create_all()

    def teardown_database(self):
        """Removes the database session and then drops all tables."""
        db.session.remove()
        db.drop_all()

    def teardown_app(self):
        """
        Restores the response class that :meth:`setup_app` replaced on
        the application.
        """
        self.app.response_class = self._original_response_class

    def setUp(self):
        """
        Sets up the warning filter, application, test client and database,
        in that order.  Fails the test if :meth:`build_environment`
        has not been called.
        """
        # be sure this value has been set first, not doing so
        # could cause some dangerous behaviors (such as testing
        # on production data)
        if not self.ENVIRONMENT_SETUP:
            self.fail(
                "build_environment() not called, aborting due to "
                "possibility of dangerous behaviors")

        self.setup_warning_filter()
        self.setup_app()
        self.setup_client(self.app)
        self.setup_database()

    def tearDown(self):
        """
        Tears down the application, the database and the warning filter,
        in that order.
        """
        self.teardown_app()
        self.teardown_database()
        self.teardown_warning_filter()

    def assert_contents_equal(self, a_source, b_source):
        """
        Explicitly check to see if the two lists contain the same data
        without regard to order.  Every value must occur the same number
        of times in both lists.  This is mostly meant for cases where
        the lists contain unhashable types, which is why the comparison
        relies on equality only rather than on sets or counters.

        :param list a_source:
            The first list to compare.

        :param list b_source:
            The second list to compare.

        :raises AssertionError:
            Raised if either input is not a list or if the contents
            of the two lists differ.
        """
        # for now, we only support lists
        self.assertIsInstance(a_source, list)
        self.assertIsInstance(b_source, list)

        # Work on a copy so the caller's list is left untouched.  Each
        # value from ``a_source`` must consume exactly one equal value
        # from ``b_source`` which also accounts for duplicates.
        unmatched = b_source[:]
        for a_value in a_source:
            self.assertIn(a_value, unmatched)
            unmatched.remove(a_value)

        # There should not be any data left over after the above
        # has completed.
        self.assertEqual(unmatched, [])

    def assert_status(self, response, status_code=None):
        """
        Asserts that ``response`` is a :class:`Response` instance with
        the given ``status_code``.

        :param response:
            The response object to check.

        :param int status_code:
            The expected status code.  Although this is a keyword
            argument it is required.

        :raises AssertionError:
            Raised if ``status_code`` was not provided, ``response`` is
            not a :class:`Response` or the status codes differ.
        """
        # A plain ``assert`` statement would be stripped when Python runs
        # with -O, the unittest assertion is always evaluated.
        self.assertIsNotNone(status_code, "status_code is required")
        self.assertIsInstance(response, Response)
        self.assertEqual(response.status_code, status_code)

    def assert_ok(self, response):
        """Asserts that the status code of ``response`` is ``OK``"""
        self.assert_status(response, status_code=OK)

    def assert_created(self, response):
        """Asserts that the status code of ``response`` is ``CREATED``"""
        self.assert_status(response, status_code=CREATED)

    def assert_accepted(self, response):
        """Asserts that the status code of ``response`` is ``ACCEPTED``"""
        self.assert_status(response, status_code=ACCEPTED)

    def assert_no_content(self, response):
        """Asserts that the status code of ``response`` is ``NO_CONTENT``"""
        self.assert_status(response, status_code=NO_CONTENT)

    def assert_temporary_redirect(self, response):
        """
        Asserts that the status code of ``response`` is
        ``TEMPORARY_REDIRECT``
        """
        self.assert_status(response, status_code=TEMPORARY_REDIRECT)

    def assert_method_not_allowed(self, response):
        """
        Asserts that the status code of ``response`` is
        ``METHOD_NOT_ALLOWED``
        """
        self.assert_status(response, status_code=METHOD_NOT_ALLOWED)

    def assert_bad_request(self, response):
        """Asserts that the status code of ``response`` is ``BAD_REQUEST``"""
        self.assert_status(response, status_code=BAD_REQUEST)

    def assert_conflict(self, response):
        """Asserts that the status code of ``response`` is ``CONFLICT``"""
        self.assert_status(response, status_code=CONFLICT)

    def assert_unauthorized(self, response):
        """Asserts that the status code of ``response`` is ``UNAUTHORIZED``"""
        self.assert_status(response, status_code=UNAUTHORIZED)

    def assert_forbidden(self, response):
        """Asserts that the status code of ``response`` is ``FORBIDDEN``"""
        self.assert_status(response, status_code=FORBIDDEN)

    def assert_not_found(self, response):
        """Asserts that the status code of ``response`` is ``NOT_FOUND``"""
        self.assert_status(response, status_code=NOT_FOUND)

    def assert_not_acceptable(self, response):
        """
        Asserts that the status code of ``response`` is ``NOT_ACCEPTABLE``
        """
        self.assert_status(response, status_code=NOT_ACCEPTABLE)

    def assert_internal_server_error(self, response):
        """
        Asserts that the status code of ``response`` is
        ``INTERNAL_SERVER_ERROR``
        """
        self.assert_status(response, status_code=INTERNAL_SERVER_ERROR)