diff --git a/README.md b/README.md index 6440947..631fbe3 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ type TableNamer interface { -### 示例 +### 批量插入 -- 示例 ```go @@ -52,12 +52,12 @@ func (u *User) SuperTableName() string { } func TestSimpleInsert(t *testing.T) { - tdMapping := NewTdMapping() + tdMapper := NewMapper() data := []any{ &User{Name: "张三", Age: 18}, &User{Name: "李四", Age: 20}, } - insertSql, err := tdMapping.ToInsertSQL(data...) + insertSql, err := tdMapper.ToInsertSQL(data...) if err != nil { t.Fatal(err) } @@ -70,6 +70,44 @@ INSERT INTO `user_李四` USING `super_user` (`name`) TAGS ('李四') (`age`) VALUES (20) ``` +### 查询映射 -- 示例 + +```go +func QueryOne(db *sql.DB) (*User,error){ + tdMapper := NewMapper() + + rows,err:=db.Query("select * from User limit 100") + if err!=nil { + return err + } + def rows.Close() + + var user User + if err:=ScanRows(&user,rows);err!=nil { + return err + } + return &user, nil +} + +// 如果查询结果数量为0,则返回 sql.ErrNoRows +// var ErrNoRows = errors.New("sql: no rows in result set") +func QueryAll(db *sql.DB)([]*User, error) { + tdMapper := NewMapper() + + rows,err:=db.Query("select * from User limit 100") + if err!=nil { + return err + } + def rows.Close() + + var users []*User + if err:=ScanRows(&users,rows);err!=nil { + return err + } + return &users, nil +} +``` + diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e69de29 diff --git a/insert_build.go b/insert_build.go index 6373dc2..1d293ee 100644 --- a/insert_build.go +++ b/insert_build.go @@ -148,7 +148,7 @@ func formatRowValues(values ...any) string { case nil: formattedValues[i] = "null" case time.Time: - formattedValues[i] = fmt.Sprintf("'%s'", v.Format("2006-01-02 15:04:05")) + formattedValues[i] = fmt.Sprintf("'%s'", v.Format(time.RFC3339)) default: if reflect.TypeOf(val).Kind() == reflect.Ptr && reflect.ValueOf(val).IsNil() { formattedValues[i] = "null" diff --git a/mapping.go b/mapping.go index 4cc5ac8..005fd81 100644 --- a/mapping.go +++ b/mapping.go @@ -1,13 +1,16 @@ package tdmap import ( + "context" + "database/sql" + "errors" "fmt" "reflect" "strings" ) -func NewMapping() *TdMapping { - return &TdMapping{} +func NewMapper() *Mapper { + return &Mapper{} } // TableNamer 定义了一个获取表名的方法。 @@ -15,20 +18,21 @@ type TableNamer interface { TableName() string } -type TdMapping struct { - modelMates Map[string, *StructMate] +type Mapper struct { + structMateMap SyncMap[string, *StructMate] // 结构体 元信息缓存 key: struct 的唯一类型名称 value: 元信息 } -func (b *TdMapping) scanStruct(data ...any) error { +func (b *Mapper) scanStruct(data ...any) error { for _, datum := range data { if reflect.TypeOf(datum).Kind() != reflect.Ptr { - return fmt.Errorf("需要指针类型:%v", reflect.TypeOf(datum)) + //return fmt.Errorf("需要指针类型:%v", reflect.TypeOf(datum)) + return fmt.Errorf("need a pointer type: %v", reflect.TypeOf(datum)) } if mate, err := scan(datum); err != nil { return err } else { - b.modelMates.Store(mate.UniqueTypeName, mate) + b.structMateMap.Store(mate.UniqueTypeName, mate) } } @@ -38,7 +42,7 @@ func (b *TdMapping) scanStruct(data ...any) error { // 提取 struct 内的信息 // @param data 包含 struct 类型的数据 // @return 返回一个 map[表名][]*TableRowMateria, 如果超级表名为空,则表示普通表 -func (b *TdMapping) extractStructData(data ...any) (map[string][]*TableRowMateria, error) { +func (b *Mapper) extractStructData(data ...any) (map[string][]*TableRowMateria, error) { result := make(map[string][]*TableRowMateria) for _, item := range data { @@ -65,7 +69,7 @@ func (b *TdMapping) extractStructData(data ...any) (map[string][]*TableRowMateri } } - mate, ok := b.modelMates.Load(uniqueTypeName) + mate, ok := b.structMateMap.Load(uniqueTypeName) if !ok { return nil, fmt.Errorf("not found struct type: %s", uniqueTypeName) } @@ -81,9 +85,9 @@ func (b *TdMapping) extractStructData(data ...any) (map[string][]*TableRowMateri for _, name := range mate.DBAnnotatedNames { // db tag 名称 -- 数据库列名 - dbColumn := mate.FiledDBNameCache[name] + dbColumn := mate.Filed2DBNameCache[name] - field := vf.FieldByIndex(mate.FieldIndexCache[name]) + field := vf.FieldByIndex(mate.Field2IndexCache[name]) for field.Kind() == reflect.Ptr { if field.IsNil() { break @@ -99,8 +103,8 @@ func (b *TdMapping) extractStructData(data ...any) (map[string][]*TableRowMateri for _, name := range mate.TaggedFieldNames { // db tag 名称 -- 数据库列名 - tagColumn := mate.FiledDBNameCache[name] - field := vf.FieldByIndex(mate.FieldIndexCache[name]) + tagColumn := mate.Filed2DBNameCache[name] + field := vf.FieldByIndex(mate.Field2IndexCache[name]) for field.Kind() == reflect.Ptr { if field.IsNil() { break @@ -121,7 +125,7 @@ func (b *TdMapping) extractStructData(data ...any) (map[string][]*TableRowMateri return result, nil } -func (b *TdMapping) ToInsertSQL(data ...any) (string, error) { +func (b *Mapper) ToInsertSQL(data ...any) (string, error) { // 扫描 struct 类型数据 if err := b.scanStruct(data...); err != nil { return "", fmt.Errorf("failed to scan struct: %w", err) @@ -171,3 +175,115 @@ func (b *TdMapping) ToInsertSQL(data ...any) (string, error) { return buf.String(), nil } + +// ScanRows 扫描多行数据 +// A Superior Option: Use ScanRowsWithContext instead. +func (b *Mapper) ScanRows(target any, rows *sql.Rows) error { + return b.ScanRowsWithContext(context.Background(), target, rows) +} + +// ScanRowsWithContext 扫描多行数据 +// @param ctx 上下文 +// @param target 结构体的指针 或 slice 指针 结构体需要使用 db 标签注解 +// @param rows 数据库返回的行数据 +// @return 返回错误 +func (b *Mapper) ScanRowsWithContext(ctx context.Context, target any, rows *sql.Rows) error { + // 确保rows不为nil + if rows == nil { + return errors.New("rows cannot be nil") + } + + // 获取target的反射类型和值 + vf := reflect.ValueOf(target) + + // 检查target是否为指针类型 + if vf.Kind() != reflect.Ptr { + return errors.New("target must be a pointer to a struct or a slice of structs") + } + + if vf.IsNil() { + return errors.New("target is nil, please pass a pointer to a struct or a slice of structs") + } + + vf = vf.Elem() // 解引用指针以获取目标值 + + switch vf.Kind() { + case reflect.Struct: + // 提取 struct 内的信息 + if err := b.scanStruct(target); err != nil { + return err + } + // target是指向单个struct的指针,扫描一行数据 + if !rows.Next() { + return sql.ErrNoRows + } + return b.scanRow(vf, rows) + case reflect.Slice: + // target是slice的指针 + sliceElementType := vf.Type().Elem() + if sliceElementType.Kind() != reflect.Ptr || sliceElementType.Elem().Kind() != reflect.Struct { + return errors.New("target must be a pointer to a slice of structs") + } + + // 提取 slice 中 struct 内的信息 + if err := b.scanStruct(reflect.New(sliceElementType.Elem()).Interface()); err != nil { // 这里在抛出 不是指针类型 + return err + } + + for rows.Next() { + select { + case <-ctx.Done(): + vf.SetLen(0) // 清空 slice + return ctx.Err() + default: + } + + // 创建slice元素的新实例 + newStruct := reflect.New(sliceElementType.Elem()) + if err := b.scanRow(newStruct, rows); err != nil { + return err + } + + // 将新实例添加到slice中 + vf.Set(reflect.Append(vf, newStruct)) + } + + if vf.Len() == 0 { // 如果slice为空,返回 sql.ErrNoRows + return sql.ErrNoRows + } + return nil + default: + return errors.New("target must be a pointer to a struct or a slice of structs") + } +} + +// scanRow 用于扫描单行数据到结构体 +func (b *Mapper) scanRow(target reflect.Value, rows *sql.Rows) error { + target = reflect.Indirect(target) + + uniqueTypeName := getUniqueTypeName(target.Type()) + mate, ok := b.structMateMap.Load(uniqueTypeName) + if !ok { + return fmt.Errorf("not found struct type mate: %s", uniqueTypeName) + } + + // 假设我们有一个函数来获取列的值 + columns, err := rows.Columns() + if err != nil { + return err + } + + // 准备目标结构体的地址 + dest := make([]interface{}, len(columns)) + for i, colName := range columns { + // 拿到缓存的字段索引 + if idx, ok := mate.DBName2IndexCache[colName]; ok { + dest[i] = target.FieldByIndex(idx).Addr().Interface() + } else { + return fmt.Errorf("no corresponding field found for column %s", colName) + } + } + + // 扫描数据 + return rows.Scan(dest...) +} diff --git a/mapping_test.go b/mapping_test.go index 6c2e0a7..4e75f70 100644 --- a/mapping_test.go +++ b/mapping_test.go @@ -2,6 +2,7 @@ package tdmap import ( "database/sql" + "encoding/json" "fmt" "testing" "time" @@ -49,50 +50,26 @@ func (s *TaosUser) TableName() string { } func TestBuilderInsert(t *testing.T) { - tdMapping := NewMapping() + tdMapper := NewMapper() p := 1 data := []any{ &TaosDevice{ - TaosTAG: &TaosTAG{ - DevId: "设备ID", - DevType: "测试设备", - DataType: "测试数据", - }, - Ts: time.Now(), - LoadUnitId: "负载单体ID", - PInt: &p, - NullInt32: sql.NullInt32{Int32: 32, Valid: true}, + TaosTAG: &TaosTAG{DevId: "设备ID", DevType: "测试设备", DataType: "测试数据"}, + Ts: time.Now(), LoadUnitId: "负载单体ID", PInt: &p, NullInt32: sql.NullInt32{Int32: 32, Valid: true}, }, &TaosUser{ - TaosTAG: &TaosTAG{ - DevId: "User001", - DevType: "User类型", - DataType: "User数据类型001", - }, - Ts: time.Now(), - Name: "张三", + TaosTAG: &TaosTAG{DevId: "User001", DevType: "User类型", DataType: "User数据类型001"}, + Ts: time.Now(), Name: "张三", }, &TaosUser{ - TaosTAG: &TaosTAG{ - DevId: "User002", - DevType: "User类型", - DataType: "User数据类型002", - }, - Ts: time.Now(), - Name: "李四", - Weight: 110, + TaosTAG: &TaosTAG{DevId: "User002", DevType: "User类型", DataType: "User数据类型002"}, + Ts: time.Now(), Name: "李四", Weight: 110, }, &TaosUser{ - TaosTAG: &TaosTAG{ - DevId: "User002", - DevType: "User类型", - DataType: "User数据类型002", - }, - Name: "李四", - Ts: time.Now(), - Weight: 100, + TaosTAG: &TaosTAG{DevId: "User002", DevType: "User类型", DataType: "User数据类型002"}, + Name: "李四", Ts: time.Now(), Weight: 100, }, } - insertSql, err := tdMapping.ToInsertSQL(data...) + insertSql, err := tdMapper.ToInsertSQL(data...) if err != nil { t.Fatal(err) } @@ -113,14 +90,108 @@ func (u *User) SuperTableName() string { } func TestSimpleInsert(t *testing.T) { - tdMapping := NewMapping() + tdMapper := NewMapper() data := []any{ &User{Name: "张三", Age: 18}, &User{Name: "李四", Age: 20}, } - insertSql, err := tdMapping.ToInsertSQL(data...) + + insertSql, err := tdMapper.ToInsertSQL(data...) if err != nil { t.Fatal(err) } fmt.Println(insertSql) } + +type SuperDevTAG struct { + DevId string `db:"dev_id" taos:"tag"` + DevType string `db:"dev_type" taos:"tag"` +} + +type SuperDev struct { + SuperDevTAG + Ts time.Time `db:"ts"` // 时间戳 + AppSn string `db:"app_sn"` + Ct float64 `db:"ct"` +} + +func (s *SuperDev) SuperTableName() string { + return "super_dev" +} +func (s *SuperDev) TableName() string { + return "dev_" + s.DevId +} + +func TestSuperDevInsert(t *testing.T) { + var data = []any{ + &SuperDev{ + SuperDevTAG: SuperDevTAG{DevId: "SN001", DevType: "模拟设备"}, + Ts: time.Now(), AppSn: "a0001", Ct: 1.0, + }, + &SuperDev{ + SuperDevTAG: SuperDevTAG{DevId: "SN001", DevType: "模拟设备"}, + Ts: time.Now().Add(time.Second), AppSn: "a0002", Ct: 2.0, + }, + &SuperDev{ + SuperDevTAG: SuperDevTAG{DevId: "SN002", DevType: "模拟设备"}, + Ts: time.Now(), AppSn: "a0003", Ct: 3.0, + }, + } + + tdMapper := NewMapper() + insertSql, err := tdMapper.ToInsertSQL(data...) + if err != nil { + t.Fatal(err) + } + fmt.Println(insertSql) +} + +func TestScanRows(t *testing.T) { + /* + 文档参考: https://docs.taosdata.com/reference/connector/go/#websocket-%E8%BF%9E%E6%8E%A5 + + go get github.com/taosdata/driver-go/v3 + + import _ "github.com/taosdata/driver-go/v3/taosWS" + + 超级表创建 + CREATE STABLE `super_dev` (`ts` TIMESTAMP , `app_sn` VARCHAR(500) , `ct` INT ) TAGS (`dev_id` VARCHAR(50), `dev_type` VARCHAR(50)) + + 批量插入 + INSERT INTO + `dev_SN001` USING `super_dev` (`dev_id`,`dev_type`) TAGS ('SN001','模拟设备') (`ts`,`app_sn`,`ct`) + VALUES ('2024-09-18T16:22:17+08:00','a0001',1),('2024-09-18T16:22:18+08:00','a0002',2) + `dev_SN002` USING `super_dev` (`dev_id`,`dev_type`) TAGS ('SN002','模拟设备') (`ts`,`app_sn`,`ct`) + VALUES ('2024-09-18T16:22:17+08:00','a0003',3) + + */ + + var taosUri = "root:taosdata@localhost:6041/test" + db, err := sql.Open("taosWS", taosUri) + if err != nil { + t.Fatal(err) + } + + rows, err := db.Query("select * from super_dev order by ts desc limit 100") + if err != nil { + t.Fatal(err) + } + defer func() { _ = rows.Close() }() + + tdMapper := NewMapper() + + var sd SuperDev + if err = tdMapper.ScanRows(&sd, rows); err != nil { + t.Fatal(err) + } + + indent, _ := json.MarshalIndent(&sd, "", " ") + fmt.Println(string(indent)) + + //var sdArray []SuperDev + //if err = tdMapper.ScanRows(&sdArray, rows); err != nil { + // t.Fatal(err) + //} + //marshal, _ := json.MarshalIndent(sdArray, "", " ") + //fmt.Println(len(sdArray), string(marshal)) +} diff --git a/scan.go b/scan.go index 79892d4..eb93511 100644 --- a/scan.go +++ b/scan.go @@ -19,12 +19,13 @@ func scan(data interface{}) (*StructMate, error) { // 初始化结果结构体 sr := StructMate{ - UniqueTypeName: uniqueTypeName, - FieldIndexCache: make(map[string][]int, t.NumField()), - FiledDBNameCache: make(map[string]string, t.NumField()), - DBAnnotatedNames: make([]string, 0, t.NumField()), - TaggedFieldNames: make([]string, 0, t.NumField()), - SuperTableName: "", + UniqueTypeName: uniqueTypeName, + DBName2IndexCache: make(map[string][]int, t.NumField()), + Field2IndexCache: make(map[string][]int, t.NumField()), + Filed2DBNameCache: make(map[string]string, t.NumField()), + DBAnnotatedNames: make([]string, 0, t.NumField()), + TaggedFieldNames: make([]string, 0, t.NumField()), + SuperTableName: "", } //timeType := reflect.TypeOf(time.Time{}) @@ -55,52 +56,34 @@ func scan(data interface{}) (*StructMate, error) { return nil, err } - for k, v := range subResult.FieldIndexCache { - sr.FieldIndexCache[k] = append(sr.FieldIndexCache[k], i) - sr.FieldIndexCache[k] = append(sr.FieldIndexCache[k], v...) + for k, v := range subResult.DBName2IndexCache { + sr.DBName2IndexCache[k] = append(sr.DBName2IndexCache[k], i) + sr.DBName2IndexCache[k] = append(sr.DBName2IndexCache[k], v...) } - for k, v := range subResult.FiledDBNameCache { - sr.FiledDBNameCache[k] = v + + for k, v := range subResult.Field2IndexCache { + sr.Field2IndexCache[k] = append(sr.Field2IndexCache[k], i) + sr.Field2IndexCache[k] = append(sr.Field2IndexCache[k], v...) } + for k, v := range subResult.Filed2DBNameCache { + sr.Filed2DBNameCache[k] = v + } + sr.DBAnnotatedNames = append(sr.DBAnnotatedNames, subResult.DBAnnotatedNames...) sr.TaggedFieldNames = append(sr.TaggedFieldNames, subResult.TaggedFieldNames...) continue } - //// 如果字段是结构体或结构体指针,递归处理 - //if (field.Type.Kind() == reflect.Struct || field.Type.Kind() == reflect.Ptr) && field.Type != timeType { - // - // if field.Type.Kind() == reflect.Ptr { - // if fieldValue.IsNil() { - // // 如果指针是 nil,创建一个该类型的零值实例 - // zeroValue := reflect.Zero(field.Type.Elem()) - // fieldValue = zeroValue - // } else { - // fieldValue = fieldValue.Elem() - // } - // } - // - // if !fieldValue.CanInterface() { - // continue - // } - // - // subResult, err := scan(fieldValue.Interface()) - // if err != nil { - // return nil, err - // } - // sr.merge(subResult) - // continue - //} - // 检查字段是否有db注解 columnName := field.Tag.Get("db") if columnName == "-" || columnName == "" { continue } - sr.FieldIndexCache[field.Name] = []int{i} - sr.FiledDBNameCache[field.Name] = columnName + sr.Field2IndexCache[field.Name] = append(sr.Field2IndexCache[field.Name], i) + sr.DBName2IndexCache[columnName] = append(sr.DBName2IndexCache[columnName], i) + sr.Filed2DBNameCache[field.Name] = columnName // 检查字段是否有taos注解 if field.Tag.Get("taos") == "tag" { @@ -122,7 +105,7 @@ func getUniqueTypeName(t reflect.Type) string { return uniqueTypeName } -func getReflectTypeAndValue(data interface{}) (reflect.Type, reflect.Value) { +func getReflectTypeAndValue(data any) (reflect.Type, reflect.Value) { t := reflect.TypeOf(data) v := reflect.ValueOf(data) diff --git a/sync_map.go b/sync_map.go index 72cbab8..8c2214f 100644 --- a/sync_map.go +++ b/sync_map.go @@ -4,16 +4,16 @@ import ( "sync" ) -type Map[T any, V any] struct { +type SyncMap[T any, V any] struct { sMap sync.Map } -func (m *Map[T, V]) Store(key T, value V) { +func (m *SyncMap[T, V]) Store(key T, value V) { m.sMap.Store(key, value) } -func (m *Map[T, V]) Load(key T) (value V, ok bool) { +func (m *SyncMap[T, V]) Load(key T) (value V, ok bool) { v, ok := m.sMap.Load(key) if ok { return v.(V), ok @@ -21,17 +21,17 @@ func (m *Map[T, V]) Load(key T) (value V, ok bool) { return } -func (m *Map[T, V]) Delete(key T) { +func (m *SyncMap[T, V]) Delete(key T) { m.sMap.Delete(key) } -func (m *Map[T, V]) Range(f func(T, V) bool) { +func (m *SyncMap[T, V]) Range(f func(T, V) bool) { m.sMap.Range(func(key, value any) bool { return f(key.(T), value.(V)) }) } -func (m *Map[T, V]) LoadOrStore(key T, value V) (actual V, loaded bool) { +func (m *SyncMap[T, V]) LoadOrStore(key T, value V) (actual V, loaded bool) { _actual, loaded := m.sMap.LoadOrStore(key, value) if loaded { actual = _actual.(V) diff --git a/sync_map_test.go b/sync_map_test.go index 618b227..5ab1f22 100644 --- a/sync_map_test.go +++ b/sync_map_test.go @@ -6,7 +6,7 @@ import ( ) func TestAA(t *testing.T) { - var m Map[string, int] + var m SyncMap[string, int] m.Store("a", 1234) value, _ := m.Load("a") diff --git a/types.go b/types.go index 17e46b0..de510bd 100644 --- a/types.go +++ b/types.go @@ -15,8 +15,9 @@ type TableRowMateria struct { type StructMate struct { UniqueTypeName string // 结构体的唯一标识符 - FieldIndexCache map[string][]int // 字段名到索引的映射缓存 - FiledDBNameCache map[string]string // 字段名到 db 注解的名称的映射缓存 + DBName2IndexCache map[string][]int // db 注解的名称到索引的映射缓存 + Field2IndexCache map[string][]int // 字段名到索引的映射缓存 + Filed2DBNameCache map[string]string // 字段名到 db 注解的名称的映射缓存 DBAnnotatedNames []string // 包含 db 注解的 属性的名称 TaggedFieldNames []string // 包含的 tag 注解的 属性的名称