Skip to content

Commit

Permalink
add JSON export to CLI (#23)
Browse files Browse the repository at this point in the history
* add export of ORA and file overwrite prompt

* add output to each method
  • Loading branch information
iblacksand authored Apr 23, 2024
1 parent eee42e4 commit 380f19a
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 121 deletions.
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ repository = "https://github.com/bzhanglab/webgestalt_rust"
bincode = "1.3.3"
clap = { version = "4.4.15", features = ["derive"] }
owo-colors = { version = "4.0.0", features = ["supports-colors"] }
serde_json = "1.0.116"
webgestalt_lib = { version = "0.3.0", path = "webgestalt_lib" }

[profile.release]
Expand Down
206 changes: 89 additions & 117 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
use bincode::deserialize_from;
use clap::{Args, Parser};
use clap::{Subcommand, ValueEnum};
use owo_colors::{OwoColorize, Stream::Stdout, Style};
use std::io::{BufReader, Write};
use std::io::Write;
use std::{fs::File, time::Instant};
use webgestalt_lib::methods::gsea::GSEAConfig;
use webgestalt_lib::methods::multilist::{combine_gmts, MultiListMethod, NormalizationMethod};
use webgestalt_lib::methods::nta::NTAConfig;
use webgestalt_lib::methods::ora::ORAConfig;
use webgestalt_lib::readers::utils::Item;
use webgestalt_lib::readers::{read_gmt_file, read_rank_file};
use webgestalt_lib::{MalformedError, WebGestaltError};

/// WebGestalt CLI.
/// ORA and GSEA enrichment tool.
Expand All @@ -24,8 +22,6 @@ struct CliArgs {

#[derive(Subcommand)]
enum Commands {
/// Benchmark different file formats for gmt. TODO: Remove later
Benchmark,
/// Run provided examples for various types of analyses
Example(ExampleArgs),
/// Run GSEA on the provided files
Expand All @@ -34,8 +30,6 @@ enum Commands {
Ora(ORAArgs),
/// Run NTA on the provided files
Nta(NtaArgs),
/// Run a test
Test,
/// Combine multiple files into a single file
Combine(CombineArgs),
}
Expand Down Expand Up @@ -65,7 +59,7 @@ struct NtaArgs {
seeds: String,
/// Output path for the results
#[arg(short, long)]
out: String,
output: String,
/// Probability of random walk resetting
#[arg(short, long, default_value = "0.5")]
reset_probability: f64,
Expand All @@ -77,8 +71,8 @@ struct NtaArgs {
neighborhood_size: usize,
/// Method to use for NTA
/// Options: prioritize, expand
#[arg(short, long)]
method: Option<NTAMethodClap>,
#[arg(short, long, default_value = "prioritize")]
method: NTAMethodClap,
}

#[derive(ValueEnum, Clone)]
Expand All @@ -90,19 +84,29 @@ enum NTAMethodClap {
#[derive(Args)]
struct GseaArgs {
/// Path to the GMT file of interest
gmt: Option<String>,
#[arg(short, long)]
gmt: String,
/// Path to the rank file of interest
rnk: Option<String>,
#[arg(short, long)]
rnk: String,
/// Output path for the results
#[arg(short, long, default_value = "out.json")]
output: String,
}

#[derive(Args)]
#[derive(Parser)]
struct ORAArgs {
/// Path to the GMT file of interest
gmt: Option<String>,
#[arg(short, long)]
gmt: String,
/// Path to the file containing the interesting analytes
interest: Option<String>,
#[arg(short, long)]
interest: String,
/// Output path for the results
#[arg(short, long, default_value = "out.json")]
output: String,
/// Path the file containing the reference list
reference: Option<String>,
#[arg(short, long)]
reference: String,
}

#[derive(Args)]
Expand Down Expand Up @@ -146,13 +150,43 @@ struct CombineListArgs {
files: Vec<String>,
}

fn prompt_yes_no(question: &str) -> bool {
loop {
print!("{} (y/n): ", question);
std::io::stdout().flush().expect("Could not flush stdout!"); // Ensure the prompt is displayed

let mut input = String::new();
std::io::stdin()
.read_line(&mut input)
.expect("Could not read line");
print!("\x1B[2J\x1B[1;1H");
std::io::stdout().flush().expect("Could not flush stdout!");
match input.trim().to_lowercase().as_str() {
"y" => return true,
"n" => return false,
_ => println!("Invalid input. Please enter 'y' or 'n'."),
}
}
}

fn check_and_overwrite(file_path: &str) {
// Check if the file exists
if std::path::Path::new(file_path).exists() {
// Check if the user wants to overwrite the file
if !prompt_yes_no(&format!(
"File at {} already exists. Do you want to overwrite it?",
file_path
)) {
println!("Stopping analysis.");
std::process::exit(1);
};
}
}

fn main() {
println!("WebGestalt CLI v{}", env!("CARGO_PKG_VERSION"));
let args = CliArgs::parse();
match &args.command {
Some(Commands::Benchmark) => {
benchmark();
}
Some(Commands::Example(ex)) => match &ex.commands {
Some(ExampleOptions::Gsea) => {
let gene_list = webgestalt_lib::readers::read_rank_file(
Expand All @@ -177,7 +211,7 @@ fn main() {
"webgestalt_lib/data/genelist.txt".to_owned(),
"webgestalt_lib/data/reference.txt".to_owned(),
);
let gmtcount = gmt.len();
let gmt_count = gmt.len();
let start = Instant::now();
let x: Vec<webgestalt_lib::methods::ora::ORAResult> =
webgestalt_lib::methods::ora::get_ora(
Expand All @@ -187,6 +221,8 @@ fn main() {
ORAConfig::default(),
);
let mut count = 0;
let output_file = File::create("test.json").expect("Could not create output file!");
serde_json::to_writer(output_file, &x).expect("Could not create JSON file!");
for i in x {
if i.p < 0.05 && i.fdr < 0.05 {
println!("{}: {}, {}, {}", i.set, i.p, i.fdr, i.overlap);
Expand All @@ -196,56 +232,48 @@ fn main() {
let duration = start.elapsed();
println!(
"ORA\nTime took: {:?}\nFound {} significant pathways out of {} pathways",
duration, count, gmtcount
duration, count, gmt_count
);
}
_ => {
println!("Please select a valid example: ora or gsea.");
}
},
Some(Commands::Gsea(gsea_args)) => {
let style = Style::new().red().bold();
if gsea_args.gmt.is_none() || gsea_args.rnk.is_none() {
println!(
"{}: DID NOT PROVIDE PATHS FOR GMT AND RANK FILE.",
"ERROR".if_supports_color(Stdout, |text| text.style(style))
);
return;
}
let gene_list = webgestalt_lib::readers::read_rank_file(gsea_args.rnk.clone().unwrap())
check_and_overwrite(&gsea_args.output);
let gene_list = webgestalt_lib::readers::read_rank_file(gsea_args.rnk.clone())
.unwrap_or_else(|_| {
panic!("File {} not found", gsea_args.rnk.clone().unwrap());
});
let gmt = webgestalt_lib::readers::read_gmt_file(gsea_args.gmt.clone().unwrap())
.unwrap_or_else(|_| {
panic!("File {} not found", gsea_args.gmt.clone().unwrap());
panic!("File {} not found", gsea_args.rnk.clone());
});
let gmt = webgestalt_lib::readers::read_gmt_file(gsea_args.gmt.clone()).unwrap_or_else(
|_| {
panic!("File {} not found", gsea_args.gmt.clone());
},
);
let res =
webgestalt_lib::methods::gsea::gsea(gene_list, gmt, GSEAConfig::default(), None);
let output_file =
File::create(&gsea_args.output).expect("Could not create output file!");
serde_json::to_writer(output_file, &res).expect("Could not create JSON file!");
let mut count = 0;
for i in res {
if i.p < 0.05 && i.fdr < 0.05 {
println!("{}: {}, {}", i.set, i.p, i.fdr);
count += 1;
}
}
println!("Done with GSEA: {}", count);
println!(
"Done with GSEA and found {} significant analyte sets",
count
);
}
Some(Commands::Ora(ora_args)) => {
let style = Style::new().red().bold();
if ora_args.gmt.is_none() || ora_args.interest.is_none() || ora_args.reference.is_none()
{
println!(
"{}: DID NOT PROVIDE PATHS FOR GMT, INTEREST, AND REFERENCE FILE.",
"ERROR".if_supports_color(Stdout, |text| text.style(style))
);
return;
}
check_and_overwrite(&ora_args.output);
let start = Instant::now();
let (gmt, interest, reference) = webgestalt_lib::readers::read_ora_files(
ora_args.gmt.clone().unwrap(),
ora_args.interest.clone().unwrap(),
ora_args.reference.clone().unwrap(),
ora_args.gmt.clone(),
ora_args.interest.clone(),
ora_args.reference.clone(),
);
println!("Reading Took {:?}", start.elapsed());
let start = Instant::now();
Expand All @@ -255,6 +283,9 @@ fn main() {
gmt,
ORAConfig::default(),
);
let output_file =
File::create(&ora_args.output).expect("Could not create output file!");
serde_json::to_writer(output_file, &res).expect("Could not create JSON file!");
println!("Analysis Took {:?}", start.elapsed());
let mut count = 0;
for row in res.iter() {
Expand All @@ -263,42 +294,33 @@ fn main() {
}
}
println!(
"Found {} significant pathways out of {} pathways",
"Found {} significant analyte sets out of {} sets",
count,
res.len()
);
}
Some(Commands::Test) => will_err(1).unwrap_or_else(|x| println!("{}", x)),
Some(Commands::Nta(nta_args)) => {
let style = Style::new().fg_rgb::<255, 179, 71>().bold();
check_and_overwrite(&nta_args.output);
let network = webgestalt_lib::readers::read_edge_list(nta_args.network.clone());
let start = Instant::now();
if nta_args.method.is_none() {
println!(
"{}: DID NOT PROVIDE A METHOD FOR NTA. USING DEFAULT EXPAND METHOD.",
"WARNING".if_supports_color(Stdout, |text| text.style(style))
);
};
let nta_method = match nta_args.method {
Some(NTAMethodClap::Prioritize) => webgestalt_lib::methods::nta::NTAMethod::Prioritize(
nta_args.neighborhood_size,
),
Some(NTAMethodClap::Expand) => webgestalt_lib::methods::nta::NTAMethod::Expand(
nta_args.neighborhood_size,
),
None => webgestalt_lib::methods::nta::NTAMethod::Expand(nta_args.neighborhood_size),
NTAMethodClap::Prioritize => {
webgestalt_lib::methods::nta::NTAMethod::Prioritize(nta_args.neighborhood_size)
}
NTAMethodClap::Expand => {
webgestalt_lib::methods::nta::NTAMethod::Expand(nta_args.neighborhood_size)
}
};
let config: NTAConfig = NTAConfig {
edge_list: network,
seeds: webgestalt_lib::readers::read_seeds(nta_args.seeds.clone()),
reset_probability: nta_args.reset_probability,
tolerance: nta_args.tolerance,
method: Some(nta_method),

};
let res = webgestalt_lib::methods::nta::get_nta(config);
println!("Analysis Took {:?}", start.elapsed());
webgestalt_lib::writers::save_nta(nta_args.out.clone(), res).unwrap();
webgestalt_lib::writers::save_nta(nta_args.output.clone(), res).unwrap();
}
Some(Commands::Combine(args)) => match &args.combine_type {
Some(CombineType::Gmt(gmt_args)) => {
Expand Down Expand Up @@ -374,53 +396,3 @@ fn main() {
}
}
}

fn benchmark() {
let mut bin_durations: Vec<f64> = Vec::new();
for _i in 0..1000 {
let start = Instant::now();
let mut r = BufReader::new(File::open("test.gmt.wga").unwrap());
let _x: Vec<webgestalt_lib::readers::utils::Item> = deserialize_from(&mut r).unwrap();
let duration = start.elapsed();
bin_durations.push(duration.as_secs_f64())
}
let mut gmt_durations: Vec<f64> = Vec::new();
for _i in 0..1000 {
let start = Instant::now();
let _x = webgestalt_lib::readers::read_gmt_file("webgestalt_lib/data/ktest.gmt".to_owned())
.unwrap();
let duration = start.elapsed();
gmt_durations.push(duration.as_secs_f64())
}
let gmt_avg: f64 = gmt_durations.iter().sum::<f64>() / gmt_durations.len() as f64;
let bin_avg: f64 = bin_durations.iter().sum::<f64>() / bin_durations.len() as f64;
let improvement: f64 = 100.0 * (gmt_avg - bin_avg) / gmt_avg;
println!(
" GMT time: {}\tGMT.WGA time: {}\n Improvement: {:.1}%",
gmt_avg, bin_avg, improvement
);
let mut whole_file: Vec<String> = Vec::new();
whole_file.push("type\ttime".to_string());
for line in bin_durations {
whole_file.push(format!("bin\t{:?}", line));
}
for line in gmt_durations {
whole_file.push(format!("gmt\t{:?}", line));
}
let mut ftsv = File::create("format_benchmarks.tsv").unwrap();
writeln!(ftsv, "{}", whole_file.join("\n")).unwrap();
}

fn will_err(x: i32) -> Result<(), WebGestaltError> {
if x == 0 {
Ok(())
} else {
Err(WebGestaltError::MalformedFile(MalformedError {
path: String::from("ExamplePath.txt"),
kind: webgestalt_lib::MalformedErrorType::WrongFormat {
found: String::from("GMT"),
expected: String::from("rank"),
},
}))
}
}
4 changes: 2 additions & 2 deletions webgestalt_lib/src/readers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,10 @@ pub fn read_intersection_list(path: String, ref_list: &AHashSet<String>) -> AHas
}

/// Read edge list from specified path. Separated by whitespace with no support for weights
///
///
/// # Parameters
/// path - A [`String`] of the path of the edge list to read.
///
///
/// # Returns
/// A [`Vec<Vec<String>>`] containing the edge list
pub fn read_edge_list(path: String) -> Vec<Vec<String>> {
Expand Down

0 comments on commit 380f19a

Please sign in to comment.