From 181b1107a3a0f996096c0b1901b826686e78ad55 Mon Sep 17 00:00:00 2001 From: altugbakan Date: Sat, 11 Mar 2023 19:23:23 +0300 Subject: [PATCH] Fix options on write request --- Cargo.lock | 2 +- Cargo.toml | 4 +-- src/config.rs | 8 +++--- src/convert.rs | 4 +-- src/message.rs | 57 ++++++++++++++---------------------- src/packet.rs | 58 +++++++++++++++++++++++++++++++++---- src/server.rs | 30 +++++++++---------- src/worker.rs | 78 ++++++++++++++++++++++++++------------------------ 8 files changed, 138 insertions(+), 103 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 26d86b5..88bd5d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,4 +4,4 @@ version = 3 [[package]] name = "tftpd" -version = "0.1.0" +version = "0.1.1" diff --git a/Cargo.toml b/Cargo.toml index d16e5a6..c712fae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,9 +1,9 @@ [package] name = "tftpd" -version = "0.1.0" +version = "0.1.1" authors = ["Altuğ Bakan "] edition = "2021" -description = "TFTP Server Daemon implemented in Rust" +description = "Multithreaded TFTP server daemon" repository = "https://github.com/altugbakan/rs-tftpd" license = "MIT" keywords = ["tftp", "server"] diff --git a/src/config.rs b/src/config.rs index 89332dc..892f624 100644 --- a/src/config.rs +++ b/src/config.rs @@ -28,14 +28,14 @@ impl Config { if let Some(ip_str) = args.next() { config.ip_address = ip_str.parse::()?; } else { - return Err("missing ip address after flag".into()); + return Err("Missing ip address after flag".into()); } } "-p" | "--port" => { if let Some(port_str) = args.next() { config.port = port_str.parse::()?; } else { - return Err("missing port number after flag".into()); + return Err("Missing port number after flag".into()); } } "-d" | "--directory" => { @@ -45,7 +45,7 @@ impl Config { } config.directory = PathBuf::from(dir_str); } else { - return Err("missing directory after flag".into()); + return Err("Missing directory after flag".into()); } } "-h" | "--help" => { @@ -60,7 +60,7 @@ impl Config { println!(" -h, --help\t\t\tPrint help information"); process::exit(0); } - invalid => return Err(format!("invalid flag: {invalid}").into()), + invalid => return Err(format!("Invalid flag: {invalid}").into()), } } diff --git a/src/convert.rs b/src/convert.rs index c535b17..ef2f98e 100644 --- a/src/convert.rs +++ b/src/convert.rs @@ -5,7 +5,7 @@ pub struct Convert; impl Convert { pub fn to_u16(buf: &[u8]) -> Result { if buf.len() < 2 { - Err("error when converting to u16") + Err("Error when converting to u16") } else { Ok(((buf[0] as u16) << 8) + buf[1] as u16) } @@ -17,7 +17,7 @@ impl Convert { String::from_utf8(buf[start..start + index].to_vec())?, index + start, )), - None => return Err("invalid string".into()), + None => return Err("Invalid string".into()), } } } diff --git a/src/message.rs b/src/message.rs index 5ae4c4b..488a1fc 100644 --- a/src/message.rs +++ b/src/message.rs @@ -3,7 +3,7 @@ use std::{ net::{SocketAddr, UdpSocket}, }; -use crate::packet::{ErrorCode, Opcode, Packet, TransferOption}; +use crate::packet::{ErrorCode, Packet, TransferOption}; pub struct Message; @@ -12,25 +12,16 @@ const MAX_REQUEST_PACKET_SIZE: usize = 512; impl Message { pub fn send_data( socket: &UdpSocket, - block_number: u16, - data: &[u8], + block_num: u16, + data: Vec, ) -> Result<(), Box> { - let buf = [ - &Opcode::Data.as_bytes()[..], - &block_number.to_be_bytes(), - data, - ] - .concat(); - - socket.send(&buf)?; + socket.send(&Packet::Data { block_num, data }.serialize()?)?; Ok(()) } pub fn send_ack(socket: &UdpSocket, block_number: u16) -> Result<(), Box> { - let buf = [Opcode::Ack.as_bytes(), block_number.to_be_bytes()].concat(); - - socket.send(&buf)?; + socket.send(&Packet::Ack(block_number).serialize()?)?; Ok(()) } @@ -38,9 +29,9 @@ impl Message { pub fn send_error( socket: &UdpSocket, code: ErrorCode, - msg: &str, + msg: String, ) -> Result<(), Box> { - socket.send(&build_error_buf(code, msg))?; + socket.send(&Packet::Error { code, msg }.serialize()?)?; Ok(()) } @@ -49,9 +40,19 @@ impl Message { socket: &UdpSocket, to: &SocketAddr, code: ErrorCode, - msg: &'a str, + msg: String, ) -> Result<(), Box> { - if socket.send_to(&build_error_buf(code, msg), to).is_err() { + if socket + .send_to( + &Packet::Error { + code, + msg: msg.clone(), + } + .serialize()?, + to, + ) + .is_err() + { eprintln!("could not send an error message"); } Err(msg.into()) @@ -59,15 +60,9 @@ impl Message { pub fn send_oack( socket: &UdpSocket, - options: &Vec, + options: Vec, ) -> Result<(), Box> { - let mut buf = Opcode::Oack.as_bytes().to_vec(); - - for option in options { - buf = [buf, option.as_bytes()].concat(); - } - - socket.send(&buf)?; + socket.send(&Packet::Oack { options }.serialize()?)?; Ok(()) } @@ -96,13 +91,3 @@ impl Message { Ok((packet, from)) } } - -fn build_error_buf(code: ErrorCode, msg: &str) -> Vec { - [ - &Opcode::Error.as_bytes()[..], - &code.as_bytes()[..], - &msg.as_bytes()[..], - &[0x00], - ] - .concat() -} diff --git a/src/packet.rs b/src/packet.rs index 95bd61b..8f7c010 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -21,6 +21,9 @@ pub enum Packet { code: ErrorCode, msg: String, }, + Oack { + options: Vec, + }, } impl Packet { @@ -32,7 +35,17 @@ impl Packet { Opcode::Data => parse_data(buf), Opcode::Ack => parse_ack(buf), Opcode::Error => parse_error(buf), - _ => Err("invalid packet".into()), + _ => Err("Invalid packet".into()), + } + } + + pub fn serialize(&self) -> Result, Box> { + match self { + Packet::Data { block_num, data } => Ok(serialize_data(block_num, data)), + Packet::Ack(block_num) => Ok(serialize_ack(block_num)), + Packet::Error { code, msg } => Ok(serialize_error(code, msg)), + Packet::Oack { options } => Ok(serialize_oack(options)), + _ => Err("Invalid packet".into()), } } } @@ -57,7 +70,7 @@ impl Opcode { 0x0004 => Ok(Opcode::Ack), 0x0005 => Ok(Opcode::Error), 0x0006 => Ok(Opcode::Oack), - _ => Err("invalid opcode"), + _ => Err("Invalid opcode"), } } @@ -105,13 +118,13 @@ impl OptionType { "blksize" => Ok(OptionType::BlockSize), "tsize" => Ok(OptionType::TransferSize), "timeout" => Ok(OptionType::Timeout), - _ => Err("invalid option type".into()), + _ => Err("Invalid option type".into()), } } } #[repr(u16)] -#[derive(PartialEq, Debug)] +#[derive(Clone, Copy, PartialEq, Debug)] pub enum ErrorCode { NotDefined = 0, FileNotFound = 1, @@ -134,7 +147,7 @@ impl ErrorCode { 5 => Ok(ErrorCode::UnknownId), 6 => Ok(ErrorCode::FileExists), 7 => Ok(ErrorCode::NoSuchUser), - _ => Err("invalid error code"), + _ => Err("Invalid error code"), } } @@ -192,7 +205,7 @@ fn parse_rq(buf: &[u8], opcode: Opcode) -> Result> { mode, options, }), - _ => Err("non request opcode".into()), + _ => Err("Non request opcode".into()), } } @@ -213,6 +226,39 @@ fn parse_error(buf: &[u8]) -> Result> { Ok(Packet::Error { code, msg }) } +fn serialize_data(block_num: &u16, data: &Vec) -> Vec { + [ + &Opcode::Data.as_bytes(), + &block_num.to_be_bytes(), + data.as_slice(), + ] + .concat() +} + +fn serialize_ack(block_num: &u16) -> Vec { + [Opcode::Ack.as_bytes(), block_num.to_be_bytes()].concat() +} + +fn serialize_error(code: &ErrorCode, msg: &String) -> Vec { + [ + &Opcode::Error.as_bytes()[..], + &code.as_bytes()[..], + &msg.as_bytes()[..], + &[0x00], + ] + .concat() +} + +fn serialize_oack(options: &Vec) -> Vec { + let mut buf = Opcode::Oack.as_bytes().to_vec(); + + for option in options { + buf = [buf, option.as_bytes()].concat(); + } + + buf +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/server.rs b/src/server.rs index d8fc6e5..1c1e72f 100644 --- a/src/server.rs +++ b/src/server.rs @@ -29,28 +29,28 @@ impl Server { filename, mut options, .. - } => match self.handle_rrq(filename.clone(), &mut options, &from) { - Ok(_) => { - println!("Sending {filename} to {from}"); + } => { + println!("Sending {filename} to {from}"); + if let Err(err) = self.handle_rrq(filename.clone(), &mut options, &from) { + eprintln!("{err}") } - Err(err) => eprintln!("{err}"), - }, + } Packet::Wrq { filename, mut options, .. - } => match self.handle_wrq(filename.clone(), &mut options, &from) { - Ok(_) => { - println!("Receiving {filename} from {from}"); + } => { + println!("Receiving {filename} from {from}"); + if let Err(err) = self.handle_wrq(filename.clone(), &mut options, &from) { + eprintln!("{err}") } - Err(err) => eprintln!("{err}"), - }, + } _ => { Message::send_error_to( &self.socket, &from, ErrorCode::IllegalOperation, - "invalid request", + "invalid request".to_string(), ) .unwrap_or_else(|err| eprintln!("{err}")); } @@ -71,7 +71,7 @@ impl Server { &self.socket, to, ErrorCode::FileNotFound, - "file does not exist", + "file does not exist".to_string(), ); } ErrorCode::AccessViolation => { @@ -79,7 +79,7 @@ impl Server { &self.socket, to, ErrorCode::AccessViolation, - "file access violation", + "file access violation".to_string(), ); } ErrorCode::FileExists => Worker::send( @@ -106,7 +106,7 @@ impl Server { &self.socket, to, ErrorCode::FileExists, - "requested file already exists", + "requested file already exists".to_string(), ); } ErrorCode::AccessViolation => { @@ -114,7 +114,7 @@ impl Server { &self.socket, to, ErrorCode::AccessViolation, - "file access violation", + "file access violation".to_string(), ); } ErrorCode::FileNotFound => Worker::receive( diff --git a/src/worker.rs b/src/worker.rs index 6a2aaf5..60b6260 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -42,9 +42,9 @@ impl Worker { let mut handle_send = || -> Result<(), Box> { let socket = setup_socket(&addr, &remote)?; let work_type = WorkType::Send(Path::new(&filename).metadata().unwrap().len()); - let worker_options = parse_options(&mut options, &work_type)?; - send_oack(&socket, &options, &work_type)?; - send_file(&socket, &worker_options, &filename, &mut options)?; + accept_request(&socket, &options, &work_type)?; + check_response(&socket)?; + send_file(&socket, &filename, &mut options)?; Ok(()) }; @@ -65,9 +65,8 @@ impl Worker { let mut handle_receive = || -> Result<(), Box> { let socket = setup_socket(&addr, &remote)?; let work_type = WorkType::Receive; - let worker_options = parse_options(&mut options, &work_type)?; - send_oack(&socket, &options, &work_type)?; - receive_file(&socket, &worker_options, &filename, &mut options)?; + accept_request(&socket, &options, &work_type)?; + receive_file(&socket, &filename, &mut options)?; Ok(()) }; @@ -81,13 +80,11 @@ impl Worker { fn send_file( socket: &UdpSocket, - worker_options: &WorkerOptions, filename: &String, options: &mut Vec, ) -> Result<(), Box> { let mut file = File::open(filename).unwrap(); - - parse_options(options, &WorkType::Send(file.metadata().unwrap().len()))?; + let worker_options = parse_options(options, &WorkType::Send(file.metadata().unwrap().len()))?; let mut block_number = 1; loop { @@ -96,7 +93,7 @@ fn send_file( let mut retry_cnt = 0; loop { - Message::send_data(socket, block_number, &chunk[..size])?; + Message::send_data(socket, block_number, chunk[..size].to_vec())?; match Message::recv(socket) { Ok(Packet::Ack(received_block_number)) => { @@ -106,12 +103,12 @@ fn send_file( } } Ok(Packet::Error { code, msg }) => { - return Err(format!("received error code {code}, with message {msg}").into()); + return Err(format!("Received error code {code}, with message {msg}").into()); } _ => { retry_cnt += 1; if retry_cnt == MAX_RETRIES { - return Err(format!("transfer timed out after {MAX_RETRIES} tries").into()); + return Err(format!("Transfer timed out after {MAX_RETRIES} tries").into()); } } } @@ -128,13 +125,11 @@ fn send_file( fn receive_file( socket: &UdpSocket, - worker_options: &WorkerOptions, filename: &String, options: &mut Vec, ) -> Result<(), Box> { let mut file = File::create(filename).unwrap(); - - parse_options(options, &WorkType::Receive)?; + let worker_options = parse_options(options, &WorkType::Receive)?; let mut block_number: u16 = 0; loop { @@ -155,13 +150,13 @@ fn receive_file( } } Ok(Packet::Error { code, msg }) => { - return Err(format!("received error code {code}: {msg}").into()); + return Err(format!("Received error code {code}: {msg}").into()); } Err(err) => { retry_cnt += 1; if retry_cnt == MAX_RETRIES { return Err( - format!("transfer timed out after {MAX_RETRIES} tries: {err}").into(), + format!("Transfer timed out after {MAX_RETRIES} tries: {err}").into(), ); } } @@ -179,6 +174,34 @@ fn receive_file( Ok(()) } +fn accept_request( + socket: &UdpSocket, + options: &Vec, + work_type: &WorkType, +) -> Result<(), Box> { + if options.len() > 0 { + Message::send_oack(socket, options.to_vec())?; + } else if *work_type == WorkType::Receive { + Message::send_ack(socket, 0)? + } + + Ok(()) +} + +fn check_response(socket: &UdpSocket) -> Result<(), Box> { + if let Packet::Ack(received_block_number) = Message::recv(&socket)? { + if received_block_number != 0 { + Message::send_error( + &socket, + ErrorCode::IllegalOperation, + "invalid oack response".to_string(), + )?; + } + } + + Ok(()) +} + fn setup_socket(addr: &SocketAddr, remote: &SocketAddr) -> Result> { let socket = UdpSocket::bind(SocketAddr::from((addr.ip(), 0)))?; socket.connect(remote)?; @@ -212,7 +235,7 @@ fn parse_options( }, OptionType::Timeout => { if *value == 0 { - return Err("invalid timeout value".into()); + return Err("Invalid timeout value".into()); } worker_options.timeout = *value as u64; } @@ -221,22 +244,3 @@ fn parse_options( Ok(worker_options) } - -fn send_oack( - socket: &UdpSocket, - options: &Vec, - work_type: &WorkType, -) -> Result<(), Box> { - if options.len() > 0 { - Message::send_oack(socket, options)?; - if let Packet::Ack(received_block_number) = Message::recv(socket)? { - if received_block_number != 0 { - Message::send_error(socket, ErrorCode::IllegalOperation, "invalid oack response")?; - } - } - } else if *work_type == WorkType::Receive { - Message::send_ack(socket, 0)? - } - - Ok(()) -}