From ee24948b73c89be61eb23b2f54df853f7a309a2b Mon Sep 17 00:00:00 2001 From: mykola2312 <49044616+mykola2312@users.noreply.github.com> Date: Sun, 29 Dec 2024 12:56:07 +0200 Subject: [PATCH] implement error handling for buffer reads --- proto/lux_buffer.go | 66 +++++++++++++++++++++++++++++----------- proto/lux_type.go | 9 ++++++ tests/lux_buffer_test.go | 25 +++++++++++---- 3 files changed, 77 insertions(+), 23 deletions(-) diff --git a/proto/lux_buffer.go b/proto/lux_buffer.go index 86d5f30..97051e6 100644 --- a/proto/lux_buffer.go +++ b/proto/lux_buffer.go @@ -2,6 +2,7 @@ package proto import ( "encoding/binary" + "fmt" ) var NO = binary.BigEndian @@ -85,11 +86,14 @@ func (buf *LuxBuffer) WriteBytes(bytes []byte) { copy(buf.WriteNext(len(bytes)), bytes) } -func (buf *LuxBuffer) ReadNext(size int) []byte { +func (buf *LuxBuffer) ReadNext(size int) ([]byte, error) { + if buf.offset+size > buf.len { + return nil, fmt.Errorf("ReadNext %d+%d > %d", buf.offset, size, buf.len) + } next := buf.data[buf.offset : buf.offset+size] buf.offset += size - return next + return next, nil } func (buf *LuxBuffer) Skip(skip int) { @@ -97,23 +101,40 @@ func (buf *LuxBuffer) Skip(skip int) { } // explicit copy of read bytes -func (buf *LuxBuffer) CopyBytes(size int) []byte { +func (buf *LuxBuffer) CopyBytes(size int) ([]byte, error) { + bytes, err := buf.ReadNext(size) + if err != nil { + return nil, err + } + read := make([]byte, size) - copy(read, buf.ReadNext(size)) + copy(read, bytes) - return read + return read, nil } -func (buf *LuxBuffer) ReadUint16() uint16 { - return NO.Uint16(buf.ReadNext(2)) +func (buf *LuxBuffer) ReadUint16() (uint16, error) { + rd, err := buf.ReadNext(2) + if err != nil { + return 0, err + } + return NO.Uint16(rd), nil } -func (buf *LuxBuffer) ReadUint32() uint32 { - return NO.Uint32(buf.ReadNext(4)) +func (buf *LuxBuffer) ReadUint32() (uint32, error) { + rd, err := buf.ReadNext(4) + if err != nil { + return 0, err + } + return NO.Uint32(rd), nil } -func (buf *LuxBuffer) ReadUint64() uint64 { - return NO.Uint64(buf.ReadNext(8)) +func (buf *LuxBuffer) ReadUint64() (uint64, error) { + rd, err := buf.ReadNext(8) + if err != nil { + return 0, err + } + return NO.Uint64(rd), nil } func (buf *LuxBuffer) WriteUint16(val uint16) { @@ -131,14 +152,21 @@ func (buf *LuxBuffer) WriteUint64(val uint64) { // variable-length blocks can cause misaligned size, // so we make method for them, forcing padding -func (buf *LuxBuffer) ReadVarBlock() []byte { - len := buf.ReadUint16() - bytes := buf.ReadNext(int(len)) +func (buf *LuxBuffer) ReadVarBlock() ([]byte, error) { + len, err := buf.ReadUint16() + if err != nil { + return nil, err + } + bytes, err := buf.ReadNext(int(len)) + if err != nil { + return nil, err + } + if len%2 != 0 { buf.Skip(1) // skip padding } - return bytes + return bytes, nil } func (buf *LuxBuffer) WriteVarBlock(bytes []byte) { @@ -152,8 +180,12 @@ func (buf *LuxBuffer) WriteVarBlock(bytes []byte) { } // strings in LUX protocol format will be padded to align of 2 -func (buf *LuxBuffer) ReadString() string { - return string(buf.ReadVarBlock()) +func (buf *LuxBuffer) ReadString() (string, error) { + block, err := buf.ReadVarBlock() + if err != nil { + return "", err + } + return string(block), nil } func (buf *LuxBuffer) WriteString(val string) { diff --git a/proto/lux_type.go b/proto/lux_type.go index ccb7823..ca6d832 100644 --- a/proto/lux_type.go +++ b/proto/lux_type.go @@ -6,3 +6,12 @@ const ( LuxTypeHost = 0 LuxTypeNode ) + +func (luxType *LuxType) Read(rd *LuxBuffer) error { + if val, err := rd.ReadUint16(); err != nil { + return err + } else { + *luxType = LuxType(val) + return nil + } +} diff --git a/tests/lux_buffer_test.go b/tests/lux_buffer_test.go index 00c361f..6a95287 100644 --- a/tests/lux_buffer_test.go +++ b/tests/lux_buffer_test.go @@ -29,16 +29,16 @@ func TestInts(t *testing.T) { wd.WriteUint64(69) wd.WriteUint32(2967279234) - if rd := wd.ReadUint16(); rd != 1234 { + if rd, _ := wd.ReadUint16(); rd != 1234 { t.Fatal(rd) } - if rd := wd.ReadUint16(); rd != 4321 { + if rd, _ := wd.ReadUint16(); rd != 4321 { t.Fatal(rd) } - if rd := wd.ReadUint64(); rd != 69 { + if rd, _ := wd.ReadUint64(); rd != 69 { t.Fatal(rd) } - if rd := wd.ReadUint32(); rd != 2967279234 { + if rd, _ := wd.ReadUint32(); rd != 2967279234 { t.Fatal(rd) } } @@ -48,13 +48,26 @@ func TestStrings(t *testing.T) { wd.WriteString("a") // 4 wd.WriteString("hello") // 8 - if rd := wd.ReadString(); rd != "a" { + if rd, _ := wd.ReadString(); rd != "a" { t.Fatal(rd) } - if rd := wd.ReadString(); rd != "hello" { + if rd, _ := wd.ReadString(); rd != "hello" { t.Fatal(rd) } if wd.Length() != 12 { t.Fatalf("string misaligned size: %d\n", wd.Length()) } } + +func TestReadOverrun(t *testing.T) { + rd := proto.FromSlice(make([]byte, 2)) + if _, err := rd.ReadNext(3); err == nil { + t.Fatalf("no error when rd.ReadNext(3) for %d\n", rd.Length()) + } + + rd = proto.FromSlice(make([]byte, 1)) + rd.ReadNext(1) + if _, err := rd.ReadNext(1); err == nil { + t.Fatalf("no error when reading at offset 1 of 1 byte buffer\n") + } +}