diff --git a/src/config.rs b/src/config.rs index 0a67240..c870d71 100644 --- a/src/config.rs +++ b/src/config.rs @@ -31,6 +31,8 @@ pub struct Config { pub read_only: bool, /// Duplicate all packets sent from the server. (default: 1) pub duplicate_packets: u8, + /// Overwrite existing files. (default: false) + pub overwrite: bool, } impl Config { @@ -44,6 +46,7 @@ impl Config { single_port: false, read_only: false, duplicate_packets: 1, + overwrite: false, }; args.next(); @@ -88,24 +91,33 @@ impl Config { println!( " -p, --port \t\tSet the listening port of the server (default: 69)" ); - println!(" -d, --directory \tSet the listening port of the server (default: Current Working Directory)"); + println!(" -d, --directory \tSet the serving directory (default: Current Working Directory)"); println!(" -s, --single-port\t\tUse a single port for both sending and receiving (default: false)"); println!(" -r, --read-only\t\tRefuse all write requests, making the server read-only (default: false)"); - println!(" --duplicate-packets \tDuplicate all packets sent from the server (default: 1)"); + println!(" --duplicate-packets \tDuplicate all packets sent from the server (default: 0)"); + println!(" --overwrite\t\t\tOverwrite existing files (default: false)"); println!(" -h, --help\t\t\tPrint help information"); process::exit(0); } "--duplicate-packets" => { if let Some(duplicate_packets_str) = args.next() { let duplicate_packets = duplicate_packets_str.parse::()?; - if duplicate_packets < 1 { - return Err("Duplicate packets must be greater than 0".into()); + + if duplicate_packets == u8::MAX { + return Err(format!( + "Duplicate packets should be less than {}", + u8::MAX + ) + .into()); } config.duplicate_packets = duplicate_packets; } else { return Err("Missing duplicate packets after flag".into()); } } + "--overwrite" => { + config.overwrite = true; + } invalid => return Err(format!("Invalid flag: {invalid}").into()), } @@ -178,13 +190,6 @@ mod tests { #[test] fn returns_error_on_invalid_duplicate_packets() { - assert!(Config::new( - ["/", "--duplicate-packets", "0"] - .iter() - .map(|s| s.to_string()), - ) - .is_err()); - assert!(Config::new( ["/", "--duplicate-packets", "-1"] .iter() @@ -192,4 +197,14 @@ mod tests { ) .is_err()); } + + #[test] + fn returns_error_on_max_duplicate_packets() { + assert!(Config::new( + ["/", "--duplicate-packets", format!("{}", u8::MAX).as_str()] + .iter() + .map(|s| s.to_string()), + ) + .is_err()); + } } diff --git a/src/server.rs b/src/server.rs index 436de8e..49eb300 100644 --- a/src/server.rs +++ b/src/server.rs @@ -32,6 +32,7 @@ pub struct Server { directory: PathBuf, single_port: bool, read_only: bool, + overwrite: bool, duplicate_packets: u8, largest_block_size: usize, clients: HashMap>, @@ -47,6 +48,7 @@ impl Server { directory: config.directory.clone(), single_port: config.single_port, read_only: config.read_only, + overwrite: config.overwrite, duplicate_packets: config.duplicate_packets, largest_block_size: DEFAULT_BLOCK_SIZE, clients: HashMap::new(), @@ -94,7 +96,7 @@ impl Server { { eprintln!("Could not send error packet"); }; - eprintln!("Received invalid request"); + eprintln!("Received write request while in read-only mode"); continue; } println!("Receiving {filename} from {from}"); @@ -179,7 +181,7 @@ impl Server { worker_options.block_size, worker_options.timeout, worker_options.window_size, - self.duplicate_packets, + self.duplicate_packets + 1, ); worker.send() } @@ -194,15 +196,51 @@ impl Server { to: &SocketAddr, ) -> Result<(), Box> { let file_path = &self.directory.join(file_name); + let initialize_write = &mut || -> Result<(), Box> { + let worker_options = parse_options(options, RequestType::Write)?; + let mut socket: Box; + + if self.single_port { + let single_socket = create_single_socket(&self.socket, to)?; + self.clients.insert(*to, single_socket.sender()); + self.largest_block_size = max(self.largest_block_size, worker_options.block_size); + + socket = Box::new(single_socket); + } else { + socket = Box::new(create_multi_socket(&self.socket.local_addr()?, to)?); + } + + socket.set_read_timeout(worker_options.timeout)?; + socket.set_write_timeout(worker_options.timeout)?; + + accept_request(&socket, options, RequestType::Write)?; + + let worker = Worker::new( + socket, + file_path.clone(), + worker_options.block_size, + worker_options.timeout, + worker_options.window_size, + self.duplicate_packets, + ); + worker.receive() + }; + match check_file_exists(file_path, &self.directory) { - ErrorCode::FileExists => Socket::send_to( - &self.socket, - &Packet::Error { - code: ErrorCode::FileExists, - msg: "requested file already exists".to_string(), - }, - to, - ), + ErrorCode::FileExists => { + if self.overwrite { + initialize_write() + } else { + Socket::send_to( + &self.socket, + &Packet::Error { + code: ErrorCode::FileExists, + msg: "requested file already exists".to_string(), + }, + to, + ) + } + } ErrorCode::AccessViolation => Socket::send_to( &self.socket, &Packet::Error { @@ -211,36 +249,7 @@ impl Server { }, to, ), - ErrorCode::FileNotFound => { - let worker_options = parse_options(options, RequestType::Write)?; - let mut socket: Box; - - if self.single_port { - let single_socket = create_single_socket(&self.socket, to)?; - self.clients.insert(*to, single_socket.sender()); - self.largest_block_size = - max(self.largest_block_size, worker_options.block_size); - - socket = Box::new(single_socket); - } else { - socket = Box::new(create_multi_socket(&self.socket.local_addr()?, to)?); - } - - socket.set_read_timeout(worker_options.timeout)?; - socket.set_write_timeout(worker_options.timeout)?; - - accept_request(&socket, options, RequestType::Write)?; - - let worker = Worker::new( - socket, - file_path.clone(), - worker_options.block_size, - worker_options.timeout, - worker_options.window_size, - self.duplicate_packets, - ); - worker.receive() - } + ErrorCode::FileNotFound => initialize_write(), _ => Err("Unexpected error code when checking file".into()), } } diff --git a/src/worker.rs b/src/worker.rs index bd5c61d..6fd4801 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -9,7 +9,6 @@ use std::{ const MAX_RETRIES: u32 = 6; const TIMEOUT_BUFFER: Duration = Duration::from_secs(1); -const DEFAULT_DUPLICATE_DELAY: Duration = Duration::from_millis(1); /// Worker `struct` is used for multithreaded file sending and receiving. /// It creates a new socket using the Server's IP and a random port @@ -44,7 +43,7 @@ pub struct Worker { blk_size: usize, timeout: Duration, windowsize: u16, - duplicate_packets: u8, + repeat_amount: u8, } impl Worker { @@ -55,7 +54,7 @@ impl Worker { blk_size: usize, timeout: Duration, windowsize: u16, - duplicate_packets: u8, + repeat_amount: u8, ) -> Worker { Worker { socket, @@ -63,7 +62,7 @@ impl Worker { blk_size, timeout, windowsize, - duplicate_packets, + repeat_amount, } } @@ -242,10 +241,7 @@ impl Worker { } fn send_packet(&self, packet: &Packet) -> Result<(), Box> { - for i in 0..self.duplicate_packets { - if i > 0 { - std::thread::sleep(DEFAULT_DUPLICATE_DELAY); - } + for _ in 0..self.repeat_amount { self.socket.send(packet)?; }