diff --git a/lib/Doctrine/ORM/Configuration.php b/lib/Doctrine/ORM/Configuration.php index 9c7d1cb13..db0ef7a7b 100644 --- a/lib/Doctrine/ORM/Configuration.php +++ b/lib/Doctrine/ORM/Configuration.php @@ -562,11 +562,22 @@ class Configuration extends \Doctrine\DBAL\Configuration * Add a filter to the list of possible filters. * * @param string $name The name of the filter. - * @param string $className The class name of the filter. + * @param string|Query\Filter\SQLFilter $filter The filter class name or an + * SQLFilter instance. + * + * @throws \InvalidArgumentException If the filter is an object and it doesn't + * extend the Query\Filter\SQLFilter class. */ - public function addFilter($name, $className) + public function addFilter($name, $filter) { - $this->_attributes['filters'][$name] = $className; + if (is_object($filter) && ! $filter instanceof Query\Filter\SQLFilter) { + throw new \InvalidArgumentException( + "A filter can be either a class name or an object extending \Doctrine\ORM\Query\Filter\SQLFilter," . + " instance of '" . get_class($filter) . "' given." + ); + } + + $this->_attributes['filters'][$name] = $filter; } /** @@ -574,10 +585,10 @@ class Configuration extends \Doctrine\DBAL\Configuration * * @param string $name The name of the filter. * - * @return string The class name of the filter, or null of it is not - * defined. + * @return string|Query\Filter\SQLFilter The class name of the filter, an + * SQLFilter instance or null of it is not defined. */ - public function getFilterClassName($name) + public function getFilter($name) { return isset($this->_attributes['filters'][$name]) ? $this->_attributes['filters'][$name] diff --git a/lib/Doctrine/ORM/Query/FilterCollection.php b/lib/Doctrine/ORM/Query/FilterCollection.php index fc47eb111..495b6a29b 100644 --- a/lib/Doctrine/ORM/Query/FilterCollection.php +++ b/lib/Doctrine/ORM/Query/FilterCollection.php @@ -103,12 +103,14 @@ class FilterCollection */ public function enable($name) { - if (null === $filterClass = $this->config->getFilterClassName($name)) { + if (null === $filter = $this->config->getFilter($name)) { throw new \InvalidArgumentException("Filter '" . $name . "' does not exist."); } if (!isset($this->enabledFilters[$name])) { - $this->enabledFilters[$name] = new $filterClass($this->em); + $this->enabledFilters[$name] = is_object($filter) + ? $filter + : new $filter($this->em); // Keep the enabled filters sorted for the hash ksort($this->enabledFilters); diff --git a/tests/Doctrine/Tests/ORM/ConfigurationTest.php b/tests/Doctrine/Tests/ORM/ConfigurationTest.php index b53fd617b..91343640d 100644 --- a/tests/Doctrine/Tests/ORM/ConfigurationTest.php +++ b/tests/Doctrine/Tests/ORM/ConfigurationTest.php @@ -215,9 +215,9 @@ class ConfigurationTest extends PHPUnit_Framework_TestCase public function testAddGetFilters() { - $this->assertSame(null, $this->configuration->getFilterClassName('NonExistingFilter')); + $this->assertSame(null, $this->configuration->getFilter('NonExistingFilter')); $this->configuration->addFilter('FilterName', __CLASS__); - $this->assertSame(__CLASS__, $this->configuration->getFilterClassName('FilterName')); + $this->assertSame(__CLASS__, $this->configuration->getFilter('FilterName')); } public function setDefaultRepositoryClassName() diff --git a/tests/Doctrine/Tests/ORM/Functional/SQLFilterTest.php b/tests/Doctrine/Tests/ORM/Functional/SQLFilterTest.php index 82c240951..133517a0a 100644 --- a/tests/Doctrine/Tests/ORM/Functional/SQLFilterTest.php +++ b/tests/Doctrine/Tests/ORM/Functional/SQLFilterTest.php @@ -58,11 +58,27 @@ class SQLFilterTest extends \Doctrine\Tests\OrmFunctionalTestCase public function testConfigureFilter() { $config = new \Doctrine\ORM\Configuration(); + $validFilter = $this->getMockBuilder('\Doctrine\ORM\Query\Filter\SQLFilter') + ->disableOriginalConstructor() + ->getMock(); + $config->addFilter("geolocation", $validFilter); $config->addFilter("locale", "\Doctrine\Tests\ORM\Functional\MyLocaleFilter"); - $this->assertEquals("\Doctrine\Tests\ORM\Functional\MyLocaleFilter", $config->getFilterClassName("locale")); - $this->assertNull($config->getFilterClassName("foo")); + $this->assertEquals("\Doctrine\Tests\ORM\Functional\MyLocaleFilter", $config->getFilter("locale")); + $this->assertNull($config->getFilter("foo")); + $this->assertInstanceOf("\Doctrine\ORM\Query\Filter\SQLFilter", $config->getFilter("geolocation")); + } + + /** + * @expectedException InvalidArgumentException + */ + public function testConfigureFilterFails() + { + $config = new \Doctrine\ORM\Configuration(); + $invalidFilter = $this->getMock('\StdClass'); + + $config->addFilter("geolocation", $invalidFilter); } public function testEntityManagerEnableFilter()