tdengine-mapper-go/mapping.go

286 lines
7.4 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package tdmap
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"reflect"
	"sort"
	"strings"
)
// NewMapper allocates a zero-valued Mapper. Struct metadata is added to its
// cache later, whenever a struct type is scanned.
func NewMapper() *Mapper {
	m := new(Mapper)
	return m
}
// TableNamer is implemented by row types that can report the name of the
// table they should be written to.
type TableNamer interface {
	// TableName returns the target table name.
	TableName() string
}
// Mapper converts between db-tagged Go structs and SQL. It renders INSERT
// statements from struct values and scans query rows back into structs.
// Parsed struct metadata is kept in a per-type cache. That cache is consulted
// when building INSERT data and when scanning rows.
type Mapper struct {
structMateMap SyncMap[string, *StructMate] // struct metadata cache. key: the struct's unique type name; value: its parsed metadata
}
// scanStruct parses the struct metadata of each datum and stores it in the
// mapper's cache under the metadata's UniqueTypeName. Any existing entry for
// that name is replaced.
//
// Every datum must be a value of pointer type. An untyped nil or a
// non-pointer value stops processing with an error. Entries already stored
// for earlier data remain in the cache. Errors from scan are returned
// unchanged.
func (b *Mapper) scanStruct(data ...any) error {
	for _, datum := range data {
		// reflect.TypeOf(nil) yields a nil reflect.Type, and calling Kind on it
		// panics. Reject it up front together with non-pointer values.
		typ := reflect.TypeOf(datum)
		if typ == nil || typ.Kind() != reflect.Ptr {
			return fmt.Errorf("need a pointer type: %v", typ)
		}
		mate, err := scan(datum)
		if err != nil {
			return err
		}
		b.structMateMap.Store(mate.UniqueTypeName, mate)
	}
	return nil
}
// extractStructData gathers everything needed to build an INSERT row from
// each item. That is the table name, the super table name, tag columns and
// values, and data columns and values.
//
// Each item's type must already be in the metadata cache (see scanStruct).
// Each item must also implement TableNamer and return a non-empty name. The
// method is looked up first on the address of the item's reflected value,
// when that value is addressable, and then on the item itself.
//
// The result maps a table name to its rows. An empty SuperTableName on a row
// means the row targets a normal (non-super) table.
func (b *Mapper) extractStructData(data ...any) (map[string][]*TableRowMateria, error) {
	// fieldValue reads the field at idx and follows non-nil pointers down to
	// the pointee. A nil pointer is returned as-is (a typed nil).
	fieldValue := func(v reflect.Value, idx []int) any {
		f := v.FieldByIndex(idx)
		for f.Kind() == reflect.Ptr && !f.IsNil() {
			f = f.Elem()
		}
		return f.Interface()
	}

	grouped := make(map[string][]*TableRowMateria)
	for _, item := range data {
		typ, val := getReflectTypeAndValue(item)
		typeName := getUniqueTypeName(typ)

		// Resolve the table name: pointer receiver first, then the item itself.
		tableName := ""
		if val.CanAddr() {
			if namer, ok := val.Addr().Interface().(TableNamer); ok {
				tableName = namer.TableName()
			}
		}
		if tableName == "" {
			if namer, ok := item.(TableNamer); ok {
				tableName = namer.TableName()
			}
		}
		if tableName == "" {
			return nil, fmt.Errorf("not import TableName() string func for struct type: %s", typeName)
		}

		meta, found := b.structMateMap.Load(typeName)
		if !found {
			return nil, fmt.Errorf("not found struct type: %s", typeName)
		}

		row := &TableRowMateria{
			SuperTableName: meta.SuperTableName,
			TableName:      tableName,
			TagColumns:     make([]string, 0, len(meta.TaggedFieldNames)),
			TagValues:      make([]any, 0, len(meta.TaggedFieldNames)),
			Columns:        make([]string, 0, len(meta.DBAnnotatedNames)),
			Values:         make([]any, 0, len(meta.DBAnnotatedNames)),
		}
		// Data columns: database column name paired with the field's value.
		for _, name := range meta.DBAnnotatedNames {
			row.Columns = append(row.Columns, meta.Filed2DBNameCache[name])
			row.Values = append(row.Values, fieldValue(val, meta.Field2IndexCache[name]))
		}
		// Tag columns: same name mapping, collected separately.
		for _, name := range meta.TaggedFieldNames {
			row.TagColumns = append(row.TagColumns, meta.Filed2DBNameCache[name])
			row.TagValues = append(row.TagValues, fieldValue(val, meta.Field2IndexCache[name]))
		}

		grouped[row.TableName] = append(grouped[row.TableName], row)
	}
	return grouped, nil
}
// ToInsertSQL builds a single multi-table INSERT statement for data.
//
// Each element of data must be a pointer to a struct that implements
// TableNamer. Rows are grouped by table name. Within a group, the first row's
// SuperTableName decides the form:
//
//	tb_name USING stb_name (tag columns) TAGS (tag values) (columns) VALUES (),()
//	tb_name (columns) VALUES (),()
//
// Groups are emitted in ascending table-name order, so the same input always
// yields the same SQL text.
func (b *Mapper) ToInsertSQL(data ...any) (string, error) {
	// Register/refresh struct metadata for every input type.
	if err := b.scanStruct(data...); err != nil {
		return "", fmt.Errorf("failed to scan struct: %w", err)
	}
	tableMap, err := b.extractStructData(data...)
	if err != nil {
		return "", fmt.Errorf("failed to extract struct data: %w", err)
	}
	if len(tableMap) == 0 {
		return "", errors.New("no data to insert")
	}

	// Map iteration order is randomized. Sort the table names so the generated
	// statement is deterministic (stable logs, statement caches and tests).
	tableNames := make([]string, 0, len(tableMap))
	for name := range tableMap {
		tableNames = append(tableNames, name)
	}
	sort.Strings(tableNames)

	var buf strings.Builder
	buf.WriteString("INSERT INTO \n")
	for _, name := range tableNames {
		materials := tableMap[name]
		if len(materials) == 0 {
			continue
		}
		var rowSQL string
		if materials[0].SuperTableName == "" {
			// tb_name (columns) VALUES (),()
			rowSQL, err = buildInsertStatementForNormalTable(materials...)
		} else {
			// tb_name USING stb_name (tag columns) TAGS (tag values) (columns) VALUES (),()
			rowSQL, err = buildInsertStatementForSuperTable(materials...)
		}
		if err != nil {
			return "", fmt.Errorf("failed to build insert statement: %w", err)
		}
		buf.WriteString(rowSQL)
		buf.WriteString(" \n")
	}
	return buf.String(), nil
}
// ScanRows scans rows into target using context.Background(), so the scan
// cannot be cancelled. Prefer ScanRowsWithContext when a context is
// available; it accepts the same target shapes.
func (b *Mapper) ScanRows(target any, rows *sql.Rows) error {
	bg := context.Background()
	return b.ScanRowsWithContext(bg, target, rows)
}
// ScanRowsWithContext scans query results into target.
//
// target must be a non-nil pointer to one of:
//   - a struct. Only the first row is scanned. sql.ErrNoRows is returned when
//     the result set is empty.
//   - a slice of structs or struct pointers ([]T or []*T). Every row is
//     scanned and appended to the existing slice contents.
//
// Struct fields are matched to result columns through their db tags. Every
// column in the result must correspond to a field.
//
// For slice targets, ctx is checked before each row. On cancellation the slice
// is truncated to length 0 and ctx.Err() is returned. An error encountered by
// rows during iteration (rows.Err) is returned instead of being dropped. This
// function does not close rows.
func (b *Mapper) ScanRowsWithContext(ctx context.Context, target any, rows *sql.Rows) error {
	if rows == nil {
		return errors.New("rows cannot be nil")
	}
	vf := reflect.ValueOf(target)
	if vf.Kind() != reflect.Ptr {
		return errors.New("target must be a pointer to a struct or a slice of structs")
	}
	if vf.IsNil() {
		return errors.New("target is nil, please pass a pointer to a struct or a slice of structs")
	}
	vf = vf.Elem() // dereference to reach the struct or slice being filled

	switch vf.Kind() {
	case reflect.Struct:
		if err := b.scanStruct(target); err != nil {
			return err
		}
		if !rows.Next() {
			// Next also returns false when iteration failed. Report that error
			// rather than masking it as "no rows".
			if err := rows.Err(); err != nil {
				return err
			}
			return sql.ErrNoRows
		}
		return b.scanRow(vf, rows)

	case reflect.Slice:
		// Accept both []*T and []T, where T is a struct.
		elemType := vf.Type().Elem()
		structType := elemType
		ptrElems := elemType.Kind() == reflect.Ptr
		if ptrElems {
			structType = elemType.Elem()
		}
		if structType.Kind() != reflect.Struct {
			return errors.New("target must be a pointer to a slice of structs")
		}
		// scanStruct requires a pointer, so register the type via a fresh *T.
		if err := b.scanStruct(reflect.New(structType).Interface()); err != nil {
			return err
		}
		for rows.Next() {
			select {
			case <-ctx.Done():
				vf.SetLen(0) // discard everything, including pre-existing elements
				return ctx.Err()
			default:
			}
			elem := reflect.New(structType) // *T, addressable for Scan
			if err := b.scanRow(elem, rows); err != nil {
				return err
			}
			if ptrElems {
				vf.Set(reflect.Append(vf, elem))
			} else {
				vf.Set(reflect.Append(vf, elem.Elem()))
			}
		}
		// rows.Next returning false may mean an error, not just end of results.
		return rows.Err()

	default:
		return errors.New("target must be a pointer to a struct or a slice of structs")
	}
}
// scanRow scans the current row of rows into a struct. target may be the
// struct value itself or a pointer to it. The struct must be addressable, and
// its type must already be in the metadata cache (see scanStruct).
// Each result column is bound to the field whose db name matches it. A column
// with no matching field is an error.
func (b *Mapper) scanRow(target reflect.Value, rows *sql.Rows) error {
	structVal := reflect.Indirect(target)
	typeName := getUniqueTypeName(structVal.Type())
	meta, found := b.structMateMap.Load(typeName)
	if !found {
		return fmt.Errorf("not found struct type mate: %s", typeName)
	}

	colNames, err := rows.Columns()
	if err != nil {
		return err
	}

	// One destination pointer per column, in result-column order.
	ptrs := make([]any, 0, len(colNames))
	for _, col := range colNames {
		fieldIdx, has := meta.DBName2IndexCache[col]
		if !has {
			return fmt.Errorf("no corresponding field found for column %s", col)
		}
		ptrs = append(ptrs, structVal.FieldByIndex(fieldIdx).Addr().Interface())
	}
	return rows.Scan(ptrs...)
}