diff --git a/Cargo.lock b/Cargo.lock index 1bf47cc..5a4bbe5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,4 +4,4 @@ version = 3 [[package]] name = "tftpd" -version = "0.2.2" +version = "0.2.4" diff --git a/Cargo.toml b/Cargo.toml index 71fe963..9fee4ce 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tftpd" -version = "0.2.2" +version = "0.2.4" authors = ["Altuğ Bakan "] edition = "2021" description = "Multithreaded TFTP server daemon" diff --git a/src/config.rs b/src/config.rs index f18d911..17a22e2 100644 --- a/src/config.rs +++ b/src/config.rs @@ -25,6 +25,8 @@ pub struct Config { pub port: u16, /// Default directory of the TFTP Server. (default: current working directory) pub directory: PathBuf, + /// Use a single port for both sending and receiving. (default: false) + pub single_port: bool, } impl Config { @@ -38,6 +40,7 @@ impl Config { ip_address: Ipv4Addr::new(127, 0, 0, 1), port: 69, directory: env::current_dir().unwrap_or_else(|_| env::temp_dir()), + single_port: false, }; args.next(); @@ -68,6 +71,9 @@ impl Config { return Err("Missing directory after flag".into()); } } + "-s" | "--single-port" => { + config.single_port = true; + } "-h" | "--help" => { println!("TFTP Server Daemon\n"); println!("Usage: tftpd [OPTIONS]\n"); @@ -77,6 +83,7 @@ impl Config { " -p, --port \t\tSet the listening port of the server (default: 69)" ); println!(" -d, --directory \tSet the listening port of the server (default: Current Working Directory)"); + println!(" -s, --single-port\t\tUse a single port for both sending and receiving (default: false)"); println!(" -h, --help\t\t\tPrint help information"); process::exit(0); } @@ -97,7 +104,7 @@ mod tests { #[test] fn parses_full_config() { let config = Config::new( - vec!["/", "-i", "0.0.0.0", "-p", "1234", "-d", "/"] + vec!["/", "-i", "0.0.0.0", "-p", "1234", "-d", "/", "-s"] .iter() .map(|s| s.to_string()), ) @@ -106,6 +113,7 @@ mod tests { assert_eq!(config.ip_address, Ipv4Addr::new(0, 0, 0, 0)); assert_eq!(config.port, 1234); assert_eq!(config.directory, PathBuf::from_str("/").unwrap()); + assert!(config.single_port); } #[test] diff --git a/src/lib.rs b/src/lib.rs index 8099d4f..9ec0d77 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,7 @@ mod convert; mod message; mod packet; mod server; +mod socket; mod window; mod worker; @@ -32,5 +33,7 @@ pub use packet::OptionType; pub use packet::Packet; pub use packet::TransferOption; pub use server::Server; +pub use socket::ServerSocket; +pub use socket::Socket; pub use window::Window; pub use worker::Worker; diff --git a/src/main.rs b/src/main.rs index 076cea6..8bc460a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,7 +7,7 @@ fn main() { process::exit(1) }); - let server = Server::new(&config).unwrap_or_else(|err| { + let mut server = Server::new(&config).unwrap_or_else(|err| { eprintln!( "Problem creating server on {}:{}: {err}", config.ip_address, config.port diff --git a/src/message.rs b/src/message.rs index 005e62d..7c35ca4 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,9 +1,6 @@ -use std::{ - error::Error, - net::{SocketAddr, UdpSocket}, -}; +use std::{error::Error, net::SocketAddr}; -use crate::{ErrorCode, Packet, TransferOption}; +use crate::{ErrorCode, Packet, Socket, TransferOption}; /// Message `struct` is used for easy message transmission of common TFTP /// message types. @@ -30,41 +27,34 @@ impl Message { /// Sends a data packet to the socket's connected remote. See /// [`UdpSocket`] for more information about connected /// sockets. - pub fn send_data( - socket: &UdpSocket, + pub fn send_data( + socket: &T, block_num: u16, data: Vec, ) -> Result<(), Box> { - socket.send(&Packet::Data { block_num, data }.serialize()?)?; - - Ok(()) + socket.send(&Packet::Data { block_num, data }) } /// Sends an acknowledgement packet to the socket's connected remote. See /// [`UdpSocket`] for more information about connected /// sockets. - pub fn send_ack(socket: &UdpSocket, block_number: u16) -> Result<(), Box> { - socket.send(&Packet::Ack(block_number).serialize()?)?; - - Ok(()) + pub fn send_ack(socket: &T, block_number: u16) -> Result<(), Box> { + socket.send(&Packet::Ack(block_number)) } /// Sends an error packet to the socket's connected remote. See /// [`UdpSocket`] for more information about connected /// sockets. - pub fn send_error( - socket: &UdpSocket, + pub fn send_error( + socket: &T, code: ErrorCode, msg: &str, ) -> Result<(), Box> { if socket - .send( - &Packet::Error { - code, - msg: msg.to_string(), - } - .serialize()?, - ) + .send(&Packet::Error { + code, + msg: msg.to_string(), + }) .is_err() { eprintln!("could not send an error message"); @@ -74,8 +64,8 @@ impl Message { } /// Sends an error packet to the supplied [`SocketAddr`]. - pub fn send_error_to( - socket: &UdpSocket, + pub fn send_error_to( + socket: &T, to: &SocketAddr, code: ErrorCode, msg: &str, @@ -85,8 +75,7 @@ impl Message { &Packet::Error { code, msg: msg.to_string(), - } - .serialize()?, + }, to, ) .is_err() @@ -98,46 +87,47 @@ impl Message { /// Sends an option acknowledgement packet to the socket's connected remote. /// See [`UdpSocket`] for more information about connected sockets. - pub fn send_oack( - socket: &UdpSocket, + pub fn send_oack( + socket: &T, options: Vec, ) -> Result<(), Box> { - socket.send(&Packet::Oack(options).serialize()?)?; - - Ok(()) + socket.send(&Packet::Oack(options)) } /// Receives a packet from the socket's connected remote, and returns the /// parsed [`Packet`]. This function cannot handle large data packets due to /// the limited buffer size. For handling data packets, see [`Message::recv_with_size()`]. - pub fn recv(socket: &UdpSocket) -> Result> { + pub fn recv(socket: &T) -> Result> { let mut buf = [0; MAX_REQUEST_PACKET_SIZE]; - let number_of_bytes = socket.recv(&mut buf)?; - let packet = Packet::deserialize(&buf[..number_of_bytes])?; - - Ok(packet) + socket.recv(&mut buf) } /// Receives a packet from any incoming remote request, and returns the /// parsed [`Packet`] and the requesting [`SocketAddr`]. This function cannot handle /// large data packets due to the limited buffer size, so it is intended for /// only accepting incoming requests. For handling data packets, see [`Message::recv_with_size()`]. - pub fn recv_from(socket: &UdpSocket) -> Result<(Packet, SocketAddr), Box> { + pub fn recv_from(socket: &T) -> Result<(Packet, SocketAddr), Box> { let mut buf = [0; MAX_REQUEST_PACKET_SIZE]; - let (number_of_bytes, from) = socket.recv_from(&mut buf)?; - let packet = Packet::deserialize(&buf[..number_of_bytes])?; - - Ok((packet, from)) + socket.recv_from(&mut buf) } /// Receives a data packet from the socket's connected remote, and returns the /// parsed [`Packet`]. The received packet can actually be of any type, however, /// this function also allows supplying the buffer size for an incoming request. - pub fn recv_with_size(socket: &UdpSocket, size: usize) -> Result> { + pub fn recv_with_size(socket: &T, size: usize) -> Result> { let mut buf = vec![0; size + 4]; - let number_of_bytes = socket.recv(&mut buf)?; - let packet = Packet::deserialize(&buf[..number_of_bytes])?; + socket.recv(&mut buf) + } - Ok(packet) + /// Receives a data packet from any incoming remote request, and returns the + /// parsed [`Packet`] and the requesting [`SocketAddr`]. The received packet can + /// actually be of any type, however, this function also allows supplying the + /// buffer size for an incoming request. + pub fn recv_from_with_size( + socket: &T, + size: usize, + ) -> Result<(Packet, SocketAddr), Box> { + let mut buf = vec![0; size + 4]; + socket.recv_from(&mut buf) } } diff --git a/src/server.rs b/src/server.rs index e3cb6d4..e793eb8 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,8 +1,16 @@ -use crate::{Config, Message, Worker}; +use crate::{Config, Message, OptionType, ServerSocket, Socket, Worker}; use crate::{ErrorCode, Packet, TransferOption}; +use std::cmp::max; +use std::collections::HashMap; use std::error::Error; use std::net::{SocketAddr, UdpSocket}; use std::path::{Path, PathBuf}; +use std::sync::mpsc::Sender; +use std::time::Duration; + +const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5); +const DEFAULT_BLOCK_SIZE: usize = 512; +const DEFAULT_WINDOW_SIZE: u16 = 1; /// Server `struct` is used for handling incoming TFTP requests. /// @@ -22,6 +30,9 @@ use std::path::{Path, PathBuf}; pub struct Server { socket: UdpSocket, directory: PathBuf, + single_port: bool, + largest_block_size: usize, + clients: HashMap>, } impl Server { @@ -32,15 +43,24 @@ impl Server { let server = Server { socket, directory: config.directory.clone(), + single_port: config.single_port, + largest_block_size: DEFAULT_BLOCK_SIZE, + clients: HashMap::new(), }; Ok(server) } /// Starts listening for connections. Note that this function does not finish running until termination. - pub fn listen(&self) { + pub fn listen(&mut self) { loop { - if let Ok((packet, from)) = Message::recv_from(&self.socket) { + let received = if self.single_port { + Message::recv_from_with_size(&self.socket, self.largest_block_size) + } else { + Message::recv_from(&self.socket) + }; + + if let Ok((packet, from)) = received { match packet { Packet::Rrq { filename, @@ -63,13 +83,15 @@ impl Server { } } _ => { - Message::send_error_to( - &self.socket, - &from, - ErrorCode::IllegalOperation, - "invalid request", - ) - .unwrap_or_else(|_| eprintln!("Received invalid request")); + if self.route_packet(packet, &from).is_err() { + Message::send_error_to( + &self.socket, + &from, + ErrorCode::IllegalOperation, + "invalid request", + ) + .unwrap_or_else(|_| eprintln!("Received invalid request")); + } } }; } @@ -77,7 +99,7 @@ impl Server { } fn handle_rrq( - &self, + &mut self, filename: String, options: &mut [TransferOption], to: &SocketAddr, @@ -97,25 +119,63 @@ impl Server { "file access violation", ), ErrorCode::FileExists => { - Worker::send( - self.socket.local_addr()?, - *to, - file_path.to_path_buf(), - options.to_vec(), - ); - Ok(()) + let worker_options = + parse_options(options, RequestType::Read(file_path.metadata()?.len()))?; + + if self.single_port { + let mut socket = create_single_socket(&self.socket, to)?; + socket.set_read_timeout(worker_options.timeout)?; + socket.set_write_timeout(worker_options.timeout)?; + + self.clients.insert(*to, socket.sender()); + self.largest_block_size = + max(self.largest_block_size, worker_options.block_size); + accept_request( + &socket, + options, + RequestType::Read(file_path.metadata()?.len()), + )?; + + let worker = Worker::new( + socket, + file_path.clone(), + worker_options.block_size, + worker_options.timeout, + worker_options.window_size, + ); + worker.send() + } else { + let socket = create_multi_socket(&self.socket.local_addr()?, to)?; + socket.set_read_timeout(Some(worker_options.timeout))?; + socket.set_write_timeout(Some(worker_options.timeout))?; + + accept_request( + &socket, + options, + RequestType::Read(file_path.metadata()?.len()), + )?; + + let worker = Worker::new( + socket, + file_path.clone(), + worker_options.block_size, + worker_options.timeout, + worker_options.window_size, + ); + worker.send() + } } - _ => Err("unexpected error code when checking file".into()), + _ => Err("Unexpected error code when checking file".into()), } } fn handle_wrq( - &self, - filename: String, + &mut self, + file_name: String, options: &mut [TransferOption], to: &SocketAddr, ) -> Result<(), Box> { - let file_path = &self.directory.join(filename); + let file_path = &self.directory.join(file_name); match check_file_exists(file_path, &self.directory) { ErrorCode::FileExists => Message::send_error_to( &self.socket, @@ -130,17 +190,159 @@ impl Server { "file access violation", ), ErrorCode::FileNotFound => { - Worker::receive( - self.socket.local_addr()?, - *to, - file_path.to_path_buf(), - options.to_vec(), - ); - Ok(()) + let worker_options = parse_options(options, RequestType::Write)?; + + if self.single_port { + let mut socket = create_single_socket(&self.socket, to)?; + socket.set_read_timeout(worker_options.timeout)?; + socket.set_write_timeout(worker_options.timeout)?; + + self.clients.insert(*to, socket.sender()); + self.largest_block_size = + max(self.largest_block_size, worker_options.block_size); + accept_request(&socket, options, RequestType::Write)?; + + let worker = Worker::new( + socket, + file_path.clone(), + worker_options.block_size, + worker_options.timeout, + worker_options.window_size, + ); + worker.receive() + } else { + let socket = create_multi_socket(&self.socket.local_addr()?, to)?; + socket.set_read_timeout(Some(worker_options.timeout))?; + socket.set_write_timeout(Some(worker_options.timeout))?; + + accept_request(&socket, options, RequestType::Write)?; + + let worker = Worker::new( + socket, + file_path.clone(), + worker_options.block_size, + worker_options.timeout, + worker_options.window_size, + ); + worker.receive() + } } - _ => Err("unexpected error code when checking file".into()), + _ => Err("Unexpected error code when checking file".into()), } } + + fn route_packet(&self, packet: Packet, to: &SocketAddr) -> Result<(), Box> { + if self.clients.contains_key(to) { + self.clients[to].send(packet)?; + Ok(()) + } else { + Err("No client found for packet".into()) + } + } +} + +#[derive(Debug, PartialEq)] +struct WorkerOptions { + block_size: usize, + transfer_size: u64, + timeout: Duration, + window_size: u16, +} + +#[derive(Debug, PartialEq)] +enum RequestType { + Read(u64), + Write, +} + +fn parse_options( + options: &mut [TransferOption], + request_type: RequestType, +) -> Result { + let mut worker_options = WorkerOptions { + block_size: DEFAULT_BLOCK_SIZE, + transfer_size: 0, + timeout: DEFAULT_TIMEOUT, + window_size: DEFAULT_WINDOW_SIZE, + }; + + for option in options { + let TransferOption { + option: option_type, + value, + } = option; + + match option_type { + OptionType::BlockSize => worker_options.block_size = *value, + OptionType::TransferSize => match request_type { + RequestType::Read(size) => { + *value = size as usize; + worker_options.transfer_size = size; + } + RequestType::Write => worker_options.transfer_size = *value as u64, + }, + OptionType::Timeout => { + if *value == 0 { + return Err("Invalid timeout value"); + } + worker_options.timeout = Duration::from_secs(*value as u64); + } + OptionType::Windowsize => { + if *value == 0 || *value > u16::MAX as usize { + return Err("Invalid windowsize value"); + } + worker_options.window_size = *value as u16; + } + } + } + + Ok(worker_options) +} + +fn create_single_socket( + socket: &UdpSocket, + remote: &SocketAddr, +) -> Result> { + let socket = ServerSocket::new(socket.try_clone()?, *remote); + + Ok(socket) +} + +fn create_multi_socket( + addr: &SocketAddr, + remote: &SocketAddr, +) -> Result> { + let socket = UdpSocket::bind(SocketAddr::from((addr.ip(), 0)))?; + socket.connect(remote)?; + + Ok(socket) +} + +fn accept_request( + socket: &T, + options: &[TransferOption], + request_type: RequestType, +) -> Result<(), Box> { + if !options.is_empty() { + Message::send_oack(socket, options.to_vec())?; + if let RequestType::Read(_) = request_type { + check_response(socket)?; + } + } else if request_type == RequestType::Write { + Message::send_ack(socket, 0)? + } + + Ok(()) +} + +fn check_response(socket: &T) -> Result<(), Box> { + if let Packet::Ack(received_block_number) = Message::recv(socket)? { + if received_block_number != 0 { + Message::send_error(socket, ErrorCode::IllegalOperation, "invalid oack response")?; + } + } + + Ok(()) } fn check_file_exists(file: &Path, directory: &PathBuf) -> ErrorCode { @@ -185,4 +387,69 @@ mod tests { &PathBuf::from("/dir/test") )); } + + #[test] + fn parses_write_options() { + let mut options = vec![ + TransferOption { + option: OptionType::BlockSize, + value: 1024, + }, + TransferOption { + option: OptionType::TransferSize, + value: 0, + }, + TransferOption { + option: OptionType::Timeout, + value: 5, + }, + ]; + + let work_type = RequestType::Read(12341234); + + let worker_options = parse_options(&mut options, work_type).unwrap(); + + assert_eq!(options[0].value, worker_options.block_size); + assert_eq!(options[1].value, worker_options.transfer_size as usize); + assert_eq!(options[2].value as u64, worker_options.timeout.as_secs()); + } + + #[test] + fn parses_read_options() { + let mut options = vec![ + TransferOption { + option: OptionType::BlockSize, + value: 1024, + }, + TransferOption { + option: OptionType::TransferSize, + value: 44554455, + }, + TransferOption { + option: OptionType::Timeout, + value: 5, + }, + ]; + + let work_type = RequestType::Write; + + let worker_options = parse_options(&mut options, work_type).unwrap(); + + assert_eq!(options[0].value, worker_options.block_size); + assert_eq!(options[1].value, worker_options.transfer_size as usize); + assert_eq!(options[2].value as u64, worker_options.timeout.as_secs()); + } + + #[test] + fn parses_default_options() { + assert_eq!( + parse_options(&mut [], RequestType::Write).unwrap(), + WorkerOptions { + block_size: DEFAULT_BLOCK_SIZE, + transfer_size: 0, + timeout: DEFAULT_TIMEOUT, + window_size: DEFAULT_WINDOW_SIZE, + } + ); + } } diff --git a/src/socket.rs b/src/socket.rs new file mode 100644 index 0000000..9638d73 --- /dev/null +++ b/src/socket.rs @@ -0,0 +1,207 @@ +use std::{ + error::Error, + net::{SocketAddr, UdpSocket}, + sync::{ + mpsc::{self, Receiver, Sender}, + Mutex, + }, + time::Duration, +}; + +use crate::Packet; + +const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5); + +/// Socket `trait` is used for easy message transmission of common TFTP +/// message types. This `trait` is implemented for [`UdpSocket`] and used +/// for abstraction of single socket communication. +pub trait Socket: Send + Sync + 'static { + /// Sends a [`Packet`] to the socket's connected remote [`Socket`]. + fn send(&self, packet: &Packet) -> Result<(), Box>; + /// Sends a [`Packet`] to the specified remote [`Socket`]. + fn send_to(&self, packet: &Packet, to: &SocketAddr) -> Result<(), Box>; + /// Receives a [`Packet`] from the socket's connected remote [`Socket`]. + fn recv(&self, buf: &mut [u8]) -> Result>; + /// Receives a [`Packet`] from any remote [`Socket`] and returns the [`SocketAddr`] + /// of the remote [`Socket`]. + fn recv_from(&self, buf: &mut [u8]) -> Result<(Packet, SocketAddr), Box>; + /// Returns the remote [`SocketAddr`] if it exists. + fn remote_addr(&self) -> Result>; + /// Sets the read timeout for the [`Socket`]. + fn set_read_timeout(&mut self, dur: Duration) -> Result<(), Box>; + /// Sets the write timeout for the [`Socket`]. + fn set_write_timeout(&mut self, dur: Duration) -> Result<(), Box>; +} + +impl Socket for UdpSocket { + fn send(&self, packet: &Packet) -> Result<(), Box> { + self.send(&packet.serialize()?)?; + + Ok(()) + } + + fn send_to(&self, packet: &Packet, to: &SocketAddr) -> Result<(), Box> { + self.send_to(&packet.serialize()?, to)?; + + Ok(()) + } + + fn recv(&self, buf: &mut [u8]) -> Result> { + let amt = self.recv(buf)?; + let packet = Packet::deserialize(&buf[..amt])?; + + Ok(packet) + } + + fn recv_from(&self, buf: &mut [u8]) -> Result<(Packet, SocketAddr), Box> { + let (amt, addr) = self.recv_from(buf)?; + let packet = Packet::deserialize(&buf[..amt])?; + + Ok((packet, addr)) + } + + fn remote_addr(&self) -> Result> { + Ok(self.peer_addr()?) + } + + fn set_read_timeout(&mut self, dur: Duration) -> Result<(), Box> { + UdpSocket::set_read_timeout(self, Some(dur))?; + + Ok(()) + } + + fn set_write_timeout(&mut self, dur: Duration) -> Result<(), Box> { + UdpSocket::set_write_timeout(self, Some(dur))?; + + Ok(()) + } +} + +/// ServerSocket `struct` is used as an abstraction layer for a server +/// [`Socket`]. This `struct` is used for abstraction of single socket +/// communication. +/// +/// # Example +/// +/// ```rust +/// use std::net::{SocketAddr, UdpSocket}; +/// use std::str::FromStr; +/// use tftpd::{Socket, ServerSocket, Packet}; +/// +/// let socket = ServerSocket::new( +/// UdpSocket::bind("127.0.0.1:0").unwrap(), +/// SocketAddr::from_str("127.0.0.1:50000").unwrap(), +/// ); +/// socket.send(&Packet::Ack(1)).unwrap(); +/// ``` +pub struct ServerSocket { + socket: UdpSocket, + remote: SocketAddr, + sender: Mutex>, + receiver: Mutex>, + timeout: Duration, +} + +impl Socket for ServerSocket { + fn send(&self, packet: &Packet) -> Result<(), Box> { + self.send_to(packet, &self.remote) + } + + fn send_to(&self, packet: &Packet, to: &SocketAddr) -> Result<(), Box> { + self.socket.send_to(&packet.serialize()?, to)?; + + Ok(()) + } + + fn recv(&self, _buf: &mut [u8]) -> Result> { + if let Ok(receiver) = self.receiver.lock() { + if let Ok(packet) = receiver.recv_timeout(self.timeout) { + Ok(packet) + } else { + Err("Failed to receive".into()) + } + } else { + Err("Failed to lock mutex".into()) + } + } + + fn recv_from(&self, buf: &mut [u8]) -> Result<(Packet, SocketAddr), Box> { + Ok((self.recv(buf)?, self.remote)) + } + + fn remote_addr(&self) -> Result> { + Ok(self.remote) + } + + fn set_read_timeout(&mut self, dur: Duration) -> Result<(), Box> { + self.timeout = dur; + + Ok(()) + } + + fn set_write_timeout(&mut self, dur: Duration) -> Result<(), Box> { + self.socket.set_write_timeout(Some(dur))?; + + Ok(()) + } +} + +impl ServerSocket { + /// Creates a new [`ServerSocket`] from a [`UdpSocket`] and a remote [`SocketAddr`]. + pub fn new(socket: UdpSocket, remote: SocketAddr) -> Self { + let (sender, receiver) = mpsc::channel(); + Self { + socket, + remote, + sender: Mutex::new(sender), + receiver: Mutex::new(receiver), + timeout: DEFAULT_TIMEOUT, + } + } + + /// Returns a [`Sender`] for sending [`Packet`]s to the remote [`Socket`]. + pub fn sender(&self) -> Sender { + self.sender.lock().unwrap().clone() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::str::FromStr; + + #[test] + fn test_recv() { + let socket = ServerSocket::new( + UdpSocket::bind("127.0.0.1:0").unwrap(), + SocketAddr::from_str("127.0.0.1:50000").unwrap(), + ); + + socket.sender.lock().unwrap().send(Packet::Ack(1)).unwrap(); + + let packet = socket.recv(&mut []).unwrap(); + + assert_eq!(packet, Packet::Ack(1)); + + socket + .sender + .lock() + .unwrap() + .send(Packet::Data { + block_num: 15, + data: vec![0x01, 0x02, 0x03], + }) + .unwrap(); + + let packet = socket.recv(&mut []).unwrap(); + + assert_eq!( + packet, + Packet::Data { + block_num: 15, + data: vec![0x01, 0x02, 0x03] + } + ); + } +} diff --git a/src/worker.rs b/src/worker.rs index 064c7c8..ff795d3 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -1,13 +1,15 @@ use std::{ error::Error, fs::{self, File}, - net::{SocketAddr, UdpSocket}, path::PathBuf, thread, - time::{Duration, SystemTime}, + time::{Duration, Instant}, }; -use crate::{ErrorCode, Message, OptionType, Packet, TransferOption, Window}; +use crate::{Message, Packet, Socket, Window}; + +const MAX_RETRIES: u32 = 6; +const TIMEOUT_BUFFER: Duration = Duration::from_secs(1); /// Worker `struct` is used for multithreaded file sending and receiving. /// It creates a new socket using the Server's IP and a random port @@ -18,55 +20,64 @@ use crate::{ErrorCode, Message, OptionType, Packet, TransferOption, Window}; /// # Example /// /// ```rust -/// use std::{net::SocketAddr, path::PathBuf, str::FromStr}; +/// use std::{net::{UdpSocket, SocketAddr}, path::PathBuf, str::FromStr, time::Duration}; /// use tftpd::Worker; /// /// // Send a file, responding to a read request. -/// Worker::send( -/// SocketAddr::from_str("127.0.0.1:1234").unwrap(), -/// SocketAddr::from_str("127.0.0.1:4321").unwrap(), -/// PathBuf::from_str("/home/rust/test.txt").unwrap(), -/// vec![] +/// let socket = UdpSocket::bind("127.0.0.1:0").unwrap(); +/// socket.connect(SocketAddr::from_str("127.0.0.1:12345").unwrap()).unwrap(); +/// +/// let worker = Worker::new( +/// socket, +/// PathBuf::from_str("Cargo.toml").unwrap(), +/// 512, +/// Duration::from_secs(1), +/// 1, /// ); +/// +/// worker.send().unwrap(); /// ``` -pub struct Worker; - -#[derive(Debug, PartialEq, Eq)] -struct WorkerOptions { +pub struct Worker +where + T: Socket, +{ + socket: T, + file_name: PathBuf, blk_size: usize, - t_size: usize, - timeout: u64, + timeout: Duration, windowsize: u16, } -#[derive(PartialEq, Eq)] -enum WorkType { - Receive, - Send(u64), -} +impl Worker +where + T: Socket, +{ + /// Creates a new [`Worker`] with the supplied options. + pub fn new( + socket: T, + file_name: PathBuf, + blk_size: usize, + timeout: Duration, + windowsize: u16, + ) -> Worker { + Worker { + socket, + file_name, + blk_size, + timeout, + windowsize, + } + } -const MAX_RETRIES: u32 = 6; -const DEFAULT_TIMEOUT_SECS: u64 = 5; -const TIMEOUT_BUFFER_SECS: u64 = 1; -const DEFAULT_BLOCK_SIZE: usize = 512; - -impl Worker { /// Sends a file to the remote [`SocketAddr`] that has sent a read request using /// a random port, asynchronously. - pub fn send( - addr: SocketAddr, - remote: SocketAddr, - file_path: PathBuf, - mut options: Vec, - ) { - thread::spawn(move || { - let mut handle_send = || -> Result<(), Box> { - let socket = setup_socket(&addr, &remote)?; - let work_type = WorkType::Send(file_path.metadata()?.len()); - let worker_options = parse_options(&mut options, &work_type)?; + pub fn send(self) -> Result<(), Box> { + let file_name = self.file_name.clone(); + let remote_addr = self.socket.remote_addr().unwrap(); - accept_request(&socket, &options, &work_type)?; - send_file(&socket, File::open(&file_path)?, &worker_options)?; + thread::spawn(move || { + let handle_send = || -> Result<(), Box> { + self.send_file(File::open(&file_name)?)?; Ok(()) }; @@ -75,8 +86,8 @@ impl Worker { Ok(_) => { println!( "Sent {} to {}", - file_path.file_name().unwrap().to_str().unwrap(), - remote + &file_name.file_name().unwrap().to_string_lossy(), + &remote_addr ); } Err(err) => { @@ -84,24 +95,19 @@ impl Worker { } } }); + + Ok(()) } /// Receives a file from the remote [`SocketAddr`] that has sent a write request using - /// a random port, asynchronously. - pub fn receive( - addr: SocketAddr, - remote: SocketAddr, - file_path: PathBuf, - mut options: Vec, - ) { - thread::spawn(move || { - let mut handle_receive = || -> Result<(), Box> { - let socket = setup_socket(&addr, &remote)?; - let work_type = WorkType::Receive; - let worker_options = parse_options(&mut options, &work_type)?; + /// the supplied socket, asynchronously. + pub fn receive(self) -> Result<(), Box> { + let file_name = self.file_name.clone(); + let remote_addr = self.socket.remote_addr().unwrap(); - accept_request(&socket, &options, &work_type)?; - receive_file(&socket, File::create(&file_path)?, &worker_options)?; + thread::spawn(move || { + let handle_receive = || -> Result<(), Box> { + self.receive_file(File::create(&file_name)?)?; Ok(()) }; @@ -110,129 +116,123 @@ impl Worker { Ok(_) => { println!( "Received {} from {}", - file_path.file_name().unwrap().to_str().unwrap(), - remote + &file_name.file_name().unwrap().to_string_lossy(), + remote_addr ); } Err(err) => { eprintln!("{err}"); - if fs::remove_file(&file_path).is_err() { - eprintln!( - "Error while cleaning {}", - file_path.file_name().unwrap().to_str().unwrap() - ); + if fs::remove_file(&file_name).is_err() { + eprintln!("Error while cleaning {}", &file_name.to_str().unwrap()); } } } }); - } -} -fn send_file( - socket: &UdpSocket, - file: File, - worker_options: &WorkerOptions, -) -> Result<(), Box> { - let mut block_number = 1; - let mut window = Window::new(worker_options.windowsize, worker_options.blk_size, file); - - loop { - let filled = window.fill()?; - - let mut retry_cnt = 0; - let mut time = - SystemTime::now() - Duration::from_secs(DEFAULT_TIMEOUT_SECS + TIMEOUT_BUFFER_SECS); - loop { - if time.elapsed()? >= Duration::from_secs(DEFAULT_TIMEOUT_SECS) { - send_window(socket, &window, block_number)?; - time = SystemTime::now(); - } - - match Message::recv(socket) { - Ok(Packet::Ack(received_block_number)) => { - let diff = received_block_number.wrapping_sub(block_number); - if diff <= worker_options.windowsize { - block_number = received_block_number.wrapping_add(1); - window.remove(diff + 1)?; - break; - } - } - Ok(Packet::Error { code, msg }) => { - return Err(format!("Received error code {code}: {msg}").into()); - } - _ => { - retry_cnt += 1; - if retry_cnt == MAX_RETRIES { - return Err(format!("Transfer timed out after {MAX_RETRIES} tries").into()); - } - } - } - } - - if !filled && window.is_empty() { - break; - } + Ok(()) } - Ok(()) -} - -fn receive_file( - socket: &UdpSocket, - file: File, - worker_options: &WorkerOptions, -) -> Result<(), Box> { - let mut block_number: u16 = 0; - let mut window = Window::new(worker_options.windowsize, worker_options.blk_size, file); - - loop { - let mut size; - let mut retry_cnt = 0; + fn send_file(self, file: File) -> Result<(), Box> { + let mut block_number = 1; + let mut window = Window::new(self.windowsize, self.blk_size, file); loop { - match Message::recv_with_size(socket, worker_options.blk_size) { - Ok(Packet::Data { - block_num: received_block_number, - data, - }) => { - if received_block_number == block_number.wrapping_add(1) { - block_number = received_block_number; - size = data.len(); - window.add(data)?; + let filled = window.fill()?; - if size < worker_options.blk_size { - break; - } + let mut retry_cnt = 0; + let mut time = Instant::now() - (self.timeout + TIMEOUT_BUFFER); + loop { + if time.elapsed() >= self.timeout { + send_window(&self.socket, &window, block_number)?; + time = Instant::now(); + } - if window.is_full() { + match Message::recv(&self.socket) { + Ok(Packet::Ack(received_block_number)) => { + let diff = received_block_number.wrapping_sub(block_number); + if diff <= self.windowsize { + block_number = received_block_number.wrapping_add(1); + window.remove(diff + 1)?; break; } } - } - Ok(Packet::Error { code, msg }) => { - return Err(format!("Received error code {code}: {msg}").into()); - } - _ => { - retry_cnt += 1; - if retry_cnt == MAX_RETRIES { - return Err(format!("Transfer timed out after {MAX_RETRIES} tries").into()); + Ok(Packet::Error { code, msg }) => { + return Err(format!("Received error code {code}: {msg}").into()); + } + _ => { + retry_cnt += 1; + if retry_cnt == MAX_RETRIES { + return Err( + format!("Transfer timed out after {MAX_RETRIES} tries").into() + ); + } } } } + + if !filled && window.is_empty() { + break; + } } - window.empty()?; - Message::send_ack(socket, block_number)?; - if size < worker_options.blk_size { - break; - }; + Ok(()) } - Ok(()) + fn receive_file(self, file: File) -> Result<(), Box> { + let mut block_number: u16 = 0; + let mut window = Window::new(self.windowsize, self.blk_size, file); + + loop { + let mut size; + let mut retry_cnt = 0; + + loop { + match Message::recv_with_size(&self.socket, self.blk_size) { + Ok(Packet::Data { + block_num: received_block_number, + data, + }) => { + if received_block_number == block_number.wrapping_add(1) { + block_number = received_block_number; + size = data.len(); + window.add(data)?; + + if size < self.blk_size { + break; + } + + if window.is_full() { + break; + } + } + } + Ok(Packet::Error { code, msg }) => { + return Err(format!("Received error code {code}: {msg}").into()); + } + _ => { + retry_cnt += 1; + if retry_cnt == MAX_RETRIES { + return Err( + format!("Transfer timed out after {MAX_RETRIES} tries").into() + ); + } + } + } + } + + window.empty()?; + Message::send_ack(&self.socket, block_number)?; + if size < self.blk_size { + break; + }; + } + + Ok(()) + } } -fn send_window( - socket: &UdpSocket, +fn send_window( + socket: &T, window: &Window, mut block_num: u16, ) -> Result<(), Box> { @@ -243,151 +243,3 @@ fn send_window( Ok(()) } - -fn accept_request( - socket: &UdpSocket, - options: &Vec, - work_type: &WorkType, -) -> Result<(), Box> { - if !options.is_empty() { - Message::send_oack(socket, options.to_vec())?; - if let WorkType::Send(_) = work_type { - check_response(socket)?; - } - } else if *work_type == WorkType::Receive { - Message::send_ack(socket, 0)? - } - - Ok(()) -} - -fn check_response(socket: &UdpSocket) -> Result<(), Box> { - if let Packet::Ack(received_block_number) = Message::recv(socket)? { - if received_block_number != 0 { - Message::send_error(socket, ErrorCode::IllegalOperation, "invalid oack response")?; - } - } - - Ok(()) -} - -fn setup_socket(addr: &SocketAddr, remote: &SocketAddr) -> Result> { - let socket = UdpSocket::bind(SocketAddr::from((addr.ip(), 0)))?; - socket.connect(remote)?; - socket.set_read_timeout(Some(Duration::from_secs(DEFAULT_TIMEOUT_SECS)))?; - socket.set_write_timeout(Some(Duration::from_secs(DEFAULT_TIMEOUT_SECS)))?; - Ok(socket) -} - -fn parse_options( - options: &mut Vec, - work_type: &WorkType, -) -> Result> { - let mut worker_options = WorkerOptions { - blk_size: DEFAULT_BLOCK_SIZE, - t_size: 0, - timeout: DEFAULT_TIMEOUT_SECS, - windowsize: 1, - }; - - for option in &mut *options { - let TransferOption { option, value } = option; - - match option { - OptionType::BlockSize => worker_options.blk_size = *value, - OptionType::TransferSize => match work_type { - WorkType::Send(size) => { - *value = *size as usize; - worker_options.t_size = *size as usize; - } - WorkType::Receive => { - worker_options.t_size = *value; - } - }, - OptionType::Timeout => { - if *value == 0 { - return Err("Invalid timeout value".into()); - } - worker_options.timeout = *value as u64; - } - OptionType::Windowsize => { - if *value == 0 || *value > u16::MAX as usize { - return Err("Invalid windowsize value".into()); - } - worker_options.windowsize = *value as u16; - } - } - } - - Ok(worker_options) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn parses_send_options() { - let mut options = vec![ - TransferOption { - option: OptionType::BlockSize, - value: 1024, - }, - TransferOption { - option: OptionType::TransferSize, - value: 0, - }, - TransferOption { - option: OptionType::Timeout, - value: 5, - }, - ]; - - let work_type = WorkType::Send(12341234); - - let worker_options = parse_options(&mut options, &work_type).unwrap(); - - assert_eq!(options[0].value, worker_options.blk_size); - assert_eq!(options[1].value, worker_options.t_size); - assert_eq!(options[2].value as u64, worker_options.timeout); - } - - #[test] - fn parses_receive_options() { - let mut options = vec![ - TransferOption { - option: OptionType::BlockSize, - value: 1024, - }, - TransferOption { - option: OptionType::TransferSize, - value: 44554455, - }, - TransferOption { - option: OptionType::Timeout, - value: 5, - }, - ]; - - let work_type = WorkType::Receive; - - let worker_options = parse_options(&mut options, &work_type).unwrap(); - - assert_eq!(options[0].value, worker_options.blk_size); - assert_eq!(options[1].value, worker_options.t_size); - assert_eq!(options[2].value as u64, worker_options.timeout); - } - - #[test] - fn parses_default_options() { - assert_eq!( - parse_options(&mut vec![], &WorkType::Receive).unwrap(), - WorkerOptions { - blk_size: DEFAULT_BLOCK_SIZE, - t_size: 0, - timeout: DEFAULT_TIMEOUT_SECS, - windowsize: 1, - } - ); - } -}