diff --git a/hawkbit/Cargo.toml b/hawkbit/Cargo.toml index 8c22cba..cd2dc90 100644 --- a/hawkbit/Cargo.toml +++ b/hawkbit/Cargo.toml @@ -11,7 +11,7 @@ repository = "https://github.com/collabora/hawkbit-rs" documentation = "https://docs.rs/hawkbit_mock/" [dependencies] -reqwest = { version = "0.10", features = ["json"] } +reqwest = { version = "0.10", features = ["json", "stream"] } tokio = { version = "0.2", features = ["full"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" @@ -23,6 +23,9 @@ digest = { version = "0.9", optional = true } md-5 = { version = "0.9", optional = true } sha-1 = { version = "0.9", optional = true } sha2 = { version = "0.9", optional = true } +generic-array = {version = "0.14", optional = true } +futures = "0.3" +bytes = "0.5.6" [dev-dependencies] hawkbit_mock = { path = "../hawkbit_mock/" } @@ -31,9 +34,10 @@ anyhow = "1.0" log = "0.4" env_logger = "0.8" tempdir = "0.3" +assert_matches = "1.4" [features] -hash-digest= ["digest"] +hash-digest= ["digest", "generic-array"] hash-md5 = ["md-5", "hash-digest"] hash-sha1 = ["sha-1", "hash-digest"] hash-sha256 = ["sha2", "hash-digest"] \ No newline at end of file diff --git a/hawkbit/src/ddi.rs b/hawkbit/src/ddi.rs index dcdb662..0c9028c 100644 --- a/hawkbit/src/ddi.rs +++ b/hawkbit/src/ddi.rs @@ -23,6 +23,8 @@ mod poll; pub use client::{Client, Error}; pub use common::{Execution, Finished}; pub use config_data::{ConfigRequest, Mode}; +#[cfg(feature = "hash-digest")] +pub use deployment_base::ChecksumType; pub use deployment_base::{ Artifact, Chunk, DownloadedArtifact, MaintenanceWindow, Type, Update, UpdatePreFetch, }; diff --git a/hawkbit/src/ddi/client.rs b/hawkbit/src/ddi/client.rs index 5015f8b..7442e8c 100644 --- a/hawkbit/src/ddi/client.rs +++ b/hawkbit/src/ddi/client.rs @@ -34,6 +34,10 @@ pub enum Error { /// IO error #[error("Failed to download update")] Io(#[from] std::io::Error), + /// Invalid checksum + #[cfg(feature = "hash-digest")] + #[error("Invalid Checksum")] + ChecksumError(crate::ddi::deployment_base::ChecksumType), } impl Client { diff --git a/hawkbit/src/ddi/deployment_base.rs b/hawkbit/src/ddi/deployment_base.rs index 0bf701b..5b21015 100644 --- a/hawkbit/src/ddi/deployment_base.rs +++ b/hawkbit/src/ddi/deployment_base.rs @@ -5,7 +5,9 @@ use std::path::{Path, PathBuf}; -use reqwest::Client; +use bytes::Bytes; +use futures::{prelude::*, TryStreamExt}; +use reqwest::{Client, Response}; use serde::{Deserialize, Serialize}; use tokio::{ fs::{DirBuilder, File}, @@ -319,15 +321,20 @@ impl<'a> Artifact<'a> { self.artifact.size } - /// Download the artifact file to the directory defined in `dir`. - pub async fn download(&'a self, dir: &Path) -> Result { - let mut resp = self + async fn download_response(&'a self) -> Result { + let resp = self .client .get(&self.artifact.links.download_http.to_string()) .send() .await?; resp.error_for_status_ref()?; + Ok(resp) + } + + /// Download the artifact file to the directory defined in `dir`. + pub async fn download(&'a self, dir: &Path) -> Result { + let mut resp = self.download_response().await?; if !dir.exists() { DirBuilder::new().recursive(true).create(dir).await?; @@ -346,6 +353,76 @@ impl<'a> Artifact<'a> { self.artifact.hashes.clone(), )) } + + /// Provide a `Stream` of `Bytes` to download the artifact. + /// + /// This can be used as an alternative to [`Artifact::download`], + /// for example, to extract an archive while it's being downloaded, + /// saving the need to store the archive file on disk. + pub async fn download_stream( + &'a self, + ) -> Result>, Error> { + let resp = self.download_response().await?; + + Ok(resp.bytes_stream().map_err(|e| e.into())) + } + + /// Provide a `Stream` of `Bytes` to download the artifact while checking md5 checksum. + /// + /// The stream will yield the same data as [`Artifact::download_stream`] but will raise + /// an error if the md5sum of the downloaded data does not match the one provided by the server. + #[cfg(feature = "hash-md5")] + pub async fn download_stream_with_md5_check( + &'a self, + ) -> Result>, Error> { + let stream = self.download_stream().await?; + let hasher = DownloadHasher::new_md5(self.artifact.hashes.md5.clone()); + + let stream = DownloadStreamHash { + stream: Box::new(stream), + hasher, + }; + + Ok(stream) + } + + /// Provide a `Stream` of `Bytes` to download the artifact while checking sha1 checksum. + /// + /// The stream will yield the same data as [`Artifact::download_stream`] but will raise + /// an error if the sha1sum of the downloaded data does not match the one provided by the server. + #[cfg(feature = "hash-sha1")] + pub async fn download_stream_with_sha1_check( + &'a self, + ) -> Result>, Error> { + let stream = self.download_stream().await?; + let hasher = DownloadHasher::new_sha1(self.artifact.hashes.sha1.clone()); + + let stream = DownloadStreamHash { + stream: Box::new(stream), + hasher, + }; + + Ok(stream) + } + + /// Provide a `Stream` of `Bytes` to download the artifact while checking sha256 checksum. + /// + /// The stream will yield the same data as [`Artifact::download_stream`] but will raise + /// an error if the sha256sum of the downloaded data does not match the one provided by the server. + #[cfg(feature = "hash-sha256")] + pub async fn download_stream_with_sha256_check( + &'a self, + ) -> Result>, Error> { + let stream = self.download_stream().await?; + let hasher = DownloadHasher::new_sha256(self.artifact.hashes.sha256.clone()); + + let stream = DownloadStreamHash { + stream: Box::new(stream), + hasher, + }; + + Ok(stream) + } } /// A downloaded file part of a [`Chunk`]. @@ -357,28 +434,140 @@ pub struct DownloadedArtifact { cfg_if::cfg_if! { if #[cfg(feature = "hash-digest")] { + use std::{ + pin::Pin, + task::Poll, + }; use digest::Digest; - use thiserror::Error; const HASH_BUFFER_SIZE: usize = 4096; - #[derive(Error, Debug)] - pub enum ChecksumError { - #[error("Failed to compute checksum")] - Io(#[from] std::io::Error), - #[error("Checksum {0} does not match")] - Invalid(CheckSumType), - } - - #[derive(Debug, strum::Display)] - pub enum CheckSumType { + /// Enum representing the different type of supported checksums + #[derive(Debug, strum::Display, Clone)] + pub enum ChecksumType { + /// md5 #[cfg(feature = "hash-md5")] Md5, + /// sha1 #[cfg(feature = "hash-sha1")] Sha1, + /// sha256 #[cfg(feature = "hash-sha256")] Sha256, } + + // quite complex trait bounds because of requirements so LowerHex is implemented on the output + #[derive(Clone)] + struct DownloadHasher + where + T: Digest, + ::OutputSize: core::ops::Add, + <::OutputSize as core::ops::Add>::Output: generic_array::ArrayLength, + { + hasher: T, + expected: String, + error: ChecksumType, + } + + impl DownloadHasher + where + T: Digest, + ::OutputSize: core::ops::Add, + <::OutputSize as core::ops::Add>::Output: generic_array::ArrayLength + { + fn update(&mut self, data: impl AsRef<[u8]>) { + self.hasher.update(data); + } + + fn finalize(self) -> Result<(), Error> { + let digest = self.hasher.finalize(); + + if format!("{:x}", digest) == self.expected { + Ok(()) + } else { + Err(Error::ChecksumError(self.error)) + } + } + } + + #[cfg(feature = "hash-md5")] + impl DownloadHasher { + fn new_md5(expected: String) -> Self { + Self { + hasher: md5::Md5::new(), + expected, + error: ChecksumType::Md5, + } + } + } + + #[cfg(feature = "hash-sha1")] + impl DownloadHasher { + fn new_sha1(expected: String) -> Self { + Self { + hasher: sha1::Sha1::new(), + expected, + error: ChecksumType::Sha1, + } + } + } + + #[cfg(feature = "hash-sha256")] + impl DownloadHasher { + fn new_sha256(expected: String) -> Self { + Self { + hasher: sha2::Sha256::new(), + expected, + error: ChecksumType::Sha256, + } + } + } + + struct DownloadStreamHash + where + T: Digest, + ::OutputSize: core::ops::Add, + <::OutputSize as core::ops::Add>::Output: generic_array::ArrayLength, + { + stream: Box> + Unpin + Send + Sync>, + hasher: DownloadHasher, + } + + impl Stream for DownloadStreamHash + where + T: Digest, + ::OutputSize: core::ops::Add, + <::OutputSize as core::ops::Add>::Output: generic_array::ArrayLength, + T: Unpin, + T: Clone, + { + type Item = Result; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let me = Pin::into_inner(self); + + match Pin::new(&mut me.stream).poll_next(cx) { + Poll::Ready(Some(Ok(data))) => { + // feed data to the hasher and then pass them back to the stream + me.hasher.update(&data); + Poll::Ready(Some(Ok(data))) + } + Poll::Ready(None) => { + // download is done, check the hash + match me.hasher.clone().finalize() { + Ok(_) => Poll::Ready(None), + Err(e) => Poll::Ready(Some(Err(e))), + } + } + // passthrough on errors and pendings + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), + Poll::Pending => Poll::Pending, + } + } + } } } @@ -393,7 +582,12 @@ impl<'a> DownloadedArtifact { } #[cfg(feature = "hash-digest")] - async fn hash(&self, mut hasher: T) -> Result, ChecksumError> { + async fn hash(&self, mut hasher: DownloadHasher) -> Result<(), Error> + where + T: Digest, + ::OutputSize: core::ops::Add, + <::OutputSize as core::ops::Add>::Output: generic_array::ArrayLength, + { use tokio::io::AsyncReadExt; let mut file = File::open(&self.file).await?; @@ -407,42 +601,27 @@ impl<'a> DownloadedArtifact { hasher.update(&buffer[..n]); } - Ok(hasher.finalize()) + hasher.finalize() } /// Check if the md5sum of the downloaded file matches the one provided by the server. #[cfg(feature = "hash-md5")] - pub async fn check_md5(&self) -> Result<(), ChecksumError> { - let digest = self.hash(md5::Md5::new()).await?; - - if format!("{:x}", digest) == self.hashes.md5 { - Ok(()) - } else { - Err(ChecksumError::Invalid(CheckSumType::Md5)) - } + pub async fn check_md5(&self) -> Result<(), Error> { + let hasher = DownloadHasher::new_md5(self.hashes.md5.clone()); + self.hash(hasher).await } /// Check if the sha1sum of the downloaded file matches the one provided by the server. #[cfg(feature = "hash-sha1")] - pub async fn check_sha1(&self) -> Result<(), ChecksumError> { - let digest = self.hash(sha1::Sha1::new()).await?; - - if format!("{:x}", digest) == self.hashes.sha1 { - Ok(()) - } else { - Err(ChecksumError::Invalid(CheckSumType::Sha1)) - } + pub async fn check_sha1(&self) -> Result<(), Error> { + let hasher = DownloadHasher::new_sha1(self.hashes.sha1.clone()); + self.hash(hasher).await } /// Check if the sha256sum of the downloaded file matches the one provided by the server. #[cfg(feature = "hash-sha256")] - pub async fn check_sha256(&self) -> Result<(), ChecksumError> { - let digest = self.hash(sha2::Sha256::new()).await?; - - if format!("{:x}", digest) == self.hashes.sha256 { - Ok(()) - } else { - Err(ChecksumError::Invalid(CheckSumType::Sha256)) - } + pub async fn check_sha256(&self) -> Result<(), Error> { + let hasher = DownloadHasher::new_sha256(self.hashes.sha256.clone()); + self.hash(hasher).await } } diff --git a/hawkbit/tests/tests.rs b/hawkbit/tests/tests.rs index b09cc33..0c4b6da 100644 --- a/hawkbit/tests/tests.rs +++ b/hawkbit/tests/tests.rs @@ -1,9 +1,13 @@ // Copyright 2020, Collabora Ltd. // SPDX-License-Identifier: MIT OR Apache-2.0 +use std::fs::File; +use std::io::prelude::*; use std::{path::PathBuf, time::Duration}; -use hawkbit::ddi::{Client, Execution, Finished, MaintenanceWindow, Mode, Type}; +use bytes::Bytes; +use futures::prelude::*; +use hawkbit::ddi::{Client, Error, Execution, Finished, MaintenanceWindow, Mode, Type}; use serde::Serialize; use serde_json::json; use tempdir::TempDir; @@ -97,25 +101,32 @@ async fn upload_config() { assert_eq!(target.config_data_hits(), 1); } -fn get_deployment() -> Deployment { +fn artifact_path() -> PathBuf { let mut test_artifact = PathBuf::new(); test_artifact.push("tests"); test_artifact.push("data"); test_artifact.push("test.txt"); + test_artifact +} + +fn get_deployment(valid_checksums: bool) -> Deployment { + let test_artifact = artifact_path(); + + let artifacts = if valid_checksums { + vec![( + test_artifact, + "5eb63bbbe01eeed093cb22bb8f5acdc3", + "2aae6c35c94fcfb415dbe95f408b9ce91ee846ed", + "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9", + )] + } else { + vec![(test_artifact, "badger", "badger", "badger")] + }; + DeploymentBuilder::new("10", Type::Forced, Type::Attempt) .maintenance_window(MaintenanceWindow::Available) - .chunk( - "app", - "1.0", - "some-chunk", - vec![( - test_artifact, - "5eb63bbbe01eeed093cb22bb8f5acdc3", - "2aae6c35c94fcfb415dbe95f408b9ce91ee846ed", - "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9", - )], - ) + .chunk("app", "1.0", "some-chunk", artifacts) .build() } @@ -125,7 +136,7 @@ async fn deployment() { let server = ServerBuilder::default().build(); let (client, target) = add_target(&server, "Target1"); - target.push_deployment(get_deployment()); + target.push_deployment(get_deployment(true)); let reply = client.poll().await.expect("poll failed"); assert!(reply.config_data_request().is_none()); @@ -178,7 +189,7 @@ async fn send_feedback() { init(); let server = ServerBuilder::default().build(); - let deploy = get_deployment(); + let deploy = get_deployment(true); let deploy_id = deploy.id.clone(); let (client, target) = add_target(&server, "Target1"); target.push_deployment(deploy); @@ -265,9 +276,153 @@ async fn config_then_deploy() { assert!(reply.update().is_none()); // server pushes an update - target.push_deployment(get_deployment()); + target.push_deployment(get_deployment(true)); let reply = client.poll().await.expect("poll failed"); assert!(reply.config_data_request().is_some()); assert!(reply.update().is_some()); } + +#[tokio::test] +async fn download_stream() { + init(); + + let server = ServerBuilder::default().build(); + let (client, target) = add_target(&server, "Target1"); + + target.push_deployment(get_deployment(true)); + let reply = client.poll().await.expect("poll failed"); + + let update = reply.update().expect("missing update"); + let update = update.fetch().await.expect("failed to fetch update info"); + let chunk = update.chunks().next().unwrap(); + let art = chunk.artifacts().next().unwrap(); + + async fn check_download(mut stream: Box> + Unpin>) { + let mut downloaded: Vec = Vec::new(); + while let Some(b) = stream.next().await { + downloaded.extend(b.unwrap().as_ref()); + } + + // Compare downloaded content with the actual file + let mut art_file = File::open(&artifact_path()).expect("failed to open artifact"); + let mut expected = Vec::new(); + art_file + .read_to_end(&mut expected) + .expect("failed to read artifact"); + + assert_eq!(downloaded, expected); + } + + // Download artifact using the stream API + let stream = art + .download_stream() + .await + .expect("failed to get download stream"); + check_download(Box::new(stream)).await; + + cfg_if::cfg_if! { + if #[cfg(feature = "hash-md5")] { + let stream = art + .download_stream_with_md5_check() + .await + .expect("failed to get download stream"); + check_download(Box::new(stream)).await; + } + } + + cfg_if::cfg_if! { + if #[cfg(feature = "hash-sha1")] { + let stream = art + .download_stream_with_sha1_check() + .await + .expect("failed to get download stream"); + check_download(Box::new(stream)).await; + } + } + + cfg_if::cfg_if! { + if #[cfg(feature = "hash-sha256")] { + let stream = art + .download_stream_with_sha256_check() + .await + .expect("failed to get download stream"); + check_download(Box::new(stream)).await; + } + } +} + +#[cfg(feature = "hash-digest")] +#[tokio::test] +async fn wrong_checksums() { + use assert_matches::assert_matches; + use hawkbit::ddi::ChecksumType; + + init(); + + let server = ServerBuilder::default().build(); + let (client, target) = add_target(&server, "Target1"); + + target.push_deployment(get_deployment(false)); + let reply = client.poll().await.expect("poll failed"); + + let update = reply.update().expect("missing update"); + let update = update.fetch().await.expect("failed to fetch update info"); + let chunk = update.chunks().next().unwrap(); + let art = chunk.artifacts().next().unwrap(); + + let out_dir = TempDir::new("test-hawkbitrs").expect("Failed to create temp dir"); + let downloaded = art + .download(out_dir.path()) + .await + .expect("failed to download artifact"); + + #[cfg(feature = "hash-md5")] + assert_matches!( + downloaded.check_md5().await, + Err(Error::ChecksumError(ChecksumType::Md5)) + ); + #[cfg(feature = "hash-sha1")] + assert_matches!( + downloaded.check_sha1().await, + Err(Error::ChecksumError(ChecksumType::Sha1)) + ); + #[cfg(feature = "hash-sha256")] + assert_matches!( + downloaded.check_sha256().await, + Err(Error::ChecksumError(ChecksumType::Sha256)) + ); + + cfg_if::cfg_if! { + if #[cfg(feature = "hash-md5")] { + let stream = art + .download_stream_with_md5_check() + .await + .expect("failed to get download stream"); + let end = stream.skip_while(|b| future::ready(b.is_ok())).next().await; + assert_matches!(end, Some(Err(Error::ChecksumError(ChecksumType::Md5)))); + } + } + + cfg_if::cfg_if! { + if #[cfg(feature = "hash-sha1")] { + let stream = art + .download_stream_with_sha1_check() + .await + .expect("failed to get download stream"); + let end = stream.skip_while(|b| future::ready(b.is_ok())).next().await; + assert_matches!(end, Some(Err(Error::ChecksumError(ChecksumType::Sha1)))); + } + } + + cfg_if::cfg_if! { + if #[cfg(feature = "hash-sha256")] { + let stream = art + .download_stream_with_sha256_check() + .await + .expect("failed to get download stream"); + let end = stream.skip_while(|b| future::ready(b.is_ok())).next().await; + assert_matches!(end, Some(Err(Error::ChecksumError(ChecksumType::Sha256)))); + } + } +}