From 340deecefb5dee542613c4e7ae3bd8fa9e2020eb Mon Sep 17 00:00:00 2001 From: zhoujie Date: Fri, 20 Jun 2025 14:13:24 +0800 Subject: [PATCH] =?UTF-8?q?feat(ToInsert):=20=E6=94=AF=E6=8C=81=E6=95=B0?= =?UTF-8?q?=E7=BB=84=E7=9B=B4=E6=8E=A5=E4=BC=A0=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mapping.go | 61 +++++++++++++++++++++++++++++++++++++++++++++++-- mapping_test.go | 16 ++++++------- 2 files changed, 67 insertions(+), 10 deletions(-) diff --git a/mapping.go b/mapping.go index 0ec0796..f3af3c3 100644 --- a/mapping.go +++ b/mapping.go @@ -125,14 +125,71 @@ func (b *Mapper) extractStructData(data ...any) (map[string][]*TableRowMateria, return result, nil } +// normalizeData 对输入数据进行统一处理和规范化 +// 支持输入为任意数量的参数,参数可以是单个结构体指针或结构体,也可以是结构体切片或指针切片 +// 处理逻辑包括: +// 1. 将切片类型的参数拆开放入一个统一的一维切片中(展平) +// 2. 跳过其中的 nil 元素,确保数据有效性 +// 3. 将所有非指针类型的元素转换为对应的指针类型,方便后续统一处理 +func (b *Mapper) normalizeData(data ...any) []any { + var normalizedData []any + + for _, item := range data { + if item == nil { + // 跳过nil值,避免后续反射调用panic + continue + } + itemType := reflect.TypeOf(item) + if itemType.Kind() == reflect.Slice { + // 处理切片类型参数,展平成单个元素 + sliceValue := reflect.ValueOf(item) + for i := 0; i < sliceValue.Len(); i++ { + elem := sliceValue.Index(i).Interface() + if elem == nil { + // 跳过nil元素 + continue + } + normalizedData = append(normalizedData, elem) + } + } else { + // 非切片类型直接追加 + normalizedData = append(normalizedData, item) + } + } + + // 转换所有元素为指针类型,便于后续一致性处理 + for i, item := range normalizedData { + if item == nil { + continue + } + itemType := reflect.TypeOf(item) + if itemType.Kind() != reflect.Ptr { + // 创建对应类型的指针 + itemPtr := reflect.New(itemType) + // 设置指针指向的值为元素本身 + itemPtr.Elem().Set(reflect.ValueOf(item)) + // 替换切片中的元素为指针类型 + normalizedData[i] = itemPtr.Interface() + } + } + return normalizedData +} + func (b *Mapper) ToInsertSQL(data ...any) (string, error) { + // 统一规范化处理输入数据,将切片拆平并转换元素为指针 + normalizedData := b.normalizeData(data...) + + if len(normalizedData) == 0 { + return "", fmt.Errorf("data is empty") + } + // 扫描 struct 类型数据 - if err := b.scanStruct(data...); err != nil { + if err := b.scanStruct(normalizedData...); err != nil { return "", fmt.Errorf("failed to scan struct: %w", err) } // 提取 struct 内的信息 - tableMap, err := b.extractStructData(data...) + tableMap, err := b.extractStructData(normalizedData...) if err != nil { return "", fmt.Errorf("failed to extract struct data: %w", err) } diff --git a/mapping_test.go b/mapping_test.go index 1ecefde..dfe1519 100644 --- a/mapping_test.go +++ b/mapping_test.go @@ -69,7 +69,7 @@ func TestBuilderInsert(t *testing.T) { }, } - insertSql, err := tdMapper.ToInsertSQL(data...) + insertSql, err := tdMapper.ToInsertSQL(data) if err != nil { t.Fatal(err) } @@ -91,12 +91,12 @@ func (u *User) SuperTableName() string { func TestSimpleInsert(t *testing.T) { tdMapper := NewMapper() - data := []any{ - &User{Name: "张三", Age: 18}, - &User{Name: "李四", Age: 20}, + data := []User{ + {Name: "张三", Age: 18}, + {Name: "李四", Age: 20}, } - insertSql, err := tdMapper.ToInsertSQL(data...) + insertSql, err := tdMapper.ToInsertSQL(data) if err != nil { t.Fatal(err) } @@ -124,15 +124,15 @@ func (s *SuperDev) TableName() string { func TestSuperDevInsert(t *testing.T) { var data = []any{ - &SuperDev{ + SuperDev{ SuperDevTAG: SuperDevTAG{DevId: "SN001", DevType: "模拟设备"}, Ts: time.Now(), AppSn: "a0001", Ct: 1.0, }, - &SuperDev{ + SuperDev{ SuperDevTAG: SuperDevTAG{DevId: "SN001", DevType: "模拟设备"}, Ts: time.Now().Add(time.Second), AppSn: "a0002", Ct: 2.0, }, - &SuperDev{ + SuperDev{ SuperDevTAG: SuperDevTAG{DevId: "SN002", DevType: "模拟设备"}, Ts: time.Now(), AppSn: "a0003", Ct: 3.0, },