diff --git a/README.md b/README.md index 1366cb9..9237738 100644 --- a/README.md +++ b/README.md @@ -85,7 +85,7 @@ func QueryOne(db *sql.DB) (*User,error){ if err!=nil { return err } - def rows.Close() + defer rows.Close() var user User if err:=ScanRows(&user,rows);err!=nil { @@ -102,7 +102,7 @@ func QueryAll(db *sql.DB)([]*User, error) { if err!=nil { return err } - def rows.Close() + defer rows.Close() var users []*User if err:=ScanRows(&users,rows);err!=nil { diff --git a/insert_build.go b/insert_build.go index 31fe919..f60a211 100644 --- a/insert_build.go +++ b/insert_build.go @@ -11,14 +11,14 @@ import ( // buildInsertStatementForSuperTable 构建超级表插入语句 // @param rows 同一超级表下的数据行 // @return 插入语句 -func buildInsertStatementForSuperTable(rows ...*TableRowMateria) (string, error) { +func buildInsertStatementForSuperTable(rows ...*TableRowMaterial) (string, error) { if len(rows) == 0 { return "", fmt.Errorf("no rows provided for super table insert") } var sb strings.Builder - // 格式化标签列和值 + // 开始格式化标签和值 formattedTagColumns := formatColumns(rows[0].TagColumns...) formattedTagValues := formatRowValues(rows[0].TagValues...) formattedColumns := formatColumns(rows[0].Columns...) @@ -56,7 +56,7 @@ func buildInsertStatementForSuperTable(rows ...*TableRowMateria) (string, error) // buildInsertStatementForNormalTable 普通表插入构建 // @param data 数据 当前应该是同一个表的 // @return 插入语句 表名1 (列名) VALUES (),() -func buildInsertStatementForNormalTable(rows ...*TableRowMateria) (string, error) { +func buildInsertStatementForNormalTable(rows ...*TableRowMaterial) (string, error) { if len(rows) == 0 { return "", fmt.Errorf("no data provided for normal table insert") } diff --git a/mapping.go b/mapping.go index ab8a315..3c55d0b 100644 --- a/mapping.go +++ b/mapping.go @@ -19,35 +19,141 @@ type TableNamer interface { } type Mapper struct { - structMateMap SyncMap[string, *StructMate] // 结构体 元信息缓存 key: struct 的唯一类型名称 value: 元信息 + structMateMap SyncMap[string, *StructMeta] // 结构体 元信息缓存 key: struct 的唯一类型名称 value: 元信息 } func (b *Mapper) scanStruct(data ...any) error { for _, datum := range data { if reflect.TypeOf(datum).Kind() != reflect.Ptr { - //return fmt.Errorf("需要指针类型:%v", reflect.TypeOf(datum)) - return fmt.Errorf("need a pointer type: %v", reflect.TypeOf(datum)) + return fmt.Errorf("need a pointer type: %v", reflect.TypeOf(data)) } - if mate, err := scan(datum); err != nil { + // 分析结构体并缓存结果 + if _, err := b.analyzeStructWithCache(datum); err != nil { return err - } else { - b.structMateMap.Store(mate.UniqueTypeName, mate) } } return nil } +// analyzeStructWithCache 带缓存的结构体分析 +func (b *Mapper) analyzeStructWithCache(data any) (*StructMeta, error) { + // 获取结构体的类型 + t, v := extractReflectInfo(data) + + // 确保是结构体类型 + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("input data is not a struct") + } + + // 获取类型的唯一标识符 + uniqueTypeName := buildTypeKey(t) + + // 先检查缓存 + if cached, ok := b.structMateMap.Load(uniqueTypeName); ok { + return cached, nil + } + + // 初始化结果结构体 + sr := StructMeta{ + UniqueTypeName: uniqueTypeName, + DBName2IndexCache: make(map[string][]int, t.NumField()), + Field2IndexCache: make(map[string][]int, t.NumField()), + Field2DBNameCache: make(map[string]string, t.NumField()), + DBAnnotatedNames: make([]string, 0, t.NumField()), + TaggedFieldNames: make([]string, 0, t.NumField()), + SuperTableName: "", + } + + // 遍历结构体的字段 + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + fieldValue := v.Field(i) + + if field.Anonymous { + // 处理匿名结构体 + if field.Type.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + // 如果指针是 nil,创建一个该类型的零值实例 + zeroValue := reflect.Zero(field.Type.Elem()) + fieldValue = zeroValue + } else { + fieldValue = fieldValue.Elem() + } + } + + if !fieldValue.CanInterface() { + continue + } + + // 递归分析匿名结构体,使用相同的缓存 + subResult, err := b.analyzeStructWithCache(fieldValue.Interface()) + if err != nil { + return nil, err + } + + // 合并分析结果 + for k, v := range subResult.DBName2IndexCache { + sr.DBName2IndexCache[k] = append(sr.DBName2IndexCache[k], i) + sr.DBName2IndexCache[k] = append(sr.DBName2IndexCache[k], v...) + } + + for k, v := range subResult.Field2IndexCache { + sr.Field2IndexCache[k] = append(sr.Field2IndexCache[k], i) + sr.Field2IndexCache[k] = append(sr.Field2IndexCache[k], v...) + } + for k, v := range subResult.Field2DBNameCache { + sr.Field2DBNameCache[k] = v + } + + sr.DBAnnotatedNames = append(sr.DBAnnotatedNames, subResult.DBAnnotatedNames...) + sr.TaggedFieldNames = append(sr.TaggedFieldNames, subResult.TaggedFieldNames...) + + continue + } + + // 处理普通字段 + columnName := field.Tag.Get("db") + if columnName == "-" || columnName == "" { + continue + } + + if len(sr.Field2IndexCache[field.Name]) > 0 { + return nil, fmt.Errorf("duplicate field [%s %s `db:%s`]", field.Name, field.Type.Name(), columnName) + } + + sr.Field2IndexCache[field.Name] = append([]int{}, i) + sr.DBName2IndexCache[columnName] = append(sr.DBName2IndexCache[columnName], i) + sr.Field2DBNameCache[field.Name] = columnName + + // 检查字段是否有taos注解 + if field.Tag.Get("taos") == "tag" { + sr.TaggedFieldNames = append(sr.TaggedFieldNames, field.Name) + } else { + sr.DBAnnotatedNames = append(sr.DBAnnotatedNames, field.Name) + } + } + + if sTableName, ok := tryGetSuperTableName(data); ok { + sr.SuperTableName = sTableName + } + + // 将结果存入缓存 + b.structMateMap.Store(uniqueTypeName, &sr) + + return &sr, nil +} + // 提取 struct 内的信息 // @param data 包含 struct 类型的数据 -// @return 返回一个 map[表名][]*TableRowMateria, 如果超级表名为空,则表示普通表 -func (b *Mapper) extractStructData(data ...any) (map[string][]*TableRowMateria, error) { - result := make(map[string][]*TableRowMateria) +// @return 返回一个 map[表名][]*TableRowMaterial, 如果超级表名为空,则表示普通表 +func (b *Mapper) extractTableMaterial(data ...any) (map[string][]*TableRowMaterial, error) { + result := make(map[string][]*TableRowMaterial) for _, item := range data { - tf, vf := getReflectTypeAndValue(item) - uniqueTypeName := getUniqueTypeName(tf) + tf, vf := extractReflectInfo(item) + uniqueTypeName := buildTypeKey(tf) // 获取表名 var tableName string @@ -74,7 +180,7 @@ func (b *Mapper) extractStructData(data ...any) (map[string][]*TableRowMateria, return nil, fmt.Errorf("not found struct type: %s", uniqueTypeName) } - materia := &TableRowMateria{ + trMaterial := &TableRowMaterial{ SuperTableName: mate.SuperTableName, TableName: tableName, TagColumns: make([]string, 0, len(mate.TaggedFieldNames)), @@ -85,7 +191,7 @@ func (b *Mapper) extractStructData(data ...any) (map[string][]*TableRowMateria, for _, name := range mate.DBAnnotatedNames { // db tag 名称 -- 数据库列名 - dbColumn := mate.Filed2DBNameCache[name] + dbColumn := mate.Field2DBNameCache[name] field := vf.FieldByIndex(mate.Field2IndexCache[name]) for field.Kind() == reflect.Ptr { @@ -97,13 +203,13 @@ func (b *Mapper) extractStructData(data ...any) (map[string][]*TableRowMateria, // 字段值 dbValue := field.Interface() - materia.Columns = append(materia.Columns, dbColumn) - materia.Values = append(materia.Values, dbValue) + trMaterial.Columns = append(trMaterial.Columns, dbColumn) + trMaterial.Values = append(trMaterial.Values, dbValue) } for _, name := range mate.TaggedFieldNames { // db tag 名称 -- 数据库列名 - tagColumn := mate.Filed2DBNameCache[name] + tagColumn := mate.Field2DBNameCache[name] field := vf.FieldByIndex(mate.Field2IndexCache[name]) for field.Kind() == reflect.Ptr { if field.IsNil() { @@ -114,13 +220,13 @@ func (b *Mapper) extractStructData(data ...any) (map[string][]*TableRowMateria, // 字段值 tagValue := field.Interface() - materia.TagColumns = append(materia.TagColumns, tagColumn) - materia.TagValues = append(materia.TagValues, tagValue) + trMaterial.TagColumns = append(trMaterial.TagColumns, tagColumn) + trMaterial.TagValues = append(trMaterial.TagValues, tagValue) } // 添加到结果 - key := materia.TableName - result[key] = append(result[key], materia) + key := trMaterial.TableName + result[key] = append(result[key], trMaterial) } return result, nil } @@ -135,17 +241,17 @@ func (b *Mapper) normalizeData(data ...any) []any { var normalizedData []any for _, item := range data { - if item == nil { + if isNil(item) { // 跳过nil值,避免后续反射调用panic continue } itemType := reflect.TypeOf(item) - if itemType.Kind() == reflect.Slice { + if itemType.Kind() == reflect.Slice || itemType.Kind() == reflect.Array { // 处理切片类型参数,展平成单个元素 sliceValue := reflect.ValueOf(item) for i := 0; i < sliceValue.Len(); i++ { elem := sliceValue.Index(i).Interface() - if elem == nil { + if isNil(elem) { // 跳过nil元素 continue } @@ -159,18 +265,10 @@ func (b *Mapper) normalizeData(data ...any) []any { // 转换所有元素为指针类型,便于后续一致性处理 for i, item := range normalizedData { - if item == nil { + if isNil(item) { continue } - itemType := reflect.TypeOf(item) - if itemType.Kind() != reflect.Ptr { - // 创建对应类型的指针 - itemPtr := reflect.New(itemType) - // 设置指针指向的值为元素本身 - itemPtr.Elem().Set(reflect.ValueOf(item)) - // 替换切片中的元素为指针类型 - normalizedData[i] = itemPtr.Interface() - } + normalizedData[i] = ensurePtr(item) } return normalizedData } @@ -185,11 +283,11 @@ func (b *Mapper) ToInsertSQL(data ...any) (string, error) { // 扫描 struct 类型数据 if err := b.scanStruct(normalizedData...); err != nil { - return "", fmt.Errorf("failed to scan struct: %w", err) + return "", fmt.Errorf("failed to analyzeStruct struct: %w", err) } // 提取 struct 内的信息 - tableMap, err := b.extractStructData(normalizedData...) + tableMap, err := b.extractTableMaterial(normalizedData...) if err != nil { return "", fmt.Errorf("failed to extract struct data: %w", err) } @@ -246,7 +344,7 @@ func (b *Mapper) ScanRows(target any, rows *sql.Rows) error { // @return 返回错误 func (b *Mapper) ScanRowsWithContext(ctx context.Context, target any, rows *sql.Rows) error { // 确保rows不为nil - if rows == nil { + if isNil(rows) { return errors.New("rows cannot be nil") } @@ -275,8 +373,8 @@ func (b *Mapper) ScanRowsWithContext(ctx context.Context, target any, rows *sql. return sql.ErrNoRows } return b.scanRow(vf, rows) - case reflect.Slice: - // target是slice的指针 + case reflect.Slice, reflect.Array: + // target是slice或array,扫描多行数据 sliceElementType := vf.Type().Elem() if sliceElementType.Kind() != reflect.Ptr || sliceElementType.Elem().Kind() != reflect.Struct { return errors.New("target must be a pointer to a slice of structs") @@ -314,7 +412,7 @@ func (b *Mapper) ScanRowsWithContext(ctx context.Context, target any, rows *sql. func (b *Mapper) scanRow(target reflect.Value, rows *sql.Rows) error { target = reflect.Indirect(target) - uniqueTypeName := getUniqueTypeName(target.Type()) + uniqueTypeName := buildTypeKey(target.Type()) mate, ok := b.structMateMap.Load(uniqueTypeName) if !ok { return fmt.Errorf("not found struct type mate: %s", uniqueTypeName) @@ -387,5 +485,5 @@ func (b *Mapper) ScanRowsToMapWithContext(ctx context.Context, rows *sql.Rows) ( } result = append(result, rowMap) } - return result, err + return result, nil } diff --git a/mapping_test.go b/mapping_test.go index c343baa..c0d5499 100644 --- a/mapping_test.go +++ b/mapping_test.go @@ -4,9 +4,9 @@ import ( "database/sql" "encoding/json" "fmt" - //_ "github.com/taosdata/driver-go/v3/taosWS" "testing" "time" + //_ "github.com/taosdata/driver-go/v3/taosWS" ) type TaosTAG struct { @@ -92,9 +92,9 @@ func (u *User) SuperTableName() string { func TestSimpleInsert(t *testing.T) { tdMapper := NewMapper() - data := []User{ - {Name: "张三", Age: 18}, - {Name: "李四", Age: 20}, + data := []any{ + &User{Name: "张三", Age: 18}, + User{Name: "李四", Age: 20}, } insertSql, err := tdMapper.ToInsertSQL(data) @@ -109,6 +109,10 @@ type SuperDevTAG struct { DevType string `db:"dev_type" taos:"tag"` } +func (s *SuperDevTAG) SuperTableName() string { + return "SuperDevTAG" +} + type SuperDev struct { SuperDevTAG Ts time.Time `db:"ts"` // 时间戳 @@ -133,7 +137,7 @@ func TestSuperDevInsert(t *testing.T) { SuperDevTAG: SuperDevTAG{DevId: "SN001", DevType: "模拟设备"}, Ts: time.Now().Add(time.Second), AppSn: "a0002", Ct: 2.0, }, - SuperDev{ + &SuperDev{ SuperDevTAG: SuperDevTAG{DevId: "SN002", DevType: "模拟设备"}, Ts: time.Now(), AppSn: "a0003", Ct: 3.0, }, diff --git a/reflect_utils.go b/reflect_utils.go new file mode 100644 index 0000000..6c7cd81 --- /dev/null +++ b/reflect_utils.go @@ -0,0 +1,58 @@ +package tdmap + +import ( + "reflect" +) + +// extractReflectInfo 提取反射信息(自动解引用到基础类型) +func extractReflectInfo(data any) (reflect.Type, reflect.Value) { + t := reflect.TypeOf(data) + v := reflect.ValueOf(data) + + // 处理结构体和结构体指针 + for t.Kind() == reflect.Ptr { + t = t.Elem() + v = v.Elem() + } + return t, v +} + +// buildTypeKey 构建类型键(包路径.类型名) +func buildTypeKey(t reflect.Type) string { + return t.PkgPath() + "." + t.Name() +} + +// tryGetSuperTableName 调用对象的 SuperTableName 方法(如果存在) +func tryGetSuperTableName(obj any) (string, bool) { + ptr := ensurePtr(obj) + if it, ok := ptr.(interface{ SuperTableName() string }); ok { + return it.SuperTableName(), true + } + return "", false +} + +// ensurePtr 确保返回的是指针类型(不是指针则转换为指针) +func ensurePtr(obj any) any { + if v := reflect.ValueOf(obj); v.Kind() != reflect.Ptr { + // 创建对应类型的指针 + ptr := reflect.New(v.Type()) + // 设置指针指向的值为元素本身 + ptr.Elem().Set(v) + // 返回 指针 + return ptr.Interface() + } + return obj +} + +// isNil 判断值是否为 nil(包括指针类型) +func isNil(d any) bool { + if d == nil { + return true + } + + // 反射判断指针类型是否为nil + if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr { + return v.IsNil() + } + return false +} diff --git a/scan.go b/scan.go deleted file mode 100644 index d394d3d..0000000 --- a/scan.go +++ /dev/null @@ -1,143 +0,0 @@ -package tdmap - -import ( - "fmt" - "reflect" -) - -func scan(data interface{}) (*StructMate, error) { - // 获取结构体的类型 - t, v := getReflectTypeAndValue(data) - - // 确保是结构体类型 - if t.Kind() != reflect.Struct { - return nil, fmt.Errorf("input data is not a struct") - } - - // 获取包路径和类型名称 - uniqueTypeName := getUniqueTypeName(t) - - // 初始化结果结构体 - sr := StructMate{ - UniqueTypeName: uniqueTypeName, - DBName2IndexCache: make(map[string][]int, t.NumField()), - Field2IndexCache: make(map[string][]int, t.NumField()), - Filed2DBNameCache: make(map[string]string, t.NumField()), - DBAnnotatedNames: make([]string, 0, t.NumField()), - TaggedFieldNames: make([]string, 0, t.NumField()), - SuperTableName: "", - } - - //timeType := reflect.TypeOf(time.Time{}) - - // 遍历结构体的字段 - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - fieldValue := v.Field(i) - - if field.Anonymous { - - if field.Type.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - // 如果指针是 nil,创建一个该类型的零值实例 - zeroValue := reflect.Zero(field.Type.Elem()) - fieldValue = zeroValue - } else { - fieldValue = fieldValue.Elem() - } - } - - if !fieldValue.CanInterface() { - continue - } - - subResult, err := scan(fieldValue.Interface()) - if err != nil { - return nil, err - } - - for k, v := range subResult.DBName2IndexCache { - sr.DBName2IndexCache[k] = append(sr.DBName2IndexCache[k], i) - sr.DBName2IndexCache[k] = append(sr.DBName2IndexCache[k], v...) - } - - for k, v := range subResult.Field2IndexCache { - sr.Field2IndexCache[k] = append(sr.Field2IndexCache[k], i) - sr.Field2IndexCache[k] = append(sr.Field2IndexCache[k], v...) - } - for k, v := range subResult.Filed2DBNameCache { - sr.Filed2DBNameCache[k] = v - } - - sr.DBAnnotatedNames = append(sr.DBAnnotatedNames, subResult.DBAnnotatedNames...) - sr.TaggedFieldNames = append(sr.TaggedFieldNames, subResult.TaggedFieldNames...) - - continue - } - - // 检查字段是否有db注解 - columnName := field.Tag.Get("db") - if columnName == "-" || columnName == "" { - continue - } - - if len(sr.Field2IndexCache[field.Name]) > 0 { - return nil, fmt.Errorf("duplicate field [%s %s `db:%s`]", field.Name, field.Type.Name(), columnName) - } - - sr.Field2IndexCache[field.Name] = append([]int{}, i) - sr.DBName2IndexCache[columnName] = append(sr.DBName2IndexCache[columnName], i) - sr.Filed2DBNameCache[field.Name] = columnName - - // 检查字段是否有taos注解 - if field.Tag.Get("taos") == "tag" { - sr.TaggedFieldNames = append(sr.TaggedFieldNames, field.Name) - } else { - sr.DBAnnotatedNames = append(sr.DBAnnotatedNames, field.Name) - } - } - - sr.SuperTableName = callSuperTableName(data) - return &sr, nil -} - -// 获取包路径和类型名称 -func getUniqueTypeName(t reflect.Type) string { - pkgPath := t.PkgPath() - typeName := t.Name() - uniqueTypeName := fmt.Sprintf("%s.%s", pkgPath, typeName) - return uniqueTypeName -} - -func getReflectTypeAndValue(data any) (reflect.Type, reflect.Value) { - t := reflect.TypeOf(data) - v := reflect.ValueOf(data) - - // 处理结构体和结构体指针 - for t.Kind() == reflect.Ptr { - t = t.Elem() - v = v.Elem() - } - return t, v -} - -func callSuperTableName(obj any) string { - v := reflect.ValueOf(obj) - - if v.Kind() == reflect.Struct { - return "" - } - - // 检查是否可以调用 SuperTableName 方法 - superTableNameMethod := v.MethodByName("SuperTableName") - if !superTableNameMethod.IsValid() { - return "" - } - - // 调用 SuperTableName 方法 - results := superTableNameMethod.Call(nil) - if len(results) == 1 && results[0].Kind() == reflect.String { - return results[0].String() - } - return "" -} diff --git a/types.go b/types.go index de510bd..89874f1 100644 --- a/types.go +++ b/types.go @@ -1,6 +1,7 @@ package tdmap -type TableRowMateria struct { +// TableRowMaterial 表行数据 +type TableRowMaterial struct { SuperTableName string TableName string @@ -11,13 +12,13 @@ type TableRowMateria struct { Values []any // 值 } -// StructMate 静态化的结构体信息 -type StructMate struct { +// StructMeta 静态化的结构体信息 +type StructMeta struct { UniqueTypeName string // 结构体的唯一标识符 DBName2IndexCache map[string][]int // db 注解的名称到索引的映射缓存 Field2IndexCache map[string][]int // 字段名到索引的映射缓存 - Filed2DBNameCache map[string]string // 字段名到 db 注解的名称的映射缓存 + Field2DBNameCache map[string]string // 字段名到 db 注解的名称的映射缓存 DBAnnotatedNames []string // 包含 db 注解的 属性的名称 TaggedFieldNames []string // 包含的 tag 注解的 属性的名称