/**
>HEADER
Copyright (c) 2013 Rob Patro robp@cs.cmu.edu
This file is part of Salmon.
Salmon is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
Salmon is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with Salmon. If not, see <http://www.gnu.org/licenses/>.
<HEADER
**/
#include
#include
namespace stats {
namespace utils {
/// Computes (X^T X) / (n_rows - 1) for the matrix X = `data`.
///
/// The result is n_cols x n_cols. Rows of `data` are treated as observations
/// and columns as variables. No centering is done here, so the result is the
/// sample covariance matrix only if the column means have already been
/// removed (e.g. via remove_column_means).
///
/// NOTE(review): n_rows - 1 is evaluated in unsigned arithmetic. With one row
/// the scale factor is 1/0 (inf). With zero rows it wraps to a tiny positive
/// value. Callers should pass at least two rows.
arma::Mat<double> make_covariance_matrix(const arma::Mat<double>& data) {
  // Return the expression directly; wrapping it in std::move adds nothing.
  return (data.t() * data) * (1. / (data.n_rows - 1));
}
/// Builds a resampled matrix with the same dimensions as `data`.
///
/// Entry (i, j) of the result is a copy of data(r, j). The row r is drawn
/// with std::rand() independently for every entry. Each column is therefore
/// resampled *with replacement* (not permuted), independently of the other
/// columns.
///
/// NOTE(review): std::rand() % n_rows has a small modulo bias. std::rand is
/// not guaranteed thread-safe, and its output depends on the global std::srand
/// seed.
arma::Mat<double> make_shuffled_matrix(const arma::Mat<double>& data) {
  const arma::uword n_rows = data.n_rows;
  const arma::uword n_cols = data.n_cols;
  arma::Mat<double> shuffle(n_rows, n_cols);
  for (arma::uword j = 0; j < n_cols; ++j) {
    for (arma::uword i = 0; i < n_rows; ++i) {
      shuffle(i, j) = data(std::rand() % n_rows, j);
    }
  }
  // Return by name (no std::move) so the compiler can apply NRVO.
  return shuffle;
}
/// Returns a column vector whose i-th entry is the arithmetic mean of
/// column i of `data`.
arma::Col<double> compute_column_means(const arma::Mat<double>& data) {
  const arma::uword n_cols = data.n_cols;
  arma::Col<double> means(n_cols);
  for (arma::uword i = 0; i < n_cols; ++i)
    means(i) = arma::mean(data.col(i));
  // Return by name (no std::move) so the compiler can apply NRVO.
  return means;
}
void remove_column_means(arma::Mat& data,
const arma::Col& means) {
if (data.n_cols != means.n_elem)
throw std::range_error("Number of elements of means is not equal to the "
"number of columns of data");
for (long i = 0; i < long(data.n_cols); ++i)
data.col(i) -= means(i);
}
/// Returns, for each column i of `data`, sqrt(dot(col_i, col_i) / (n_rows - 1)).
///
/// For mean-centred columns this equals the sample standard deviation.
///
/// NOTE(review): the divisor is computed in unsigned arithmetic. With one row
/// it is 0, giving inf (or NaN for an all-zero column). With zero rows it
/// wraps to a huge value, giving 0.
arma::Col<double> compute_column_rms(const arma::Mat<double>& data) {
  const arma::uword n_cols = data.n_cols;
  // Every column has n_rows elements, so the divisor is loop-invariant.
  const arma::uword divisor = data.n_rows - 1;
  arma::Col<double> rms(n_cols);
  for (arma::uword i = 0; i < n_cols; ++i) {
    const double dot = arma::dot(data.col(i), data.col(i));
    rms(i) = std::sqrt(dot / divisor);
  }
  // Return by name (no std::move) so the compiler can apply NRVO.
  return rms;
}
void normalize_by_column(arma::Mat& data,
const arma::Col& rms) {
if (data.n_cols != rms.n_elem)
throw std::range_error("Number of elements of rms is not equal to the "
"number of columns of data");
for (long i = 0; i < long(data.n_cols); ++i) {
if (rms(i) == 0)
throw std::runtime_error(
"At least one of the entries of rms equals to zero");
data.col(i) *= 1. / rms(i);
}
}
void enforce_positive_sign_by_column(arma::Mat& data) {
for (long i = 0; i < long(data.n_cols); ++i) {
const double max = arma::max(data.col(i));
const double min = arma::min(data.col(i));
bool change_sign = false;
if (std::abs(max) >= std::abs(min)) {
if (max < 0)
change_sign = true;
} else {
if (min < 0)
change_sign = true;
}
if (change_sign)
data.col(i) *= -1;
}
}
std::vector extract_column_vector(const arma::Mat& data,
long index) {
if (index < 0 || index >= long(data.n_cols))
throw std::range_error(join("Index out of range: ", index));
const long n_rows = data.n_rows;
const double* memptr = data.colptr(index);
std::vector result(memptr, memptr + n_rows);
return std::move(result);
}
std::vector extract_row_vector(const arma::Mat& data,
long index) {
if (index < 0 || index >= long(data.n_rows))
throw std::range_error(join("Index out of range: ", index));
const arma::Row row(data.row(index));
const double* memptr = row.memptr();
std::vector result(memptr, memptr + row.n_elem);
return std::move(result);
}
/// Guard for file-opening code: returns silently when `is_file_good` is true,
/// otherwise throws std::ios_base::failure mentioning `filename`.
void assert_file_good(const bool& is_file_good, const std::string& filename) {
  if (is_file_good)
    return;
  throw std::ios_base::failure(join("Cannot open file: ", filename));
}
/// Returns the arithmetic mean of the values in `iter`.
///
/// @param iter  values to average.
/// @return the mean, or NaN if `iter` is empty.
double get_mean(const std::vector<double>& iter) {
  // Produce NaN explicitly rather than evaluating 0.0 / 0 for empty input.
  if (iter.empty())
    return std::nan("");
  const double init = 0;  // double seed so accumulate sums in double
  return std::accumulate(iter.begin(), iter.end(), init) / iter.size();
}
/// Returns the sample standard deviation of `iter`, dividing by size - 1
/// (Bessel's correction).
///
/// @param iter  sample values.
/// @return the standard deviation, or NaN if fewer than two values are given,
///         since the sample standard deviation is undefined there.
double get_sigma(const std::vector<double>& iter) {
  // Guard before evaluating size() - 1. For an empty vector that unsigned
  // subtraction wraps around and would yield 0 instead of NaN.
  if (iter.size() < 2)
    return std::nan("");
  const double mean = get_mean(iter);
  double sum = 0;
  for (const double v : iter) {
    const double d = v - mean;
    sum += d * d;  // cheaper than std::pow(d, 2.)
  }
  return std::sqrt(sum / (iter.size() - 1));
}
} // namespace utils
} // namespace stats