5 Commits

Author SHA1 Message Date
ac1054d833 feat(ScanRows): 新增 scanRows 转map数组 2025-06-25 17:42:22 +08:00
340deecefb feat(ToInsert): 支持数组直接传入 2025-06-20 14:13:24 +08:00
09770b07dc fix(insert_build): 修复在生成sql语句时,时间对象精度损失问题 2025-06-19 17:10:47 +08:00
ca03e37c18 fix(scan): 修复对struct存在重复属性导致的崩溃
新增了在递归scan对象属性的时候发现嵌套属性名称和当前属性名称一致的时候做了友好错误提示
2025-02-05 13:26:41 +08:00
dde9b6f74e feat: 去掉了批量查询时,结果为空的报错 2024-09-19 14:11:17 +08:00
5 changed files with 175 additions and 24 deletions

View File

@@ -73,6 +73,7 @@ INSERT INTO
### 查询映射 -- 示例 ### 查询映射 -- 示例
```go ```go
// 如果没有结果,会返回 sql.ErrNoRows
func QueryOne(db *sql.DB) (*User,error){ func QueryOne(db *sql.DB) (*User,error){
tdMapper := NewMapper() tdMapper := NewMapper()
@@ -89,8 +90,7 @@ func QueryOne(db *sql.DB) (*User,error){
return &user, nil return &user, nil
} }
// 如果查询结果数量为0则返回 sql.ErrNoRows // 如果查询结果数量为0不会返回错误
// var ErrNoRows = errors.New("sql: no rows in result set")
func QueryAll(db *sql.DB)([]*User, error) { func QueryAll(db *sql.DB)([]*User, error) {
tdMapper := NewMapper() tdMapper := NewMapper()

View File

@@ -137,7 +137,7 @@ func formatRowValues(values ...any) string {
} }
case sql.NullTime: case sql.NullTime:
if v.Valid { if v.Valid {
formattedValues[i] = fmt.Sprintf("'%s'", v.Time.Format(time.RFC3339)) formattedValues[i] = fmt.Sprintf("'%s'", v.Time.Format(time.RFC3339Nano))
} else { } else {
formattedValues[i] = "null" formattedValues[i] = "null"
} }
@@ -148,7 +148,7 @@ func formatRowValues(values ...any) string {
case nil: case nil:
formattedValues[i] = "null" formattedValues[i] = "null"
case time.Time: case time.Time:
formattedValues[i] = fmt.Sprintf("'%s'", v.Format(time.RFC3339)) formattedValues[i] = fmt.Sprintf("'%s'", v.Format(time.RFC3339Nano))
default: default:
if reflect.TypeOf(val).Kind() == reflect.Ptr && reflect.ValueOf(val).IsNil() { if reflect.TypeOf(val).Kind() == reflect.Ptr && reflect.ValueOf(val).IsNil() {
formattedValues[i] = "null" formattedValues[i] = "null"

View File

@@ -65,7 +65,7 @@ func (b *Mapper) extractStructData(data ...any) (map[string][]*TableRowMateria,
} }
if tableName == "" { if tableName == "" {
return nil, fmt.Errorf("not import TableName() string func for struct type: %s", uniqueTypeName) return nil, fmt.Errorf("not func TableName() string func for struct type: %s", uniqueTypeName)
} }
} }
@@ -125,14 +125,71 @@ func (b *Mapper) extractStructData(data ...any) (map[string][]*TableRowMateria,
return result, nil return result, nil
} }
// normalizeData 对输入数据进行统一处理和规范化
// 支持输入为任意数量的参数,参数可以是单个结构体指针或结构体,也可以是结构体切片或指针切片
// 处理逻辑包括:
// 1. 将切片类型的参数拆开放入一个统一的一维切片中(展平)
// 2. 跳过其中的 nil 元素,确保数据有效性
// 3. 将所有非指针类型的元素转换为对应的指针类型,方便后续统一处理
func (b *Mapper) normalizeData(data ...any) []any {
var normalizedData []any
for _, item := range data {
if item == nil {
// 跳过nil值避免后续反射调用panic
continue
}
itemType := reflect.TypeOf(item)
if itemType.Kind() == reflect.Slice {
// 处理切片类型参数,展平成单个元素
sliceValue := reflect.ValueOf(item)
for i := 0; i < sliceValue.Len(); i++ {
elem := sliceValue.Index(i).Interface()
if elem == nil {
// 跳过nil元素
continue
}
normalizedData = append(normalizedData, elem)
}
} else {
// 非切片类型直接追加
normalizedData = append(normalizedData, item)
}
}
// 转换所有元素为指针类型,便于后续一致性处理
for i, item := range normalizedData {
if item == nil {
continue
}
itemType := reflect.TypeOf(item)
if itemType.Kind() != reflect.Ptr {
// 创建对应类型的指针
itemPtr := reflect.New(itemType)
// 设置指针指向的值为元素本身
itemPtr.Elem().Set(reflect.ValueOf(item))
// 替换切片中的元素为指针类型
normalizedData[i] = itemPtr.Interface()
}
}
return normalizedData
}
func (b *Mapper) ToInsertSQL(data ...any) (string, error) { func (b *Mapper) ToInsertSQL(data ...any) (string, error) {
// 统一规范化处理输入数据,将切片拆平并转换元素为指针
normalizedData := b.normalizeData(data...)
if len(normalizedData) == 0 {
return "", fmt.Errorf("data is empty")
}
// 扫描 struct 类型数据 // 扫描 struct 类型数据
if err := b.scanStruct(data...); err != nil { if err := b.scanStruct(normalizedData...); err != nil {
return "", fmt.Errorf("failed to scan struct: %w", err) return "", fmt.Errorf("failed to scan struct: %w", err)
} }
// 提取 struct 内的信息 // 提取 struct 内的信息
tableMap, err := b.extractStructData(data...) tableMap, err := b.extractStructData(normalizedData...)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to extract struct data: %w", err) return "", fmt.Errorf("failed to extract struct data: %w", err)
} }
@@ -247,10 +304,6 @@ func (b *Mapper) ScanRowsWithContext(ctx context.Context, target any, rows *sql.
// 将新实例添加到slice中 // 将新实例添加到slice中
vf.Set(reflect.Append(vf, newStruct)) vf.Set(reflect.Append(vf, newStruct))
} }
if vf.Len() == 0 { // 如果slice为空返回 sql.ErrNoRows
return sql.ErrNoRows
}
return nil return nil
default: default:
return errors.New("target must be a pointer to a struct or a slice of structs") return errors.New("target must be a pointer to a struct or a slice of structs")
@@ -277,13 +330,62 @@ func (b *Mapper) scanRow(target reflect.Value, rows *sql.Rows) error {
dest := make([]interface{}, len(columns)) dest := make([]interface{}, len(columns))
for i, colName := range columns { for i, colName := range columns {
// 拿到缓存的字段索引 // 拿到缓存的字段索引
if idx, ok := mate.DBName2IndexCache[colName]; ok { idx, ok := mate.DBName2IndexCache[colName]
dest[i] = target.FieldByIndex(idx).Addr().Interface() if !ok {
} else {
return fmt.Errorf("no corresponding field found for column %s", colName) return fmt.Errorf("no corresponding field found for column %s", colName)
} }
dest[i] = target.FieldByIndex(idx).Addr().Interface()
} }
// 扫描数据 // 扫描数据
return rows.Scan(dest...) return rows.Scan(dest...)
} }
// ScanRowsToMap 扫描多行数据到map
// @param rows 数据库返回的行数据
// @return 返回map数组和错误
func (b *Mapper) ScanRowsToMap(rows *sql.Rows) ([]map[string]any, error) {
return b.ScanRowsToMapWithContext(context.Background(), rows)
}
// ScanRowsToMapWithContext 扫描多行数据到map
// @param ctx 上下文
// @param rows 数据库返回的行数据
// @return 返回map数组和错误
func (b *Mapper) ScanRowsToMapWithContext(ctx context.Context, rows *sql.Rows) ([]map[string]any, error) {
columns, err := rows.Columns()
if err != nil {
return nil, err
}
result := make([]map[string]any, 0)
for rows.Next() {
// 检查上下文是否已取消
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
// 创建一个切片来存储列的值
values := make([]any, len(columns))
for i := range values {
values[i] = new(any)
}
// 扫描行数据到切片中
if err := rows.Scan(values...); err != nil {
return nil, err
}
// 构建 结果映射
rowMap := make(map[string]any, len(columns))
for i := range values {
column := columns[i]
value := *(values[i].(*any))
rowMap[column] = value
}
result = append(result, rowMap)
}
return result, err
}

