From 2240ae80ae43ce75bf5db34ba41eee24742e8238 Mon Sep 17 00:00:00 2001 From: ennucore Date: Sun, 12 Dec 2021 21:47:23 +0300 Subject: [PATCH] Fix some IP things for devices behind NAT and bad connections --- Cargo.lock | 3 + degeon/Cargo.toml | 7 +- degeon/src/gui_events.rs | 6 + degeon/src/main.rs | 56 +------- degeon/src/message.rs | 4 +- degeon/src/state.rs | 271 +++++++++++++++++++++++++++++++----- src/crypto.rs | 22 ++- src/interfaces/ip.rs | 291 ++++++++++++++++++++++++++------------- src/ironforce.rs | 53 +++++-- src/message.rs | 37 ++--- src/transport.rs | 9 +- src/tunnel.rs | 4 +- 12 files changed, 549 insertions(+), 214 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2243030..dfebbfc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -553,9 +553,12 @@ name = "degeon" version = "0.1.0" dependencies = [ "base64", + "futures", "iced", + "iced_native", "ironforce", "serde", + "serde_json", ] [[package]] diff --git a/degeon/Cargo.toml b/degeon/Cargo.toml index 7b1985d..b7354eb 100644 --- a/degeon/Cargo.toml +++ b/degeon/Cargo.toml @@ -6,7 +6,10 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -iced = { version = "0.3", features = ["glow"] } +iced = { version = "0.3.0", features = ["glow"] } ironforce = { path = "../", features = ["std"] } base64 = "0.13.0" -serde = { version = "1.0" } \ No newline at end of file +serde = { version = "1.0" } +serde_json = "1.0.72" +futures = "0.3.18" +iced_native = "0.4.0" \ No newline at end of file diff --git a/degeon/src/gui_events.rs b/degeon/src/gui_events.rs index a7a7125..d1c85ac 100644 --- a/degeon/src/gui_events.rs +++ b/degeon/src/gui_events.rs @@ -1,6 +1,12 @@ +use crate::message::DegMessage; +use ironforce::PublicKey; + #[derive(Clone, Debug)] pub enum GuiEvent { ChatSelect(usize), Typed(String), SendClick, + NewChat(PublicKey), + NewMessageInChat(PublicKey, DegMessage), + SetName(PublicKey, String), } diff --git a/degeon/src/main.rs b/degeon/src/main.rs index 21562a9..0c1a974 100644 --- a/degeon/src/main.rs +++ b/degeon/src/main.rs @@ -3,65 +3,11 @@ mod message; mod state; mod gui_events; -use iced::Sandbox; +use iced::Application; use ironforce::res::IFResult; use ironforce::{IronForce, Message, MessageType, PublicKey}; use crate::state::State; -fn main_if() -> IFResult<()> { - let ironforce = IronForce::from_file("".to_string())?; - let if_keys = ironforce.keys.clone(); - println!( - "Our public key: {}", - base64::encode(if_keys.get_public().to_vec().as_slice()) - ); - let (_thread, if_mutex) = ironforce.launch_main_loop(100); - let stdin = std::io::stdin(); - let if_mutex_clone = if_mutex.clone(); - let if_keys_clone = if_keys.clone(); - std::thread::spawn(move || loop { - if let Some(msg) = if_mutex_clone.lock().unwrap().read_message() { - println!( - "New message: {}", - String::from_utf8(msg.get_decrypted(&if_keys_clone).unwrap()).unwrap() - ); - } - std::thread::sleep(std::time::Duration::from_millis(300)) - }); - loop { - let mut buf = String::new(); - stdin.read_line(&mut buf)?; - let msg_base = if buf.starts_with('>') { - let target_base64 = buf - .split(')') - .next() - .unwrap() - .trim_start_matches(">(") - .to_string(); - let target = if let Ok(res) = base64::decode(target_base64) { - res - } else { - println!("Wrong b64."); - continue; - }; - buf = buf - .split(')') - .skip(1) - .map(|s| s.to_string()) - .collect::>() - .join(")"); - Message::build() - .message_type(MessageType::SingleCast) - .recipient(&PublicKey::from_vec(target).unwrap()) - } else { - Message::build().message_type(MessageType::Broadcast) - }; - if_mutex - .lock() - .unwrap() - .send_to_all(msg_base.content(buf.into_bytes()).sign(&if_keys).build()?)?; - } -} fn main() -> Result<(), Box> { // let ironforce = IronForce::from_file("".to_string()).unwrap(); diff --git a/degeon/src/message.rs b/degeon/src/message.rs index 3edda2c..a840a17 100644 --- a/degeon/src/message.rs +++ b/degeon/src/message.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Serialize, Deserialize)] -pub enum Message { +pub enum DegMessage { Text(String), File(Vec), Service(ServiceMsg), @@ -11,4 +11,6 @@ pub enum Message { pub enum ServiceMsg { NameRequest, NameStatement(String), + Ping, + HiThere, } diff --git a/degeon/src/state.rs b/degeon/src/state.rs index 5b9be0f..b245292 100644 --- a/degeon/src/state.rs +++ b/degeon/src/state.rs @@ -1,33 +1,49 @@ use crate::gui_events::GuiEvent; -use crate::message::Message; +use crate::message::{DegMessage, ServiceMsg}; use core::default::Default; +use futures::Stream; use iced::{ - button, Align, Button, Column, Element, HorizontalAlignment, Length, Row, Sandbox, Settings, + button, Align, Application, Button, Column, Element, HorizontalAlignment, Length, Row, Text, TextInput, VerticalAlignment, }; -use ironforce::{Keys, PublicKey}; -use serde::{Deserialize, Serialize}; +use ironforce::res::{IFError, IFResult}; +use ironforce::{IronForce, Keys, Message, MessageType, PublicKey}; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll}; #[derive(Clone, Debug)] pub struct Chat { pkey: PublicKey, - messages: Vec<(bool, Message)>, + messages: Vec<(bool, DegMessage)>, name: String, scrolled: f32, pub input: String, } -pub fn view_message(msg: &(bool, Message)) -> Option> { +impl Chat { + pub fn new(pkey: PublicKey) -> Self { + Self { + pkey, + messages: vec![], + name: "".to_string(), + scrolled: 0.0, + input: "".to_string(), + } + } +} + +pub fn view_message(msg: &(bool, DegMessage)) -> Option> { let msg = &msg.1; match msg { - Message::Text(t) => Some( + DegMessage::Text(t) => Some( iced::Container::new(Text::new(t.as_str())) .padding(10) .style(style::Container::Message) .into(), ), - Message::File(_) => None, - Message::Service(_) => None, + DegMessage::File(_) => None, + DegMessage::Service(_) => None, } } @@ -54,7 +70,11 @@ mod style { })), border_radius: 5.0, shadow_offset: Vector::new(1.0, 1.0), - text_color: if self != &Button::InactiveChat { Color::WHITE } else { Color::BLACK }, + text_color: if self != &Button::InactiveChat { + Color::WHITE + } else { + Color::BLACK + }, ..button::Style::default() } } @@ -118,11 +138,20 @@ impl Chat { .into() } - pub fn preview<'a>(&'a self, state: &'a mut button::State, i: usize, is_selected: bool) -> Element<'a, GuiEvent> { + pub fn preview<'a>( + &'a self, + state: &'a mut button::State, + i: usize, + is_selected: bool, + ) -> Element<'a, GuiEvent> { Button::new(state, Text::new(self.name.as_str())) .width(Length::Fill) .padding(10) - .style(if is_selected { style::Button::Primary } else { style::Button::InactiveChat }) + .style(if is_selected { + style::Button::Primary + } else { + style::Button::InactiveChat + }) .on_press(GuiEvent::ChatSelect(i)) .into() } @@ -157,7 +186,7 @@ impl Chat { pub fn example(i: usize) -> Chat { Self { pkey: Keys::generate().get_public(), - messages: vec![(false, Message::Text(format!("Example message {}", i)))], + messages: vec![(false, DegMessage::Text(format!("Example message {}", i)))], name: format!("Example user ({})", i), scrolled: 0.0, input: "".to_string(), @@ -165,12 +194,135 @@ impl Chat { } } -#[derive(Default, Clone, Debug)] +#[derive(Clone)] +pub struct Degeon { + pub chats: Vec, + pub my_name: String, + pub keys: Keys, + pub ironforce: Arc>, +} + +impl Default for Degeon { + fn default() -> Self { + let ironforce = IronForce::from_file("".to_string()).unwrap(); + let keys = ironforce.keys.clone(); + let (_thread, ironforce) = ironforce.launch_main_loop(500); + Self { + chats: vec![], + my_name: "".to_string(), + keys, + ironforce, + } + } +} + +impl Degeon { + pub fn chat_with(&self, pkey: &PublicKey) -> Option { + self.chats.iter().position(|chat| &chat.pkey == pkey) + } + + pub fn process_message(&self, msg: ironforce::Message) -> IFResult> { + let deg_msg: DegMessage = + serde_json::from_slice(msg.get_decrypted(&self.keys)?.as_slice())?; + let sender = msg.get_sender(&self.keys).unwrap(); + Ok(match °_msg { + DegMessage::Text(_) | DegMessage::File(_) => { + Some(GuiEvent::NewMessageInChat(sender, deg_msg)) + } + DegMessage::Service(msg) => match msg { + ServiceMsg::NameRequest => self + .send_message( + DegMessage::Service(ServiceMsg::NameStatement(self.my_name.clone())), + &sender, + ) + .map(|_| None)?, + ServiceMsg::NameStatement(name) => { + Some(GuiEvent::SetName(sender, name.to_string())) + } + ServiceMsg::Ping => self + .send_message(DegMessage::Service(ServiceMsg::HiThere), &sender) + .map(|_| None)?, + ServiceMsg::HiThere => Some(GuiEvent::NewChat(sender)), + }, + }) + } + + pub fn send_multicast(&self, msg: DegMessage) -> IFResult<()> { + self.ironforce.lock().unwrap().send_to_all( + Message::build() + .message_type(MessageType::Broadcast) + .content(serde_json::to_vec(&msg)?) + .sign(&self.keys) + .build()?, + ) + } + + pub fn send_message(&self, msg: DegMessage, target: &PublicKey) -> IFResult<()> { + if self.ironforce.lock().unwrap().get_tunnel(target).is_none() { + println!("Creating a tunnel"); + self.ironforce + .lock() + .unwrap() + .initialize_tunnel_creation(target)?; + let mut counter = 0; + while self.ironforce.lock().unwrap().get_tunnel(target).is_none() { + std::thread::sleep(std::time::Duration::from_millis(350)); + counter += 1; + if counter > 100 { + return Err(IFError::TunnelNotFound); + } + } + } + self.ironforce.lock().unwrap().send_to_all( + Message::build() + .message_type(MessageType::Broadcast) + .content(serde_json::to_vec(&msg)?) + .recipient(target) + .sign(&self.keys) + .build()?, + ) + } +} + +impl Stream for Degeon { + type Item = GuiEvent; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + println!("Degeon worker is being polled"); + let msg = self.ironforce.lock().unwrap().read_message(); + match msg.map(|msg| self.process_message(msg).unwrap()) { + None | Some(None) => Poll::Pending, + Some(Some(msg)) => Poll::Ready(Some(msg)), + } + } +} + +impl iced_native::subscription::Recipe for Degeon +where + H: std::hash::Hasher, +{ + type Output = GuiEvent; + + fn hash(&self, state: &mut H) { + use std::hash::Hash; + + std::any::TypeId::of::().hash(state); + self.ironforce.lock().unwrap().hash(state); + } + + fn stream( + self: Box, + _input: futures::stream::BoxStream<'static, I>, + ) -> futures::stream::BoxStream<'static, Self::Output> { + Box::pin(self) + } +} + +#[derive(Default)] pub struct State { - chats: Vec, - my_name: String, + pub data: Degeon, selected_chat: usize, - pub send_button_state: iced::button::State, + send_button_state: iced::button::State, text_input_state: iced::text_input::State, preview_button_states: Vec, } @@ -179,8 +331,11 @@ impl State { fn chat_list<'a>( chats: &'a Vec, preview_button_states: &'a mut Vec, - selected: usize + selected: usize, ) -> Element<'a, GuiEvent> { + while preview_button_states.len() < chats.len() { + preview_button_states.push(Default::default()) + } Column::with_children( chats .iter() @@ -214,41 +369,87 @@ impl State { } } -impl Sandbox for State { +impl Application for State { + type Executor = iced::executor::Default; type Message = GuiEvent; + type Flags = (); - fn new() -> Self { + fn new(_: ()) -> (Self, iced::Command) { let mut st = Self::default(); - st.chats = vec![Chat::example(1), Chat::example(2)]; + st.data.chats = vec![Chat::example(1), Chat::example(2)]; st.preview_button_states = vec![Default::default(), Default::default()]; - st + st.data.my_name = "John".to_string(); + st.data + .send_multicast(DegMessage::Service(ServiceMsg::Ping)) + .unwrap(); + let data_clone = st.data.clone(); + std::thread::spawn(move || { + std::thread::sleep(std::time::Duration::from_secs(10)); + loop { + data_clone + .send_multicast(DegMessage::Service(ServiceMsg::Ping)) + .unwrap(); + std::thread::sleep(std::time::Duration::from_secs(120)); + } + }); + (st, iced::Command::none()) } fn title(&self) -> String { String::from("Degeon") } - fn update(&mut self, message: GuiEvent) { + fn update(&mut self, message: GuiEvent, _: &mut iced::Clipboard) -> iced::Command { match message { GuiEvent::ChatSelect(i) => self.selected_chat = i, - GuiEvent::Typed(st) => self.chats[self.selected_chat].input = st, + GuiEvent::Typed(st) => self.data.chats[self.selected_chat].input = st, GuiEvent::SendClick => { - if self.chats[self.selected_chat].input.is_empty() { - return; + if self.data.chats[self.selected_chat].input.is_empty() { + return iced::Command::none(); } - let new_msg = Message::Text(self.chats[self.selected_chat].input.clone()); - self.chats[self.selected_chat].input = String::new(); - self.chats[self.selected_chat] + let new_msg = DegMessage::Text(self.data.chats[self.selected_chat].input.clone()); + self.data.chats[self.selected_chat].input = String::new(); + self.data.chats[self.selected_chat] .messages - .push((true, new_msg)); - // todo + .push((true, new_msg.clone())); + let data_cloned = self.data.clone(); + let target = self.data.chats[self.selected_chat].pkey.clone(); + std::thread::spawn(move || { + data_cloned + .send_message(new_msg, &target) + .unwrap() + }); + } + GuiEvent::NewChat(pkey) => { + if self.data.chat_with(&pkey).is_none() { + self.data.chats.push(Chat::new(pkey)) + } + } + GuiEvent::NewMessageInChat(pkey, msg) => { + if self.data.chat_with(&pkey).is_none() { + self.data.chats.push(Chat::new(pkey.clone())) + } + let ind = self.data.chat_with(&pkey).unwrap(); + self.data.chats[ind].messages.push((false, msg)) + } + GuiEvent::SetName(pkey, name) => { + if self.data.chat_with(&pkey).is_none() { + self.data.chats.push(Chat::new(pkey.clone())) + } + let ind = self.data.chat_with(&pkey).unwrap(); + self.data.chats[ind].name = name; } } + iced::Command::none() + } + + fn subscription(&self) -> iced::Subscription { + iced::Subscription::from_recipe(self.data.clone()) } fn view(&mut self) -> Element { let Self { - chats, + data: Degeon { chats, .. }, selected_chat, send_button_state, text_input_state, @@ -257,7 +458,11 @@ impl Sandbox for State { } = self; Row::new() .padding(20) - .push(Self::chat_list(chats, preview_button_states, *selected_chat)) + .push(Self::chat_list( + chats, + preview_button_states, + *selected_chat, + )) .push(Self::active_chat( chats, *selected_chat, diff --git a/src/crypto.rs b/src/crypto.rs index e56fb89..5788893 100644 --- a/src/crypto.rs +++ b/src/crypto.rs @@ -9,13 +9,15 @@ use rsa::errors::Result as RsaRes; use rsa::{BigUint, PaddingScheme, PublicKey as RPK, PublicKeyParts, RsaPrivateKey, RsaPublicKey}; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha224}; +use core::hash::Hash; +use core::hash::Hasher; static KEY_LENGTH: usize = 2048; static ENCRYPTION_CHUNK_SIZE: usize = 240; static ENCRYPTION_OUTPUT_CHUNK_SIZE: usize = 256; /// Public key of a node -#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct PublicKey { pub key: RsaPublicKey, } @@ -99,6 +101,18 @@ impl PublicKey { } } +impl Hash for PublicKey { + fn hash(&self, state: &mut H) { + Hash::hash(&self.to_vec(), state) + } +} + +impl PartialEq for PublicKey { + fn eq(&self, other: &Self) -> bool { + self.key == other.key + } +} + impl PublicKey { fn hash(&self) -> Vec { self.to_vec() @@ -112,6 +126,12 @@ pub struct Keys { private_key: RsaPrivateKey, } +impl Hash for Keys { + fn hash(&self, state: &mut H) { + Hash::hash(&self.get_public(), state) + } +} + impl Keys { /// Generate new random key pub fn generate() -> Self { diff --git a/src/interfaces/ip.rs b/src/interfaces/ip.rs index 95066ce..be6aacc 100644 --- a/src/interfaces/ip.rs +++ b/src/interfaces/ip.rs @@ -99,116 +99,149 @@ impl Interface for IPInterface { self.connections.push(stream) } Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {} - Err(e) => return Err(IFError::from(e)), + Err(e) => println!("An error happened with an incoming connection: {:?}", e), } } let mut new_connections: Vec = vec![]; - for connection in &mut self.connections { - connection.set_nonblocking(true)?; - let mut buf = [0u8; 6]; - let peek_res = connection.peek(&mut buf); - if peek_res.is_err() || peek_res.unwrap() < 6 { - continue - } - let mut header: [u8; 6] = [0, 0, 0, 0, 0, 0]; - match connection.read_exact(&mut header) { - Ok(_) => {} - Err(ref e) - if e.kind() == std::io::ErrorKind::WouldBlock - || e.kind() == std::io::ErrorKind::UnexpectedEof => - { - continue - } - Err(e) => { - println!("Error: {:?}", e); + let mut connections_to_delete = vec![]; + for (i, connection) in self.connections.iter_mut().enumerate() { + let res: std::io::Result<()> = { + connection.set_nonblocking(true)?; + let mut buf = [0u8; 6]; + let peek_res = connection.peek(&mut buf); + if peek_res.is_err() || peek_res.unwrap() < 6 { continue; } - }; - let version = header[0]; - let package_type = MessageType::from_u8(header[1])?; - let size = bytes_to_size([header[2], header[3], header[4], header[5]]); - connection.set_nonblocking(false)?; - connection.set_read_timeout(Some(std::time::Duration::from_millis(500)))?; - - let mut message_take = connection.take(size as u64); - let mut message: Vec = vec![]; - message_take.read_to_end(&mut message)?; - - match package_type { - MessageType::PeerRequest => { - let peers_to_share = if self.peers.len() < PEER_THRESHOLD { - self.peers.clone() - } else { - self.peers.iter().skip(7).step_by(2).cloned().collect() - }; - let message = serde_cbor::to_vec(&peers_to_share)?; - IPInterface::send_package( - connection, - IPPackage { + let mut header: [u8; 6] = [0, 0, 0, 0, 0, 0]; + match connection.read_exact(&mut header) { + Ok(_) => {} + Err(ref e) + if e.kind() == std::io::ErrorKind::WouldBlock + || e.kind() == std::io::ErrorKind::UnexpectedEof => + { + continue + } + Err(e) => { + println!("Error: {:?}", e); + connections_to_delete.push(i); + let connection_addr = if let Ok(r) = connection.peer_addr() { + r + } else { + continue; + }; + if let Some(peer) = self + .peers + .iter() + .find(|p| compare_addrs(p, connection_addr)) + { + if let Some(Some(conn)) = IPInterface::new_connection(peer).ok() { + new_connections.push(conn) + } + } + continue; + } + }; + let version = header[0]; + let package_type = MessageType::from_u8(header[1])?; + let size = bytes_to_size([header[2], header[3], header[4], header[5]]); + connection.set_nonblocking(false)?; + connection.set_read_timeout(Some(std::time::Duration::from_millis(500)))?; + + let mut message_take = connection.take(size as u64); + let mut message: Vec = vec![]; + message_take.read_to_end(&mut message)?; + + match package_type { + MessageType::PeerRequest => { + let peers_to_share = if self.peers.len() < PEER_THRESHOLD { + self.peers.clone() + } else { + self.peers.iter().skip(7).step_by(2).cloned().collect() + }; + let message = serde_cbor::to_vec(&peers_to_share)?; + IPInterface::send_package( + connection, + IPPackage { + version, + package_type: MessageType::PeersShared, + size: message.len() as u32, + message, + }, + )?; + } + MessageType::Common => { + let package = IPPackage { version, - package_type: MessageType::PeersShared, - size: message.len() as u32, + package_type, + size, message, - }, - )?; - } - MessageType::Common => { - let package = IPPackage { - version, - package_type, - size, - message, - }; - self.package_queue - .push((package, format!("{:?}", connection.peer_addr()?))); - } - MessageType::PeersShared => { - let peers: Vec = serde_cbor::from_slice(message.as_slice())?; - for peer in peers { - if !self.peers.contains(&peer) { - if let Some(conn) = IPInterface::new_connection(&peer)? { - new_connections.push(conn) + }; + self.package_queue + .push((package, format!("{:?}", connection.peer_addr()?))); + } + MessageType::PeersShared => { + let peers: Vec = serde_cbor::from_slice(message.as_slice())?; + for peer in peers { + if !self.peers.contains(&peer) { + if let Some(conn) = IPInterface::new_connection(&peer)? { + new_connections.push(conn) + } + self.peers.push(peer); } - self.peers.push(peer); } } } - } + Ok(()) + }; + if res.is_err() && res.unwrap_err().kind() == std::io::ErrorKind::BrokenPipe { + connections_to_delete.push(i) + }; + } + for (j, index) in connections_to_delete.iter().enumerate() { + self.connections.remove(index - j); } + for conn in new_connections.iter_mut() { - self.initialize_connection(conn)?; + self.initialize_connection(conn) + .unwrap_or_else(|e| println!("Couldn't initialize connection: {:?}", e)); } self.connections.extend(new_connections); self.main_loop_iterations += 1; // Every 50 iterations we connect to everyone we know if self.main_loop_iterations % 50 == 0 { - let connected_addresses = self - .connections - .iter() - .filter_map(|conn| conn.peer_addr().ok()) - .collect::>(); - let peers_we_do_not_have_connections_with = self - .peers - .iter() - .filter(|p| { - !connected_addresses - .iter() - .any(|addr| compare_addrs(p, *addr)) - }) - .copied() - .collect::>(); + let peers_we_do_not_have_connections_with = self.disconnected_peers(); self.connections .extend(IPInterface::get_connections_to_peers( &peers_we_do_not_have_connections_with, self.peers.len() < PEER_THRESHOLD * 2, )); } + if self.connections.is_empty() { + for peer in self.peers.clone() { + self.obtain_connection(&peer) + .map(|_| ()) + .unwrap_or_else(|e| println!("Error in obtaining connection: {:?}", e)); + } + } + // We do a peer exchange every 30 iterations if self.main_loop_iterations % 30 == 0 && !self.connections.is_empty() { let connection_index = (self.main_loop_iterations / 30) as usize % self.connections.len(); - IPInterface::request_peers(&mut self.connections[connection_index])?; + match IPInterface::request_peers(&mut self.connections[connection_index]) { + Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => { + let peer = ( + self.connections[connection_index].peer_addr()?.ip(), + self.connections[connection_index].peer_addr()?.port(), + ); + self.connections.remove(connection_index); + let connection_index = self.obtain_connection(&peer)?; + IPInterface::request_peers(&mut self.connections[connection_index])?; + } + Err(e) => println!("An error in peer sharing: {:?}", e), + _ => {} + }; } Ok(()) } @@ -219,8 +252,8 @@ impl Interface for IPInterface { fn id(&self) -> &str { &*self.id } - fn send(&mut self, message: &[u8], interface_data: Option) -> IFResult<()> { + fn send(&mut self, message: &[u8], interface_data: Option) -> IFResult<()> { let package = IPPackage { version: 0, package_type: MessageType::Common, @@ -231,13 +264,50 @@ impl Interface for IPInterface { match interface_data { Some(ip_string) => { let addr: net::SocketAddr = ip_string.parse().expect("Unable to parse address"); - let index = self.obtain_connection(&(addr.ip(), addr.port()))?; - IPInterface::send_package(&mut self.connections[index], package)?; + let peer = (addr.ip(), addr.port()); + let index = self.obtain_connection(&peer)?; + match IPInterface::send_package(&mut self.connections[index], package.clone()) { + Ok(_) => {} + Err(_) => { + self.remove_all_connections_to_peer(&peer); + let index = self.obtain_connection(&(addr.ip(), addr.port()))?; + IPInterface::send_package(&mut self.connections[index], package).map_err( + |e| { + println!("Error while sending: {:?}", e); + e + }, + ); + } + } } None => { - for conn in &mut self.connections { - IPInterface::send_package(conn, package.clone())?; + if self.connections.len() < PEER_THRESHOLD + && self.connections.len() < self.peers.len() + { + let new_connections = IPInterface::get_connections_to_peers( + &self.disconnected_peers(), + self.peers.len() < PEER_THRESHOLD, + ); + self.connections.extend(new_connections); + } + let connections_to_delete = self + .connections + .iter_mut() + .enumerate() + .filter_map(|(i, conn)| { + IPInterface::send_package(conn, package.clone()) + .err() + .map(|_| i) + }) + .collect::>(); + for (j, index) in connections_to_delete.iter().enumerate() { + self.connections.remove(index - j); } + self.connections + .extend(IPInterface::get_connections_to_peers( + &self.disconnected_peers(), + self.peers.len() < PEER_THRESHOLD, + )) } }; Ok(()) @@ -296,7 +366,7 @@ impl IPInterface { .filter_map(|r| r.ok()) .filter_map(|r| r) .map(|mut c| -> IFResult { - println!("Requesting peers from {:?}", c.peer_addr().unwrap()); + println!("Requesting peers from {:?}", c.peer_addr().ok()); if do_peer_request { Self::request_peers(&mut c)?; } @@ -306,6 +376,37 @@ impl IPInterface { .collect::>() } + fn connected_addresses(&self) -> Vec { + self.connections + .iter() + .filter_map(|conn| conn.peer_addr().ok()) + .collect::>() + } + + fn disconnected_peers(&self) -> Vec { + let connected_addresses = self.connected_addresses(); + self.peers + .iter() + .filter(|p| { + !connected_addresses + .iter() + .any(|addr| compare_addrs(p, *addr)) + }) + .copied() + .collect::>() + } + + fn remove_all_connections_to_peer(&mut self, peer: &Peer) { + while let Some(ind) = self + .connections + .iter() + .filter_map(|conn| conn.peer_addr().ok()) + .position(|addr| compare_addrs(&peer, addr)) + { + self.connections.remove(ind); + } + } + pub fn new(port: u16, peers: Vec) -> IFResult { let listener = match create_tcp_listener(port) { Some(listener) => listener, @@ -339,9 +440,8 @@ impl IPInterface { Ok(()) } - fn send_package(stream: &mut net::TcpStream, package: IPPackage) -> IFResult<()> { + fn send_package(stream: &mut net::TcpStream, package: IPPackage) -> std::io::Result<()> { stream.set_write_timeout(Some(std::time::Duration::from_millis(700)))?; - #[cfg(test)] stream.set_nonblocking(false)?; let mut header: Vec = vec![package.version, package.package_type.as_u8()]; for byte in size_to_bytes(package.size) { @@ -360,7 +460,7 @@ impl IPInterface { Ok(()) } - fn request_peers(conn: &mut TcpStream) -> IFResult<()> { + fn request_peers(conn: &mut TcpStream) -> std::io::Result<()> { IPInterface::send_package( conn, IPPackage { @@ -374,9 +474,12 @@ impl IPInterface { } fn obtain_connection(&mut self, addr: &Peer) -> IFResult { - if let Some(pos) = self.connections.iter().position(|con| { - con.peer_addr().is_ok() && compare_addrs(addr, con.peer_addr().unwrap()) - }) { + if let Some(pos) = self + .connections + .iter() + .filter_map(|conn| conn.peer_addr().ok()) + .position(|pa| compare_addrs(addr, pa)) + { return Ok(pos); } if let Some(conn) = Self::new_connection(addr)? { @@ -387,7 +490,7 @@ impl IPInterface { } } - fn new_connection(addr: &Peer) -> IFResult> { + fn new_connection(addr: &Peer) -> std::io::Result> { for port in addr.1..addr.1 + 3 { match net::TcpStream::connect_timeout( &net::SocketAddr::new(addr.0, port as u16), diff --git a/src/ironforce.rs b/src/ironforce.rs index 28e7415..0bed09a 100644 --- a/src/ironforce.rs +++ b/src/ironforce.rs @@ -17,6 +17,7 @@ const TUNNEL_MAX_REPEAT_COUNT: u32 = 3; pub const DEFAULT_FILE: &str = ".if_data.json"; /// Main worker +#[derive(Hash)] pub struct IronForce { /// Keys for this instance pub keys: Keys, @@ -70,6 +71,12 @@ impl IFSerializationData { } } +impl Default for IronForce { + fn default() -> Self { + Self::new() + } +} + impl IronForce { /// Create new worker pub fn new() -> Self { @@ -88,7 +95,7 @@ impl IronForce { } /// Create a new tunnel to another node - fn initialize_tunnel_creation(&mut self, destination: &PublicKey) -> IFResult<()> { + pub fn initialize_tunnel_creation(&mut self, destination: &PublicKey) -> IFResult<()> { let tunnel = TunnelPublic::new_singlecast(); self.tunnels_pending .push((tunnel.clone(), Some(destination.clone()), (0, 0))); @@ -144,10 +151,9 @@ impl IronForce { Ok(()) } - /// Send a message to another node, - /// creating a new tunnel if needed - pub fn send_message(&mut self, message: Message, destination: &PublicKey) -> IFResult<()> { - if let Some(Some(tunnel_id)) = self + /// Find a tunnel to another node (and return its id) + pub fn get_tunnel(&self, destination: &PublicKey) -> Option { + if let Some(Some(tun)) = self .tunnels .iter() .find(|t| { @@ -159,9 +165,28 @@ impl IronForce { }) .map(|tunnel| tunnel.id) { + Some(tun) + } else { + None + } + } + + /// Send a message to another node, + /// creating a new tunnel if needed + pub fn send_message(&mut self, message: Message, destination: &PublicKey) -> IFResult<()> { + if let Some(tunnel_id) = self.get_tunnel(destination) { self.send_through_tunnel(tunnel_id, message, None) } else { - Err(IFError::TunnelNotFound) + self.initialize_tunnel_creation(destination)?; + while self.get_tunnel(destination).is_none() { + if !self.has_background_worker { + self.main_loop_iteration()? + } + #[cfg(feature = "std")] + std::thread::sleep(std::time::Duration::from_millis(10)); + } + let tunnel_id = self.get_tunnel(destination).unwrap(); + self.send_through_tunnel(tunnel_id, message, None) } } @@ -284,7 +309,11 @@ impl IronForce { } } MessageType::Broadcast => { - self.messages.push(message.clone()); + #[cfg(feature = "std")] + println!("New message: {:?}", message.get_decrypted(&self.keys)); + if message.check_recipient(&self.keys) { + self.messages.push(message.clone()); + } self.send_to_all(message)?; } } @@ -334,9 +363,15 @@ impl IronForce { #[cfg(feature = "std")] pub fn from_file(filename: alloc::string::String) -> IFResult { - let filename = if filename.is_empty() { DEFAULT_FILE.to_string() } else { filename }; + let filename = if filename.is_empty() { + DEFAULT_FILE.to_string() + } else { + filename + }; if std::path::Path::new(&filename).exists() { - Self::from_serialization_data(serde_json::from_str(std::fs::read_to_string(filename)?.as_str())?) + Self::from_serialization_data(serde_json::from_str( + std::fs::read_to_string(filename)?.as_str(), + )?) } else { Ok(Self::new()) } diff --git a/src/message.rs b/src/message.rs index 7d5631b..11bf440 100644 --- a/src/message.rs +++ b/src/message.rs @@ -10,7 +10,7 @@ use sha2::Digest; pub(crate) type MessageBytes = Vec; /// Signature of the message: optional and optionally encrypted sender's key and signed hash -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Serialize, Deserialize, Clone, Debug, Hash)] pub enum Signature { /// The message is signed. Author is unknown NotSigned, @@ -40,7 +40,7 @@ impl Signature { } /// Network name and version -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Serialize, Deserialize, Clone, Debug, Hash)] pub struct NetworkInfo { network_name: String, version: String, @@ -55,7 +55,7 @@ impl Default for NetworkInfo { } } -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Serialize, Deserialize, Clone, Debug, Hash)] pub enum MessageType { SingleCast, Broadcast, @@ -83,7 +83,7 @@ impl MessageType { } } -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Serialize, Deserialize, Clone, Debug, Hash)] pub enum ServiceMessageType { /// Creating a tunnel - stage 1 /// @@ -92,7 +92,7 @@ pub enum ServiceMessageType { TunnelBuildingBackwardMovement(TunnelPublic), } -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Serialize, Deserialize, Clone, Debug, Hash)] pub enum MessageContent { /// Just plaintext message content Plain(Vec), @@ -102,6 +102,12 @@ pub enum MessageContent { None, } +impl Default for MessageContent { + fn default() -> Self { + MessageContent::None + } +} + impl MessageContent { pub fn hash(&self) -> Vec { match self { @@ -121,7 +127,7 @@ impl MessageContent { } /// The struct for messages that are sent in the network -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Serialize, Deserialize, Clone, Debug, Hash)] pub struct Message { /// Content of the message (not to be confused with the bytes that we are sending through interfaces) /// @@ -180,8 +186,12 @@ impl Message { /// Check if this message is for this set of keys pub fn check_recipient(&self, keys: &Keys) -> bool { - keys.decrypt_data(&self.recipient_verification.clone().unwrap()) - .is_ok() + if self.recipient_verification.is_none() { + true + } else { + keys.decrypt_data(&self.recipient_verification.clone().unwrap()) + .is_ok() + } } /// Get decrypted content of the message @@ -220,7 +230,7 @@ impl Message { } /// Try to get sender from the signature - fn get_sender(&self, keys: &Keys) -> Option { + pub fn get_sender(&self, keys: &Keys) -> Option { match &self.signature { Signature::NotSigned => None, Signature::Signed { sender, .. } => Some(sender.clone()), @@ -247,6 +257,7 @@ impl Message { } /// Message builder to create a new message step-by-step, like `Message::build().message_type(...).sign(...)` +#[derive(Default)] pub struct MessageBuilder { content: MessageContent, /// The type of the message to be built @@ -262,13 +273,7 @@ pub struct MessageBuilder { impl MessageBuilder { /// Create a new `MessageBuilder` with default parameters pub fn new() -> Self { - Self { - content: MessageContent::None, - message_type: None, - sender: None, - recipient: None, - tunnel_id: (0, false), - } + Default::default() } pub fn content(mut self, cont: Vec) -> Self { diff --git a/src/transport.rs b/src/transport.rs index 702bf6f..e37462c 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -13,7 +13,7 @@ use rayon::prelude::*; use std::println; /// An identification of a peer - something that we can use to send a message to id -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)] pub struct PeerInfo { /// Something to locally identify this peer pub peer_id: u64, @@ -196,6 +196,13 @@ impl Transport { } } +impl core::hash::Hash for Transport { + fn hash(&self, state: &mut H) { + core::hash::Hash::hash(&self.get_interfaces_data(), state); + core::hash::Hash::hash(&self.peers, state); + } +} + #[cfg(test)] use crate::interface::test_interface::SimpleTestInterface; diff --git a/src/tunnel.rs b/src/tunnel.rs index 8d02aa6..968e41e 100644 --- a/src/tunnel.rs +++ b/src/tunnel.rs @@ -5,7 +5,7 @@ use sha2::Digest; use alloc::vec; /// A tunnel that is used for communication -#[derive(Serialize, Clone, Deserialize, Debug)] +#[derive(Serialize, Clone, Deserialize, Debug, Hash)] pub struct Tunnel { /// Tunnel's id. /// By the way, this id is `None` until the tunnel is validated in the backward movement @@ -25,7 +25,7 @@ pub struct Tunnel { } /// Tunnel, but only the fields that are ok to share -#[derive(Serialize, Clone, Deserialize, Debug)] +#[derive(Serialize, Clone, Deserialize, Debug, Hash)] pub struct TunnelPublic { /// Tunnel's id pub id: Option,