From a03e8acee071e8882599bcbf1672e93f0476b8da Mon Sep 17 00:00:00 2001 From: kercylan98 Date: Fri, 28 Apr 2023 13:38:03 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8C=89=E6=9D=83=E9=87=8D=E9=9A=8F=E6=9C=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- utils/random/number.go | 8 ++++++++ utils/random/weight.go | 45 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 utils/random/weight.go diff --git a/utils/random/number.go b/utils/random/number.go index 11cf867..a6d5f90 100644 --- a/utils/random/number.go +++ b/utils/random/number.go @@ -21,3 +21,11 @@ func Float64() float64 { func Float32() float32 { return rand.Float32() } + +// IntN 返回一个0~n的整数 +func IntN(n int) int { + if n <= 0 { + return 0 + } + return rand.Intn(n) +} diff --git a/utils/random/weight.go b/utils/random/weight.go new file mode 100644 index 0000000..7bf4ce1 --- /dev/null +++ b/utils/random/weight.go @@ -0,0 +1,45 @@ +package random + +// WeightSlice 按权重随机从切片中产生一个数据并返回 +func WeightSlice[T any](getWeightHandle func(data T) int, data ...T) T { + var total int + var overlayWeight []int + for _, d := range data { + total += getWeightHandle(d) + overlayWeight = append(overlayWeight, total) + } + var r = IntN(total) + var i, count = 0, len(overlayWeight) + for i < count { + h := int(uint(i+count) >> 1) + if overlayWeight[h] < r { + i = h + 1 + } else { + count = h + } + } + return data[i] +} + +// WeightMap 按权重随机从map中产生一个数据并返回 +func WeightMap[K comparable, T any](getWeightHandle func(data T) int, data map[K]T) T { + var total int + var overlayWeight []int + var dataSlice = make([]T, 0, len(data)) + for _, d := range data { + total += getWeightHandle(d) + dataSlice = append(dataSlice, d) + overlayWeight = append(overlayWeight, total) + } + var r = IntN(total) + var i, count = 0, len(overlayWeight) + for i < count { + h := int(uint(i+count) >> 1) + if overlayWeight[h] < r { + i = h + 1 + } else { + count = h + } + } + return dataSlice[i] +}