implement error handling for buffer reads

This commit is contained in:
mykola2312 2024-12-29 12:56:07 +02:00
parent 5342cfec63
commit ee24948b73
3 changed files with 77 additions and 23 deletions

View file

@ -2,6 +2,7 @@ package proto
import ( import (
"encoding/binary" "encoding/binary"
"fmt"
) )
var NO = binary.BigEndian var NO = binary.BigEndian
@ -85,11 +86,14 @@ func (buf *LuxBuffer) WriteBytes(bytes []byte) {
copy(buf.WriteNext(len(bytes)), bytes) copy(buf.WriteNext(len(bytes)), bytes)
} }
func (buf *LuxBuffer) ReadNext(size int) []byte { func (buf *LuxBuffer) ReadNext(size int) ([]byte, error) {
if buf.offset+size > buf.len {
return nil, fmt.Errorf("ReadNext %d+%d > %d", buf.offset, size, buf.len)
}
next := buf.data[buf.offset : buf.offset+size] next := buf.data[buf.offset : buf.offset+size]
buf.offset += size buf.offset += size
return next return next, nil
} }
func (buf *LuxBuffer) Skip(skip int) { func (buf *LuxBuffer) Skip(skip int) {
@ -97,23 +101,40 @@ func (buf *LuxBuffer) Skip(skip int) {
} }
// explicit copy of read bytes // explicit copy of read bytes
func (buf *LuxBuffer) CopyBytes(size int) []byte { func (buf *LuxBuffer) CopyBytes(size int) ([]byte, error) {
bytes, err := buf.ReadNext(size)
if err != nil {
return nil, err
}
read := make([]byte, size) read := make([]byte, size)
copy(read, buf.ReadNext(size)) copy(read, bytes)
return read return read, nil
} }
func (buf *LuxBuffer) ReadUint16() uint16 { func (buf *LuxBuffer) ReadUint16() (uint16, error) {
return NO.Uint16(buf.ReadNext(2)) rd, err := buf.ReadNext(2)
if err != nil {
return 0, err
}
return NO.Uint16(rd), nil
} }
func (buf *LuxBuffer) ReadUint32() uint32 { func (buf *LuxBuffer) ReadUint32() (uint32, error) {
return NO.Uint32(buf.ReadNext(4)) rd, err := buf.ReadNext(4)
if err != nil {
return 0, err
}
return NO.Uint32(rd), nil
} }
func (buf *LuxBuffer) ReadUint64() uint64 { func (buf *LuxBuffer) ReadUint64() (uint64, error) {
return NO.Uint64(buf.ReadNext(8)) rd, err := buf.ReadNext(8)
if err != nil {
return 0, err
}
return NO.Uint64(rd), nil
} }
func (buf *LuxBuffer) WriteUint16(val uint16) { func (buf *LuxBuffer) WriteUint16(val uint16) {
@ -131,14 +152,21 @@ func (buf *LuxBuffer) WriteUint64(val uint64) {
// variable-length blocks can cause misaligned size, // variable-length blocks can cause misaligned size,
// so we make method for them, forcing padding // so we make method for them, forcing padding
func (buf *LuxBuffer) ReadVarBlock() []byte { func (buf *LuxBuffer) ReadVarBlock() ([]byte, error) {
len := buf.ReadUint16() len, err := buf.ReadUint16()
bytes := buf.ReadNext(int(len)) if err != nil {
return nil, err
}
bytes, err := buf.ReadNext(int(len))
if err != nil {
return nil, err
}
if len%2 != 0 { if len%2 != 0 {
buf.Skip(1) // skip padding buf.Skip(1) // skip padding
} }
return bytes return bytes, nil
} }
func (buf *LuxBuffer) WriteVarBlock(bytes []byte) { func (buf *LuxBuffer) WriteVarBlock(bytes []byte) {
@ -152,8 +180,12 @@ func (buf *LuxBuffer) WriteVarBlock(bytes []byte) {
} }
// strings in LUX protocol format will be padded to align of 2 // strings in LUX protocol format will be padded to align of 2
func (buf *LuxBuffer) ReadString() string { func (buf *LuxBuffer) ReadString() (string, error) {
return string(buf.ReadVarBlock()) block, err := buf.ReadVarBlock()
if err != nil {
return "", err
}
return string(block), nil
} }
func (buf *LuxBuffer) WriteString(val string) { func (buf *LuxBuffer) WriteString(val string) {

View file

@ -6,3 +6,12 @@ const (
LuxTypeHost = 0 LuxTypeHost = 0
LuxTypeNode LuxTypeNode
) )
func (luxType *LuxType) Read(rd *LuxBuffer) error {
if val, err := rd.ReadUint16(); err != nil {
return err
} else {
*luxType = LuxType(val)
return nil
}
}

View file

@ -29,16 +29,16 @@ func TestInts(t *testing.T) {
wd.WriteUint64(69) wd.WriteUint64(69)
wd.WriteUint32(2967279234) wd.WriteUint32(2967279234)
if rd := wd.ReadUint16(); rd != 1234 { if rd, _ := wd.ReadUint16(); rd != 1234 {
t.Fatal(rd) t.Fatal(rd)
} }
if rd := wd.ReadUint16(); rd != 4321 { if rd, _ := wd.ReadUint16(); rd != 4321 {
t.Fatal(rd) t.Fatal(rd)
} }
if rd := wd.ReadUint64(); rd != 69 { if rd, _ := wd.ReadUint64(); rd != 69 {
t.Fatal(rd) t.Fatal(rd)
} }
if rd := wd.ReadUint32(); rd != 2967279234 { if rd, _ := wd.ReadUint32(); rd != 2967279234 {
t.Fatal(rd) t.Fatal(rd)
} }
} }
@ -48,13 +48,26 @@ func TestStrings(t *testing.T) {
wd.WriteString("a") // 4 wd.WriteString("a") // 4
wd.WriteString("hello") // 8 wd.WriteString("hello") // 8
if rd := wd.ReadString(); rd != "a" { if rd, _ := wd.ReadString(); rd != "a" {
t.Fatal(rd) t.Fatal(rd)
} }
if rd := wd.ReadString(); rd != "hello" { if rd, _ := wd.ReadString(); rd != "hello" {
t.Fatal(rd) t.Fatal(rd)
} }
if wd.Length() != 12 { if wd.Length() != 12 {
t.Fatalf("string misaligned size: %d\n", wd.Length()) t.Fatalf("string misaligned size: %d\n", wd.Length())
} }
} }
func TestReadOverrun(t *testing.T) {
rd := proto.FromSlice(make([]byte, 2))
if _, err := rd.ReadNext(3); err == nil {
t.Fatalf("no error when rd.ReadNext(3) for %d\n", rd.Length())
}
rd = proto.FromSlice(make([]byte, 1))
rd.ReadNext(1)
if _, err := rd.ReadNext(1); err == nil {
t.Fatalf("no error when reading at offset 1 of 1 byte buffer\n")
}
}