mirror of
https://github.com/ollama/ollama.git
synced 2026-01-02 12:38:15 -05:00
Compare commits
3 Commits
implement-
...
mxyng/16-b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
69f3dfdedf | ||
|
|
7bd3f0269c | ||
|
|
276c4df770 |
21
convert/bfloat16/bfloat16.go
Normal file
21
convert/bfloat16/bfloat16.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package bfloat16
|
||||
|
||||
import "math"
|
||||
|
||||
// FromFloat32s converts a slice of float32 values to a slice of bfloat16 values, represented as uint16s.
//
// Each element is narrowed by keeping the upper 16 bits of its IEEE 754
// binary32 encoding (sign, 8 exponent bits, top 7 mantissa bits). The lower 16
// mantissa bits are truncated, not rounded.
//
// A NaN whose payload lives entirely in the discarded bits would otherwise
// truncate to an infinity bit pattern; for those inputs the quiet bit is set
// so the result is still a NaN. The returned slice always has len(f32s)
// elements.
func FromFloat32s(f32s []float32) (u16s []uint16) {
	const (
		expMask  = 0x7F800000 // binary32 exponent field
		mantMask = 0x007FFFFF // binary32 mantissa field
		quietBit = 0x0040     // most significant bfloat16 mantissa bit
	)

	u16s = make([]uint16, len(f32s))
	for i, f := range f32s {
		bits := math.Float32bits(f)
		hi := uint16(bits >> 16)

		// NaN (all-ones exponent, non-zero mantissa) whose surviving 7
		// mantissa bits are all zero: force it to remain a NaN.
		if bits&expMask == expMask && bits&mantMask != 0 && hi&0x007F == 0 {
			hi |= quietBit
		}

		u16s[i] = hi
	}
	return u16s
}
|
||||
|
||||
// Float32s widens bfloat16 values (stored as uint16 bit patterns) to float32.
//
// A bfloat16 is the upper half of a binary32 encoding, so each element is
// placed in the high 16 bits of a 32-bit word and the low 16 bits are zero.
// The result has the same length as the input.
func Float32s(u16s []uint16) (f32s []float32) {
	f32s = make([]float32, 0, len(u16s))
	for _, half := range u16s {
		widened := uint32(half) << 16
		f32s = append(f32s, math.Float32frombits(widened))
	}
	return f32s
}
|
||||
82
convert/bfloat16/bfloat16_test.go
Normal file
82
convert/bfloat16/bfloat16_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package bfloat16
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestBfloat16(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
input uint16
|
||||
want uint32
|
||||
}{
|
||||
// Zero cases
|
||||
{"positive zero", 0x0000, 0x0},
|
||||
{"negative zero", 0x8000, 0x80000000},
|
||||
|
||||
// Normal numbers
|
||||
{"one", 0x3F80, 0x3F800000},
|
||||
{"negative one", 0xBF80, 0xBF800000},
|
||||
{"two", 0x4000, 0x40000000},
|
||||
{"half", 0x3F00, 0x3F000000},
|
||||
{"quarter", 0x3E80, 0x3E800000},
|
||||
{"max finite", 0x7F7F, 0x7F7F0000},
|
||||
{"min positive normal", 0x0080, 0x00800000},
|
||||
|
||||
// Infinity cases
|
||||
{"positive infinity", 0x7F80, 0x7F800000},
|
||||
{"negative infinity", 0xFF80, 0xFF800000},
|
||||
|
||||
// NaN cases
|
||||
{"NaN", 0x7FC0, 0x7FC00000},
|
||||
{"NaN with payload", 0x7FC1, 0x7FC10000},
|
||||
|
||||
// Subnormal cases
|
||||
{"min positive subnormal", 0x0001, 0x00010000},
|
||||
{"max subnormal", 0x007F, 0x007F0000},
|
||||
|
||||
// Powers of 2
|
||||
{"2^10", 0x4480, 0x44800000},
|
||||
{"2^-10", 0x3A80, 0x3A800000},
|
||||
{"2^20", 0x4B80, 0x4B800000},
|
||||
|
||||
// Common approximations in BF16
|
||||
{"pi approximation", 0x4049, 0x40490000},
|
||||
{"e approximation", 0x402E, 0x402E0000},
|
||||
{"sqrt(2) approximation", 0x3FB5, 0x3FB50000},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Float32s", func(t *testing.T) {
|
||||
got := Float32s([]uint16{tt.input})[0]
|
||||
if diff := cmp.Diff(tt.want, math.Float32bits(got)); diff != "" {
|
||||
t.Errorf("Float32s mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FromFloat32s", func(t *testing.T) {
|
||||
got := FromFloat32s([]float32{math.Float32frombits(tt.want)})
|
||||
if diff := cmp.Diff([]uint16{tt.input}, got); diff != "" {
|
||||
t.Errorf("FromFloat32s mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBfloat16(b *testing.B) {
|
||||
f32s := make([]float32, 1_000_000)
|
||||
for i := range f32s {
|
||||
f32s[i] = rand.Float32()
|
||||
}
|
||||
for b.Loop() {
|
||||
Float32s(FromFloat32s(f32s))
|
||||
}
|
||||
}
|
||||
97
convert/float16/float16.go
Normal file
97
convert/float16/float16.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package float16
|
||||
|
||||
import (
|
||||
"math"
|
||||
)
|
||||
|
||||
// FromFloat32s converts a slice of float32 values to a slice of IEEE 754
// binary16 (half precision) values, represented as uint16s.
//
// Finite values are rounded to the nearest representable half, with ties
// going to the value whose last mantissa bit is even. Magnitudes too large
// for binary16 (including those that round up past the largest finite half)
// become infinity; magnitudes that round below the smallest subnormal become
// a signed zero. NaN inputs always produce a NaN: the top 10 payload bits are
// kept, and if they are all zero the quiet bit is set so the result does not
// collapse to infinity. The returned slice always has len(f32s) elements.
func FromFloat32s(f32s []float32) (u16s []uint16) {
	u16s = make([]uint16, len(f32s))
	for i, f := range f32s {
		u16s[i] = fromFloat32(f)
	}
	return u16s
}

// fromFloat32 narrows a single float32 to its binary16 bit pattern.
func fromFloat32(f float32) uint16 {
	bits := math.Float32bits(f)
	sign := uint16(bits>>16) & 0x8000
	exp32 := int(bits>>23) & 0xFF
	mant := bits & 0x7FFFFF

	switch {
	case exp32 == 0xFF:
		if mant == 0 {
			// Infinity
			return sign | 0x7C00
		}
		// NaN: never let the payload truncate to zero, which would
		// encode an infinity instead.
		payload := uint16(mant >> 13)
		if payload == 0 {
			payload = 0x0200
		}
		return sign | 0x7C00 | payload
	case exp32 == 0:
		// Zero, or a float32 subnormal, which is far smaller than half of
		// the smallest binary16 subnormal and therefore rounds to zero.
		return sign
	}

	// Convert exponent from FP32 bias (127) to FP16 bias (15).
	exp16 := exp32 - 127 + 15
	switch {
	case exp16 >= 31:
		// Overflow to infinity
		return sign | 0x7C00
	case exp16 < -10:
		// Below half of the smallest subnormal: underflow to zero.
		return sign
	case exp16 <= 0:
		// Subnormal result: restore the implicit leading 1 and shift the
		// 24-bit significand down to units of 2^-24. A carry out of the
		// rounding step correctly yields the smallest normal (0x0400).
		return sign | uint16(roundShift(mant|0x800000, uint(14-exp16)))
	default:
		// Normal result: drop 13 mantissa bits. Exponent and mantissa are
		// rounded as one integer so a mantissa carry bumps the exponent,
		// reaching 0x7C00 (infinity) when the largest finite value
		// overflows.
		return sign | uint16(roundShift(uint32(exp16)<<23|mant, 13))
	}
}

// roundShift returns v >> n rounded to nearest, ties to even. n must be in
// the range [1, 31].
func roundShift(v uint32, n uint) uint32 {
	q := v >> n
	rem := v & (uint32(1)<<n - 1)
	half := uint32(1) << (n - 1)
	if rem > half || (rem == half && q&1 == 1) {
		q++
	}
	return q
}
|
||||
|
||||
// Float32s widens IEEE 754 binary16 (half precision) values, stored as uint16
// bit patterns, to float32. Every half is exactly representable as a float32,
// so the conversion is lossless; NaN payload bits are moved to the top of the
// float32 mantissa. The result has the same length as the input.
func Float32s(u16s []uint16) (f32s []float32) {
	f32s = make([]float32, len(u16s))
	for i, h := range u16s {
		signBit := uint32(h>>15) << 31
		exp := uint32(h>>10) & 0x1F
		frac := uint32(h) & 0x3FF

		bits := signBit
		switch {
		case exp == 0x1F:
			// Infinity when frac is zero, NaN otherwise.
			bits |= 0x7F800000 | frac<<13
		case exp != 0:
			// Normal number: rebias the exponent from 15 to 127.
			bits |= (exp+127-15)<<23 | frac<<13
		case frac != 0:
			// Subnormal half: shift left until the leading 1 reaches the
			// implicit-bit position (0x400), lowering the float32
			// exponent by one for each step, then drop that bit.
			e := uint32(127 - 15 + 1)
			for frac&0x400 == 0 {
				frac <<= 1
				e--
			}
			bits |= e<<23 | (frac&0x3FF)<<13
		}
		// Remaining case is a signed zero: only the sign bit is set.

		f32s[i] = math.Float32frombits(bits)
	}
	return f32s
}
|
||||
75
convert/float16/float16_test.go
Normal file
75
convert/float16/float16_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package float16
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestFloat16(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
input uint16
|
||||
want uint32
|
||||
}{
|
||||
// Zero cases
|
||||
{"positive zero", 0x0000, 0x0},
|
||||
{"negative zero", 0x8000, 0x80000000},
|
||||
|
||||
// Normal numbers
|
||||
{"one", 0x3C00, 0x3F800000},
|
||||
{"negative one", 0xBC00, 0xBF800000},
|
||||
{"two", 0x4000, 0x40000000},
|
||||
{"half", 0x3800, 0x3F000000},
|
||||
{"max normal", 0x7BFF, 0x477fe000},
|
||||
{"min positive normal", 0x0400, 0x38800000},
|
||||
|
||||
// Infinity cases
|
||||
{"positive infinity", 0x7C00, 0x7F800000},
|
||||
{"negative infinity", 0xFC00, 0xFF800000},
|
||||
|
||||
// NaN cases
|
||||
{"NaN", 0x7C01, 0x7f802000},
|
||||
{"NaN with payload", 0x7E00, 0x7FC00000},
|
||||
|
||||
// Subnormal cases
|
||||
{"min positive subnormal", 0x0001, 0x33800000},
|
||||
{"max subnormal", 0x03FF, 0x387fc000},
|
||||
|
||||
// Common values
|
||||
{"pi approximation", 0x4248, 0x40490000},
|
||||
{"e approximation", 0x416F, 0x402de000},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Float32s", func(t *testing.T) {
|
||||
got := Float32s([]uint16{tt.input})[0]
|
||||
if diff := cmp.Diff(tt.want, math.Float32bits(got)); diff != "" {
|
||||
t.Errorf("Float32s mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FromFloat32s", func(t *testing.T) {
|
||||
got := FromFloat32s([]float32{math.Float32frombits(tt.want)})
|
||||
if diff := cmp.Diff([]uint16{tt.input}, got); diff != "" {
|
||||
t.Errorf("FromFloat32s mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFloat16(b *testing.B) {
|
||||
f32s := make([]float32, 1_000_000)
|
||||
for i := range f32s {
|
||||
f32s[i] = rand.Float32()
|
||||
}
|
||||
for b.Loop() {
|
||||
Float32s(FromFloat32s(f32s))
|
||||
}
|
||||
}
|
||||
@@ -13,8 +13,8 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/d4l3k/go-bfloat16"
|
||||
"github.com/x448/float16"
|
||||
"github.com/ollama/ollama/convert/bfloat16"
|
||||
"github.com/ollama/ollama/convert/float16"
|
||||
)
|
||||
|
||||
type safetensorMetadata struct {
|
||||
@@ -163,18 +163,16 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
f32s = make([]float32, len(u16s))
|
||||
for i := range u16s {
|
||||
f32s[i] = float16.Frombits(u16s[i]).Float32()
|
||||
}
|
||||
f32s = float16.Float32s(u16s)
|
||||
|
||||
case "BF16":
|
||||
u8s := make([]uint8, st.size)
|
||||
if err = binary.Read(br, binary.LittleEndian, u8s); err != nil {
|
||||
u16s := make([]uint16, st.size/2)
|
||||
if err = binary.Read(br, binary.LittleEndian, u16s); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
f32s = bfloat16.DecodeFloat32(u8s)
|
||||
f32s = bfloat16.Float32s(u16s)
|
||||
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown data type: %s", st.dtype)
|
||||
}
|
||||
@@ -190,15 +188,9 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
|
||||
case tensorKindFP32:
|
||||
return 0, binary.Write(w, binary.LittleEndian, f32s)
|
||||
case tensorKindFP16:
|
||||
f16s := make([]uint16, len(f32s))
|
||||
for i := range f32s {
|
||||
f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
|
||||
}
|
||||
|
||||
return 0, binary.Write(w, binary.LittleEndian, f16s)
|
||||
return 0, binary.Write(w, binary.LittleEndian, float16.FromFloat32s(f32s))
|
||||
case tensorKindBF16:
|
||||
u8s := bfloat16.EncodeFloat32(f32s)
|
||||
return 0, binary.Write(w, binary.LittleEndian, u8s)
|
||||
return 0, binary.Write(w, binary.LittleEndian, bfloat16.FromFloat32s(f32s))
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown storage type: %d", st.Kind())
|
||||
}
|
||||
|
||||
@@ -7,9 +7,9 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/d4l3k/go-bfloat16"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/x448/float16"
|
||||
"github.com/ollama/ollama/convert/bfloat16"
|
||||
"github.com/ollama/ollama/convert/float16"
|
||||
)
|
||||
|
||||
func TestSafetensors(t *testing.T) {
|
||||
@@ -21,6 +21,11 @@ func TestSafetensors(t *testing.T) {
|
||||
}
|
||||
defer root.Close()
|
||||
|
||||
f32s := make([]float32, 32)
|
||||
for i := range f32s {
|
||||
f32s[i] = float32(i)
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name,
|
||||
dtype string
|
||||
@@ -36,11 +41,6 @@ func TestSafetensors(t *testing.T) {
|
||||
size: 32 * 4, // 32 floats, each 4 bytes
|
||||
shape: []uint64{32},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
f32s := make([]float32, 32)
|
||||
for i := range f32s {
|
||||
f32s[i] = float32(i)
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -62,11 +62,6 @@ func TestSafetensors(t *testing.T) {
|
||||
size: 32 * 4, // 32 floats, each 4 bytes
|
||||
shape: []uint64{16, 2},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
f32s := make([]float32, 32)
|
||||
for i := range f32s {
|
||||
f32s[i] = float32(i)
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -84,12 +79,7 @@ func TestSafetensors(t *testing.T) {
|
||||
size: 32 * 2, // 32 floats, each 2 bytes
|
||||
shape: []uint64{16, 2},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
u16s := make([]uint16, 32)
|
||||
for i := range u16s {
|
||||
u16s[i] = float16.Fromfloat32(float32(i)).Bits()
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, u16s); err != nil {
|
||||
if err := binary.Write(f, binary.LittleEndian, float16.FromFloat32s(f32s)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
@@ -106,12 +96,7 @@ func TestSafetensors(t *testing.T) {
|
||||
size: 32 * 2, // 32 floats, each 2 bytes
|
||||
shape: []uint64{32},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
u16s := make([]uint16, 32)
|
||||
for i := range u16s {
|
||||
u16s[i] = float16.Fromfloat32(float32(i)).Bits()
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, u16s); err != nil {
|
||||
if err := binary.Write(f, binary.LittleEndian, float16.FromFloat32s(f32s)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
@@ -132,12 +117,7 @@ func TestSafetensors(t *testing.T) {
|
||||
size: 32 * 2, // 32 brain floats, each 2 bytes
|
||||
shape: []uint64{16, 2},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
f32s := make([]float32, 32)
|
||||
for i := range f32s {
|
||||
f32s[i] = float32(i)
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, bfloat16.EncodeFloat32(f32s)); err != nil {
|
||||
if err := binary.Write(f, binary.LittleEndian, bfloat16.FromFloat32s(f32s)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
@@ -154,12 +134,7 @@ func TestSafetensors(t *testing.T) {
|
||||
size: 32 * 2, // 32 brain floats, each 2 bytes
|
||||
shape: []uint64{32},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
f32s := make([]float32, 32)
|
||||
for i := range f32s {
|
||||
f32s[i] = float32(i)
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, bfloat16.EncodeFloat32(f32s)); err != nil {
|
||||
if err := binary.Write(f, binary.LittleEndian, bfloat16.FromFloat32s(f32s)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
|
||||
2
go.mod
2
go.mod
@@ -10,13 +10,11 @@ require (
|
||||
github.com/olekukonko/tablewriter v0.0.5
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/x448/float16 v0.8.4
|
||||
golang.org/x/sync v0.12.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/agnivade/levenshtein v1.1.1
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
|
||||
github.com/dlclark/regexp2 v1.11.4
|
||||
github.com/emirpasic/gods/v2 v2.0.0-alpha
|
||||
github.com/google/go-cmp v0.7.0
|
||||
|
||||
4
go.sum
4
go.sum
@@ -35,8 +35,6 @@ github.com/containerd/console v1.0.3 h1:lIr7SlA5PxZyMV30bDW0MGbiOPXwc63yRuCP0ARu
|
||||
github.com/containerd/console v1.0.3/go.mod h1:7LqA/THxQ86k76b8c/EMSiaJ3h1eZkMkXar0TQ1gf3U=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 h1:cBzrdJPAFBsgCrDPnZxlp1dF2+k4r1kVpD7+1S1PVjY=
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1/go.mod h1:uw2gLcxEuYUlAd/EXyjc/v55nd3+47YAgWbSXVxPrNI=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
@@ -197,8 +195,6 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=
|
||||
github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
|
||||
Reference in New Issue
Block a user