View File

@@ -4,6 +4,7 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
//_ "github.com/taosdata/driver-go/v3/taosWS"
"testing" "testing"
"time" "time"
) )
@@ -58,7 +59,7 @@ func TestBuilderInsert(t *testing.T) {
Ts: time.Now(), LoadUnitId: "负载单体ID", PInt: &p, NullInt32: sql.NullInt32{Int32: 32, Valid: true}, Ts: time.Now(), LoadUnitId: "负载单体ID", PInt: &p, NullInt32: sql.NullInt32{Int32: 32, Valid: true},
}, },
&TaosUser{ &TaosUser{
TaosTAG: &TaosTAG{DevId: "User001", DevType: "User类型", DataType: "User数据类型001"}, TaosTAG: &TaosTAG{DevId: "User001", DevType: "User类型", DataType: "User数据类型001", Alias: "三儿"},
Ts: time.Now(), Name: "张三", Ts: time.Now(), Name: "张三",
}, &TaosUser{ }, &TaosUser{
TaosTAG: &TaosTAG{DevId: "User002", DevType: "User类型", DataType: "User数据类型002"}, TaosTAG: &TaosTAG{DevId: "User002", DevType: "User类型", DataType: "User数据类型002"},
@@ -69,7 +70,7 @@ func TestBuilderInsert(t *testing.T) {
}, },
} }
insertSql, err := tdMapper.ToInsertSQL(data...) insertSql, err := tdMapper.ToInsertSQL(data)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -91,12 +92,12 @@ func (u *User) SuperTableName() string {
func TestSimpleInsert(t *testing.T) { func TestSimpleInsert(t *testing.T) {
tdMapper := NewMapper() tdMapper := NewMapper()
data := []any{ data := []User{
&User{Name: "张三", Age: 18}, {Name: "张三", Age: 18},
&User{Name: "李四", Age: 20}, {Name: "李四", Age: 20},
} }
insertSql, err := tdMapper.ToInsertSQL(data...) insertSql, err := tdMapper.ToInsertSQL(data)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -124,15 +125,15 @@ func (s *SuperDev) TableName() string {
func TestSuperDevInsert(t *testing.T) { func TestSuperDevInsert(t *testing.T) {
var data = []any{ var data = []any{
&SuperDev{ SuperDev{
SuperDevTAG: SuperDevTAG{DevId: "SN001", DevType: "模拟设备"}, SuperDevTAG: SuperDevTAG{DevId: "SN001", DevType: "模拟设备"},
Ts: time.Now(), AppSn: "a0001", Ct: 1.0, Ts: time.Now(), AppSn: "a0001", Ct: 1.0,
}, },
&SuperDev{ SuperDev{
SuperDevTAG: SuperDevTAG{DevId: "SN001", DevType: "模拟设备"}, SuperDevTAG: SuperDevTAG{DevId: "SN001", DevType: "模拟设备"},
Ts: time.Now().Add(time.Second), AppSn: "a0002", Ct: 2.0, Ts: time.Now().Add(time.Second), AppSn: "a0002", Ct: 2.0,
}, },
&SuperDev{ SuperDev{
SuperDevTAG: SuperDevTAG{DevId: "SN002", DevType: "模拟设备"}, SuperDevTAG: SuperDevTAG{DevId: "SN002", DevType: "模拟设备"},
Ts: time.Now(), AppSn: "a0003", Ct: 3.0, Ts: time.Now(), AppSn: "a0003", Ct: 3.0,
}, },
@@ -195,3 +196,47 @@ func TestScanRows(t *testing.T) {
//marshal, _ := json.MarshalIndent(sdArray, "", " ") //marshal, _ := json.MarshalIndent(sdArray, "", " ")
//fmt.Println(len(sdArray), string(marshal)) //fmt.Println(len(sdArray), string(marshal))
} }
func TestScanRowsToMap(t *testing.T) {
/*
文档参考: https://docs.taosdata.com/reference/connector/go/#websocket-%E8%BF%9E%E6%8E%A5
go get github.com/taosdata/driver-go/v3
import _ "github.com/taosdata/driver-go/v3/taosWS"
超级表创建
CREATE STABLE `super_dev` (`ts` TIMESTAMP , `app_sn` VARCHAR(500) , `ct` INT ) TAGS (`dev_id` VARCHAR(50), `dev_type` VARCHAR(50))
批量插入
INSERT INTO
`dev_SN001` USING `super_dev` (`dev_id`,`dev_type`) TAGS ('SN001','模拟设备') (`ts`,`app_sn`,`ct`)
VALUES ('2024-09-18T16:22:17+08:00','a0001',1),('2024-09-18T16:22:18+08:00','a0002',2)
`dev_SN002` USING `super_dev` (`dev_id`,`dev_type`) TAGS ('SN002','模拟设备') (`ts`,`app_sn`,`ct`)
VALUES ('2024-09-18T16:22:17+08:00','a0003',3)
*/
var taosUri = "root:taosdata@localhost:6041/test"
db, err := sql.Open("taosWS", taosUri)
if err != nil {
t.Fatal(err)
}
rows, err := db.Query("select * from super_dev order by ts desc limit 100")
if err != nil {
t.Fatal(err)
}
defer func() { _ = rows.Close() }()
tdMapper := NewMapper()
toMap, err := tdMapper.ScanRowsToMap(rows)
if err != nil {
t.Fatal(err)
}
indent, _ := json.MarshalIndent(toMap, "", " ")
fmt.Println(string(indent))
fmt.Println("len:", len(toMap))
}

View File

@@ -81,7 +81,11 @@ func scan(data interface{}) (*StructMate, error) {
continue continue
} }
sr.Field2IndexCache[field.Name] = append(sr.Field2IndexCache[field.Name], i) if len(sr.Field2IndexCache[field.Name]) > 0 {
return nil, fmt.Errorf("duplicate field [%s %s `db:%s`]", field.Name, field.Type.Name(), columnName)
}
sr.Field2IndexCache[field.Name] = append([]int{}, i)
sr.DBName2IndexCache[columnName] = append(sr.DBName2IndexCache[columnName], i) sr.DBName2IndexCache[columnName] = append(sr.DBName2IndexCache[columnName], i)
sr.Filed2DBNameCache[field.Name] = columnName sr.Filed2DBNameCache[field.Name] = columnName