Skip to content

Commit 19f56ff

Browse files
authored
feature: Datasource 抽象 (#170)
1 parent a97835a commit 19f56ff

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+2646
-880
lines changed

.CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
- [eorm: 分库分表: Merger抽象与批量查询实现](https://github.com/ecodeclub/eorm/pull/160)
1717
- [eorm: 增强的 ShardingAlgorithm 设计与实现](https://github.com/ecodeclub/eorm/pull/161)
1818
- [eorm: 分库分表: Merger排序实现](https://github.com/ecodeclub/eorm/pull/166)
19+
- [eorm: Datasource 抽象](https://github.com/ecodeclub/eorm/pull/167)
1920
- [eorm: 分库分表: Merger分页实现](https://github.com/ecodeclub/eorm/pull/175)
2021
- [eorm: BasicTypeValue重命名](https://github.com/ecodeclub/eorm/pull/177)
2122
- [eorm: 分库分表: hash、shadow_hash算法不符合预期](https://github.com/ecodeclub/eorm/pull/174)

builder.go

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,22 @@ import (
1818
"context"
1919
"database/sql"
2020

21+
"github.com/ecodeclub/eorm/internal/datasource"
22+
2123
"github.com/ecodeclub/eorm/internal/errs"
2224
"github.com/ecodeclub/eorm/internal/model"
25+
"github.com/ecodeclub/eorm/internal/query"
2326
"github.com/valyala/bytebufferpool"
2427
)
2528

2629
var _ Executor = &Inserter[any]{}
2730
var _ Executor = &Updater[any]{}
2831
var _ Executor = &Deleter[any]{}
2932

33+
var EmptyQuery = Query{}
34+
3035
// Query 代表一个查询
31-
type Query struct {
32-
SQL string
33-
Args []any
34-
}
36+
type Query = query.Query
3537

3638
// Querier 查询器,代表最基本的查询
3739
type Querier[T any] struct {
@@ -48,7 +50,7 @@ func RawQuery[T any](sess Session, sql string, args ...any) Querier[T] {
4850
core: sess.getCore(),
4951
Session: sess,
5052
qc: &QueryContext{
51-
q: &Query{
53+
q: Query{
5254
SQL: sql,
5355
Args: args,
5456
},
@@ -57,7 +59,7 @@ func RawQuery[T any](sess Session, sql string, args ...any) Querier[T] {
5759
}
5860
}
5961

60-
func newQuerier[T any](sess Session, q *Query, meta *model.TableMeta, typ string) Querier[T] {
62+
func newQuerier[T any](sess Session, q Query, meta *model.TableMeta, typ string) Querier[T] {
6163
return Querier[T]{
6264
core: sess.getCore(),
6365
Session: sess,
@@ -72,7 +74,7 @@ func newQuerier[T any](sess Session, q *Query, meta *model.TableMeta, typ string
7274
// Exec 执行 SQL
7375
func (q Querier[T]) Exec(ctx context.Context) Result {
7476
var handler HandleFunc = func(ctx context.Context, qc *QueryContext) *QueryResult {
75-
res, err := q.Session.execContext(ctx, qc.q.SQL, qc.q.Args...)
77+
res, err := q.Session.execContext(ctx, datasource.Query(qc.q))
7678
return &QueryResult{Result: res, Err: err}
7779
}
7880

@@ -326,16 +328,16 @@ func (b *builder) buildColumn(c Column) error {
326328
// buildSubquery 構建子查詢 SQL,
327329
// useAlias 決定是否顯示別名,即使有別名
328330
func (b *builder) buildSubquery(sub Subquery, useAlias bool) error {
329-
query, err := sub.q.Build()
331+
q, err := sub.q.Build()
330332
if err != nil {
331333
return err
332334
}
333335
b.writeByte('(')
334336
// 拿掉最後 ';'
335-
b.writeString(query.SQL[:len(query.SQL)-1])
337+
b.writeString(q.SQL[:len(q.SQL)-1])
336338
// 因為有 build() ,所以理應 args 也需要跟 SQL 一起處理
337-
if len(query.Args) > 0 {
338-
b.addArgs(query.Args...)
339+
if len(q.Args) > 0 {
340+
b.addArgs(q.Args...)
339341
}
340342
b.writeByte(')')
341343
if useAlias {

builder_test.go

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,17 @@ import (
2121
"fmt"
2222
"testing"
2323

24+
"github.com/ecodeclub/eorm/internal/datasource/single"
25+
2426
"github.com/DATA-DOG/go-sqlmock"
2527
"github.com/ecodeclub/eorm/internal/errs"
2628
"github.com/ecodeclub/eorm/internal/valuer"
2729
"github.com/stretchr/testify/assert"
2830
)
2931

3032
func ExampleRawQuery() {
31-
orm := memoryDB()
32-
q := RawQuery[any](orm, `SELECT * FROM user_tab WHERE id = ?;`, 1)
33+
db := memoryDB()
34+
q := RawQuery[any](db, `SELECT * FROM user_tab WHERE id = ?;`, 1)
3335
fmt.Printf(`
3436
SQL: %s
3537
Args: %v
@@ -40,9 +42,9 @@ Args: %v
4042
}
4143

4244
func ExampleQuerier_Exec() {
43-
orm := memoryDB()
45+
db := memoryDB()
4446
// 在 Exec 的时候,泛型参数可以是任意的
45-
q := RawQuery[any](orm, `CREATE TABLE IF NOT EXISTS groups (
47+
q := RawQuery[any](db, `CREATE TABLE IF NOT EXISTS groups (
4648
group_id INTEGER PRIMARY KEY,
4749
name TEXT NOT NULL
4850
)`)
@@ -54,10 +56,6 @@ func ExampleQuerier_Exec() {
5456
// SUCCESS
5557
}
5658

57-
func (q Query) string() string {
58-
return fmt.Sprintf("SQL: %s\nArgs: %#v\n", q.SQL, q.Args)
59-
}
60-
6159
func TestQuerier_Get(t *testing.T) {
6260
t.Run("unsafe", func(t *testing.T) {
6361
testQuerierGet(t, valuer.PrimitiveCreator{Creator: valuer.NewUnsafeValue})
@@ -75,7 +73,7 @@ func testQuerierGet(t *testing.T, creator valuer.PrimitiveCreator) {
7573
}
7674
defer func() { _ = db.Close() }()
7775

78-
orm, err := openDB("mysql", db)
76+
orm, err := OpenDS("mysql", single.NewDB(db))
7977
if err != nil {
8078
t.Fatal(err)
8179
}
@@ -165,7 +163,7 @@ func testQuerier_GetMulti(t *testing.T, creator valuer.PrimitiveCreator) {
165163
defer func() {
166164
_ = db.Close()
167165
}()
168-
orm, err := openDB("mysql", db)
166+
orm, err := OpenDS("mysql", single.NewDB(db))
169167
if err != nil {
170168
t.Fatal(err)
171169
}

core.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ type core struct {
3232
}
3333

3434
func getHandler[T any](ctx context.Context, sess Session, c core, qc *QueryContext) *QueryResult {
35-
rows, err := sess.queryContext(ctx, qc.q.SQL, qc.q.Args...)
35+
rows, err := sess.queryContext(ctx, qc.q)
3636
if err != nil {
3737
return &QueryResult{Err: err}
3838
}
@@ -68,7 +68,7 @@ func get[T any](ctx context.Context, sess Session, core core, qc *QueryContext)
6868
}
6969

7070
func getMultiHandler[T any](ctx context.Context, sess Session, c core, qc *QueryContext) *QueryResult {
71-
rows, err := sess.queryContext(ctx, qc.q.SQL, qc.q.Args...)
71+
rows, err := sess.queryContext(ctx, qc.q)
7272
if err != nil {
7373
return &QueryResult{Err: err}
7474
}

db.go

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ package eorm
1717
import (
1818
"context"
1919
"database/sql"
20-
"database/sql/driver"
21-
"log"
22-
"time"
2320

21+
"github.com/ecodeclub/eorm/internal/datasource"
22+
"github.com/ecodeclub/eorm/internal/datasource/single"
2423
"github.com/ecodeclub/eorm/internal/dialect"
24+
"github.com/ecodeclub/eorm/internal/errs"
2525
"github.com/ecodeclub/eorm/internal/model"
2626
"github.com/ecodeclub/eorm/internal/valuer"
2727
)
@@ -39,8 +39,8 @@ type DBOption func(db *DB)
3939

4040
// DB represents a database
4141
type DB struct {
42-
db *sql.DB
4342
core
43+
ds datasource.DataSource
4444
}
4545

4646
// DBWithMiddlewares 为 db 配置 Middleware
@@ -50,31 +50,37 @@ func DBWithMiddlewares(ms ...Middleware) DBOption {
5050
}
5151
}
5252

53+
func DBOptionWithMetaRegistry(r model.MetaRegistry) DBOption {
54+
return func(db *DB) {
55+
db.metaRegistry = r
56+
}
57+
}
58+
5359
func UseReflection() DBOption {
5460
return func(db *DB) {
5561
db.valCreator = valuer.PrimitiveCreator{Creator: valuer.NewUnsafeValue}
5662
}
5763
}
5864

59-
func (db *DB) queryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
60-
return db.db.QueryContext(ctx, query, args...)
65+
func (db *DB) queryContext(ctx context.Context, q datasource.Query) (*sql.Rows, error) {
66+
return db.ds.Query(ctx, q)
6167
}
6268

63-
func (db *DB) execContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
64-
return db.db.ExecContext(ctx, query, args...)
69+
func (db *DB) execContext(ctx context.Context, q datasource.Query) (sql.Result, error) {
70+
return db.ds.Exec(ctx, q)
6571
}
6672

6773
// Open 创建一个 ORM 实例
6874
// 注意该实例是一个无状态的对象,你应该尽可能复用它
6975
func Open(driver string, dsn string, opts ...DBOption) (*DB, error) {
70-
db, err := sql.Open(driver, dsn)
76+
db, err := single.OpenDB(driver, dsn)
7177
if err != nil {
7278
return nil, err
7379
}
74-
return openDB(driver, db, opts...)
80+
return OpenDS(driver, db, opts...)
7581
}
7682

77-
func openDB(driver string, db *sql.DB, opts ...DBOption) (*DB, error) {
83+
func OpenDS(driver string, ds datasource.DataSource, opts ...DBOption) (*DB, error) {
7884
dl, err := dialect.Of(driver)
7985
if err != nil {
8086
return nil, err
@@ -88,37 +94,28 @@ func openDB(driver string, db *sql.DB, opts ...DBOption) (*DB, error) {
8894
Creator: valuer.NewUnsafeValue,
8995
},
9096
},
91-
db: db,
97+
ds: ds,
9298
}
9399
for _, o := range opts {
94100
o(orm)
95101
}
96102
return orm, nil
97103
}
98104

99-
// BeginTx 开启事务
100105
func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
101-
tx, err := db.db.BeginTx(ctx, opts)
106+
inst, ok := db.ds.(datasource.TxBeginner)
107+
if !ok {
108+
return nil, errs.ErrNotCompleteTxBeginner
109+
}
110+
tx, err := inst.BeginTx(ctx, opts)
102111
if err != nil {
103112
return nil, err
104113
}
105-
return &Tx{tx: tx, db: db.db, core: db.core}, nil
106-
}
107-
108-
// Wait 会等待数据库连接
109-
// 注意只能用于测试
110-
func (db *DB) Wait() error {
111-
err := db.db.Ping()
112-
for err == driver.ErrBadConn {
113-
log.Printf("等待数据库启动...")
114-
err = db.db.Ping()
115-
time.Sleep(time.Second)
116-
}
117-
return err
114+
return &Tx{tx: tx, core: db.getCore()}, nil
118115
}
119116

120117
func (db *DB) Close() error {
121-
return db.db.Close()
118+
return db.ds.Close()
122119
}
123120

124121
func (db *DB) getCore() core {

0 commit comments

Comments
 (0)