Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ tar -xvf emu.tar

##### 2. Activate appropriate conda environment

Emu requires Python version to be >=3.6 and <3.11.
Emu requires Python version to be >=3.6.

###### Option A: Create new Conda environment

Expand Down Expand Up @@ -95,17 +95,20 @@ Each step of the installation process is expected to take a matter of seconds.
|--N| 50| max number of alignments utilized for each read in minimap2|
|--K| 500M| minibatch size for mapping in minimap2|
|--mm2-forward-only| FALSE| force minimap2 to consider the forward transcript strand only ([for long mRNA/cDNA reads](https://github.com/lh3/minimap2?tab=readme-ov-file#map-long-mrnacdna-reads))|
|--min-pid| 0| minimum percent identity (PID) based on NM tag|
|--min-align-len| 0| minimum aligned query length (excluding clipped bp)|
|--max-align-len| 2000| maximum aligned query length (excluding clipped bp)|
|--output-dir| ./results| directory for output results|
|--output-basename| stem of input_file(s)| basename of all output files saved in output-dir; default utilizes basename from input file(s)|
|--keep-files| FALSE| keep working files in output-dir ( alignments [.sam], reads of specified length [.fa])|
|--keep-counts| FALSE| include estimated read counts for each species in output*|
|--keep-read-assignments| FALSE| output .tsv file with read assignment distributions: each row as an input read; each entry as the likelihood it is derived from that taxa (taxid is the column header); each row sums to 1|
|--output-unclassified| FALSE| generate two additional sequence files of unmapped and unclassified mapped input reads**|
|--output-unclassified| FALSE| generate three additional sequence files: unmapped, filtered mapped, and unclassified mapped input reads**|
|--threads| 3| number of threads utilized by minimap2|

*Estimated read counts are based on likelihood probabilities and therefore may not be integer values. They are calculated as the product of estimated relative abundance and total classified reads.

**Here, "unmapped reads" are reads that did not result in a mapping to the provided database with minimap2. "Unclassified mapped reads" are those that mapped only to database sequences of species that are presumed to not be present in the sample by Emu's algorithm (likely due to low overall abundance).
**Here, "unmapped" reads are reads that did not result in a mapping to the provided database with minimap2. "Filtered mapped" reads are those that were mapped with minimap2, but all alignments for the given query (read) were filtered via the align-len and percent identity (pid) requirement parameters. "Unclassified mapped" reads are those that mapped only to database sequences of species that are presumed to not be present in the sample by Emu's algorithm (likely due to low overall abundance).

Note: If you are experiencing heavy RAM consumption, first upgrade minimap2 to at least v2.22. If memory is still an issue, try decreasing the number of secondary alignments evaluated for each read (--N).

Expand Down
144 changes: 125 additions & 19 deletions emu
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,15 @@ def get_align_len(alignment):
return sum(alignment.get_cigar_stats()[0][cigar_op] for cigar_op in CIGAR_OPS_ALL)


def get_aligned_query_len(alignment):
    """Length of the aligned portion of the query sequence.

    Soft- and hard-clipped bases are excluded: pysam's
    ``query_alignment_length`` already omits clipping operations.

    alignment (pysam.AlignedSegment): alignment of interest
    return (int): aligned query length excluding S/H
    """
    aligned_bases = alignment.query_alignment_length
    return aligned_bases


def output_sequences(in_path, seq_output_path, input_type, keep_ids):
"""Output specified sequences from input_file based on sequence id.

Expand All @@ -143,7 +152,6 @@ def output_sequences(in_path, seq_output_path, input_type, keep_ids):
if ids_match(rec.id, getattr(rec, "name", ""), keep_ids):
SeqIO.write(rec, out_seq_file, input_type)


def get_cigar_op_log_probabilities(sam_path):
"""P(align_type) for each type in CIGAR_OPS by counting how often the corresponding
operations occur in the primary alignments and by normalizing over the total
Expand Down Expand Up @@ -179,6 +187,22 @@ def get_cigar_op_log_probabilities(sam_path):
return [math.log(x) for x in np.array(cigar_stats_primary)/n_char], zero_locs, \
dict_longest_align

def passes_edit_distance_filter(alignment, min_pid):
    """Check whether an alignment meets the minimum percent-identity threshold.

    Percent identity (PID) is derived from the SAM NM tag (edit distance) as
    100 * (1 - NM / aligned_query_length), helping ensure only
    high-confidence matches are used for taxonomic assignment.

    alignment (pysam.AlignedSegment): alignment providing get_tag("NM") and
        query_alignment_length
    min_pid (float): minimum percent identity required to pass
    return (bool): True if PID >= min_pid; False when the threshold is not met
        or the NM tag is missing
    """
    try:
        edit_distance = alignment.get_tag("NM")  # edit distance from sam file
    except KeyError:
        # NM tag absent: identity cannot be computed, reject the alignment
        return False
    n_aligned = alignment.query_alignment_length
    if n_aligned > 0:
        percent_identity = 100 * (1 - edit_distance / n_aligned)
    else:
        percent_identity = 0
    return percent_identity >= min_pid

def compute_log_prob_rgs(alignment, cigar_stats, log_p_cigar_op, dict_longest_align, align_len):
""" log(L(r|s)) = log(P(cigar_op)) × n_cigar_op for CIGAR_OPS
Expand All @@ -200,8 +224,8 @@ def compute_log_prob_rgs(alignment, cigar_stats, log_p_cigar_op, dict_longest_al
species_tid = int(ref_name.split(":")[0])
return log_score, query_name, species_tid


def log_prob_rgs_dict(sam_path, log_p_cigar_op, dict_longest_align, p_cigar_op_zero_locs=None):
def log_prob_rgs_dict(sam_path, log_p_cigar_op, dict_longest_align,
p_cigar_op_zero_locs=None, args=None):
"""dict containing log(L(read|seq)) for all pairwise alignments in sam file

sam_path(str): path to sam file
Expand All @@ -214,14 +238,48 @@ def log_prob_rgs_dict(sam_path, log_p_cigar_op, dict_longest_align, p_cigar_op_z
int: mapped read count
"""
# calculate log(L(read|seq)) for all alignments
log_p_rgs, unmapped_set = {}, set()
log_p_rgs, unmapped_set, all_queries_set = {}, set(), set()
# pylint: disable=maybe-no-member
sam_filename = pysam.AlignmentFile(sam_path, 'rb')

# track PID filtering stats
filtered_pid_count = 0

# track aligned length filtering stats
filtered_align_len_count = 0

#debug_path = os.path.join(args.output_dir, "pid_debug_output.txt")
#debug_pid_file = open(debug_path, "w", encoding="utf-8")
#debug_pid_file.write("read_id\tref_id\tNM\tpid\n")

if not p_cigar_op_zero_locs:
for alignment in sam_filename.fetch():
all_queries_set.add(alignment.query_name)
align_len = get_align_len(alignment)
align_len_q = get_aligned_query_len(alignment)
if alignment.reference_name and align_len:

# Output the PID filtering metadata
try:
nm = alignment.get_tag("NM")
aligned_bases = alignment.query_alignment_length
pid = 100 * (1 - (nm / aligned_bases)) if aligned_bases > 0 else 0
#debug_pid_file.write(f"{alignment.query_name}\t{alignment.reference_name}\t{nm}\t{pid:.2f}\n")
except (KeyError, ZeroDivisionError):
#debug_pid_file.write(f"{alignment.query_name}\t{alignment.reference_name}\t0\n")
pid = None

if args.min_align_len is not None and align_len_q < args.min_align_len:
filtered_align_len_count += 1
continue
if args.max_align_len is not None and align_len_q > args.max_align_len:
filtered_align_len_count += 1
continue

if args.min_pid is not None and not passes_edit_distance_filter(alignment, args.min_pid):
filtered_pid_count += 1
continue # skip low identity alignments

cigar_stats = get_align_stats(alignment)
log_score, query_name, species_tid = \
compute_log_prob_rgs(alignment, cigar_stats, log_p_cigar_op,
Expand All @@ -242,8 +300,32 @@ def log_prob_rgs_dict(sam_path, log_p_cigar_op, dict_longest_align, p_cigar_op_z
unmapped_set.add(alignment.query_name)
else:
for alignment in sam_filename.fetch():
all_queries_set.add(alignment.query_name)
align_len = get_align_len(alignment)
align_len_q = get_aligned_query_len(alignment)
if alignment.reference_name and align_len:

# Output the PID filtering metadata
try:
nm = alignment.get_tag("NM")
aligned_bases = alignment.query_alignment_length
pid = 100 * (1 - (nm / aligned_bases)) if aligned_bases > 0 else 0
#debug_pid_file.write(f"{alignment.query_name}\t{alignment.reference_name}\t{nm}\t{pid:.2f}\n")
except (KeyError, ZeroDivisionError):
#debug_pid_file.write(f"{alignment.query_name}\t{alignment.reference_name}\t0\n")
pid = None

if args.min_align_len and align_len_q < args.min_align_len:
filtered_align_len_count += 1
continue
if args.max_align_len is not None and align_len_q > args.max_align_len:
filtered_align_len_count += 1
continue

if args.min_pid and not passes_edit_distance_filter(alignment, args.min_pid):
filtered_pid_count += 1
continue # skip low identity alignments

cigar_stats = get_align_stats(alignment)
if sum(cigar_stats[x] for x in p_cigar_op_zero_locs) == 0:
for i in sorted(p_cigar_op_zero_locs, reverse=True):
Expand All @@ -265,16 +347,27 @@ def log_prob_rgs_dict(sam_path, log_p_cigar_op, dict_longest_align, p_cigar_op_z
unmapped_set.add(alignment.query_name)

mapped_set = set(log_p_rgs.keys())
unmapped_set = unmapped_set - mapped_set
unmapped_count = len(unmapped_set)
stdout.write(f"Unmapped read count: {unmapped_count}\n")
unmapped_set = unmapped_set - mapped_set # double check
filtered_set = all_queries_set - unmapped_set - mapped_set
stdout.write(f"Unmapped read count: {len(unmapped_set)}\n")
stdout.write(f"Filtered read count: {len(filtered_set)}\n")

if args.min_align_len is not None or args.max_align_len is not None:
stdout.write(
f"Filtered {filtered_align_len_count} alignments outside aligned query length bounds "
f"(min_align_len={args.min_align_len}, max_align_len={args.max_align_len}).\n"
)

# output PID filtering information
if args.min_pid is not None:
stdout.write(f"Filtered {filtered_pid_count} alignments below min-pid ({args.min_pid}%) threshold.\n")
#debug_pid_file.close()

## remove low likelihood alignments?
## remove if p(r|s) < 0.01
#min_p_thresh = math.log(0.01)
#log_p_rgs = {r_map: val for r_map, val in log_p_rgs.items() if val > min_p_thresh}
return log_p_rgs, unmapped_set, mapped_set

return log_p_rgs, unmapped_set, mapped_set, filtered_set

def expectation_maximization(log_p_rgs, freq):
"""One iteration of the EM algorithm. Updates the relative abundance estimation in f based on
Expand Down Expand Up @@ -342,7 +435,8 @@ def expectation_maximization_iterations(log_p_rgs, db_ids, lli_thresh, input_thr
# check if there are enough reads
if n_reads == 0:
raise ValueError("0 reads mapped")
freq, counter = dict.fromkeys(db_ids, 1 / n_db), 1
freq = {int(k): 1 / n_db for k in db_ids}
counter = 1

# set output abundance threshold
freq_thresh = 1/(n_reads + 1)
Expand Down Expand Up @@ -409,7 +503,7 @@ def lineage_dict_from_tid(taxid, nodes_dict, names_dict):


def freq_to_lineage_df(freq, tsv_output_path, taxonomy_df, mapped_count,
unmapped_count, mapped_unclassified_count, counts=False):
unmapped_count, mapped_unclassified_count, mapped_filtered_count, counts=False):
"""Converts freq to a pandas df where each row contains abundance and tax lineage for
classified species in f.keys(). Stores df as .tsv file in tsv_output_path.

Expand All @@ -424,15 +518,16 @@ def freq_to_lineage_df(freq, tsv_output_path, taxonomy_df, mapped_count,
returns(df): pandas df with lineage and abundances for values in f
"""
#add tax lineage for values in freq
results_df = pd.DataFrame(zip(list(freq.keys()) + ['unmapped', 'mapped_unclassified'],
list(freq.values()) + [0, 0]),
results_df = pd.DataFrame(zip(list(freq.keys()) + ['unmapped', 'mapped_filtered', 'mapped_unclassified'],
list(freq.values()) + [0, 0, 0]),
columns=["tax_id", "abundance"]).set_index('tax_id')
results_df = results_df.join(taxonomy_df, how='left').reset_index()
#add in the estimated count values for the mapped and unmapped counts
if counts:
classified_count = mapped_count - mapped_unclassified_count
counts_series = pd.concat([(results_df["abundance"] * classified_count)[:-2],
pd.Series(unmapped_count), pd.Series(mapped_unclassified_count)],
counts_series = pd.concat([(results_df["abundance"] * classified_count)[:-3],
pd.Series(unmapped_count), pd.Series(mapped_filtered_count),
pd.Series(mapped_unclassified_count)],
ignore_index=True)
results_df["estimated counts"] = counts_series
results_df.to_csv("{}.tsv".format(tsv_output_path), sep='\t', index=False)
Expand Down Expand Up @@ -743,7 +838,7 @@ def combine_outputs(dir_path, rank, split_files=False, count_table=False):
return df_combined_full

if __name__ == "__main__":
__version__ = "3.5.5"
__version__ = "3.6.0"
parser = argparse.ArgumentParser()
parser.add_argument('--version', '-v', action='version', version='%(prog)s v' + __version__)
subparsers = parser.add_subparsers(dest="subparser_name", help='sub-commands')
Expand Down Expand Up @@ -792,6 +887,15 @@ if __name__ == "__main__":
abundance_parser.add_argument(
'--threads', type=int, default=3,
help='threads utilized by minimap [3]')
# Alignment-quality filter options (consumed by log_prob_rgs_dict).
# '%%' is required in argparse help text to render a literal '%'.
abundance_parser.add_argument(
    '--min-pid', type=float, default=0,
    help='Minimum percent identity (PID) based on NM tag [0%%]')
abundance_parser.add_argument(
    '--min-align-len', type=int, default=0,
    help='Minimum aligned query length (excludes soft/hard clipping) [0]')
abundance_parser.add_argument(
    '--max-align-len', type=int, default=2000,
    help='Maximum aligned query length (excludes soft/hard clipping) [2000]')

build_db_parser = subparsers.add_parser("build-database",
help="Build custom Emu database")
Expand Down Expand Up @@ -860,8 +964,8 @@ if __name__ == "__main__":
SAM_FILE = generate_alignments(args.input_file, out_file, args.db)
log_prob_cigar_op, locs_p_cigar_zero, longest_align_dict = \
get_cigar_op_log_probabilities(SAM_FILE)
log_prob_rgs, set_unmapped, set_mapped = log_prob_rgs_dict(
SAM_FILE, log_prob_cigar_op, longest_align_dict, locs_p_cigar_zero)
log_prob_rgs, set_unmapped, set_mapped, set_filtered = log_prob_rgs_dict(
SAM_FILE, log_prob_cigar_op, longest_align_dict, locs_p_cigar_zero, args)
f_full, f_set_thresh, read_dist = expectation_maximization_iterations(log_prob_rgs,
db_species_tids,
.01, args.min_abundance)
Expand All @@ -870,7 +974,7 @@ if __name__ == "__main__":
stdout.write(f"Unclassified mapped read count: {len(mapped_unclassified)}\n")
freq_to_lineage_df(f_full, "{}_rel-abundance".format(out_file), df_taxonomy,
len(set_mapped), len(set_unmapped), len(mapped_unclassified),
args.keep_counts)
len(set_filtered), args.keep_counts)


# output read assignment distributions as a tsv
Expand All @@ -892,6 +996,8 @@ if __name__ == "__main__":
set_unmapped)
output_sequences(args.input_file[0], "{}_unclassified_mapped".format(out_file),
input_filetype, mapped_unclassified)
output_sequences(args.input_file[0], "{}_filtered_mapped".format(out_file),
input_filetype, set_filtered)

# clean up extra file
if not args.keep_files:
Expand Down