implement error handling for buffer reads
This commit is contained in:
parent
5342cfec63
commit
ee24948b73
3 changed files with 77 additions and 23 deletions
|
|
@ -2,6 +2,7 @@ package proto
|
|||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var NO = binary.BigEndian
|
||||
|
|
@ -85,11 +86,14 @@ func (buf *LuxBuffer) WriteBytes(bytes []byte) {
|
|||
copy(buf.WriteNext(len(bytes)), bytes)
|
||||
}
|
||||
|
||||
func (buf *LuxBuffer) ReadNext(size int) []byte {
|
||||
func (buf *LuxBuffer) ReadNext(size int) ([]byte, error) {
|
||||
if buf.offset+size > buf.len {
|
||||
return nil, fmt.Errorf("ReadNext %d+%d > %d", buf.offset, size, buf.len)
|
||||
}
|
||||
next := buf.data[buf.offset : buf.offset+size]
|
||||
buf.offset += size
|
||||
|
||||
return next
|
||||
return next, nil
|
||||
}
|
||||
|
||||
func (buf *LuxBuffer) Skip(skip int) {
|
||||
|
|
@ -97,23 +101,40 @@ func (buf *LuxBuffer) Skip(skip int) {
|
|||
}
|
||||
|
||||
// explicit copy of read bytes
|
||||
func (buf *LuxBuffer) CopyBytes(size int) []byte {
|
||||
func (buf *LuxBuffer) CopyBytes(size int) ([]byte, error) {
|
||||
bytes, err := buf.ReadNext(size)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
read := make([]byte, size)
|
||||
copy(read, buf.ReadNext(size))
|
||||
copy(read, bytes)
|
||||
|
||||
return read
|
||||
return read, nil
|
||||
}
|
||||
|
||||
func (buf *LuxBuffer) ReadUint16() uint16 {
|
||||
return NO.Uint16(buf.ReadNext(2))
|
||||
func (buf *LuxBuffer) ReadUint16() (uint16, error) {
|
||||
rd, err := buf.ReadNext(2)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return NO.Uint16(rd), nil
|
||||
}
|
||||
|
||||
func (buf *LuxBuffer) ReadUint32() uint32 {
|
||||
return NO.Uint32(buf.ReadNext(4))
|
||||
func (buf *LuxBuffer) ReadUint32() (uint32, error) {
|
||||
rd, err := buf.ReadNext(4)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return NO.Uint32(rd), nil
|
||||
}
|
||||
|
||||
func (buf *LuxBuffer) ReadUint64() uint64 {
|
||||
return NO.Uint64(buf.ReadNext(8))
|
||||
func (buf *LuxBuffer) ReadUint64() (uint64, error) {
|
||||
rd, err := buf.ReadNext(8)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return NO.Uint64(rd), nil
|
||||
}
|
||||
|
||||
func (buf *LuxBuffer) WriteUint16(val uint16) {
|
||||
|
|
@ -131,14 +152,21 @@ func (buf *LuxBuffer) WriteUint64(val uint64) {
|
|||
// variable-length blocks can cause misaligned size,
|
||||
// so we make method for them, forcing padding
|
||||
|
||||
func (buf *LuxBuffer) ReadVarBlock() []byte {
|
||||
len := buf.ReadUint16()
|
||||
bytes := buf.ReadNext(int(len))
|
||||
func (buf *LuxBuffer) ReadVarBlock() ([]byte, error) {
|
||||
len, err := buf.ReadUint16()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bytes, err := buf.ReadNext(int(len))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len%2 != 0 {
|
||||
buf.Skip(1) // skip padding
|
||||
}
|
||||
|
||||
return bytes
|
||||
return bytes, nil
|
||||
}
|
||||
|
||||
func (buf *LuxBuffer) WriteVarBlock(bytes []byte) {
|
||||
|
|
@ -152,8 +180,12 @@ func (buf *LuxBuffer) WriteVarBlock(bytes []byte) {
|
|||
}
|
||||
|
||||
// strings in LUX protocol format will be padded to align of 2
|
||||
func (buf *LuxBuffer) ReadString() string {
|
||||
return string(buf.ReadVarBlock())
|
||||
func (buf *LuxBuffer) ReadString() (string, error) {
|
||||
block, err := buf.ReadVarBlock()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(block), nil
|
||||
}
|
||||
|
||||
func (buf *LuxBuffer) WriteString(val string) {
|
||||
|
|
|
|||
|
|
@ -6,3 +6,12 @@ const (
|
|||
LuxTypeHost = 0
|
||||
LuxTypeNode
|
||||
)
|
||||
|
||||
func (luxType *LuxType) Read(rd *LuxBuffer) error {
|
||||
if val, err := rd.ReadUint16(); err != nil {
|
||||
return err
|
||||
} else {
|
||||
*luxType = LuxType(val)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,16 +29,16 @@ func TestInts(t *testing.T) {
|
|||
wd.WriteUint64(69)
|
||||
wd.WriteUint32(2967279234)
|
||||
|
||||
if rd := wd.ReadUint16(); rd != 1234 {
|
||||
if rd, _ := wd.ReadUint16(); rd != 1234 {
|
||||
t.Fatal(rd)
|
||||
}
|
||||
if rd := wd.ReadUint16(); rd != 4321 {
|
||||
if rd, _ := wd.ReadUint16(); rd != 4321 {
|
||||
t.Fatal(rd)
|
||||
}
|
||||
if rd := wd.ReadUint64(); rd != 69 {
|
||||
if rd, _ := wd.ReadUint64(); rd != 69 {
|
||||
t.Fatal(rd)
|
||||
}
|
||||
if rd := wd.ReadUint32(); rd != 2967279234 {
|
||||
if rd, _ := wd.ReadUint32(); rd != 2967279234 {
|
||||
t.Fatal(rd)
|
||||
}
|
||||
}
|
||||
|
|
@ -48,13 +48,26 @@ func TestStrings(t *testing.T) {
|
|||
wd.WriteString("a") // 4
|
||||
wd.WriteString("hello") // 8
|
||||
|
||||
if rd := wd.ReadString(); rd != "a" {
|
||||
if rd, _ := wd.ReadString(); rd != "a" {
|
||||
t.Fatal(rd)
|
||||
}
|
||||
if rd := wd.ReadString(); rd != "hello" {
|
||||
if rd, _ := wd.ReadString(); rd != "hello" {
|
||||
t.Fatal(rd)
|
||||
}
|
||||
if wd.Length() != 12 {
|
||||
t.Fatalf("string misaligned size: %d\n", wd.Length())
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadOverrun(t *testing.T) {
|
||||
rd := proto.FromSlice(make([]byte, 2))
|
||||
if _, err := rd.ReadNext(3); err == nil {
|
||||
t.Fatalf("no error when rd.ReadNext(3) for %d\n", rd.Length())
|
||||
}
|
||||
|
||||
rd = proto.FromSlice(make([]byte, 1))
|
||||
rd.ReadNext(1)
|
||||
if _, err := rd.ReadNext(1); err == nil {
|
||||
t.Fatalf("no error when reading at offset 1 of 1 byte buffer\n")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue