tdengine-mapper-go/mapping.go

286 lines
7.4 KiB
Go
Raw Permalink Normal View History

2024-09-14 15:57:03 +08:00
package tdmap
import (
2024-09-18 16:51:30 +08:00
"context"
"database/sql"
"errors"
"fmt"
"reflect"
"strings"
)
2024-09-18 16:51:30 +08:00
func NewMapper() *Mapper {
return &Mapper{}
}
// TableNamer 定义了一个获取表名的方法。
type TableNamer interface {
TableName() string
}
2024-09-18 16:51:30 +08:00
type Mapper struct {
structMateMap SyncMap[string, *StructMate] // 结构体 元信息缓存 key: struct 的唯一类型名称 value: 元信息
}
2024-09-18 16:51:30 +08:00
func (b *Mapper) scanStruct(data ...any) error {
for _, datum := range data {
if reflect.TypeOf(datum).Kind() != reflect.Ptr {
2024-09-18 16:51:30 +08:00
//return fmt.Errorf("需要指针类型:%v", reflect.TypeOf(datum))
return fmt.Errorf("need a pointer type: %v", reflect.TypeOf(datum))
}
if mate, err := scan(datum); err != nil {
return err
} else {
2024-09-18 16:51:30 +08:00
b.structMateMap.Store(mate.UniqueTypeName, mate)
}
}
return nil
}
// 提取 struct 内的信息
// @param data 包含 struct 类型的数据
// @return 返回一个 map[表名][]*TableRowMateria 如果超级表名为空,则表示普通表
2024-09-18 16:51:30 +08:00
func (b *Mapper) extractStructData(data ...any) (map[string][]*TableRowMateria, error) {
result := make(map[string][]*TableRowMateria)
for _, item := range data {
tf, vf := getReflectTypeAndValue(item)
uniqueTypeName := getUniqueTypeName(tf)
// 获取表名
var tableName string
{
if vf.CanAddr() {
if a, ok := vf.Addr().Interface().(TableNamer); ok {
tableName = a.TableName()
}
}
if tableName == "" {
if a, ok := item.(TableNamer); ok {
tableName = a.TableName()
}
}
if tableName == "" {
return nil, fmt.Errorf("not import TableName() string func for struct type: %s", uniqueTypeName)
}
}
2024-09-18 16:51:30 +08:00
mate, ok := b.structMateMap.Load(uniqueTypeName)
if !ok {
return nil, fmt.Errorf("not found struct type: %s", uniqueTypeName)
}
materia := &TableRowMateria{
SuperTableName: mate.SuperTableName,
TableName: tableName,
TagColumns: make([]string, 0, len(mate.TaggedFieldNames)),
TagValues: make([]any, 0, len(mate.TaggedFieldNames)),
Columns: make([]string, 0, len(mate.DBAnnotatedNames)),
Values: make([]any, 0, len(mate.DBAnnotatedNames)),
}
for _, name := range mate.DBAnnotatedNames {
// db tag 名称 -- 数据库列名
2024-09-18 16:51:30 +08:00
dbColumn := mate.Filed2DBNameCache[name]
2024-09-18 16:51:30 +08:00
field := vf.FieldByIndex(mate.Field2IndexCache[name])
for field.Kind() == reflect.Ptr {
if field.IsNil() {
break
}
field = field.Elem()
}
// 字段值
dbValue := field.Interface()
materia.Columns = append(materia.Columns, dbColumn)
materia.Values = append(materia.Values, dbValue)
}
for _, name := range mate.TaggedFieldNames {
// db tag 名称 -- 数据库列名
2024-09-18 16:51:30 +08:00
tagColumn := mate.Filed2DBNameCache[name]
field := vf.FieldByIndex(mate.Field2IndexCache[name])
for field.Kind() == reflect.Ptr {
if field.IsNil() {
break
}
field = field.Elem()
}
// 字段值
tagValue := field.Interface()
materia.TagColumns = append(materia.TagColumns, tagColumn)
materia.TagValues = append(materia.TagValues, tagValue)
}
// 添加到结果
key := materia.TableName
result[key] = append(result[key], materia)
}
return result, nil
}
2024-09-18 16:51:30 +08:00
func (b *Mapper) ToInsertSQL(data ...any) (string, error) {
// 扫描 struct 类型数据
if err := b.scanStruct(data...); err != nil {
return "", fmt.Errorf("failed to scan struct: %w", err)
}
// 提取 struct 内的信息
tableMap, err := b.extractStructData(data...)
if err != nil {
return "", fmt.Errorf("failed to extract struct data: %w", err)
}
if len(tableMap) == 0 {
return "", fmt.Errorf("no data to insert")
}
// 构建 insert 语句
/*
INSERT INTO
表名1 USING 超级表名1 (tag列名) TAGS(tag值) (列名) VALUES (),()
表名2 USING 超级表名2 (tag列名)TAGS(tag值) (列名) VALUES (),()
*/
var buf strings.Builder
buf.WriteString("INSERT INTO \n")
for _, materials := range tableMap {
if len(materials) == 0 {
continue
}
var (
rowSql string
err error
)
if materials[0].SuperTableName == "" {
// 表名1 (列名) VALUES (),()
rowSql, err = buildInsertStatementForNormalTable(materials...)
} else {
// 表名1 USING 超级表名1 (tag列名) TAGS(tag值) (列名) VALUES (),()
rowSql, err = buildInsertStatementForSuperTable(materials...)
}
if err != nil {
return "", fmt.Errorf("failed to build insert statement: %w", err)
}
buf.WriteString(rowSql)
buf.WriteString(" \n")
}
return buf.String(), nil
}
2024-09-18 16:51:30 +08:00
// ScanRows 扫描多行数据
// A Superior Option: Use ScanRowsWithContext instead.
func (b *Mapper) ScanRows(target any, rows *sql.Rows) error {
return b.ScanRowsWithContext(context.Background(), target, rows)
}
// ScanRowsWithContext 扫描多行数据
// @param ctx 上下文
// @param target 结构体的指针 或 slice 指针 结构体需要使用 db 标签注解
// @param rows 数据库返回的行数据
// @return 返回错误
func (b *Mapper) ScanRowsWithContext(ctx context.Context, target any, rows *sql.Rows) error {
// 确保rows不为nil
if rows == nil {
return errors.New("rows cannot be nil")
}
// 获取target的反射类型和值
vf := reflect.ValueOf(target)
// 检查target是否为指针类型
if vf.Kind() != reflect.Ptr {
return errors.New("target must be a pointer to a struct or a slice of structs")
}
if vf.IsNil() {
return errors.New("target is nil, please pass a pointer to a struct or a slice of structs")
}
vf = vf.Elem() // 解引用指针以获取目标值
switch vf.Kind() {
case reflect.Struct:
// 提取 struct 内的信息
if err := b.scanStruct(target); err != nil {
return err
}
// target是指向单个struct的指针扫描一行数据
if !rows.Next() {
return sql.ErrNoRows
}
return b.scanRow(vf, rows)
case reflect.Slice:
// target是slice的指针
sliceElementType := vf.Type().Elem()
if sliceElementType.Kind() != reflect.Ptr || sliceElementType.Elem().Kind() != reflect.Struct {
return errors.New("target must be a pointer to a slice of structs")
}
// 提取 slice 中 struct 内的信息
if err := b.scanStruct(reflect.New(sliceElementType.Elem()).Interface()); err != nil { // 这里在抛出 不是指针类型
return err
}
for rows.Next() {
select {
case <-ctx.Done():
vf.SetLen(0) // 清空 slice
return ctx.Err()
default:
}
// 创建slice元素的新实例
newStruct := reflect.New(sliceElementType.Elem())
if err := b.scanRow(newStruct, rows); err != nil {
return err
}
// 将新实例添加到slice中
vf.Set(reflect.Append(vf, newStruct))
}
return nil
default:
return errors.New("target must be a pointer to a struct or a slice of structs")
}
}
// scanRow 用于扫描单行数据到结构体
func (b *Mapper) scanRow(target reflect.Value, rows *sql.Rows) error {
target = reflect.Indirect(target)
uniqueTypeName := getUniqueTypeName(target.Type())
mate, ok := b.structMateMap.Load(uniqueTypeName)
if !ok {
return fmt.Errorf("not found struct type mate: %s", uniqueTypeName)
}
// 假设我们有一个函数来获取列的值
columns, err := rows.Columns()
if err != nil {
return err
}
// 准备目标结构体的地址
dest := make([]interface{}, len(columns))
for i, colName := range columns {
// 拿到缓存的字段索引
idx, ok := mate.DBName2IndexCache[colName]
if !ok {
2024-09-18 16:51:30 +08:00
return fmt.Errorf("no corresponding field found for column %s", colName)
}
dest[i] = target.FieldByIndex(idx).Addr().Interface()
2024-09-18 16:51:30 +08:00
}
// 扫描数据
return rows.Scan(dest...)
}