diff --git a/lib/Doctrine/ORM/Query/Filter/SQLFilter.php b/lib/Doctrine/ORM/Query/Filter/SQLFilter.php index 2ac413788..dd90f1789 100644 --- a/lib/Doctrine/ORM/Query/Filter/SQLFilter.php +++ b/lib/Doctrine/ORM/Query/Filter/SQLFilter.php @@ -67,5 +67,8 @@ abstract class SQLFilter return serialize($this->parameters); } + /** + * @return string The contstraint if there is one, empty string otherwise + */ abstract function addFilterConstraint(ClassMetadata $targetEntity, $targetTableAlias); } diff --git a/lib/Doctrine/ORM/Query/SqlWalker.php b/lib/Doctrine/ORM/Query/SqlWalker.php index 7393ba893..8810aed87 100644 --- a/lib/Doctrine/ORM/Query/SqlWalker.php +++ b/lib/Doctrine/ORM/Query/SqlWalker.php @@ -350,6 +350,20 @@ class SqlWalker implements TreeWalker return ($encapsulate) ? '(' . $sql . ')' : $sql; } + private function generateFilterConditionSQL(ClassMetadata $targetEntity, $targetTableAlias) + { + $filterSql = ''; + + $first = true; + foreach($this->_em->getEnabledFilters() as $filter) { + if("" !== $filterExpr = $filter->addFilterConstraint($targetEntity, $targetTableAlias)) { + if ( ! $first) $sql .= ' AND '; else $first = false; + $filterSql .= '(' . $filterExpr . ')'; + } + } + + return $filterSql; + } /** * Walks down a SelectStatement AST node, thereby generating the appropriate SQL. * @@ -771,6 +785,12 @@ class SqlWalker implements TreeWalker $sql .= $sourceTableAlias . '.' . $quotedTargetColumn . ' = ' . $targetTableAlias . '.' . $sourceColumn; } } + + // Apply the filters + if("" !== $filterExpr = $this->generateFilterConditionSQL($targetClass, $targetTableAlias)) { + if ( ! $first) $sql .= ' AND '; else $first = false; + $sql .= $filterExpr; + } } else if ($assoc['type'] == ClassMetadata::MANY_TO_MANY) { // Join relation table $joinTable = $assoc['joinTable']; @@ -835,6 +855,12 @@ class SqlWalker implements TreeWalker $sql .= $targetTableAlias . '.' . $quotedTargetColumn . ' = ' . $joinTableAlias . '.' . $relationColumn; } } + + // Apply the filters + if("" !== $filterExpr = $this->generateFilterConditionSQL($targetClass, $targetTableAlias)) { + if ( ! $first) $sql .= ' AND '; else $first = false; + $sql .= $filterExpr; + } } // Handle WITH clause @@ -1414,6 +1440,26 @@ class SqlWalker implements TreeWalker $condSql = null !== $whereClause ? $this->walkConditionalExpression($whereClause->conditionalExpression) : ''; $discrSql = $this->_generateDiscriminatorColumnConditionSql($this->_rootAliases); + $filterSql = ''; + + // Apply the filters for all entities in FROM + $first = true; + foreach ($this->_rootAliases as $dqlAlias) { + $class = $this->_queryComponents[$dqlAlias]['metadata']; + $tableAlias = $this->getSQLTableAlias($class->table['name'], $dqlAlias); + + if("" !== $filterExpr = $this->generateFilterConditionSQL($class, $tableAlias)) { + if ( ! $first) $filterSql .= ' AND '; else $first = false; + $filterSql .= $filterExpr; + } + } + + if ('' !== $filterSql) { + if($condSql) $condSql .= ' AND '; + + $condSql .= $filterSql; + } + if ($condSql) { return ' WHERE ' . (( ! $discrSql) ? $condSql : '(' . $condSql . ') AND ' . $discrSql); } else if ($discrSql) { diff --git a/tests/Doctrine/Tests/ORM/Functional/SQLFilterTest.php b/tests/Doctrine/Tests/ORM/Functional/SQLFilterTest.php index ac9bae0c7..997032cbc 100644 --- a/tests/Doctrine/Tests/ORM/Functional/SQLFilterTest.php +++ b/tests/Doctrine/Tests/ORM/Functional/SQLFilterTest.php @@ -6,6 +6,15 @@ use Doctrine\ORM\Query\Filter\SQLFilter; use Doctrine\ORM\Mapping\ClassMetaData; use Doctrine\Common\Cache\ArrayCache; +use Doctrine\ORM\Mapping\ClassMetadataInfo; + +use Doctrine\Tests\Models\CMS\CmsUser; +use Doctrine\Tests\Models\CMS\CmsPhonenumber; +use Doctrine\Tests\Models\CMS\CmsAddress; +use Doctrine\Tests\Models\CMS\CmsGroup; +use Doctrine\Tests\Models\CMS\CmsArticle; +use Doctrine\Tests\Models\CMS\CmsComment; + require_once __DIR__ . '/../../TestInit.php'; /** @@ -15,6 +24,8 @@ require_once __DIR__ . '/../../TestInit.php'; */ class SQLFilterTest extends \Doctrine\Tests\OrmFunctionalTestCase { + private $userId, $userId2, $articleId, $articleId2; + public function setUp() { $this->useModelSet('cms'); @@ -211,10 +222,124 @@ class SQLFilterTest extends \Doctrine\Tests\OrmFunctionalTestCase $query->getResult(); $this->assertEquals(2, count($cache->getIds())); - return $query; + // Another time doesn't add another cache entry + $query->getResult(); + $this->assertEquals(2, count($cache->getIds())); + } + + public function testToOneFilter() + { + //$this->_em->getConnection()->getConfiguration()->setSQLLogger(new \Doctrine\DBAL\Logging\EchoSQLLogger); + $this->loadFixtureData(); + + $query = $this->_em->createQuery('select ux, ua from Doctrine\Tests\Models\CMS\CmsUser ux JOIN ux.address ua'); + + // We get two users before enabling the filter + $this->assertEquals(2, count($query->getResult())); + + $conf = $this->_em->getConfiguration(); + $conf->addFilter("country", "\Doctrine\Tests\ORM\Functional\CMSCountryFilter"); + $this->_em->enableFilter("country")->setParameter("country", "Germany", \Doctrine\DBAL\Types\Type::getType(\Doctrine\DBAL\Types\Type::STRING)->getBindingType()); + + // We get one user after enabling the filter + $this->assertEquals(1, count($query->getResult())); + } + + public function testManyToManyFilter() + { + $this->loadFixtureData(); + $query = $this->_em->createQuery('select ux, ug from Doctrine\Tests\Models\CMS\CmsUser ux JOIN ux.groups ug'); + + // We get two users before enabling the filter + $this->assertEquals(2, count($query->getResult())); + + $conf = $this->_em->getConfiguration(); + $conf->addFilter("group_prefix", "\Doctrine\Tests\ORM\Functional\CMSGroupPrefixFilter"); + $this->_em->enableFilter("group_prefix")->setParameter("prefix", "bar_%", \Doctrine\DBAL\Types\Type::getType(\Doctrine\DBAL\Types\Type::STRING)->getBindingType()); + + // We get one user after enabling the filter + $this->assertEquals(1, count($query->getResult())); + + } + + public function testWhereFilter() + { + $this->loadFixtureData(); + $query = $this->_em->createQuery('select ug from Doctrine\Tests\Models\CMS\CmsGroup ug WHERE 1=1'); + + // We get two users before enabling the filter + $this->assertEquals(2, count($query->getResult())); + + $conf = $this->_em->getConfiguration(); + $conf->addFilter("group_prefix", "\Doctrine\Tests\ORM\Functional\CMSGroupPrefixFilter"); + $this->_em->enableFilter("group_prefix")->setParameter("prefix", "bar_%", \Doctrine\DBAL\Types\Type::getType(\Doctrine\DBAL\Types\Type::STRING)->getBindingType()); + + // We get one user after enabling the filter + $this->assertEquals(1, count($query->getResult())); } + private function loadFixtureData() + { + $user = new CmsUser; + $user->name = 'Roman'; + $user->username = 'romanb'; + $user->status = 'developer'; + + $address = new CmsAddress; + $address->country = 'Germany'; + $address->city = 'Berlin'; + $address->zip = '12345'; + + $user->address = $address; // inverse side + $address->user = $user; // owning side! + + $group = new CmsGroup; + $group->name = 'foo_group'; + $user->addGroup($group); + + $article1 = new CmsArticle; + $article1->topic = "Test1"; + $article1->text = "Test"; + $article1->setAuthor($user); + + $article2 = new CmsArticle; + $article2->topic = "Test2"; + $article2->text = "Test"; + $article2->setAuthor($user); + + $this->_em->persist($article1); + $this->_em->persist($article2); + + $this->_em->persist($user); + + $user2 = new CmsUser; + $user2->name = 'Guilherme'; + $user2->username = 'gblanco'; + $user2->status = 'developer'; + + $address2 = new CmsAddress; + $address2->country = 'France'; + $address2->city = 'Paris'; + $address2->zip = '12345'; + + $user->address = $address2; // inverse side + $address2->user = $user2; // owning side! + + $user2->addGroup($group); + $group2 = new CmsGroup; + $group2->name = 'bar_group'; + $user2->addGroup($group2); + + $this->_em->persist($user2); + $this->_em->flush(); + $this->_em->clear(); + + $this->userId = $user->getId(); + $this->userId2 = $user2->getId(); + $this->articleId = $article1->id; + $this->articleId2 = $article2->id; + } } class MySoftDeleteFilter extends SQLFilter @@ -237,6 +362,41 @@ class MyLocaleFilter extends SQLFilter return ""; } - return $targetTableAlias.'.locale = ' . $this->getParam('locale'); // getParam uses connection to quote the value. + return $targetTableAlias.'.locale = ' . $this->getParameter('locale'); // getParam uses connection to quote the value. + } +} + +class CMSCountryFilter extends SQLFilter +{ + public function addFilterConstraint(ClassMetadata $targetEntity, $targetTableAlias) + { + if ($targetEntity->name != "Doctrine\Tests\Models\CMS\CmsAddress") { + return ""; + } + + return $targetTableAlias.'.country = ' . $this->getParameter('country'); // getParam uses connection to quote the value. + } +} + +class CMSGroupPrefixFilter extends SQLFilter +{ + public function addFilterConstraint(ClassMetadata $targetEntity, $targetTableAlias) + { + if ($targetEntity->name != "Doctrine\Tests\Models\CMS\CmsGroup") { + return ""; + } + + return $targetTableAlias.'.name LIKE ' . $this->getParameter('prefix'); // getParam uses connection to quote the value. + } +} +class CMSArticleTopicFilter extends SQLFilter +{ + public function addFilterConstraint(ClassMetadata $targetEntity, $targetTableAlias) + { + if ($targetEntity->name != "Doctrine\Tests\Models\CMS\CmsArticle") { + return ""; + } + + return $targetTableAlias.'.topic = ' . $this->getParameter('topic'); // getParam uses connection to quote the value. } }