btree: replace the reflection-based encoding/binary struct I/O with explicit
little-endian helpers, implement the item shifting in insert/remove with
copy, sync the file on Close, and add a round-trip test.

Review fixes folded into this patch:
  * cmp: the last comparison read a's final word on both sides, so the
    trailing four key bytes never took part in ordering; it now reads b.
  * cmp: int64(x - y) on unsigned words can report the wrong sign; ordering
    now goes through the new order64 helper, which returns -1, 0 or +1.
  * read64/read32/readItem: io.ReadFull replaces a bare Read so that a short
    read is reported instead of decoding a partly filled buffer.
  * btree_test.go: the CreateTemp error is checked before the file is used,
    and the error returned by Open is asserted.

This text was rebuilt from a copy whose newlines had been collapsed. Go
indentation follows gofmt, blank context lines were placed to satisfy each
hunk's line counts, and spacing inside comments could not be recovered (one
-/+ comment pair in insert is therefore textually identical). Run
`git apply --check` before applying. The index lines keep the original blob
ids and no longer describe the new side.

diff --git a/internal/btree/btree.go b/internal/btree/btree.go
index ee74b62..5a3c34a 100644
--- a/internal/btree/btree.go
+++ b/internal/btree/btree.go
@@ -2,7 +2,6 @@
 package btree
 
 import (
-	"encoding/binary"
 	"io"
 	"math/rand"
 	"os"
@@ -15,7 +14,7 @@ const (
 	sha1Size        = 20 // md5 sha1
 	tableSize       = (4096 - 1) / int(unsafe.Sizeof(item{}))
 	cacheSlots      = 23 // prime
-	superSize       = unsafe.Sizeof(super{})
+	superSize       = int(unsafe.Sizeof(super{}))
 	tableStructSize = int(unsafe.Sizeof(table{}))
 )
 
@@ -65,7 +64,7 @@ func (bt *Btree) get(offset int64) *table {
 	table := new(table)
 
 	bt.fd.Seek(offset, io.SeekStart)
-	err := binary.Read(bt.fd, binary.LittleEndian, table) // todo(wdvxdr): efficient reading
+	err := readTable(bt.fd, table)
 	if err != nil {
 		panic(errors.Wrap(err, "btree I/O error"))
 	}
@@ -85,7 +84,7 @@ func (bt *Btree) flush(t *table, offset int64) {
 	assert(offset != 0)
 
 	bt.fd.Seek(offset, io.SeekStart)
-	err := binary.Write(bt.fd, binary.LittleEndian, t)
+	err := writeTable(bt.fd, t)
 	if err != nil {
 		panic(errors.Wrap(err, "btree I/O error"))
 	}
@@ -99,7 +98,7 @@ func (bt *Btree) flushSuper() {
 		freeTop: bt.freeTop,
 		alloc:   bt.alloc,
 	}
-	err := binary.Write(bt.fd, binary.LittleEndian, super)
+	err := writeSuper(bt.fd, &super)
 	if err != nil {
 		panic(errors.Wrap(err, "btree I/O error"))
 	}
@@ -115,7 +114,7 @@ func Open(name string) (*Btree, error) {
 	btree.fd = fd
 
 	super := super{}
-	err = binary.Read(fd, binary.LittleEndian, &super)
+	err = readSuper(fd, &super)
 	btree.top = super.top
 	btree.freeTop = super.freeTop
 	btree.alloc = super.alloc
@@ -138,6 +137,7 @@ func Create(name string) (*Btree, error) {
 
 // Close closes the database
 func (bt *Btree) Close() error {
+	_ = bt.fd.Sync()
 	err := bt.fd.Close()
 	for i := 0; i < cacheSlots; i++ {
 		bt.cache[i] = cache{}
@@ -256,10 +256,7 @@ func (bt *Btree) remove(t *table, i int, sha1 *byte) int64 {
 	} else {
 		// memmove(&table->items[i], &table->items[i + 1],
 		// (table->size - i) * sizeof(struct btree_item));
-		// table->size--;
-		for j := i; j < t.size-i; j++ { // fuck you, go!
-			t.items[j] = t.items[j+1]
-		}
+		copy(t.items[i:], t.items[i+1:])
 		t.size--
 
 		if lc != 0 {
@@ -319,9 +316,9 @@ func (bt *Btree) insert(toff int64, sha1 *byte, data []byte, size int) int64 {
 	}
 
 	table.size++
-	// todo:
 	// memmove(&table->items[i + 1], &table->items[i],
-	// (table->size - i) * sizeof(struct btree_item));
+	// (table->size - i) * sizeof(struct btree_item));
+	copy(table.items[i+1:], table.items[i:])
 	copysha1(&table.items[i].sha1[0], sha1)
 	table.items[i].offset = off
 	table.items[i].child = lc
@@ -340,7 +337,7 @@ func (bt *Btree) insertData(data []byte, size int) int64 {
 	offset := bt.allocChunk(4 + len(data))
 
 	bt.fd.Seek(offset, io.SeekStart)
-	err := binary.Write(bt.fd, binary.LittleEndian, int32(len(data)))
+	err := write32(bt.fd, int32(len(data)))
 	if err != nil {
 		panic(errors.Wrap(err, "btree I/O error"))
 	}
@@ -486,8 +483,7 @@ func (bt *Btree) Get(sha1 *byte) []byte {
 	}
 
 	bt.fd.Seek(off, io.SeekStart)
-	var length int32
-	err := binary.Read(bt.fd, binary.LittleEndian, &length)
+	length, err := read32(bt.fd)
 	if err != nil {
 		return nil
 	}
diff --git a/internal/btree/btree_test.go b/internal/btree/btree_test.go
new file mode 100644
index 0000000..9de67b0
--- /dev/null
+++ b/internal/btree/btree_test.go
@@ -0,0 +1,55 @@
+package btree
+
+import (
+	"crypto/sha1"
+	"os"
+	"testing"
+
+	assert2 "github.com/stretchr/testify/assert"
+)
+
+// tempfile creates an empty temporary file and returns its path.
+func tempfile(t *testing.T) string {
+	temp, err := os.CreateTemp("", "temp.*.db")
+	if err != nil {
+		t.Fatal(err) // temp is nil on failure, so stop before using it
+	}
+	assert2.NoError(t, temp.Close())
+	return temp.Name()
+}
+
+func TestCreate(t *testing.T) {
+	f := tempfile(t)
+	_, err := Create(f)
+	assert2.NoError(t, err)
+	defer os.Remove(f)
+}
+
+func TestBtree(t *testing.T) {
+	f := tempfile(t)
+	defer os.Remove(f)
+	bt, err := Create(f)
+	assert2.NoError(t, err)
+
+	var tests = []string{
+		"hello world",
+		"123",
+		"We are met on a great battle-field of that war.",
+		"Abraham Lincoln, November 19, 1863, Gettysburg, Pennsylvania",
+	}
+	var sha = make([]*byte, len(tests))
+	for i, tt := range tests {
+		var hash = sha1.New()
+		hash.Write([]byte(tt))
+		sha[i] = &hash.Sum(nil)[0]
+		bt.Insert(sha[i], []byte(tt))
+	}
+	assert2.NoError(t, bt.Close())
+
+	bt, err = Open(f)
+	assert2.NoError(t, err)
+	for i, tt := range tests {
+		assert2.Equal(t, []byte(tt), bt.Get(sha[i]))
+	}
+	assert2.NoError(t, bt.Close())
+}
diff --git a/internal/btree/helper.go b/internal/btree/helper.go
index 518ec66..cbff971 100644
--- a/internal/btree/helper.go
+++ b/internal/btree/helper.go
@@ -1,6 +1,10 @@
 package btree
 
-import "unsafe"
+import (
+	"encoding/binary"
+	"io"
+	"unsafe"
+)
 
 func assert(cond bool) {
 	if !cond {
@@ -19,16 +23,33 @@ func power2(val int) int {
 
 // helpers for sha1
 
-func cmp(a, b *byte) uint64 {
+// cmp orders two 20-byte keys and returns -1, 0 or +1 when a sorts before,
+// equal to or after b. The keys are read as two uint64 words followed by one
+// uint32, all in the machine's native byte order.
+func cmp(a, b *byte) int64 {
 	pa, pb := unsafe.Pointer(a), unsafe.Pointer(b)
 	if *(*uint64)(pa) != *(*uint64)(pb) {
-		return *(*uint64)(pa) - *(*uint64)(pb)
+		return order64(*(*uint64)(pa), *(*uint64)(pb))
 	}
 	pa, pb = unsafe.Add(pa, 8), unsafe.Add(pb, 8)
 	if *(*uint64)(pa) != *(*uint64)(pb) {
-		return *(*uint64)(pa) - *(*uint64)(pb)
+		return order64(*(*uint64)(pa), *(*uint64)(pb))
 	}
-	return uint64(*(*uint32)(unsafe.Add(pa, 8)) - *(*uint32)(unsafe.Add(pa, 8)))
+	// pa and pb already point at byte 8, so +8 addresses the last 4 bytes.
+	return order64(uint64(*(*uint32)(unsafe.Add(pa, 8))), uint64(*(*uint32)(unsafe.Add(pb, 8))))
 }
+
+// order64 reports how x compares with y as -1, 0 or +1. A plain int64(x - y)
+// cannot be used: the unsigned difference wraps, so its sign can be wrong
+// once the operands differ by 1<<63 or more.
+func order64(x, y uint64) int64 {
+	switch {
+	case x < y:
+		return -1
+	case x > y:
+		return 1
+	}
+	return 0
+}
 
 func copysha1(dst *byte, src *byte) {
@@ -40,3 +61,139 @@ func resetsha1(sha1 *byte) {
 	p := unsafe.Pointer(sha1)
 	*(*[sha1Size]byte)(p) = [sha1Size]byte{}
 }
+
+// reading table
+
+// read64 reads a little-endian int64 from r. io.ReadFull is used because a
+// bare Read may return fewer than 8 bytes without reporting an error.
+func read64(r io.Reader) (int64, error) {
+	var b = make([]byte, 8)
+	_, err := io.ReadFull(r, b)
+	if err != nil {
+		return 0, err
+	}
+	return int64(binary.LittleEndian.Uint64(b)), nil
+}
+
+// read32 reads a little-endian int32 from r, failing on a short read just
+// like read64.
+func read32(r io.Reader) (int32, error) {
+	var b = make([]byte, 4)
+	_, err := io.ReadFull(r, b)
+	if err != nil {
+		return 0, err
+	}
+	return int32(binary.LittleEndian.Uint32(b)), nil
+}
+
+// readTable fills t from r: tableSize items followed by the size field. The
+// size field is as wide as the platform's int (8 or 4 bytes), so files are
+// not portable between 64-bit and 32-bit builds.
+func readTable(r io.Reader, t *table) error {
+	for i := 0; i < tableSize; i++ {
+		err := readItem(r, &t.items[i])
+		if err != nil {
+			return err
+		}
+	}
+	switch unsafe.Sizeof(0) {
+	case 8:
+		i, err := read64(r)
+		t.size = int(i)
+		return err
+	case 4:
+		i, err := read32(r)
+		t.size = int(i)
+		return err
+	default:
+		panic("unreachable")
+	}
+}
+
+// readItem decodes one item from r: the sha1 field, then offset and child as
+// little-endian int64s.
+func readItem(r io.Reader, i *item) error {
+	_, err := io.ReadFull(r, i.sha1[:])
+	if err != nil {
+		return err
+	}
+	i.offset, err = read64(r)
+	if err != nil {
+		return err
+	}
+	i.child, err = read64(r)
+	return err
+}
+
+// readSuper decodes top, freeTop and alloc, in that order, as little-endian
+// int64s.
+func readSuper(r io.Reader, s *super) error {
+	var err error
+	if s.top, err = read64(r); err != nil {
+		return err
+	}
+	if s.freeTop, err = read64(r); err != nil {
+		return err
+	}
+	s.alloc, err = read64(r)
+	return err
+}
+
+// write table
+
+// write64 writes i to w as 8 little-endian bytes.
+func write64(w io.Writer, i int64) error {
+	var b = make([]byte, 8)
+	binary.LittleEndian.PutUint64(b, uint64(i))
+	_, err := w.Write(b)
+	return err
+}
+
+// write32 writes i to w as 4 little-endian bytes.
+func write32(w io.Writer, i int32) error {
+	var b = make([]byte, 4)
+	binary.LittleEndian.PutUint32(b, uint32(i))
+	_, err := w.Write(b)
+	return err
+}
+
+// writeTable encodes t in the layout readTable expects: tableSize items and
+// then the size field at the platform's int width.
+func writeTable(w io.Writer, t *table) error {
+	for i := 0; i < tableSize; i++ {
+		err := writeItem(w, &t.items[i])
+		if err != nil {
+			return err
+		}
+	}
+	switch unsafe.Sizeof(0) {
+	case 8:
+		return write64(w, int64(t.size))
+	case 4:
+		return write32(w, int32(t.size))
+	default:
+		panic("unreachable")
+	}
+}
+
+// writeItem encodes the sha1 field followed by offset and child.
+func writeItem(w io.Writer, i *item) error {
+	if _, err := w.Write(i.sha1[:]); err != nil {
+		return err
+	}
+	if err := write64(w, i.offset); err != nil {
+		return err
+	}
+	return write64(w, i.child)
+}
+
+// writeSuper encodes top, freeTop and alloc in the order readSuper reads them.
+func writeSuper(w io.Writer, s *super) error {
+	if err := write64(w, s.top); err != nil {
+		return err
+	}
+	if err := write64(w, s.freeTop); err != nil {
+		return err
+	}
+	return write64(w, s.alloc)
+}