diff --git a/src/interfaces/ip.rs b/src/interfaces/ip.rs index 3063cd3..6c61734 100644 --- a/src/interfaces/ip.rs +++ b/src/interfaces/ip.rs @@ -30,6 +30,7 @@ pub struct IPInterface { listener: net::TcpListener, peers: Vec, package_queue: Vec<(IPPackage, String /* from_peer */)>, + main_loop_iterations: u64, } #[derive(Debug, Clone)] @@ -66,6 +67,11 @@ impl MessageType { } } + +fn compare_addrs(peer: &Peer, addr: net::SocketAddr) -> bool { + addr.ip() == peer.0 && addr.port() == peer.1 +} + impl InterfaceRequirements for IPInterface {} impl Interface for IPInterface { @@ -162,6 +168,18 @@ impl Interface for IPInterface { } self.connections.extend(new_connections); + self.main_loop_iterations += 1; + // Every 50 iterations we connect to everyone we know + if self.main_loop_iterations % 50 == 0 { + let connected_addresses = self.connections.iter().filter_map(|conn| conn.peer_addr().ok()).collect::>(); + let peers_we_do_not_have_connections_with = self.peers.iter().filter(|p| !connected_addresses.iter().any(|addr| compare_addrs(p, *addr))).copied().collect::>(); + self.connections.extend(IPInterface::get_connections_to_peers(&peers_we_do_not_have_connections_with, self.peers.len() < PEER_THRESHOLD * 2)); + } + // We do a peer exchange every 30 iterations + if self.main_loop_iterations % 30 == 0 { + let connection_index = (self.main_loop_iterations / 30) as usize % self.connections.len(); + IPInterface::request_peers(&mut self.connections[connection_index])?; + } Ok(()) } @@ -206,13 +224,15 @@ impl Interface for IPInterface { Ok(()) } fn receive(&mut self) -> IFResult> { - if !self.package_queue.is_empty() { - println!( - "({:?}): New message from {}", - self.listener.local_addr().unwrap(), - self.package_queue.last().unwrap().1 - ); - } + // if !self.package_queue.is_empty() { + // println!( + // "({:?}): New message from {}. By the way, I know {} peers and have {} connections", + // self.listener.local_addr().unwrap(), + // self.package_queue.last().unwrap().1, + // self.peers.len(), + // self.connections.len() + // ); + // } match self.package_queue.pop() { Some((ip_package, data)) => Ok(Some((ip_package.message, data))), None => Ok(None), @@ -221,6 +241,22 @@ impl Interface for IPInterface { } impl IPInterface { + fn get_connections_to_peers(peers: &[Peer], do_peer_request: bool) -> Vec { + peers + .par_iter() + .map(Self::new_connection) + .filter_map(|r| r.ok()) + .filter_map(|r| r) + .map(|mut c| -> IFResult { + if do_peer_request { + Self::request_peers(&mut c)?; + } + Ok(c) + }) + .filter_map(|r| r.ok()) + .collect::>() + } + pub fn new(port: u16, peers: Vec) -> IFResult { let listener = match create_tcp_listener(port) { Some(listener) => listener, @@ -233,17 +269,7 @@ impl IPInterface { listener.set_nonblocking(true)?; - let connections = peers - .par_iter() - .map(Self::new_connection) - .filter_map(|r| r.ok()) - .filter_map(|r| r) - .map(|mut c| -> IFResult { - Self::request_peers(&mut c)?; - Ok(c) - }) - .filter_map(|r| r.ok()) - .collect::>(); + let connections = Self::get_connections_to_peers(&peers, true); Ok(IPInterface { id: String::from("IP Interface"), @@ -251,6 +277,7 @@ impl IPInterface { listener, peers, package_queue: vec![], + main_loop_iterations: 0, }) } pub fn dump(&self) -> IFResult> { @@ -298,10 +325,6 @@ impl IPInterface { } fn obtain_connection(&mut self, addr: &Peer) -> IFResult { - fn compare_addrs(peer: &Peer, addr: net::SocketAddr) -> bool { - addr.ip() == peer.0 && addr.port() == peer.1 - } - if let Some(pos) = self.connections.iter().position(|con| { con.peer_addr().is_ok() && compare_addrs(addr, con.peer_addr().unwrap()) }) { @@ -319,7 +342,7 @@ impl IPInterface { for port in addr.1..addr.1 + 3 { match net::TcpStream::connect_timeout( &net::SocketAddr::new(addr.0, port as u16), - Duration::from_millis(500), + Duration::from_millis(300), ) { Ok(connection) => { return Ok(Some(connection)); @@ -419,7 +442,7 @@ fn test_creating_connection() -> IFResult<()> { } #[cfg(test)] -fn create_test_interfaces(n: usize) -> impl Iterator { +pub fn create_test_interfaces(n: usize) -> impl Iterator { let ip_addr = std::net::IpAddr::from_str("127.0.0.1").unwrap(); (0..n).map(move |i| { IPInterface::new( @@ -428,7 +451,7 @@ fn create_test_interfaces(n: usize) -> impl Iterator { // .filter(|j| *j != i) // .map(|j| (ip_addr, (5000 + 5 * j) as u16)) // .collect(), - vec![(ip_addr, (5000 + 5 * ((i + 5) % n)) as u16)], + vec![(ip_addr, (5000 + 5 * ((i + 1) % n)) as u16)], ) .unwrap() }) diff --git a/src/ironforce.rs b/src/ironforce.rs index 4e2a0df..4e61b0d 100644 --- a/src/ironforce.rs +++ b/src/ironforce.rs @@ -60,7 +60,7 @@ impl IronForce { } /// Create a new tunnel to another node - fn initialize_tunnel_creation(&mut self, destination: PublicKey) -> IFResult<()> { + fn initialize_tunnel_creation(&mut self, destination: &PublicKey) -> IFResult<()> { let tunnel = TunnelPublic::new_singlecast(); #[cfg(std)] println!( @@ -107,7 +107,7 @@ impl IronForce { } else { return Err(IFError::TunnelNotFound); }; - message.tunnel_id = tunnel_id; + message.tunnel_id = (tunnel_id, tunnel.peer_ids.0 != 0); let peer_ids = match (direction, tunnel.peer_ids) { (_, (x, 0)) => vec![x], (_, (0, x)) => vec![x], @@ -124,15 +124,15 @@ impl IronForce { /// Send a message to another node, /// creating a new tunnel if needed - pub fn send_message(&mut self, message: Message, destination: PublicKey) -> IFResult<()> { + pub fn send_message(&mut self, message: Message, destination: &PublicKey) -> IFResult<()> { if let Some(Some(tunnel_id)) = self .tunnels .iter() .find(|t| { - t.target_node.as_ref() == Some(&destination) + t.target_node.as_ref() == Some(destination) || t.nodes_in_tunnel .as_ref() - .map(|nodes| nodes.contains(&destination)) + .map(|nodes| nodes.contains(destination)) == Some(true) }) .map(|tunnel| tunnel.id) @@ -171,18 +171,12 @@ impl IronForce { let tunnel = Tunnel { id: tunnel_pub.id, local_ids: tunnel_pub.local_ids.clone(), - peer_ids: (inc_peer, 0), + peer_ids: (0, inc_peer), ttd: 0, nodes_in_tunnel: None, is_multicast: false, target_node: Some(sender), }; - #[cfg(feature = "std")] - println!( - "[{}] Got an incoming tunnel for me! {:?}", - self.short_id(), - tunnel_pub - ); self.tunnels.push(tunnel); self.transport.send_message( serde_cbor::to_vec( @@ -192,7 +186,7 @@ impl IronForce { tunnel_pub.clone(), ), )) - .tunnel(tunnel_pub.id.unwrap()) + .tunnel((tunnel_pub.id.unwrap(), false)) .sign(&self.keys) .build()?, )?, @@ -255,7 +249,17 @@ impl IronForce { } } } - MessageType::SingleCast | MessageType::Broadcast => self.messages.push(message), + MessageType::SingleCast if message.check_recipient(&self.keys) => self.messages.push(message.clone()), + MessageType::SingleCast => { + if let Some(tunnel) = self.tunnels.iter().find(|tun| tun.id == Some(message.tunnel_id.0)) { + let peer_id = if message.tunnel_id.1 { tunnel.peer_ids.0 } else { tunnel.peer_ids.1 }; + self.transport.send_message(serde_cbor::to_vec(&message)?, Some(peer_id))?; + } + } + MessageType::Broadcast => { + self.messages.push(message.clone()); + self.send_to_all(message)?; + }, } Ok(()) } @@ -323,7 +327,7 @@ mod if_testing { fn test_creating_a_tunnel() -> IFResult<()> { let mut network = create_test_network(); let key_1 = network[1].keys.get_public(); - network[0].initialize_tunnel_creation(key_1)?; + network[0].initialize_tunnel_creation(&key_1)?; network[0].main_loop_iteration()?; network[1].main_loop_iteration()?; network[0].main_loop_iteration()?; @@ -335,7 +339,7 @@ mod if_testing { fn test_sending_message() -> IFResult<()> { let mut network = create_test_network(); let key_1 = network[1].keys.get_public(); - network[0].initialize_tunnel_creation(key_1.clone())?; + network[0].initialize_tunnel_creation(&key_1)?; network[0].main_loop_iteration()?; network[1].main_loop_iteration()?; network[0].main_loop_iteration()?; @@ -344,10 +348,10 @@ mod if_testing { Message::build() .message_type(MessageType::SingleCast) .sign(&zero_keys) - .recipient(key_1.clone()) + .recipient(&key_1) .content(b"hello".to_vec()) .build()?, - key_1, + &key_1, )?; network[1].main_loop_iteration()?; let msg = network[1].read_message(); @@ -367,29 +371,29 @@ mod if_testing { #[cfg(feature = "std")] mod test_with_ip { use crate::crypto::Keys; - use crate::interfaces::ip::IPInterface; use crate::ironforce::IronForce; use crate::res::IFResult; use crate::transport::Transport; use alloc::boxed::Box; use alloc::vec; use alloc::vec::Vec; - use core::str::FromStr; use std::println; + use crate::interfaces::ip::create_test_interfaces; + use crate::message::{Message, MessageType}; - fn create_test_interfaces(n: usize) -> impl Iterator { - let ip_addr = std::net::IpAddr::from_str("127.0.0.1").unwrap(); - (0..n).map(move |i| { - IPInterface::new( - (5000 + 5 * i) as u16, - (0..n) - .filter(|j| *j != i) - .map(|j| (ip_addr, (5000 + 5 * j) as u16)) - .collect(), - ) - .unwrap() - }) - } + // fn create_test_interfaces(n: usize) -> impl Iterator { + // let ip_addr = std::net::IpAddr::from_str("127.0.0.1").unwrap(); + // (0..n).map(move |i| { + // IPInterface::new( + // (5000 + 5 * i) as u16, + // (0..n) + // .filter(|j| *j != i) + // .map(|j| (ip_addr, (5000 + 5 * j) as u16)) + // .collect(), + // ) + // .unwrap() + // }) + // } fn create_test_network() -> Vec { let interfaces = create_test_interfaces(4); @@ -411,43 +415,47 @@ mod test_with_ip { .collect() } + // MAIN TEST RIGHT HERE #[test] - fn test_creating_a_tunnel() -> IFResult<()> { + fn test_creating_a_tunnel_and_sending_message() -> IFResult<()> { let mut network = create_test_network(); let key_1 = network[1].keys.get_public(); let (mut node0, mut node1) = (network.remove(0), network.remove(0)); + let node0_keys = node0.keys.clone(); println!("node0 id: {}", node0.short_id()); println!("node1 id: {}", node1.short_id()); let (mut node2, mut node3) = (network.remove(0), network.remove(0)); let t1 = std::thread::spawn(move || { - for _i in 0..5 { + for _i in 0..170 { // println!("Iteration {} (1)", i); node0.main_loop_iteration().unwrap(); + std::thread::sleep(std::time::Duration::from_millis(10)); } node0 }); let t2 = std::thread::spawn(move || { - for _i in 0..15 { + for _i in 0..250 { // println!("Iteration {} (2)", i); node1.main_loop_iteration().unwrap(); + std::thread::sleep(std::time::Duration::from_millis(10)); } node1 }); std::thread::spawn(move || loop { node2.main_loop_iteration().unwrap(); - std::thread::sleep(std::time::Duration::from_secs(5)); + std::thread::sleep(std::time::Duration::from_millis(10)); }); std::thread::spawn(move || loop { - std::thread::sleep(std::time::Duration::from_secs(5)); + std::thread::sleep(std::time::Duration::from_millis(10)); node3.main_loop_iteration().unwrap(); }); let mut node0 = t1.join().unwrap(); - node0.initialize_tunnel_creation(key_1)?; + node0.initialize_tunnel_creation(&key_1)?; let mut node1 = t2.join().unwrap(); let t1 = std::thread::spawn(move || { for _ in 0..18 { node0.main_loop_iteration().unwrap(); - std::thread::sleep(std::time::Duration::from_millis(150)); + std::thread::sleep(std::time::Duration::from_millis(50)); } node0 }); @@ -457,9 +465,27 @@ mod test_with_ip { } node1 }); - let node0 = t1.join().unwrap(); - t2.join().unwrap(); + let mut node0 = t1.join().unwrap(); + let mut node1 = t2.join().unwrap(); assert!(!node0.tunnels.is_empty()); + node0.send_message( + Message::build() + .message_type(MessageType::SingleCast) + .content(b"Hello!".to_vec()) + .recipient(&key_1) + .sign(&node0_keys) + .build()?, + &key_1)?; + let t2 = std::thread::spawn(move || { + for _ in 0..18 { + node1.main_loop_iteration().unwrap(); + } + node1 + }); + let mut node1 = t2.join().unwrap(); + let msg = node1.read_message(); + assert!(msg.is_some()); + assert_eq!(msg.unwrap().get_decrypted(&node1.keys)?, b"Hello!".to_vec()); Ok(()) } } diff --git a/src/message.rs b/src/message.rs index eced2e9..7d5631b 100644 --- a/src/message.rs +++ b/src/message.rs @@ -137,8 +137,8 @@ pub struct Message { hash: Vec, /// Optional: hash of the message encrypted for the recipient, so that the recipient can know that this message is for them, but nobody else recipient_verification: Option>, - /// ID of the tunnel that is used - pub tunnel_id: u64, + /// ID of the tunnel that is used and the direction + pub tunnel_id: (u64, bool), /// Network info network_info: NetworkInfo, } @@ -256,7 +256,7 @@ pub struct MessageBuilder { /// Recipient's public key (if present, the content will be encrypted and recipient verification field will be set) recipient: Option, /// ID of the tunnel that is used - tunnel_id: u64, + tunnel_id: (u64, bool), } impl MessageBuilder { @@ -267,7 +267,7 @@ impl MessageBuilder { message_type: None, sender: None, recipient: None, - tunnel_id: 0, + tunnel_id: (0, false), } } @@ -283,13 +283,13 @@ impl MessageBuilder { } /// Set message's recipient (and therefore set recipient verification and encrypt the content) - pub fn recipient(mut self, recipient: PublicKey) -> Self { - self.recipient = Some(recipient); + pub fn recipient(mut self, recipient: &PublicKey) -> Self { + self.recipient = Some(recipient.clone()); self } /// Set tunnel id - pub fn tunnel(mut self, tunnel_id: u64) -> Self { + pub fn tunnel(mut self, tunnel_id: (u64, bool)) -> Self { self.tunnel_id = tunnel_id; self } @@ -383,8 +383,8 @@ fn test_building_message() -> IFResult<()> { let msg = Message::build() .content(b"hello".to_vec()) .sign(&keys_1) - .recipient(keys_2.get_public()) - .tunnel(1) + .recipient(&keys_2.get_public()) + .tunnel((1, false)) .message_type(MessageType::SingleCast) .build()?; assert!(msg.verify_hash());