implement error handling for buffer reads

This commit is contained in:
mykola2312 2024-12-29 12:56:07 +02:00
parent 5342cfec63
commit ee24948b73
3 changed files with 77 additions and 23 deletions

View file

@ -2,6 +2,7 @@ package proto
import (
"encoding/binary"
"fmt"
)
var NO = binary.BigEndian
@ -85,11 +86,14 @@ func (buf *LuxBuffer) WriteBytes(bytes []byte) {
copy(buf.WriteNext(len(bytes)), bytes)
}
func (buf *LuxBuffer) ReadNext(size int) []byte {
func (buf *LuxBuffer) ReadNext(size int) ([]byte, error) {
if buf.offset+size > buf.len {
return nil, fmt.Errorf("ReadNext %d+%d > %d", buf.offset, size, buf.len)
}
next := buf.data[buf.offset : buf.offset+size]
buf.offset += size
return next
return next, nil
}
func (buf *LuxBuffer) Skip(skip int) {
@ -97,23 +101,40 @@ func (buf *LuxBuffer) Skip(skip int) {
}
// explicit copy of read bytes
func (buf *LuxBuffer) CopyBytes(size int) []byte {
func (buf *LuxBuffer) CopyBytes(size int) ([]byte, error) {
bytes, err := buf.ReadNext(size)
if err != nil {
return nil, err
}
read := make([]byte, size)
copy(read, buf.ReadNext(size))
copy(read, bytes)
return read
return read, nil
}
func (buf *LuxBuffer) ReadUint16() uint16 {
return NO.Uint16(buf.ReadNext(2))
func (buf *LuxBuffer) ReadUint16() (uint16, error) {
rd, err := buf.ReadNext(2)
if err != nil {
return 0, err
}
return NO.Uint16(rd), nil
}
func (buf *LuxBuffer) ReadUint32() uint32 {
return NO.Uint32(buf.ReadNext(4))
func (buf *LuxBuffer) ReadUint32() (uint32, error) {
rd, err := buf.ReadNext(4)
if err != nil {
return 0, err
}
return NO.Uint32(rd), nil
}
func (buf *LuxBuffer) ReadUint64() uint64 {
return NO.Uint64(buf.ReadNext(8))
func (buf *LuxBuffer) ReadUint64() (uint64, error) {
rd, err := buf.ReadNext(8)
if err != nil {
return 0, err
}
return NO.Uint64(rd), nil
}
func (buf *LuxBuffer) WriteUint16(val uint16) {
@ -131,14 +152,21 @@ func (buf *LuxBuffer) WriteUint64(val uint64) {
// variable-length blocks can cause misaligned size,
// so we make method for them, forcing padding
func (buf *LuxBuffer) ReadVarBlock() []byte {
len := buf.ReadUint16()
bytes := buf.ReadNext(int(len))
func (buf *LuxBuffer) ReadVarBlock() ([]byte, error) {
len, err := buf.ReadUint16()
if err != nil {
return nil, err
}
bytes, err := buf.ReadNext(int(len))
if err != nil {
return nil, err
}
if len%2 != 0 {
buf.Skip(1) // skip padding
}
return bytes
return bytes, nil
}
func (buf *LuxBuffer) WriteVarBlock(bytes []byte) {
@ -152,8 +180,12 @@ func (buf *LuxBuffer) WriteVarBlock(bytes []byte) {
}
// strings in LUX protocol format will be padded to align of 2
func (buf *LuxBuffer) ReadString() string {
return string(buf.ReadVarBlock())
func (buf *LuxBuffer) ReadString() (string, error) {
block, err := buf.ReadVarBlock()
if err != nil {
return "", err
}
return string(block), nil
}
func (buf *LuxBuffer) WriteString(val string) {

View file

@ -6,3 +6,12 @@ const (
LuxTypeHost = 0
LuxTypeNode
)
func (luxType *LuxType) Read(rd *LuxBuffer) error {
if val, err := rd.ReadUint16(); err != nil {
return err
} else {
*luxType = LuxType(val)
return nil
}
}

View file

@ -29,16 +29,16 @@ func TestInts(t *testing.T) {
wd.WriteUint64(69)
wd.WriteUint32(2967279234)
if rd := wd.ReadUint16(); rd != 1234 {
if rd, _ := wd.ReadUint16(); rd != 1234 {
t.Fatal(rd)
}
if rd := wd.ReadUint16(); rd != 4321 {
if rd, _ := wd.ReadUint16(); rd != 4321 {
t.Fatal(rd)
}
if rd := wd.ReadUint64(); rd != 69 {
if rd, _ := wd.ReadUint64(); rd != 69 {
t.Fatal(rd)
}
if rd := wd.ReadUint32(); rd != 2967279234 {
if rd, _ := wd.ReadUint32(); rd != 2967279234 {
t.Fatal(rd)
}
}
@ -48,13 +48,26 @@ func TestStrings(t *testing.T) {
wd.WriteString("a") // 4
wd.WriteString("hello") // 8
if rd := wd.ReadString(); rd != "a" {
if rd, _ := wd.ReadString(); rd != "a" {
t.Fatal(rd)
}
if rd := wd.ReadString(); rd != "hello" {
if rd, _ := wd.ReadString(); rd != "hello" {
t.Fatal(rd)
}
if wd.Length() != 12 {
t.Fatalf("string misaligned size: %d\n", wd.Length())
}
}
func TestReadOverrun(t *testing.T) {
rd := proto.FromSlice(make([]byte, 2))
if _, err := rd.ReadNext(3); err == nil {
t.Fatalf("no error when rd.ReadNext(3) for %d\n", rd.Length())
}
rd = proto.FromSlice(make([]byte, 1))
rd.ReadNext(1)
if _, err := rd.ReadNext(1); err == nil {
t.Fatalf("no error when reading at offset 1 of 1 byte buffer\n")
}
}