Skip to content

Commit

Permalink
add pysam and remove samtools
Browse files Browse the repository at this point in the history
  • Loading branch information
fangli80 committed Jun 1, 2024
1 parent 54e10de commit 8e3650b
Showing 1 changed file with 73 additions and 63 deletions.
136 changes: 73 additions & 63 deletions src/NanoRepeat/nanoRepeat_bam.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,11 @@
import sys
import shutil
import numpy as np
import pysam
from typing import List
import concurrent.futures
import subprocess
import multiprocessing


from NanoRepeat import tk
from NanoRepeat.repeat_region import *
Expand Down Expand Up @@ -652,8 +654,41 @@ def split_allele_using_gmm_1d(repeat_region:RepeatRegion, ploidy, error_rate, ma
plot_repeat_counts_1d(readinfo_dict, allele_list, repeat_region.to_unique_id(), out_prefix)

return

def extract_fastq_from_bam(in_bam_file, repeat_region:RepeatRegion, flank_dist, out_fastq_file):
    """Write the reads overlapping a repeat region (plus flanks) to a FASTQ file.

    The query window is ``[start_pos - flank_dist, end_pos + flank_dist)`` on
    ``repeat_region.chrom``, with the left edge clamped at 0.  Alignments that
    carry no sequence or no base qualities are skipped, because a FASTQ record
    cannot be written for them.

    Args:
        in_bam_file: Path of the input BAM file.  Region-based ``fetch``
            needs an index next to the BAM (see the pysam documentation).
        repeat_region: Object providing ``chrom``, ``start_pos`` and ``end_pos``.
        flank_dist: Non-negative number of bases added on each side of the region.
        out_fastq_file: Path of the FASTQ file to create (overwritten if present).

    Raises:
        AssertionError: If ``flank_dist`` is negative.
    """
    assert (flank_dist >= 0)
    chrom = repeat_region.chrom
    # Clamp the left edge so a flank larger than start_pos cannot yield a negative coordinate.
    start_pos = max(0, repeat_region.start_pos - flank_dist)
    end_pos = repeat_region.end_pos + flank_dist

    # Both handles live in one `with`, so the BAM file is closed even when
    # fetching or writing raises part-way through.
    with pysam.AlignmentFile(in_bam_file, "rb") as bam, open(out_fastq_file, 'w') as fastq:
        for read in bam.fetch(chrom, start_pos, end_pos):
            # NOTE(review): every fetched alignment is written as stored in the BAM;
            # presumably this includes secondary/supplementary records and
            # reverse-strand reads in aligned orientation -- confirm this matches
            # what the previous `samtools fastq` based extraction produced.
            seq = read.query_sequence
            quals = read.query_qualities
            # Identity checks: `== None` on an array-like quality object is not a reliable test.
            if seq is None or quals is None:
                continue
            # Phred+33 encoding of the per-base qualities.
            qual_string = ''.join([chr(q + 33) for q in quals])
            fastq.write(f'@{read.query_name}\n{seq}\n+\n{qual_string}\n')

    return

def quantify1repeat_from_bam_worker(process_id, num_para_regions, num_threads_per_region, input_args, error_rate, in_bam_file, ref_fasta_dict, repeat_region_list, result_queue):
    """Quantify this worker's share of the repeat regions and post the results.

    Worker ``process_id`` handles the regions at indices ``process_id``,
    ``process_id + num_para_regions``, ``process_id + 2 * num_para_regions``, ...
    of ``repeat_region_list``.  Each region is passed to
    ``quantify1repeat_from_bam`` and the returned objects are collected in order.

    Exactly one list is always put on ``result_queue``, even when a region
    raises: the consumer waits for one item per worker, so a worker that died
    without posting anything would leave it blocked forever.  On failure the
    list holds the regions finished so far and the exception still propagates.

    Args:
        process_id: Zero-based index of this worker; also used in log messages.
        num_para_regions: Number of workers, i.e. the stride between regions.
        num_threads_per_region: Thread count forwarded to the per-region routine.
        input_args: Parsed program arguments, forwarded unchanged.
        error_rate: Error rate forwarded unchanged.
        in_bam_file: Path of the input BAM file.
        ref_fasta_dict: Reference sequences, forwarded unchanged.
        repeat_region_list: All repeat regions (shared by every worker).
        result_queue: Queue-like object; receives one list via ``put``.
    """
    # The label is the same for every region handled by this worker.
    process_name = f'Process {process_id:02}'
    result_list = []
    try:
        for i in range(process_id, len(repeat_region_list), num_para_regions):
            repeat_region = repeat_region_list[i]
            quantified_repeat_region = quantify1repeat_from_bam(process_name, num_threads_per_region, input_args, error_rate, in_bam_file, ref_fasta_dict, repeat_region)
            result_list.append(quantified_repeat_region)
    finally:
        # Always hand back whatever was completed so the consumer's get() returns.
        result_queue.put(result_list)

    return

def quantify1repeat_from_bam(process_name, num_threads_per_region, input_args, error_rate, in_bam_file, ref_fasta_dict, repeat_region):

def quantify1repeat_from_bam(process_id, num_theads_per_region, input_args, error_rate, in_bam_file:string, ref_fasta_dict:dict, repeat_region:RepeatRegion):
tk.eprint(f'NOTICE: [{process_name}] Quantifying repeat: {repeat_region.to_outfile_prefix()}')

def _clean_and_exit(input_args, repeat_region:RepeatRegion):

Expand All @@ -662,20 +697,16 @@ def _clean_and_exit(input_args, repeat_region:RepeatRegion):

if repeat_region.no_details == True:
shutil.rmtree(f'{input_args.out_prefix}.details')

out_string = f'{repeat_region.to_tab_invertal()}\t{repeat_region.repeat_unit_seq}\t'
num_alleles = len(repeat_region.results.quantified_allele_list)
out_string += f'{num_alleles}\t{repeat_region.results.max_repeat_size1()}\t{repeat_region.results.min_repeat_size1()}\t{repeat_region.results.allele_summary()}\t{repeat_region.results.read_summary()}\n'
repeat_region.final_output = out_string

return

repeat_region.get_final_output()
return repeat_region

if repeat_region.chrom[0:3].lower() != 'chr':
formatted_chr_name = 'chr' + repeat_region.chrom
else:
formatted_chr_name = repeat_region.chrom

temp_out_dir = f'{input_args.out_prefix}.NanoRepeat_temp_{process_id:02}.{repeat_region.to_outfile_prefix()}'
temp_out_dir = f'{input_args.out_prefix}.NanoRepeat_temp.{repeat_region.to_outfile_prefix()}'
out_dir = f'{input_args.out_prefix}.details/{formatted_chr_name}'
os.makedirs(temp_out_dir, exist_ok=True)
os.makedirs(out_dir, exist_ok=True)
Expand All @@ -686,22 +717,7 @@ def _clean_and_exit(input_args, repeat_region:RepeatRegion):

# extract reads from bam file
repeat_region.region_fq_file = os.path.join(temp_out_dir, f'{repeat_region.to_outfile_prefix()}.fastq')
region_bam_file = f'{repeat_region.out_prefix}.sorted.bam'

'''
cmd = f'{input_args.samtools} view -hb {in_bam_file} {repeat_region.to_invertal(flank_dist=repeat_region.anchor_len)} > {region_bam_file}'
tk.run_system_cmd(cmd)
cmd = f'{input_args.samtools} index {region_bam_file}'
tk.run_system_cmd(cmd)
cmd = f'{input_args.samtools} fastq {region_bam_file} > {repeat_region.region_fq_file}'
tk.run_system_cmd(cmd)
'''

cmd = f'{input_args.samtools} view -h {in_bam_file} {repeat_region.to_invertal(flank_dist=repeat_region.anchor_len)} | {input_args.samtools} fastq - > {repeat_region.region_fq_file}'

tk.run_system_cmd(cmd)
extract_fastq_from_bam(in_bam_file, repeat_region, repeat_region.anchor_len, repeat_region.region_fq_file)

fastq_file_size = os.path.getsize(repeat_region.region_fq_file)

Expand All @@ -712,39 +728,31 @@ def _clean_and_exit(input_args, repeat_region:RepeatRegion):
# extract ref sequence
extract_ref_sequence(ref_fasta_dict, repeat_region)

refine_repeat_region_in_ref(input_args.minimap2, repeat_region, num_theads_per_region)
refine_repeat_region_in_ref(input_args.minimap2, repeat_region, num_threads_per_region)
if repeat_region.ref_has_issue == True and input_args.save_temp_files == False:
return _clean_and_exit(input_args, repeat_region)

tk.eprint(f'NOTICE: [Process {process_id:02}] Step 1: finding anchor location in reads')
find_anchor_locations_in_reads(input_args.minimap2, input_args.data_type, repeat_region, num_theads_per_region)
tk.eprint(f'NOTICE: [{process_name}] Step 1: finding anchor location in reads')
find_anchor_locations_in_reads(input_args.minimap2, input_args.data_type, repeat_region, num_threads_per_region)

# make core sequence fastq
make_core_seq_fastq(repeat_region)

tk.eprint(f'NOTICE: [Process {process_id:02}] Step 2: round 1 and round 2 estimation')
round1_and_round2_estimation(input_args.minimap2, input_args.data_type, repeat_region, num_theads_per_region)
tk.eprint(f'NOTICE: [{process_name}] Step 2: round 1 and round 2 estimation')
round1_and_round2_estimation(input_args.minimap2, input_args.data_type, repeat_region, num_threads_per_region)

tk.eprint(f'NOTICE: [Process {process_id:02}] Step 3: round 3 estimation')
round3_estimation(input_args.minimap2, input_args.data_type, input_args.fast_mode, repeat_region, num_theads_per_region)
tk.eprint(f'NOTICE: [{process_name}] Step 3: round 3 estimation')
round3_estimation(input_args.minimap2, input_args.data_type, input_args.fast_mode, repeat_region, num_threads_per_region)

output_repeat_size_1d(repeat_region)

tk.eprint(f'NOTICE: [Process {process_id:02}] Step 4: phasing reads using GMM')
tk.eprint(f'NOTICE: [{process_name}] Step 4: phasing reads using GMM')
split_allele_using_gmm_1d(repeat_region, input_args.ploidy, error_rate, input_args.max_mutual_overlap, input_args.max_num_components, input_args.remove_noisy_reads)

return _clean_and_exit(input_args, repeat_region)

def quantify_repeats_from_bam_1process(args):
    """Quantify the repeat regions assigned to one process.

    ``args`` is an 8-tuple: ``(process_id, num_para_regions,
    num_theads_per_region, input_args, error_rate, in_bam_file,
    ref_fasta_dict, repeat_region_list)``.  A region at index ``i`` is handled
    by this process when ``i % num_para_regions == process_id``; a notice is
    logged for it and it is passed to ``quantify1repeat_from_bam``.
    """
    (process_id,
     num_para_regions,
     num_theads_per_region,
     input_args,
     error_rate,
     in_bam_file,
     ref_fasta_dict,
     repeat_region_list) = args

    for idx, region in enumerate(repeat_region_list):
        # Only regions whose index maps onto this process are handled here.
        if idx % num_para_regions == process_id:
            tk.eprint(f'NOTICE: [Process {process_id:02}] Quantifying repeat: {region.to_outfile_prefix()}')
            quantify1repeat_from_bam(process_id, num_theads_per_region, input_args, error_rate, in_bam_file, ref_fasta_dict, region)
    return


def nanoRepeat_bam (input_args, in_bam_file:string):

# ont, ont_sup, ont_q20, clr, hifi
Expand All @@ -771,29 +779,31 @@ def nanoRepeat_bam (input_args, in_bam_file:string):
num_para_regions = min(max_num_para_regions, len(repeat_region_list), input_args.num_cpu)
num_threads_per_region = int(input_args.num_cpu / num_para_regions)

args_list = []
for i in range(0, len(repeat_region_list)):
repeat_region_list[i].index = i

processes = []
result_queue = multiprocessing.Queue()
for process_id in range(num_para_regions):
args = (process_id, num_para_regions, num_threads_per_region, input_args, error_rate, in_bam_file, ref_fasta_dict, repeat_region_list)
args_list.append(args)

with concurrent.futures.ProcessPoolExecutor(max_workers=max_num_para_regions) as executor:

future_to_args = {executor.submit(quantify_repeats_from_bam_1process, args): args for args in args_list}
for future in concurrent.futures.as_completed(future_to_args):
args = future_to_args[future]
process_id, num_para_regions, num_theads_per_region, input_args, error_rate, in_bam_file, ref_fasta_dict, repeat_region_list = args
try:
future.result()
except Exception as exc:
tk.eprint(f'Process {process_id} generated an exception: {exc}')
else:
tk.eprint(f'Process {process_id} completed successfully.')

p = multiprocessing.Process(target=quantify1repeat_from_bam_worker, args=(process_id, num_para_regions, num_threads_per_region, input_args, error_rate, in_bam_file, ref_fasta_dict, repeat_region_list, result_queue))
processes.append(p)
p.start()

quantified_repeat_region = []
for p in processes:
quantified_repeat_region.extend(result_queue.get())

for p in processes:
p.join()

quantified_repeat_region.sort(key = lambda repeat_region:repeat_region.index)

out_tsv_file = f'{input_args.out_prefix}.NanoRepeat_output.tsv'
out_tsv_f = open(out_tsv_file, 'w')

for i in range(0, len(repeat_region_list)):
out_tsv_f.write(repeat_region_list[i].final_output)
for i in range(0, len(quantified_repeat_region)):
quantified_repeat_region[i].get_final_output()
out_tsv_f.write(quantified_repeat_region[i].final_output)
out_tsv_f.close()

tk.eprint('NOTICE: Program finished.')
Expand Down

0 comments on commit 8e3650b

Please sign in to comment.