From 5e91f0c1ca473f2c899c329c3a844257e7f48cd5 Mon Sep 17 00:00:00 2001 From: Alexander Date: Wed, 7 Dec 2011 10:02:15 +0100 Subject: [PATCH] [DDC-551] Update SQLWalker to reflect filter requirements for inheritance --- lib/Doctrine/ORM/Query/SqlWalker.php | 37 ++++++++++++++++--- .../Tests/ORM/Functional/SQLFilterTest.php | 14 +++++++ 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/lib/Doctrine/ORM/Query/SqlWalker.php b/lib/Doctrine/ORM/Query/SqlWalker.php index ea834fc6a..69bf903be 100644 --- a/lib/Doctrine/ORM/Query/SqlWalker.php +++ b/lib/Doctrine/ORM/Query/SqlWalker.php @@ -268,6 +268,11 @@ class SqlWalker implements TreeWalker $sqlParts[] = $baseTableAlias . '.' . $columnName . ' = ' . $tableAlias . '.' . $columnName; } + // Add filters on the root class + if ('' !== $filterSql = $this->generateFilterConditionSQL($parentClass, $tableAlias)) { + $sqlParts[] = $filterSql; + } + $sql .= implode(' AND ', $sqlParts); } @@ -363,13 +368,35 @@ class SqlWalker implements TreeWalker */ private function generateFilterConditionSQL(ClassMetadata $targetEntity, $targetTableAlias) { - $filterClauses = array(); + if (!$this->_em->hasFilters()) { + return ''; + } - if ($this->_em->hasFilters()) { - foreach ($this->_em->getFilters()->getEnabledFilters() as $filter) { - if ('' !== $filterExpr = $filter->addFilterConstraint($targetEntity, $targetTableAlias)) { - $filterClauses[] = '(' . $filterExpr . ')'; + switch($targetEntity->inheritanceType) { + case ClassMetadata::INHERITANCE_TYPE_NONE: + break; + case ClassMetadata::INHERITANCE_TYPE_JOINED: + // The classes in the inheritance will be added to the query one by one, + // but only the root node is getting filtered + if ($targetEntity->name !== $targetEntity->rootEntityName) { + return ''; } + break; + case ClassMetadata::INHERITANCE_TYPE_SINGLE_TABLE: + // With STI the table will only be queried once, make sure that the filters + // are added to the root entity + $targetEntity = $this->_em->getClassMetadata($targetEntity->rootEntityName); + break; + default: + //@todo: throw exception? + return ''; + break; + } + + $filterClauses = array(); + foreach ($this->_em->getFilters()->getEnabledFilters() as $filter) { + if ('' !== $filterExpr = $filter->addFilterConstraint($targetEntity, $targetTableAlias)) { + $filterClauses[] = '(' . $filterExpr . ')'; } } diff --git a/tests/Doctrine/Tests/ORM/Functional/SQLFilterTest.php b/tests/Doctrine/Tests/ORM/Functional/SQLFilterTest.php index ff9332373..d56ddb99b 100644 --- a/tests/Doctrine/Tests/ORM/Functional/SQLFilterTest.php +++ b/tests/Doctrine/Tests/ORM/Functional/SQLFilterTest.php @@ -519,7 +519,10 @@ class SQLFilterTest extends \Doctrine\Tests\OrmFunctionalTestCase public function testJoinSubclassPersister_FilterOnlyOnRootTableWhenFetchingSubEntity() { $this->loadCompanyJoinedSubclassFixtureData(); + // Persister $this->assertEquals(2, count($this->_em->getRepository('Doctrine\Tests\Models\Company\CompanyManager')->findAll())); + // SQLWalker + $this->assertEquals(2, count($this->_em->createQuery("SELECT cm FROM Doctrine\Tests\Models\Company\CompanyManager cm")->getResult())); // Enable the filter $conf = $this->_em->getConfiguration(); @@ -531,12 +534,15 @@ class SQLFilterTest extends \Doctrine\Tests\OrmFunctionalTestCase $managers = $this->_em->getRepository('Doctrine\Tests\Models\Company\CompanyManager')->findAll(); $this->assertEquals(1, count($managers)); $this->assertEquals("Guilherme", $managers[0]->getName()); + + $this->assertEquals(1, count($this->_em->createQuery("SELECT cm FROM Doctrine\Tests\Models\Company\CompanyManager cm")->getResult())); } public function testJoinSubclassPersister_FilterOnlyOnRootTableWhenFetchingRootEntity() { $this->loadCompanyJoinedSubclassFixtureData(); $this->assertEquals(3, count($this->_em->getRepository('Doctrine\Tests\Models\Company\CompanyPerson')->findAll())); + $this->assertEquals(3, count($this->_em->createQuery("SELECT cp FROM Doctrine\Tests\Models\Company\CompanyPerson cp")->getResult())); // Enable the filter $conf = $this->_em->getConfiguration(); @@ -548,6 +554,8 @@ class SQLFilterTest extends \Doctrine\Tests\OrmFunctionalTestCase $persons = $this->_em->getRepository('Doctrine\Tests\Models\Company\CompanyPerson')->findAll(); $this->assertEquals(1, count($persons)); $this->assertEquals("Guilherme", $persons[0]->getName()); + + $this->assertEquals(1, count($this->_em->createQuery("SELECT cp FROM Doctrine\Tests\Models\Company\CompanyPerson cp")->getResult())); } private function loadCompanyJoinedSubclassFixtureData() @@ -577,7 +585,10 @@ class SQLFilterTest extends \Doctrine\Tests\OrmFunctionalTestCase public function testSingleTableInheritance_FilterOnlyOnRootTableWhenFetchingSubEntity() { $this->loadCompanySingleTableInheritanceFixtureData(); + // Persister $this->assertEquals(2, count($this->_em->getRepository('Doctrine\Tests\Models\Company\CompanyFlexUltraContract')->findAll())); + // SQLWalker + $this->assertEquals(2, count($this->_em->createQuery("SELECT cfc FROM Doctrine\Tests\Models\Company\CompanyFlexUltraContract cfc")->getResult())); // Enable the filter $conf = $this->_em->getConfiguration(); @@ -586,12 +597,14 @@ class SQLFilterTest extends \Doctrine\Tests\OrmFunctionalTestCase ->enable("completed_contract"); $this->assertEquals(1, count($this->_em->getRepository('Doctrine\Tests\Models\Company\CompanyFlexUltraContract')->findAll())); + $this->assertEquals(1, count($this->_em->createQuery("SELECT cfc FROM Doctrine\Tests\Models\Company\CompanyFlexUltraContract cfc")->getResult())); } public function testSingleTableInheritance_FilterOnlyOnRootTableWhenFetchingRootEntity() { $this->loadCompanySingleTableInheritanceFixtureData(); $this->assertEquals(4, count($this->_em->getRepository('Doctrine\Tests\Models\Company\CompanyFlexContract')->findAll())); + $this->assertEquals(4, count($this->_em->createQuery("SELECT cfc FROM Doctrine\Tests\Models\Company\CompanyFlexContract cfc")->getResult())); // Enable the filter $conf = $this->_em->getConfiguration(); @@ -600,6 +613,7 @@ class SQLFilterTest extends \Doctrine\Tests\OrmFunctionalTestCase ->enable("completed_contract"); $this->assertEquals(2, count($this->_em->getRepository('Doctrine\Tests\Models\Company\CompanyFlexContract')->findAll())); + $this->assertEquals(2, count($this->_em->createQuery("SELECT cfc FROM Doctrine\Tests\Models\Company\CompanyFlexContract cfc")->getResult())); } private function loadCompanySingleTableInheritanceFixtureData()