Skip to content

Commit 675dd1f

Browse files
Merge pull request #912 from circulon/fix/allow_override_of_model_default_selects
allow override of model default selects
2 parents ccf9025 + 3970262 commit 675dd1f

16 files changed

+44
-32
lines changed

src/masoniteorm/models/Model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ def get_builder(self):
354354
dry=self.__dry__,
355355
)
356356

357-
return self.builder.select(*self.get_selects())
357+
return self.builder
358358

359359
def get_selects(self):
360360
return self.__selects__

src/masoniteorm/query/QueryBuilder.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
import inspect
22
from copy import deepcopy
33
from datetime import datetime
4-
from typing import Any, Dict, List, Optional, Callable
4+
from typing import Any, Callable, Dict, List, Optional
55

66
from ..collection.Collection import Collection
77
from ..config import load_config
88
from ..exceptions import (
99
HTTP404,
1010
ConnectionNotRegistered,
11+
InvalidArgument,
1112
ModelNotFound,
1213
MultipleRecordsFound,
13-
InvalidArgument,
1414
)
1515
from ..expressions.expressions import (
1616
AggregateExpression,
@@ -1229,6 +1229,7 @@ def or_where_doesnt_have(self, relationship, callback):
12291229
return self
12301230

12311231
def with_count(self, relationship, callback=None):
1232+
self.select(*self._model.get_selects())
12321233
return getattr(self._model, relationship).get_with_count_query(
12331234
self, callback=callback
12341235
)
@@ -2067,6 +2068,9 @@ def get_grammar(self):
20672068

20682069
# Either _creates when creating, otherwise use columns
20692070
columns = self._creates or self._columns
2071+
if not columns and not self._aggregates and self._model:
2072+
self.select(*self._model.get_selects())
2073+
columns = self._columns
20702074

20712075
return self.grammar(
20722076
columns=columns,

src/masoniteorm/query/grammars/BaseGrammar.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import re
22

33
from ...expressions.expressions import (
4-
SubGroupExpression,
5-
SubSelectExpression,
6-
SelectExpression,
74
JoinClause,
85
OnClause,
6+
SelectExpression,
7+
SubGroupExpression,
8+
SubSelectExpression,
99
)
1010

1111

src/masoniteorm/query/grammars/PostgresGrammar.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from .BaseGrammar import BaseGrammar
21
import re
32

3+
from .BaseGrammar import BaseGrammar
4+
45

56
class PostgresGrammar(BaseGrammar):
67
"""Postgres grammar class."""

src/masoniteorm/query/grammars/SQLiteGrammar.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from .BaseGrammar import BaseGrammar
21
import re
32

3+
from .BaseGrammar import BaseGrammar
4+
45

56
class SQLiteGrammar(BaseGrammar):
67
"""SQLite grammar class."""
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .SQLiteGrammar import SQLiteGrammar
2-
from .PostgresGrammar import PostgresGrammar
3-
from .MySQLGrammar import MySQLGrammar
41
from .MSSQLGrammar import MSSQLGrammar
2+
from .MySQLGrammar import MySQLGrammar
3+
from .PostgresGrammar import PostgresGrammar
4+
from .SQLiteGrammar import SQLiteGrammar
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1+
from .MSSQLPostProcessor import MSSQLPostProcessor
12
from .MySQLPostProcessor import MySQLPostProcessor
23
from .PostgresPostProcessor import PostgresPostProcessor
34
from .SQLitePostProcessor import SQLitePostProcessor
4-
from .MSSQLPostProcessor import MSSQLPostProcessor

tests/models/test_models.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class ModelTestForced(Model):
3737
__force_update__ = True
3838

3939
class BaseModel(Model):
40+
__dry__ = True
4041
def get_selects(self):
4142
return [f"{self.get_table_name()}.*"]
4243

@@ -267,9 +268,26 @@ def test_model_can_provide_default_select(self):
267268
"""SELECT `users`.* FROM `users`""",
268269
)
269270

270-
def test_model_can_add_to_default_select(self):
271+
def test_model_can_override_to_default_select(self):
271272
sql = ModelWithBaseModel.select(["products.name", "products.id", "store.name"]).to_sql()
272273
self.assertEqual(
273274
sql,
274-
"""SELECT `users`.*, `products`.`name`, `products`.`id`, `store`.`name` FROM `users`""",
275+
"""SELECT `products`.`name`, `products`.`id`, `store`.`name` FROM `users`""",
276+
)
277+
278+
def test_model_can_use_aggregate_funcs_with_default_selects(self):
279+
sql = ModelWithBaseModel.count().to_sql()
280+
self.assertEqual(
281+
sql,
282+
"""SELECT COUNT(*) AS m_count_reserved FROM `users`""",
283+
)
284+
sql = ModelWithBaseModel.max("id").to_sql()
285+
self.assertEqual(
286+
sql,
287+
"""SELECT MAX(`users`.`id`) AS id FROM `users`""",
288+
)
289+
sql = ModelWithBaseModel.min("id").to_sql()
290+
self.assertEqual(
291+
sql,
292+
"""SELECT MIN(`users`.`id`) AS id FROM `users`""",
275293
)

tests/mssql/builder/test_mssql_query_builder_relationships.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def get_builder(self, table="users"):
5959
connection_class=connection,
6060
connection="mssql",
6161
table=table,
62-
model=User,
62+
model=User(),
6363
)
6464

6565
def test_has(self):

tests/mysql/builder/test_transactions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class TestTransactions(unittest.TestCase):
2626
pass
2727
# def get_builder(self, table="users"):
2828
# connection = ConnectionFactory().make("default")
29-
# return QueryBuilder(MySQLGrammar, connection, table=table, model=User)
29+
# return QueryBuilder(MySQLGrammar, connection, table=table, model=User())
3030

3131
# def test_can_start_transaction(self, table="users"):
3232
# builder = self.get_builder()

0 commit comments

Comments
 (0)