feat: 新增 ScanRows 接口
This commit is contained in:
144
mapping.go
144
mapping.go
@@ -1,13 +1,16 @@
|
||||
package tdmap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func NewMapping() *TdMapping {
|
||||
return &TdMapping{}
|
||||
func NewMapper() *Mapper {
|
||||
return &Mapper{}
|
||||
}
|
||||
|
||||
// TableNamer 定义了一个获取表名的方法。
|
||||
@@ -15,20 +18,21 @@ type TableNamer interface {
|
||||
TableName() string
|
||||
}
|
||||
|
||||
type TdMapping struct {
|
||||
modelMates Map[string, *StructMate]
|
||||
type Mapper struct {
|
||||
structMateMap SyncMap[string, *StructMate] // 结构体 元信息缓存 key: struct 的唯一类型名称 value: 元信息
|
||||
}
|
||||
|
||||
func (b *TdMapping) scanStruct(data ...any) error {
|
||||
func (b *Mapper) scanStruct(data ...any) error {
|
||||
for _, datum := range data {
|
||||
if reflect.TypeOf(datum).Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("需要指针类型:%v", reflect.TypeOf(datum))
|
||||
//return fmt.Errorf("需要指针类型:%v", reflect.TypeOf(datum))
|
||||
return fmt.Errorf("need a pointer type: %v", reflect.TypeOf(datum))
|
||||
}
|
||||
|
||||
if mate, err := scan(datum); err != nil {
|
||||
return err
|
||||
} else {
|
||||
b.modelMates.Store(mate.UniqueTypeName, mate)
|
||||
b.structMateMap.Store(mate.UniqueTypeName, mate)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,7 +42,7 @@ func (b *TdMapping) scanStruct(data ...any) error {
|
||||
// 提取 struct 内的信息
|
||||
// @param data 包含 struct 类型的数据
|
||||
// @return 返回一个 map[表名][]*TableRowMateria, 如果超级表名为空,则表示普通表
|
||||
func (b *TdMapping) extractStructData(data ...any) (map[string][]*TableRowMateria, error) {
|
||||
func (b *Mapper) extractStructData(data ...any) (map[string][]*TableRowMateria, error) {
|
||||
result := make(map[string][]*TableRowMateria)
|
||||
|
||||
for _, item := range data {
|
||||
@@ -65,7 +69,7 @@ func (b *TdMapping) extractStructData(data ...any) (map[string][]*TableRowMateri
|
||||
}
|
||||
}
|
||||
|
||||
mate, ok := b.modelMates.Load(uniqueTypeName)
|
||||
mate, ok := b.structMateMap.Load(uniqueTypeName)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("not found struct type: %s", uniqueTypeName)
|
||||
}
|
||||
@@ -81,9 +85,9 @@ func (b *TdMapping) extractStructData(data ...any) (map[string][]*TableRowMateri
|
||||
|
||||
for _, name := range mate.DBAnnotatedNames {
|
||||
// db tag 名称 -- 数据库列名
|
||||
dbColumn := mate.FiledDBNameCache[name]
|
||||
dbColumn := mate.Filed2DBNameCache[name]
|
||||
|
||||
field := vf.FieldByIndex(mate.FieldIndexCache[name])
|
||||
field := vf.FieldByIndex(mate.Field2IndexCache[name])
|
||||
for field.Kind() == reflect.Ptr {
|
||||
if field.IsNil() {
|
||||
break
|
||||
@@ -99,8 +103,8 @@ func (b *TdMapping) extractStructData(data ...any) (map[string][]*TableRowMateri
|
||||
|
||||
for _, name := range mate.TaggedFieldNames {
|
||||
// db tag 名称 -- 数据库列名
|
||||
tagColumn := mate.FiledDBNameCache[name]
|
||||
field := vf.FieldByIndex(mate.FieldIndexCache[name])
|
||||
tagColumn := mate.Filed2DBNameCache[name]
|
||||
field := vf.FieldByIndex(mate.Field2IndexCache[name])
|
||||
for field.Kind() == reflect.Ptr {
|
||||
if field.IsNil() {
|
||||
break
|
||||
@@ -121,7 +125,7 @@ func (b *TdMapping) extractStructData(data ...any) (map[string][]*TableRowMateri
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (b *TdMapping) ToInsertSQL(data ...any) (string, error) {
|
||||
func (b *Mapper) ToInsertSQL(data ...any) (string, error) {
|
||||
// 扫描 struct 类型数据
|
||||
if err := b.scanStruct(data...); err != nil {
|
||||
return "", fmt.Errorf("failed to scan struct: %w", err)
|
||||
@@ -171,3 +175,115 @@ func (b *TdMapping) ToInsertSQL(data ...any) (string, error) {
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// ScanRows 扫描多行数据
|
||||
// A Superior Option: Use ScanRowsWithContext instead.
|
||||
func (b *Mapper) ScanRows(target any, rows *sql.Rows) error {
|
||||
return b.ScanRowsWithContext(context.Background(), target, rows)
|
||||
}
|
||||
|
||||
// ScanRowsWithContext 扫描多行数据
|
||||
// @param ctx 上下文
|
||||
// @param target 结构体的指针 或 slice 指针 结构体需要使用 db 标签注解
|
||||
// @param rows 数据库返回的行数据
|
||||
// @return 返回错误
|
||||
func (b *Mapper) ScanRowsWithContext(ctx context.Context, target any, rows *sql.Rows) error {
|
||||
// 确保rows不为nil
|
||||
if rows == nil {
|
||||
return errors.New("rows cannot be nil")
|
||||
}
|
||||
|
||||
// 获取target的反射类型和值
|
||||
vf := reflect.ValueOf(target)
|
||||
|
||||
// 检查target是否为指针类型
|
||||
if vf.Kind() != reflect.Ptr {
|
||||
return errors.New("target must be a pointer to a struct or a slice of structs")
|
||||
}
|
||||
|
||||
if vf.IsNil() {
|
||||
return errors.New("target is nil, please pass a pointer to a struct or a slice of structs")
|
||||
}
|
||||
|
||||
vf = vf.Elem() // 解引用指针以获取目标值
|
||||
|
||||
switch vf.Kind() {
|
||||
case reflect.Struct:
|
||||
// 提取 struct 内的信息
|
||||
if err := b.scanStruct(target); err != nil {
|
||||
return err
|
||||
}
|
||||
// target是指向单个struct的指针,扫描一行数据
|
||||
if !rows.Next() {
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
return b.scanRow(vf, rows)
|
||||
case reflect.Slice:
|
||||
// target是slice的指针
|
||||
sliceElementType := vf.Type().Elem()
|
||||
if sliceElementType.Kind() != reflect.Ptr || sliceElementType.Elem().Kind() != reflect.Struct {
|
||||
return errors.New("target must be a pointer to a slice of structs")
|
||||
}
|
||||
|
||||
// 提取 slice 中 struct 内的信息
|
||||
if err := b.scanStruct(reflect.New(sliceElementType.Elem()).Interface()); err != nil { // 这里在抛出 不是指针类型
|
||||
return err
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
vf.SetLen(0) // 清空 slice
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// 创建slice元素的新实例
|
||||
newStruct := reflect.New(sliceElementType.Elem())
|
||||
if err := b.scanRow(newStruct, rows); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 将新实例添加到slice中
|
||||
vf.Set(reflect.Append(vf, newStruct))
|
||||
}
|
||||
|
||||
if vf.Len() == 0 { // 如果slice为空,返回 sql.ErrNoRows
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return errors.New("target must be a pointer to a struct or a slice of structs")
|
||||
}
|
||||
}
|
||||
|
||||
// scanRow 用于扫描单行数据到结构体
|
||||
func (b *Mapper) scanRow(target reflect.Value, rows *sql.Rows) error {
|
||||
target = reflect.Indirect(target)
|
||||
|
||||
uniqueTypeName := getUniqueTypeName(target.Type())
|
||||
mate, ok := b.structMateMap.Load(uniqueTypeName)
|
||||
if !ok {
|
||||
return fmt.Errorf("not found struct type mate: %s", uniqueTypeName)
|
||||
}
|
||||
|
||||
// 假设我们有一个函数来获取列的值
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 准备目标结构体的地址
|
||||
dest := make([]interface{}, len(columns))
|
||||
for i, colName := range columns {
|
||||
// 拿到缓存的字段索引
|
||||
if idx, ok := mate.DBName2IndexCache[colName]; ok {
|
||||
dest[i] = target.FieldByIndex(idx).Addr().Interface()
|
||||
} else {
|
||||
return fmt.Errorf("no corresponding field found for column %s", colName)
|
||||
}
|
||||
}
|
||||
|
||||
// 扫描数据
|
||||
return rows.Scan(dest...)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user