summaryrefslogtreecommitdiff
path: root/src/gateway/socket.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/gateway/socket.rs')
-rw-r--r--src/gateway/socket.rs161
1 files changed, 161 insertions, 0 deletions
diff --git a/src/gateway/socket.rs b/src/gateway/socket.rs
new file mode 100644
index 0000000..f252e7c
--- /dev/null
+++ b/src/gateway/socket.rs
@@ -0,0 +1,161 @@
+use chan;
+use chan::Sender;
+use rustc_serialize::json;
+use std::io::{BufReader, Read, Write};
+use std::net::Shutdown;
+use std::sync::{Arc, Mutex};
+use std::{fs, thread};
+
+use datatype::{Command, Error, Event};
+use super::{Gateway, Interpret};
+use unix_socket::{UnixListener, UnixStream};
+
+
+/// The `Socket` gateway is used for communication via Unix Domain Sockets.
+pub struct Socket {
+ pub commands_path: String,
+ pub events_path: String,
+}
+
+impl Gateway for Socket {
+ fn initialize(&mut self, itx: Sender<Interpret>) -> Result<(), String> {
+ let _ = fs::remove_file(&self.commands_path);
+ let commands = match UnixListener::bind(&self.commands_path) {
+ Ok(sock) => sock,
+ Err(err) => return Err(format!("couldn't open commands socket: {}", err))
+ };
+
+ let itx = Arc::new(Mutex::new(itx));
+ thread::spawn(move || {
+ for conn in commands.incoming() {
+ if let Err(err) = conn {
+ error!("couldn't get commands socket connection: {}", err);
+ continue
+ }
+ let mut stream = conn.unwrap();
+ let itx = itx.clone();
+
+ thread::spawn(move || {
+ let resp = handle_client(&mut stream, itx)
+ .map(|ev| json::encode(&ev).expect("couldn't encode Event").into_bytes())
+ .unwrap_or_else(|err| format!("{}", err).into_bytes());
+
+ stream.write_all(&resp)
+ .unwrap_or_else(|err| error!("couldn't write to commands socket: {}", err));
+ stream.shutdown(Shutdown::Write)
+ .unwrap_or_else(|err| error!("couldn't close commands socket: {}", err));
+ });
+ }
+ });
+
+ Ok(info!("Socket listening for commands at {} and sending events to {}.",
+ self.commands_path, self.events_path))
+ }
+
+ fn pulse(&self, event: Event) {
+ match event {
+ Event::DownloadComplete(dl) => {
+ let _ = UnixStream::connect(&self.events_path).map(|mut stream| {
+ stream.write_all(&json::encode(&dl).expect("couldn't encode Event").into_bytes())
+ .unwrap_or_else(|err| error!("couldn't write to events socket: {}", err));
+ stream.shutdown(Shutdown::Write)
+ .unwrap_or_else(|err| error!("couldn't close events socket: {}", err));
+ }).map_err(|err| error!("couldn't open events socket: {}", err));
+ }
+
+ _ => ()
+ }
+ }
+}
+
+fn handle_client(stream: &mut UnixStream, itx: Arc<Mutex<Sender<Interpret>>>) -> Result<Event, Error> {
+ info!("New domain socket connection");
+ let mut reader = BufReader::new(stream);
+ let mut input = String::new();
+ try!(reader.read_to_string(&mut input));
+ debug!("socket input: {}", input);
+
+ let cmd = try!(input.parse::<Command>());
+ let (etx, erx) = chan::async::<Event>();
+ itx.lock().unwrap().send(Interpret {
+ command: cmd,
+ response_tx: Some(Arc::new(Mutex::new(etx))),
+ });
+ erx.recv().ok_or(Error::Socket("internal receiver error".to_string()))
+}
+
+
+#[cfg(test)]
+mod tests {
+ use chan;
+ use crossbeam;
+ use rustc_serialize::json;
+ use std::{fs, thread};
+ use std::io::{Read, Write};
+ use std::net::Shutdown;
+ use std::time::Duration;
+
+ use datatype::{Command, DownloadComplete, Event};
+ use gateway::{Gateway, Interpret};
+ use super::*;
+ use unix_socket::{UnixListener, UnixStream};
+
+
+ #[test]
+ fn socket_commands_and_events() {
+ let (etx, erx) = chan::sync::<Event>(0);
+ let (itx, irx) = chan::sync::<Interpret>(0);
+
+ thread::spawn(move || Socket {
+ commands_path: "/tmp/sota-commands.socket".to_string(),
+ events_path: "/tmp/sota-events.socket".to_string(),
+ }.start(itx, erx));
+ thread::sleep(Duration::from_millis(100)); // wait until socket gateway is created
+
+ let path = "/tmp/sota-events.socket";
+ let _ = fs::remove_file(&path);
+ let server = UnixListener::bind(&path).expect("couldn't create events socket for testing");
+
+ let send = DownloadComplete {
+ update_id: "1".to_string(),
+ update_image: "/foo/bar".to_string(),
+ signature: "abc".to_string()
+ };
+ etx.send(Event::DownloadComplete(send.clone()));
+
+ let (mut stream, _) = server.accept().expect("couldn't read from events socket");
+ let mut text = String::new();
+ stream.read_to_string(&mut text).unwrap();
+ let receive: DownloadComplete = json::decode(&text).expect("couldn't decode DownloadComplete message");
+ assert_eq!(send, receive);
+
+ thread::spawn(move || {
+ let _ = etx; // move into this scope
+ loop {
+ let interpret = irx.recv().expect("gtx is closed");
+ match interpret.command {
+ Command::StartDownload(ids) => {
+ let tx = interpret.response_tx.unwrap();
+ tx.lock().unwrap().send(Event::FoundSystemInfo(ids.first().unwrap().to_owned()));
+ }
+ _ => panic!("expected AcceptUpdates"),
+ }
+ }
+ });
+
+ crossbeam::scope(|scope| {
+ for id in 0..10 {
+ scope.spawn(move || {
+ let mut stream = UnixStream::connect("/tmp/sota-commands.socket").expect("couldn't connect to socket");
+ let _ = stream.write_all(&format!("dl {}", id).into_bytes()).expect("couldn't write to stream");
+ stream.shutdown(Shutdown::Write).expect("couldn't shut down writing");
+
+ let mut resp = String::new();
+ stream.read_to_string(&mut resp).expect("couldn't read from stream");
+ let ev: Event = json::decode(&resp).expect("couldn't decode json event");
+ assert_eq!(ev, Event::FoundSystemInfo(format!("{}", id)));
+ });
+ }
+ });
+ }
+}