diff --git a/src/lib.rs b/src/lib.rs index 9ec0d77..51c6d51 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,7 +17,6 @@ mod config; mod convert; -mod message; mod packet; mod server; mod socket; @@ -26,7 +25,6 @@ mod worker; pub use config::Config; pub use convert::Convert; -pub use message::Message; pub use packet::ErrorCode; pub use packet::Opcode; pub use packet::OptionType; diff --git a/src/message.rs b/src/message.rs deleted file mode 100644 index 7c35ca4..0000000 --- a/src/message.rs +++ /dev/null @@ -1,133 +0,0 @@ -use std::{error::Error, net::SocketAddr}; - -use crate::{ErrorCode, Packet, Socket, TransferOption}; - -/// Message `struct` is used for easy message transmission of common TFTP -/// message types. -/// -/// # Example -/// -/// ```rust -/// use std::{net::{SocketAddr, UdpSocket}, str::FromStr}; -/// use tftpd::{Message, ErrorCode}; -/// -/// // Send a FileNotFound error. -/// Message::send_error_to( -/// &UdpSocket::bind(SocketAddr::from_str("127.0.0.1:6969").unwrap()).unwrap(), -/// &SocketAddr::from_str("127.0.0.1:1234").unwrap(), -/// ErrorCode::FileNotFound, -/// "file does not exist", -/// ); -/// ``` -pub struct Message; - -const MAX_REQUEST_PACKET_SIZE: usize = 512; - -impl Message { - /// Sends a data packet to the socket's connected remote. See - /// [`UdpSocket`] for more information about connected - /// sockets. - pub fn send_data( - socket: &T, - block_num: u16, - data: Vec, - ) -> Result<(), Box> { - socket.send(&Packet::Data { block_num, data }) - } - - /// Sends an acknowledgement packet to the socket's connected remote. See - /// [`UdpSocket`] for more information about connected - /// sockets. - pub fn send_ack(socket: &T, block_number: u16) -> Result<(), Box> { - socket.send(&Packet::Ack(block_number)) - } - - /// Sends an error packet to the socket's connected remote. See - /// [`UdpSocket`] for more information about connected - /// sockets. - pub fn send_error( - socket: &T, - code: ErrorCode, - msg: &str, - ) -> Result<(), Box> { - if socket - .send(&Packet::Error { - code, - msg: msg.to_string(), - }) - .is_err() - { - eprintln!("could not send an error message"); - }; - - Err(msg.into()) - } - - /// Sends an error packet to the supplied [`SocketAddr`]. - pub fn send_error_to( - socket: &T, - to: &SocketAddr, - code: ErrorCode, - msg: &str, - ) -> Result<(), Box> { - if socket - .send_to( - &Packet::Error { - code, - msg: msg.to_string(), - }, - to, - ) - .is_err() - { - eprintln!("could not send an error message"); - } - Err(msg.into()) - } - - /// Sends an option acknowledgement packet to the socket's connected remote. - /// See [`UdpSocket`] for more information about connected sockets. - pub fn send_oack( - socket: &T, - options: Vec, - ) -> Result<(), Box> { - socket.send(&Packet::Oack(options)) - } - - /// Receives a packet from the socket's connected remote, and returns the - /// parsed [`Packet`]. This function cannot handle large data packets due to - /// the limited buffer size. For handling data packets, see [`Message::recv_with_size()`]. - pub fn recv(socket: &T) -> Result> { - let mut buf = [0; MAX_REQUEST_PACKET_SIZE]; - socket.recv(&mut buf) - } - - /// Receives a packet from any incoming remote request, and returns the - /// parsed [`Packet`] and the requesting [`SocketAddr`]. This function cannot handle - /// large data packets due to the limited buffer size, so it is intended for - /// only accepting incoming requests. For handling data packets, see [`Message::recv_with_size()`]. - pub fn recv_from(socket: &T) -> Result<(Packet, SocketAddr), Box> { - let mut buf = [0; MAX_REQUEST_PACKET_SIZE]; - socket.recv_from(&mut buf) - } - - /// Receives a data packet from the socket's connected remote, and returns the - /// parsed [`Packet`]. The received packet can actually be of any type, however, - /// this function also allows supplying the buffer size for an incoming request. - pub fn recv_with_size(socket: &T, size: usize) -> Result> { - let mut buf = vec![0; size + 4]; - socket.recv(&mut buf) - } - - /// Receives a data packet from any incoming remote request, and returns the - /// parsed [`Packet`] and the requesting [`SocketAddr`]. The received packet can - /// actually be of any type, however, this function also allows supplying the - /// buffer size for an incoming request. - pub fn recv_from_with_size( - socket: &T, - size: usize, - ) -> Result<(Packet, SocketAddr), Box> { - let mut buf = vec![0; size + 4]; - socket.recv_from(&mut buf) - } -} diff --git a/src/server.rs b/src/server.rs index df910cf..3e5ec7f 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,4 +1,4 @@ -use crate::{Config, Message, OptionType, ServerSocket, Socket, Worker}; +use crate::{Config, OptionType, ServerSocket, Socket, Worker}; use crate::{ErrorCode, Packet, TransferOption}; use std::cmp::max; use std::collections::HashMap; @@ -55,9 +55,9 @@ impl Server { pub fn listen(&mut self) { loop { let received = if self.single_port { - Message::recv_from_with_size(&self.socket, self.largest_block_size) + self.socket.recv_from_with_size(self.largest_block_size) } else { - Message::recv_from(&self.socket) + Socket::recv_from(&self.socket) }; if let Ok((packet, from)) = received { @@ -84,13 +84,19 @@ impl Server { } _ => { if self.route_packet(packet, &from).is_err() { - Message::send_error_to( + if Socket::send_to( &self.socket, + &Packet::Error { + code: ErrorCode::IllegalOperation, + msg: "invalid request".to_string(), + }, &from, - ErrorCode::IllegalOperation, - "invalid request", ) - .unwrap_or_else(|_| eprintln!("Received invalid request")); + .is_err() + { + eprintln!("Could not send error packet"); + }; + eprintln!("Received invalid request"); } } }; @@ -106,17 +112,21 @@ impl Server { ) -> Result<(), Box> { let file_path = &self.directory.join(filename); match check_file_exists(file_path, &self.directory) { - ErrorCode::FileNotFound => Message::send_error_to( + ErrorCode::FileNotFound => Socket::send_to( &self.socket, + &Packet::Error { + code: ErrorCode::FileNotFound, + msg: "file does not exist".to_string(), + }, to, - ErrorCode::FileNotFound, - "file does not exist", ), - ErrorCode::AccessViolation => Message::send_error_to( + ErrorCode::AccessViolation => Socket::send_to( &self.socket, + &Packet::Error { + code: ErrorCode::AccessViolation, + msg: "file access violation".to_string(), + }, to, - ErrorCode::AccessViolation, - "file access violation", ), ErrorCode::FileExists => { let worker_options = @@ -164,17 +174,21 @@ impl Server { ) -> Result<(), Box> { let file_path = &self.directory.join(file_name); match check_file_exists(file_path, &self.directory) { - ErrorCode::FileExists => Message::send_error_to( + ErrorCode::FileExists => Socket::send_to( &self.socket, + &Packet::Error { + code: ErrorCode::FileExists, + msg: "requested file already exists".to_string(), + }, to, - ErrorCode::FileExists, - "requested file already exists", ), - ErrorCode::AccessViolation => Message::send_error_to( + ErrorCode::AccessViolation => Socket::send_to( &self.socket, + &Packet::Error { + code: ErrorCode::AccessViolation, + msg: "file access violation".to_string(), + }, to, - ErrorCode::AccessViolation, - "file access violation", ), ErrorCode::FileNotFound => { let worker_options = parse_options(options, RequestType::Write)?; @@ -302,21 +316,24 @@ fn accept_request( request_type: RequestType, ) -> Result<(), Box> { if !options.is_empty() { - Message::send_oack(socket, options.to_vec())?; + socket.send(&Packet::Oack(options.to_vec()))?; if let RequestType::Read(_) = request_type { check_response(socket)?; } } else if request_type == RequestType::Write { - Message::send_ack(socket, 0)? + socket.send(&Packet::Ack(0))?; } Ok(()) } fn check_response(socket: &T) -> Result<(), Box> { - if let Packet::Ack(received_block_number) = Message::recv(socket)? { + if let Packet::Ack(received_block_number) = socket.recv()? { if received_block_number != 0 { - Message::send_error(socket, ErrorCode::IllegalOperation, "invalid oack response")?; + socket.send(&Packet::Error { + code: ErrorCode::IllegalOperation, + msg: "invalid oack response".to_string(), + })?; } } diff --git a/src/socket.rs b/src/socket.rs index 00b853e..409e030 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -1,3 +1,4 @@ +use crate::Packet; use std::{ error::Error, net::{SocketAddr, UdpSocket}, @@ -8,9 +9,8 @@ use std::{ time::Duration, }; -use crate::Packet; - const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5); +const MAX_REQUEST_PACKET_SIZE: usize = 512; /// Socket `trait` is used to allow building custom sockets to be used for /// TFTP communication. @@ -19,11 +19,29 @@ pub trait Socket: Send + Sync + 'static { fn send(&self, packet: &Packet) -> Result<(), Box>; /// Sends a [`Packet`] to the specified remote [`Socket`]. fn send_to(&self, packet: &Packet, to: &SocketAddr) -> Result<(), Box>; - /// Receives a [`Packet`] from the socket's connected remote [`Socket`]. - fn recv(&self, buf: &mut [u8]) -> Result>; + /// Receives a [`Packet`] from the socket's connected remote [`Socket`]. This + /// function cannot handle large data packets due to the limited buffer size, + /// so it is intended for only accepting incoming requests. For handling data + /// packets, see [`Socket::recv_with_size()`]. + fn recv(&self) -> Result> { + self.recv_with_size(MAX_REQUEST_PACKET_SIZE) + } + /// Receives a data packet from the socket's connected remote, and returns the + /// parsed [`Packet`]. The received packet can actually be of any type, however, + /// this function also allows supplying the buffer size for an incoming request. + fn recv_with_size(&self, size: usize) -> Result>; /// Receives a [`Packet`] from any remote [`Socket`] and returns the [`SocketAddr`] - /// of the remote [`Socket`]. - fn recv_from(&self, buf: &mut [u8]) -> Result<(Packet, SocketAddr), Box>; + /// of the remote [`Socket`]. This function cannot handle large data packets + /// due to the limited buffer size, so it is intended for only accepting incoming + /// requests. For handling data packets, see [`Socket::recv_from_with_size()`]. + fn recv_from(&self) -> Result<(Packet, SocketAddr), Box> { + self.recv_from_with_size(MAX_REQUEST_PACKET_SIZE) + } + /// Receives a data packet from any incoming remote request, and returns the + /// parsed [`Packet`] and the requesting [`SocketAddr`]. The received packet can + /// actually be of any type, however, this function also allows supplying the + /// buffer size for an incoming request. + fn recv_from_with_size(&self, size: usize) -> Result<(Packet, SocketAddr), Box>; /// Returns the remote [`SocketAddr`] if it exists. fn remote_addr(&self) -> Result>; /// Sets the read timeout for the [`Socket`]. @@ -45,15 +63,17 @@ impl Socket for UdpSocket { Ok(()) } - fn recv(&self, buf: &mut [u8]) -> Result> { - let amt = self.recv(buf)?; + fn recv_with_size(&self, size: usize) -> Result> { + let mut buf = vec![0; size + 4]; + let amt = self.recv(&mut buf)?; let packet = Packet::deserialize(&buf[..amt])?; Ok(packet) } - fn recv_from(&self, buf: &mut [u8]) -> Result<(Packet, SocketAddr), Box> { - let (amt, addr) = self.recv_from(buf)?; + fn recv_from_with_size(&self, size: usize) -> Result<(Packet, SocketAddr), Box> { + let mut buf = vec![0; size + 4]; + let (amt, addr) = self.recv_from(&mut buf)?; let packet = Packet::deserialize(&buf[..amt])?; Ok((packet, addr)) @@ -112,7 +132,7 @@ impl Socket for ServerSocket { Ok(()) } - fn recv(&self, _buf: &mut [u8]) -> Result> { + fn recv_with_size(&self, _size: usize) -> Result> { if let Ok(receiver) = self.receiver.lock() { if let Ok(packet) = receiver.recv_timeout(self.timeout) { Ok(packet) @@ -124,8 +144,8 @@ impl Socket for ServerSocket { } } - fn recv_from(&self, buf: &mut [u8]) -> Result<(Packet, SocketAddr), Box> { - Ok((self.recv(buf)?, self.remote)) + fn recv_from_with_size(&self, _size: usize) -> Result<(Packet, SocketAddr), Box> { + Ok((self.recv()?, self.remote)) } fn remote_addr(&self) -> Result> { @@ -173,12 +193,12 @@ impl Socket for Box { (**self).send_to(packet, to) } - fn recv(&self, buf: &mut [u8]) -> Result> { - (**self).recv(buf) + fn recv_with_size(&self, size: usize) -> Result> { + (**self).recv_with_size(size) } - fn recv_from(&self, buf: &mut [u8]) -> Result<(Packet, SocketAddr), Box> { - (**self).recv_from(buf) + fn recv_from_with_size(&self, size: usize) -> Result<(Packet, SocketAddr), Box> { + (**self).recv_from_with_size(size) } fn remote_addr(&self) -> Result> { @@ -209,7 +229,7 @@ mod tests { socket.sender.lock().unwrap().send(Packet::Ack(1)).unwrap(); - let packet = socket.recv(&mut []).unwrap(); + let packet = socket.recv().unwrap(); assert_eq!(packet, Packet::Ack(1)); @@ -223,7 +243,7 @@ mod tests { }) .unwrap(); - let packet = socket.recv(&mut []).unwrap(); + let packet = socket.recv().unwrap(); assert_eq!( packet, diff --git a/src/window.rs b/src/window.rs index eb07e28..7aad4f1 100644 --- a/src/window.rs +++ b/src/window.rs @@ -206,10 +206,12 @@ mod tests { .unwrap() } - #[allow(unused_must_use)] fn clean(file_name: &str) { let file_name = DIR_NAME.to_string() + "/" + file_name; fs::remove_file(file_name).unwrap(); - fs::remove_dir(DIR_NAME); + if fs::remove_dir(DIR_NAME).is_err() { + // ignore removing directory, as other tests are + // still running + } } } diff --git a/src/worker.rs b/src/worker.rs index 36e0d03..f302a74 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -1,3 +1,4 @@ +use crate::{Packet, Socket, Window}; use std::{ error::Error, fs::{self, File}, @@ -6,8 +7,6 @@ use std::{ time::{Duration, Instant}, }; -use crate::{Message, Packet, Socket, Window}; - const MAX_RETRIES: u32 = 6; const TIMEOUT_BUFFER: Duration = Duration::from_secs(1); @@ -141,7 +140,7 @@ impl Worker { time = Instant::now(); } - match Message::recv(&self.socket) { + match self.socket.recv() { Ok(Packet::Ack(received_block_number)) => { let diff = received_block_number.wrapping_sub(block_number); if diff <= self.windowsize { @@ -181,7 +180,7 @@ impl Worker { let mut retry_cnt = 0; loop { - match Message::recv_with_size(&self.socket, self.blk_size) { + match self.socket.recv_with_size(self.blk_size) { Ok(Packet::Data { block_num: received_block_number, data, @@ -215,7 +214,7 @@ impl Worker { } window.empty()?; - Message::send_ack(&self.socket, block_number)?; + self.socket.send(&Packet::Ack(block_number))?; if size < self.blk_size { break; }; @@ -231,7 +230,10 @@ fn send_window( mut block_num: u16, ) -> Result<(), Box> { for frame in window.get_elements() { - Message::send_data(socket, block_num, frame.to_vec())?; + socket.send(&Packet::Data { + block_num, + data: frame.to_vec(), + })?; block_num = block_num.wrapping_add(1); }