8 changes: 6 additions & 2 deletions hawkbit/Cargo.toml
@@ -11,7 +11,7 @@ repository = "https://github.com/collabora/hawkbit-rs"
documentation = "https://docs.rs/hawkbit_mock/"

[dependencies]
reqwest = { version = "0.10", features = ["json"] }
reqwest = { version = "0.10", features = ["json", "stream"] }
tokio = { version = "0.2", features = ["full"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
@@ -23,6 +23,9 @@ digest = { version = "0.9", optional = true }
md-5 = { version = "0.9", optional = true }
sha-1 = { version = "0.9", optional = true }
sha2 = { version = "0.9", optional = true }
generic-array = { version = "0.14", optional = true }
futures = "0.3"
bytes = "0.5.6"

[dev-dependencies]
hawkbit_mock = { path = "../hawkbit_mock/" }
@@ -31,9 +34,10 @@ anyhow = "1.0"
log = "0.4"
env_logger = "0.8"
tempdir = "0.3"
assert_matches = "1.4"

[features]
hash-digest= ["digest"]
hash-digest= ["digest", "generic-array"]
hash-md5 = ["md-5", "hash-digest"]
hash-sha1 = ["sha-1", "hash-digest"]
hash-sha256 = ["sha2", "hash-digest"]
2 changes: 2 additions & 0 deletions hawkbit/src/ddi.rs
@@ -23,6 +23,8 @@ mod poll;
pub use client::{Client, Error};
pub use common::{Execution, Finished};
pub use config_data::{ConfigRequest, Mode};
#[cfg(feature = "hash-digest")]
pub use deployment_base::ChecksumType;
pub use deployment_base::{
Artifact, Chunk, DownloadedArtifact, MaintenanceWindow, Type, Update, UpdatePreFetch,
};
4 changes: 4 additions & 0 deletions hawkbit/src/ddi/client.rs
@@ -34,6 +34,10 @@ pub enum Error {
/// IO error
#[error("Failed to download update")]
Io(#[from] std::io::Error),
/// Invalid checksum
#[cfg(feature = "hash-digest")]
#[error("Invalid Checksum")]
ChecksumError(crate::ddi::deployment_base::ChecksumType),
}

impl Client {
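Illustrative sketch, not part of this patch: with the `hash-md5` feature enabled (and assuming the crate is consumed as `hawkbit`), a caller could react to the new error variant roughly like this; `downloaded` is a hypothetical `DownloadedArtifact` returned by `Artifact::download`.

```rust
use hawkbit::ddi::{ChecksumType, Error};

// Verify the downloaded file and distinguish a checksum mismatch from other failures.
match downloaded.check_md5().await {
    Ok(()) => println!("md5 verified"),
    Err(Error::ChecksumError(ChecksumType::Md5)) => eprintln!("md5 mismatch, discarding download"),
    Err(e) => eprintln!("could not verify artifact: {}", e),
}
```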
261 changes: 220 additions & 41 deletions hawkbit/src/ddi/deployment_base.rs
@@ -5,7 +5,9 @@

use std::path::{Path, PathBuf};

use reqwest::Client;
use bytes::Bytes;
use futures::{prelude::*, TryStreamExt};
use reqwest::{Client, Response};
use serde::{Deserialize, Serialize};
use tokio::{
fs::{DirBuilder, File},
@@ -319,15 +321,20 @@ impl<'a> Artifact<'a> {
self.artifact.size
}

/// Download the artifact file to the directory defined in `dir`.
pub async fn download(&'a self, dir: &Path) -> Result<DownloadedArtifact, Error> {
let mut resp = self
async fn download_response(&'a self) -> Result<Response, Error> {
let resp = self
.client
.get(&self.artifact.links.download_http.to_string())
.send()
.await?;

resp.error_for_status_ref()?;
Ok(resp)
}

/// Download the artifact file to the directory defined in `dir`.
pub async fn download(&'a self, dir: &Path) -> Result<DownloadedArtifact, Error> {
let mut resp = self.download_response().await?;

if !dir.exists() {
DirBuilder::new().recursive(true).create(dir).await?;
@@ -346,6 +353,76 @@ impl<'a> Artifact<'a> {
self.artifact.hashes.clone(),
))
}

/// Provide a `Stream` of `Bytes` to download the artifact.
///
/// This can be used as an alternative to [`Artifact::download`],
/// for example, to extract an archive while it's being downloaded,
    /// avoiding the need to store the archive file on disk.
pub async fn download_stream(
&'a self,
) -> Result<impl Stream<Item = Result<Bytes, Error>>, Error> {
let resp = self.download_response().await?;

Ok(resp.bytes_stream().map_err(|e| e.into()))
}
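Illustrative sketch, not part of this patch: draining the stream to a file, assuming an `Artifact` named `artifact` and an enclosing async function whose error type absorbs both `hawkbit::ddi::Error` and `std::io::Error` (e.g. `anyhow::Error`).

```rust
use futures::StreamExt;
use tokio::io::AsyncWriteExt;

let mut stream = artifact.download_stream().await?;
let mut file = tokio::fs::File::create("artifact.bin").await?;
while let Some(chunk) = stream.next().await {
    // Each item is a Result<Bytes, Error>; bail out on the first failure.
    file.write_all(&chunk?).await?;
}
```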

    /// Provide a `Stream` of `Bytes` to download the artifact while checking the md5 checksum.
///
/// The stream will yield the same data as [`Artifact::download_stream`] but will raise
/// an error if the md5sum of the downloaded data does not match the one provided by the server.
#[cfg(feature = "hash-md5")]
pub async fn download_stream_with_md5_check(
&'a self,
) -> Result<impl Stream<Item = Result<Bytes, Error>>, Error> {
let stream = self.download_stream().await?;
let hasher = DownloadHasher::new_md5(self.artifact.hashes.md5.clone());

let stream = DownloadStreamHash {
stream: Box::new(stream),
hasher,
};

Ok(stream)
}

    /// Provide a `Stream` of `Bytes` to download the artifact while checking the sha1 checksum.
///
/// The stream will yield the same data as [`Artifact::download_stream`] but will raise
/// an error if the sha1sum of the downloaded data does not match the one provided by the server.
#[cfg(feature = "hash-sha1")]
pub async fn download_stream_with_sha1_check(
&'a self,
) -> Result<impl Stream<Item = Result<Bytes, Error>>, Error> {
let stream = self.download_stream().await?;
let hasher = DownloadHasher::new_sha1(self.artifact.hashes.sha1.clone());

let stream = DownloadStreamHash {
stream: Box::new(stream),
hasher,
};

Ok(stream)
}

    /// Provide a `Stream` of `Bytes` to download the artifact while checking the sha256 checksum.
///
/// The stream will yield the same data as [`Artifact::download_stream`] but will raise
/// an error if the sha256sum of the downloaded data does not match the one provided by the server.
#[cfg(feature = "hash-sha256")]
pub async fn download_stream_with_sha256_check(
&'a self,
) -> Result<impl Stream<Item = Result<Bytes, Error>>, Error> {
let stream = self.download_stream().await?;
let hasher = DownloadHasher::new_sha256(self.artifact.hashes.sha256.clone());

let stream = DownloadStreamHash {
stream: Box::new(stream),
hasher,
};

Ok(stream)
}
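Illustrative sketch, not part of this patch: the checked streams are consumed the same way as the plain one; a checksum mismatch is yielded as the final stream item, so it surfaces through the usual `?`. Assumes the `hash-sha256` feature and an `Artifact` named `artifact`.

```rust
use futures::TryStreamExt;

let stream = artifact.download_stream_with_sha256_check().await?;
stream
    .try_for_each(|chunk| async move {
        // In a real consumer this is where the bytes would be unpacked or written out.
        println!("received {} bytes", chunk.len());
        Ok(())
    })
    .await?; // a mismatch ends the stream with Err(Error::ChecksumError(ChecksumType::Sha256))
```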
}

/// A downloaded file part of a [`Chunk`].
@@ -357,28 +434,140 @@ pub struct DownloadedArtifact {

cfg_if::cfg_if! {
if #[cfg(feature = "hash-digest")] {
use std::{
pin::Pin,
task::Poll,
};
use digest::Digest;
use thiserror::Error;

const HASH_BUFFER_SIZE: usize = 4096;

#[derive(Error, Debug)]
pub enum ChecksumError {
#[error("Failed to compute checksum")]
Io(#[from] std::io::Error),
#[error("Checksum {0} does not match")]
Invalid(CheckSumType),
}

#[derive(Debug, strum::Display)]
pub enum CheckSumType {
        /// Enum representing the different types of supported checksums
#[derive(Debug, strum::Display, Clone)]
pub enum ChecksumType {
/// md5
#[cfg(feature = "hash-md5")]
Md5,
/// sha1
#[cfg(feature = "hash-sha1")]
Sha1,
/// sha256
#[cfg(feature = "hash-sha256")]
Sha256,
}

        // Quite complex trait bounds are needed so that LowerHex is implemented on the digest output.
#[derive(Clone)]
struct DownloadHasher<T>
where
T: Digest,
<T as Digest>::OutputSize: core::ops::Add,
<<T as Digest>::OutputSize as core::ops::Add>::Output: generic_array::ArrayLength<u8>,
{
hasher: T,
expected: String,
error: ChecksumType,
}

impl<T> DownloadHasher<T>
where
T: Digest,
<T as Digest>::OutputSize: core::ops::Add,
<<T as Digest>::OutputSize as core::ops::Add>::Output: generic_array::ArrayLength<u8>
{
fn update(&mut self, data: impl AsRef<[u8]>) {
self.hasher.update(data);
}

fn finalize(self) -> Result<(), Error> {
let digest = self.hasher.finalize();

if format!("{:x}", digest) == self.expected {
Ok(())
} else {
Err(Error::ChecksumError(self.error))
}
}
}

#[cfg(feature = "hash-md5")]
impl DownloadHasher<md5::Md5> {
fn new_md5(expected: String) -> Self {
Self {
hasher: md5::Md5::new(),
expected,
error: ChecksumType::Md5,
}
}
}

#[cfg(feature = "hash-sha1")]
impl DownloadHasher<sha1::Sha1> {
fn new_sha1(expected: String) -> Self {
Self {
hasher: sha1::Sha1::new(),
expected,
error: ChecksumType::Sha1,
}
}
}

#[cfg(feature = "hash-sha256")]
impl DownloadHasher<sha2::Sha256> {
fn new_sha256(expected: String) -> Self {
Self {
hasher: sha2::Sha256::new(),
expected,
error: ChecksumType::Sha256,
}
}
}

struct DownloadStreamHash<T>
where
T: Digest,
<T as Digest>::OutputSize: core::ops::Add,
<<T as Digest>::OutputSize as core::ops::Add>::Output: generic_array::ArrayLength<u8>,
{
stream: Box<dyn Stream<Item = Result<Bytes, Error>> + Unpin + Send + Sync>,
hasher: DownloadHasher<T>,
}

impl<T> Stream for DownloadStreamHash<T>
where
T: Digest,
<T as Digest>::OutputSize: core::ops::Add,
<<T as Digest>::OutputSize as core::ops::Add>::Output: generic_array::ArrayLength<u8>,
T: Unpin,
T: Clone,
{
type Item = Result<Bytes, Error>;

fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let me = Pin::into_inner(self);

match Pin::new(&mut me.stream).poll_next(cx) {
Poll::Ready(Some(Ok(data))) => {
                        // feed data to the hasher and then pass it on to the stream consumer
me.hasher.update(&data);
Poll::Ready(Some(Ok(data)))
}
Poll::Ready(None) => {
// download is done, check the hash
match me.hasher.clone().finalize() {
Ok(_) => Poll::Ready(None),
Err(e) => Poll::Ready(Some(Err(e))),
}
}
                    // pass errors and Pending through unchanged
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
Poll::Pending => Poll::Pending,
}
}
}
}
}

@@ -393,7 +582,12 @@ impl<'a> DownloadedArtifact {
}

#[cfg(feature = "hash-digest")]
async fn hash<T: Digest>(&self, mut hasher: T) -> Result<digest::Output<T>, ChecksumError> {
async fn hash<T>(&self, mut hasher: DownloadHasher<T>) -> Result<(), Error>
where
T: Digest,
<T as Digest>::OutputSize: core::ops::Add,
<<T as Digest>::OutputSize as core::ops::Add>::Output: generic_array::ArrayLength<u8>,
{
use tokio::io::AsyncReadExt;

let mut file = File::open(&self.file).await?;
Expand All @@ -407,42 +601,27 @@ impl<'a> DownloadedArtifact {
hasher.update(&buffer[..n]);
}

Ok(hasher.finalize())
hasher.finalize()
}

/// Check if the md5sum of the downloaded file matches the one provided by the server.
#[cfg(feature = "hash-md5")]
pub async fn check_md5(&self) -> Result<(), ChecksumError> {
let digest = self.hash(md5::Md5::new()).await?;

if format!("{:x}", digest) == self.hashes.md5 {
Ok(())
} else {
Err(ChecksumError::Invalid(CheckSumType::Md5))
}
pub async fn check_md5(&self) -> Result<(), Error> {
let hasher = DownloadHasher::new_md5(self.hashes.md5.clone());
self.hash(hasher).await
}

/// Check if the sha1sum of the downloaded file matches the one provided by the server.
#[cfg(feature = "hash-sha1")]
pub async fn check_sha1(&self) -> Result<(), ChecksumError> {
let digest = self.hash(sha1::Sha1::new()).await?;

if format!("{:x}", digest) == self.hashes.sha1 {
Ok(())
} else {
Err(ChecksumError::Invalid(CheckSumType::Sha1))
}
pub async fn check_sha1(&self) -> Result<(), Error> {
let hasher = DownloadHasher::new_sha1(self.hashes.sha1.clone());
self.hash(hasher).await
}

/// Check if the sha256sum of the downloaded file matches the one provided by the server.
#[cfg(feature = "hash-sha256")]
pub async fn check_sha256(&self) -> Result<(), ChecksumError> {
let digest = self.hash(sha2::Sha256::new()).await?;

if format!("{:x}", digest) == self.hashes.sha256 {
Ok(())
} else {
Err(ChecksumError::Invalid(CheckSumType::Sha256))
}
pub async fn check_sha256(&self) -> Result<(), Error> {
let hasher = DownloadHasher::new_sha256(self.hashes.sha256.clone());
self.hash(hasher).await
}
}
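Illustrative sketch, not part of this patch: the non-streaming flow with the `hash-sha256` feature enabled, assuming an `Artifact` named `artifact` and a hypothetical target directory.

```rust
use std::path::Path;

// Download into a directory, then verify the file against the server-provided sha256.
let downloaded = artifact.download(Path::new("/tmp/hawkbit-artifacts")).await?;
downloaded.check_sha256().await?; // Err(Error::ChecksumError(ChecksumType::Sha256)) on mismatch
```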