Skip to content

Features

Some code adapted from @mheinzinger

https://github.com/mheinzinger/ProstT5/blob/main/scripts/generate_foldseek_db.py

create_foldseek_prostt5_gpu_db(fasta_aa, foldseek_db_path, db_dir, logdir)

Create a Foldseek DB from an amino-acid FASTA file, with ProstT5 3Di predictions, using Foldseek-GPU

Parameters:

Name Type Description Default
fasta_aa Path

Path to the amino-acid FASTA file.

required
foldseek_db_path Path

Path to the directory where Foldseek database will be stored.

required
db_dir Path

Path to the baktfold DB

required
logdir Path

Path to the directory where logs will be stored.

required

Returns:

Type Description
None

None

Source code in src/baktfold/features/create_foldseek_db.py
def create_foldseek_prostt5_gpu_db(
    fasta_aa: Path, foldseek_db_path: Path, db_dir: Path, logdir: Path
) -> None:
    """
    Build a Foldseek database from an amino-acid FASTA with ``foldseek createdb --prostt5-model``.

    Args:
        fasta_aa (Path): Path to the amino-acid FASTA file.
        foldseek_db_path (Path): Output location handed to ``foldseek createdb``.
        db_dir (Path): Path to the baktfold DB; the ProstT5 model is expected in
            its ``prostt5_weights`` subdirectory.
        logdir (Path): Path to the directory where logs will be stored.

    Returns:
        None
    """
    weights_dir = Path(db_dir) / "prostt5_weights"

    # The complete foldseek invocation travels in `params`; input/output are
    # intentionally left empty.
    createdb_params = f"createdb {fasta_aa} {foldseek_db_path}  --prostt5-model {weights_dir}  "

    createdb_tool = ExternalTool(
        tool="foldseek",
        input="",
        output="",
        params=createdb_params,
        logdir=logdir,
    )
    ExternalTool.run_tool(createdb_tool)

foldseek_tsv2db(in_tsv, out_db_name, db_type, logdir)

Convert a Foldseek TSV file to a Foldseek database.

Parameters:

Name Type Description Default
in_tsv Path

Path to the input TSV file.

required
out_db_name Path

Path for the output Foldseek database.

required
db_type int

Type of the output database.

required
logdir Path

Path to the directory where logs will be stored.

required

Returns:

Type Description
None

None

Source code in src/baktfold/features/create_foldseek_db.py
def foldseek_tsv2db(
    in_tsv: Path, out_db_name: Path, db_type: int, logdir: Path
) -> None:
    """
    Run ``foldseek tsv2db`` to turn a TSV file into a Foldseek database.

    Args:
        in_tsv (Path): Path to the input TSV file.
        out_db_name (Path): Path for the output Foldseek database.
        db_type (int): Value passed to ``--output-dbtype``.
        logdir (Path): Path to the directory where logs will be stored.

    Returns:
        None
    """
    dbtype_text = str(db_type)
    tsv2db_params = f"tsv2db {in_tsv} {out_db_name}  --output-dbtype {dbtype_text} "

    # The complete foldseek invocation travels in `params`; input/output are
    # intentionally left empty.
    tsv2db_tool = ExternalTool(
        tool="foldseek",
        input="",
        output="",
        params=tsv2db_params,
        logdir=logdir,
    )
    ExternalTool.run_tool(tsv2db_tool)

generate_foldseek_db_from_aa_3di(fasta_aa, fasta_3di, foldseek_db_path, logdir, prefix)

Generate Foldseek database from amino-acid and 3Di sequences.

Parameters:

Name Type Description Default
fasta_aa Path

Path to the amino-acid FASTA file.

required
fasta_3di Path

Path to the 3Di FASTA file.

required
foldseek_db_path Path

Path to the directory where Foldseek database will be stored.

required
logdir Path

Path to the directory where logs will be stored.

required
prefix str

Prefix for the Foldseek database.

required

Returns:

Type Description
None

None

Source code in src/baktfold/features/create_foldseek_db.py
def generate_foldseek_db_from_aa_3di(
    fasta_aa: Path, fasta_3di: Path, foldseek_db_path: Path, logdir: Path, prefix: str
) -> None:
    """
    Generate Foldseek database from amino-acid and 3Di sequences.

    Records are paired by FASTA record ID. 3Di records whose ID is absent from
    the amino-acid FASTA are ignored, and amino-acid records without a 3Di
    string are dropped; both situations are logged as warnings. The remaining
    entries are numbered from 1 (in amino-acid FASTA order), written to three
    temporary TSV files inside ``foldseek_db_path`` and converted with
    ``foldseek tsv2db`` into ``<prefix>``, ``<prefix>_ss`` and ``<prefix>_h``.

    Args:
        fasta_aa (Path): Path to the amino-acid FASTA file.
        fasta_3di (Path): Path to the 3Di FASTA file.
        foldseek_db_path (Path): Path to the directory where Foldseek database will be stored.
        logdir (Path): Path to the directory where logs will be stored.
        prefix (str): Prefix for the Foldseek database.

    Returns:
        None
    """
    # read in amino-acid sequences
    sequences_aa = {
        record.id: str(record.seq) for record in SeqIO.parse(fasta_aa, "fasta")
    }

    # read in 3Di strings, keeping only IDs that also have an amino-acid entry
    sequences_3di = {}
    for record in SeqIO.parse(fasta_3di, "fasta"):
        if record.id not in sequences_aa:
            logger.warning(
                "Warning: ignoring 3Di entry {}, since it is not in the amino-acid FASTA file".format(
                    record.id
                )
            )
        else:
            # stored as-is (deliberately not upper-cased)
            sequences_3di[record.id] = str(record.seq)

    # report amino-acid entries that have no 3Di string ...
    missing_ids = [seq_id for seq_id in sequences_aa if seq_id not in sequences_3di]
    for seq_id in missing_ids:
        logger.warning(
            "Warning: entry {} in amino-acid FASTA file has no corresponding 3Di string".format(
                seq_id
            )
        )
        logger.warning("Removing: entry {} from the Foldseek database ".format(seq_id))

    # ... and drop them all in a single pass (instead of rebuilding the dict
    # once per missing entry)
    if missing_ids:
        sequences_aa = {
            seq_id: sequence
            for seq_id, sequence in sequences_aa.items()
            if seq_id in sequences_3di
        }

    # generate TSV file contents; the three files share the same 1-based index
    aa_lines = []
    threedi_lines = []
    header_lines = []
    for index, seq_id in enumerate(sequences_aa, start=1):
        aa_lines.append("{}\t{}\n".format(index, sequences_aa[seq_id]))
        threedi_lines.append("{}\t{}\n".format(index, sequences_3di[seq_id]))
        header_lines.append("{}\t{}\n".format(index, seq_id))

    # write temp TSV files
    temp_aa_tsv: Path = Path(foldseek_db_path) / "aa.tsv"
    temp_3di_tsv: Path = Path(foldseek_db_path) / "3di.tsv"
    temp_header_tsv: Path = Path(foldseek_db_path) / "header.tsv"
    with open(temp_aa_tsv, "w") as f:
        f.write("".join(aa_lines))
    with open(temp_3di_tsv, "w") as f:
        f.write("".join(threedi_lines))
    with open(temp_header_tsv, "w") as f:
        f.write("".join(header_lines))

    # create foldseek db names
    short_db_name = f"{prefix}"
    aa_db_name: Path = Path(foldseek_db_path) / short_db_name
    tsv_db_name: Path = Path(foldseek_db_path) / f"{short_db_name}_ss"
    header_db_name: Path = Path(foldseek_db_path) / f"{short_db_name}_h"

    # create Foldseek database with foldseek tsv2db
    # (db type 0 for the two sequence DBs, 12 for the header DB)
    foldseek_tsv2db(temp_aa_tsv, aa_db_name, 0, logdir)
    foldseek_tsv2db(temp_3di_tsv, tsv_db_name, 0, logdir)
    foldseek_tsv2db(temp_header_tsv, header_db_name, 12, logdir)

    # clean up
    remove_file(temp_aa_tsv)
    remove_file(temp_3di_tsv)
    remove_file(temp_header_tsv)

generate_foldseek_db_from_structures(fasta_aa, foldseek_db_path, structure_dir, logdir, prefix, proteins_flag)

Generate a Foldseek database from structure files (.pdb or .cif).

Parameters:

Name Type Description Default
fasta_aa Path

Path to the amino-acid FASTA file.

required
foldseek_db_path Path

Path to the directory where Foldseek database will be stored.

required
structure_dir Path

Path to the directory containing .pdb or .cif structure files.

required
logdir Path

Path to the directory where logs will be stored.

required
prefix str

Prefix for the Foldseek database.

required
proteins_flag bool

Flag - True if proteins-compare is run

required

Returns:

Type Description
None

None

Source code in src/baktfold/features/create_foldseek_db.py
def generate_foldseek_db_from_structures(
    fasta_aa: Path,
    foldseek_db_path: Path,
    structure_dir: Path,
    logdir: Path,
    prefix: str,
    proteins_flag: bool,
) -> None:
    """
    Generate Foldseek database from structure files (.pdb or .cif).

    Every record ID in ``fasta_aa`` is checked for a ``<id>.pdb`` or
    ``<id>.cif`` file in ``structure_dir`` and missing or duplicated structures
    are logged. ``foldseek createdb`` is then run on the whole directory.

    Args:
        fasta_aa (Path): Path to the amino-acid FASTA file.
        foldseek_db_path (Path): Path to the directory where Foldseek database will be stored.
        structure_dir (Path): Path to the directory containing .pdb or .cif structure files.
        logdir (Path): Path to the directory where logs will be stored.
        prefix (str): Prefix for the Foldseek database.
        proteins_flag (bool): Flag - True if proteins-compare is run
            (currently not used inside this function).

    Returns:
        None
    """

    # read in amino-acid sequences (only the record IDs are needed below)
    sequences_aa = {}
    for record in SeqIO.parse(fasta_aa, "fasta"):
        sequences_aa[record.id] = str(record.seq)

    # names of all structure files, as a set for O(1) membership tests
    structure_files = {
        file
        for file in os.listdir(structure_dir)
        if file.endswith((".pdb", ".cif"))
    }

    # count CDS ids that have at least one matching structure file
    num_structures = 0

    for cds_id in sequences_aa:
        num_matches = sum(
            candidate in structure_files
            for candidate in (f"{cds_id}.pdb", f"{cds_id}.cif")
        )

        if num_matches > 1:
            # happens when both <id>.pdb and <id>.cif exist.
            # NOTE: createdb below is given the whole directory, so nothing in
            # this function actually selects one of the two files.
            logger.warning(f"More than 1 structures found for {cds_id}")
            logger.warning("Taking the first one")
            num_structures += 1
        elif num_matches == 1:
            num_structures += 1
        else:
            logger.warning(f"No structure found for {cds_id}")
            logger.warning(f"{cds_id} will be ignored in annotation")

    if num_structures == 0:
        # NOTE(review): execution continues to createdb after this message
        # unless the logger itself terminates on error - confirm intended.
        logger.error(
            f"No structures with matching CDS ids were found at all. Check the {structure_dir} directory"
        )

    # generate the db
    short_db_name = f"{prefix}"
    structure_db_name: Path = Path(foldseek_db_path) / short_db_name

    foldseek_createdb_from_structures = ExternalTool(
        tool="foldseek",
        input="",
        output="",
        params=f"createdb {structure_dir} {structure_db_name} ",
        logdir=logdir,
    )

    ExternalTool.run_tool(foldseek_createdb_from_structures)

create_result_tsv(query_db, target_db, result_db, result_tsv, logdir, foldseek_gpu, structures, threads)

Create a TSV file containing the results of a Foldseek search.

Parameters:

Name Type Description Default
query_db Path

Path to the query database.

required
target_db Path

Path to the target database.

required
result_db Path

Path to the result database generated by the search.

required
result_tsv Path

Path to save the resulting TSV file.

required
logdir Path

Path to the directory where logs will be stored.

required
foldseek_gpu bool

Run Foldseek-GPU with the accelerated ungapped prefilter

required
structures bool

Whether structures were input (not ProstT5)

required
threads int

Number of threads to use.

required

Returns:

Type Description
None

None

Source code in src/baktfold/features/run_foldseek.py
def create_result_tsv(
    query_db: Path, target_db: Path, result_db: Path, result_tsv: Path, logdir: Path, foldseek_gpu: bool, structures: bool, threads: int
) -> None:
    """
    Export a Foldseek result database to TSV with ``foldseek convertalis``.

    Args:
        query_db (Path): Path to the query database.
        target_db (Path): Path to the target database. When ``foldseek_gpu`` is
            True, ``_gpu`` is appended to this path.
        result_db (Path): Path to the result database generated by the search.
        result_tsv (Path): Path to save the resulting TSV file.
        logdir (Path): Path to the directory where logs will be stored.
        foldseek_gpu (bool): Whether the search was run against the ``_gpu`` target database.
        structures (bool): Whether structures were input (not ProstT5); adds the
            ``alntmscore`` and ``lddt`` output columns.
        threads (int): Number of threads to use.

    Returns:
        None
    """
    # Structure input gets two extra columns appended to the default set.
    format_string = (
        "--format-output query,target,bits,fident,evalue,qstart,qend,qlen,tstart,tend,tlen,alntmscore,lddt"
        if structures
        else "--format-output query,target,bits,fident,evalue,qstart,qend,qlen,tstart,tend,tlen"
    )

    searched_target = f"{target_db}_gpu" if foldseek_gpu else target_db

    convertalis_params = f"convertalis {query_db} {searched_target} {result_db} {result_tsv} {format_string} --threads {threads}"

    convertalis_tool = ExternalTool(
        tool="foldseek",
        input="",
        output="",
        params=convertalis_params,
        logdir=logdir,
    )
    ExternalTool.run_tool(convertalis_tool)

Run a Foldseek search using given parameters.

Parameters:

Name Type Description Default
query_db Path

Path to the query database.

required
target_db Path

Path to the target database.

required
result_db Path

Path to store the result database.

required
temp_db Path

Path to store temporary files.

required
threads int

Number of threads to use for the search.

required
logdir Path

Path to the directory where logs will be stored.

required
evalue float

E-value threshold for the search.

required
sensitivity float

Sensitivity threshold for the search.

required
max_seqs int

Maximum results per query sequence allowed to pass the prefilter for foldseek.

required
ultra_sensitive bool

Whether to skip foldseek prefilter for maximum sensitivity

required
extra_foldseek_params str

Extra foldseek search params

required
foldseek_gpu bool

Run Foldseek-GPU with the accelerated ungapped prefilter

required
structures bool

Run Foldseek with structures, not ProstT5 3Dis

required
gpus Optional[str]

Comma-separated CUDA indices (e.g. "0,2") to restrict foldseek's GPU prefilter to a subset of devices. When foldseek_gpu is True and this resolves to ≥1 CUDA device, the foldseek subprocess gets CUDA_VISIBLE_DEVICES set accordingly. None = use all visible CUDA GPUs (foldseek default). Ignored when foldseek_gpu is False.

None

Returns:

Type Description
None

None

Source code in src/baktfold/features/run_foldseek.py
def run_foldseek_search(
    query_db: Path,
    target_db: Path,
    result_db: Path,
    temp_db: Path,
    threads: int,
    logdir: Path,
    evalue: float,
    sensitivity: float,
    max_seqs: int,
    ultra_sensitive: bool,
    extra_foldseek_params: str,
    foldseek_gpu: bool,
    structures: bool,
    gpus: Optional[str] = None,
) -> None:
    """
    Assemble and run a ``foldseek search`` command.

    Exactly one of three base commands is used: the GPU variant (when
    ``foldseek_gpu`` is True; it takes precedence and carries neither ``-s``
    nor ``--exhaustive-search``), the exhaustive variant (when
    ``ultra_sensitive`` is True), or the regular ``--max-seqs`` variant.

    Args:
        query_db (Path): Path to the query database.
        target_db (Path): Path to the target database (``_gpu`` is appended in GPU mode).
        result_db (Path): Path to store the result database.
        temp_db (Path): Path to store temporary files.
        threads (int): Number of threads to use for the search.
        logdir (Path): Path to the directory where logs will be stored.
        evalue (float): E-value threshold for the search.
        sensitivity (float): Sensitivity for the search (not passed in GPU mode).
        max_seqs (int): Maximum results per query sequence allowed to pass the prefilter for foldseek.
        ultra_sensitive (bool): Whether to skip foldseek prefilter for maximum sensitivity
            (has no effect in GPU mode).
        extra_foldseek_params (str): Extra foldseek search params, appended verbatim.
        foldseek_gpu (bool): Run Foldseek-GPU with the ungapped GPU prefilter.
        structures (bool): Run Foldseek with structures, not ProstT5 3Dis; appends ``-a 1``.
        gpus (Optional[str]): Comma-separated CUDA indices (e.g. "0,2") to
            restrict foldseek's GPU prefilter to a subset of devices. When
            ``foldseek_gpu`` is True and this resolves to a non-None
            ``CUDA_VISIBLE_DEVICES`` value, that value is handed to the
            foldseek tool via ``env``. None = no ``env`` override.
            Ignored when ``foldseek_gpu`` is False.

    Returns:
        None
    """
    thread_count = str(threads)

    # support foldseek gpu only for the regular DB search for now, so the GPU
    # command wins over both CPU variants
    if foldseek_gpu:
        cmd = f"search {query_db} {target_db}_gpu {result_db} {temp_db} --threads {thread_count} -e {evalue}  --gpu 1 --prefilter-mode 1 --max-seqs {max_seqs}"
    elif ultra_sensitive:
        cmd = f"search {query_db} {target_db} {result_db} {temp_db} --threads {thread_count} -e {evalue} -s {sensitivity} --exhaustive-search"
    else:
        cmd = f"search {query_db} {target_db} {result_db} {temp_db} --threads {thread_count} -e {evalue} -s {sensitivity} --max-seqs {max_seqs}"

    if extra_foldseek_params:
        cmd = f"{cmd} {extra_foldseek_params}"

    # need -a 1 to compute the alignment so tmscore and lddt can be output (if using --structures)
    if structures:
        cmd = cmd + " -a 1"

    # Optional device restriction: only in GPU mode and only when the caller
    # supplied an explicit device list.
    env = None
    if foldseek_gpu and gpus is not None:
        visible_devices = cuda_visible_devices_value(parse_gpus(cpu=False, gpus=gpus))
        if visible_devices is not None:
            env = {"CUDA_VISIBLE_DEVICES": visible_devices}

    search_tool = ExternalTool(
        tool="foldseek",
        input="",
        output="",
        params=cmd,
        logdir=logdir,
        env=env,
    )
    ExternalTool.run_tool(search_tool)

summarise_hits(result_db, result_db_greedy_best_hits, logdir, threads)

Get all non-overlapping tophits covering a query (designed for CATH)

Parameters:

Name Type Description Default
result_db Path

Path to the result database generated by the search.

required
result_db_greedy_best_hits Path

Path to save the greedy best hits results db.

required
logdir Path

Path to the directory where logs will be stored.

required
threads int

Number of threads to use.

required

Returns:

Type Description
None

None

Source code in src/baktfold/features/run_foldseek.py
def summarise_hits(result_db: Path, result_db_greedy_best_hits: Path, logdir: Path, threads: int) -> None:
    """
    Run ``foldseek summarizeresult`` to get all non-overlapping tophits covering a query (designed for CATH).

    Args:
        result_db (Path): Path to the result database generated by the search.
        result_db_greedy_best_hits (Path): Path to save the greedy best hits results db.
        logdir (Path): Path to the directory where logs will be stored.
        threads (int): Number of threads to use.

    Returns:
        None
    """
    summarize_params = f"summarizeresult  {result_db} {result_db_greedy_best_hits} --threads {threads} -a 1"

    # The complete foldseek invocation travels in `params`; input/output are
    # intentionally left empty.
    summarize_tool = ExternalTool(
        tool="foldseek",
        input="",
        output="",
        params=summarize_params,
        logdir=logdir,
    )
    ExternalTool.run_tool(summarize_tool)

3Di prediction for baktfold — wraps pholdlib's shared inference engine.

Baktfold-specific: flat cds_dict (no contig nesting), Bakta hypotheticals format with in-place annotation updates, has_duplicate_locus support.

Code adapted from @mheinzinger https://github.com/mheinzinger/ProstT5/blob/main/scripts/predict_3Di_encoderOnly.py

get_embeddings(hypotheticals, cds_dict, out_path, prefix, model_dir, model_name, checkpoint_path, output_3di, output_h5_per_residue, output_h5_per_protein, half_precision, max_residues=100000, max_seq_len=30000, max_batch=10000, cpu=False, output_probs=True, save_per_residue_embeddings=False, save_per_protein_embeddings=False, threads=1, mask_threshold=0, has_duplicate_locus=False, gpus=None)

Run ProstT5 + CNN 3Di prediction for all sequences in cds_dict.

Parameters:

Name Type Description Default
hypotheticals List[Dict]

List of Bakta feature dicts (mutated in-place with "3di").

required
cds_dict Dict[str, str]

Flat {seq_id: amino_acid_str} dict.

required
out_path Path

Directory for output files.

required
prefix str

Filename prefix for CSV / JSONL outputs.

required
model_dir Path

Directory where ProstT5 is cached.

required
model_name str

HuggingFace model identifier.

required
checkpoint_path Path

Path to the CNN .pt checkpoint.

required
output_3di Path

Output FASTA path for 3Di sequences.

required
output_h5_per_residue Path

HDF5 path for per-residue embeddings.

required
output_h5_per_protein Path

HDF5 path for per-protein embeddings.

required
half_precision bool

If True, cast model + predictor to fp16 after loading.

required
max_residues int

Max total residues per inference batch.

100000
max_seq_len int

Sequences longer than this flush a batch immediately.

30000
max_batch int

Max sequences per batch.

10000
cpu bool

Force CPU inference.

False
output_probs bool

Whether to write per-residue probability JSONL.

True
save_per_residue_embeddings bool

Save per-residue HDF5.

False
save_per_protein_embeddings bool

Save per-protein HDF5.

False
threads int

Number of CPU threads for torch.

1
mask_threshold float

Residues with max softmax prob < threshold/100 → 'X'.

0
has_duplicate_locus bool

If True use feat["id"] rather than feat["locus"].

False
gpus Optional[str]

Comma-separated CUDA indices (e.g. "0,2"). None = auto-detect all visible CUDA GPUs. Overridden by cpu=True.

None

Returns:

Name Type Description
predictions Dict

Flat {seq_id: (pred, mean_prob, all_prob)} dict, in original cds_dict key order.

Source code in src/baktfold/features/predict_3Di.py
def get_embeddings(
    hypotheticals: List[Dict],
    cds_dict: Dict[str, str],
    out_path: Path,
    prefix: str,
    model_dir: Path,
    model_name: str,
    checkpoint_path: Path,
    output_3di: Path,
    output_h5_per_residue: Path,
    output_h5_per_protein: Path,
    half_precision: bool,
    max_residues: int = 100000,
    max_seq_len: int = 30000,
    max_batch: int = 10000,
    cpu: bool = False,
    output_probs: bool = True,
    save_per_residue_embeddings: bool = False,
    save_per_protein_embeddings: bool = False,
    threads: int = 1,
    mask_threshold: float = 0,
    has_duplicate_locus: bool = False,
    gpus: Optional[str] = None,
) -> Dict:
    """Predict 3Di states with ProstT5 + CNN for every sequence in *cds_dict*.

    Empty or non-string entries are skipped and recorded as failures. The
    remaining sequences are sorted longest-first, handed to the shared
    inference engine, and the results are written out (fails TSV, 3Di FASTA,
    optional HDF5 embeddings, probability files).

    Args:
        hypotheticals: List of Bakta feature dicts (mutated in-place with "3di").
        cds_dict: Flat ``{seq_id: amino_acid_str}`` dict.
        out_path: Directory for output files.
        prefix: Filename prefix for CSV / JSON probability outputs.
        model_dir: Directory where ProstT5 is cached.
        model_name: HuggingFace model identifier.
        checkpoint_path: Path to the CNN ``.pt`` checkpoint.
        output_3di: Output FASTA path for 3Di sequences.
        output_h5_per_residue: HDF5 path for per-residue embeddings.
        output_h5_per_protein: HDF5 path for per-protein embeddings.
        half_precision: Request half precision; forced to False when the
            resolved device list is exactly ``["cpu"]``.
        max_residues: Max total residues per inference batch.
        max_seq_len: Sequences longer than this flush a batch immediately.
        max_batch: Max sequences per batch.
        cpu: Force CPU inference.
        output_probs: Whether to write the per-residue probability file.
        save_per_residue_embeddings: Save per-residue HDF5.
        save_per_protein_embeddings: Save per-protein HDF5.
        threads: Number of CPU threads for torch.
        mask_threshold: Residues with max softmax prob < threshold/100 → 'X'.
        has_duplicate_locus: If True use feat["id"] rather than feat["locus"].
        gpus: Comma-separated CUDA indices (e.g. "0,2"). None = auto-detect
              all visible CUDA GPUs. Overridden by ``cpu=True``.

    Returns:
        predictions: Flat ``{seq_id: (pred, mean_prob, all_prob)}`` dict,
                     in original cds_dict key order.
    """
    # -- device / precision resolution --------------------------------------
    devices = parse_gpus(cpu, gpus)
    logger.info(f"Beginning ProstT5 predictions on device(s): {devices}")
    if half_precision and devices == ["cpu"]:
        logger.info("CPU device — forcing full-precision (half-precision disabled).")
        half_precision = False
    logger.info(
        "Using models in half-precision"
        if half_precision
        else "Using models in full-precision"
    )

    # -- collect usable sequences; anything empty / non-string is a failure --
    original_keys = list(cds_dict)
    fail_ids: List[str] = []
    seq_records: List[Tuple] = []

    for seq_id, aa_seq in cds_dict.items():
        if not (isinstance(aa_seq, str) and aa_seq):
            logger.warning(
                f"Protein header {seq_id} is corrupt or empty — will be saved in fails.tsv"
            )
            fail_ids.append(seq_id)
            continue
        seq_records.append((seq_id, aa_seq, len(aa_seq)))

    # longest first (minimises padding in each batch)
    seq_records.sort(key=lambda entry: entry[2], reverse=True)

    # -- shared inference engine (single- or multi-GPU) ----------------------
    inference_options = dict(
        devices=devices,
        model_dir=model_dir,
        model_name=model_name,
        checkpoint_path=checkpoint_path,
        half_precision=half_precision,
        threads=threads,
        check_fn=check_prostT5_download,
        zenodo_fn=download_zenodo_prostT5,
        max_residues=max_residues,
        max_seq_len=max_seq_len,
        max_batch=max_batch,
        output_probs=output_probs,
        save_per_residue_embeddings=save_per_residue_embeddings,
        save_per_protein_embeddings=save_per_protein_embeddings,
        desc="Predicting 3Di",
    )
    raw_predictions, emb_res, emb_prot, inf_fail_ids = run_prostt5_inference_multi_gpu(
        seq_records, **inference_options
    )
    fail_ids.extend(inf_fail_ids)

    # put the predictions back into the caller's key order
    predictions = {
        key: raw_predictions[key] for key in original_keys if key in raw_predictions
    }

    # -- outputs --------------------------------------------------------------
    out_dir = Path(out_path)

    if fail_ids:
        write_fail_ids(fail_ids, out_dir / "fails.tsv")

    write_predictions(
        hypotheticals, predictions, output_3di, mask_threshold, has_duplicate_locus
    )

    if save_per_residue_embeddings:
        write_embeddings(emb_res, output_h5_per_residue)
    if save_per_protein_embeddings:
        write_embeddings(emb_prot, output_h5_per_protein)

    mean_probs_path = out_dir / f"{prefix}_prostT5_3di_mean_probabilities.csv"
    # the all-probabilities file is only requested when output_probs is set
    if output_probs:
        all_probs_path = out_dir / f"{prefix}_prostT5_3di_all_probabilities.json"
    else:
        all_probs_path = None
    write_probs(predictions, mean_probs_path, all_probs_path, original_keys)

    return predictions

write_embeddings(embeddings, out_path)

Write per-residue or per-protein embeddings to HDF5 (flat key structure).

Source code in src/baktfold/features/predict_3Di.py
def write_embeddings(embeddings: Dict[str, Any], out_path: Path) -> None:
    """Store each embedding as its own HDF5 dataset, keyed by sequence ID.

    Args:
        embeddings: ``{sequence_id: embedding}`` mapping (per-residue or per-protein).
        out_path: Destination HDF5 file; opened in ``"w"`` mode.
    """
    h5_target = str(out_path)
    with h5py.File(h5_target, "w") as h5_handle:
        for dataset_name in embeddings:
            h5_handle.create_dataset(dataset_name, data=embeddings[dataset_name])

write_predictions(hypotheticals, predictions, out_path, mask_threshold, has_duplicate_locus=False)

Write 3Di predictions to FASTA and update Bakta hypotheticals in-place.

Parameters:

Name Type Description Default
hypotheticals List[Dict]

List of Bakta feature dicts. Each is mutated in-place with a "3di" key set to the predicted 3Di string (or None if prediction failed / was skipped).

required
predictions Dict[str, Tuple]

Flat {seq_id: (pred, mean_prob, all_prob)} dict.

required
out_path Path

Output FASTA path.

required
mask_threshold float

Residues with max softmax prob (0–100) below this threshold are replaced with 'X'.

required
has_duplicate_locus bool

If True, use feat["id"] as seq_id (needed for eukaryotic inputs that may have duplicate locus tags). Otherwise use feat["locus"].

False
Source code in src/baktfold/features/predict_3Di.py
def write_predictions(
    hypotheticals: List[Dict],
    predictions: Dict[str, Tuple],
    out_path: Path,
    mask_threshold: float,
    has_duplicate_locus: bool = False,
) -> None:
    """Mask low-confidence states, write the 3Di FASTA and annotate the features.

    Args:
        hypotheticals: List of Bakta feature dicts. Each is mutated in-place
                       with a ``"3di"`` key set to the predicted 3Di string
                       (or None when no non-empty prediction exists for it).
        predictions: Flat ``{seq_id: (pred, mean_prob, all_prob)}`` dict. The
                     ``pred`` sequences are modified in-place by the masking step.
        out_path: Output FASTA path.
        mask_threshold: Threshold on a 0-100 scale; it is divided by 100 and
                        positions whose value in ``all_prob[0]`` is below the
                        result are set to index 20 ('X').
        has_duplicate_locus: If True, use ``feat["id"]`` as seq_id (needed for
                             eukaryotic inputs that may have duplicate locus tags).
                             Otherwise use ``feat["locus"]``.
    """
    mask_prop = mask_threshold / 100
    id_field = "id" if has_duplicate_locus else "locus"

    # drop zero-length predictions (issue #47)
    non_empty = {}
    for seq_id, pred_tuple in predictions.items():
        if len(pred_tuple[0]) > 0:
            non_empty[seq_id] = pred_tuple

    # confidence masking, applied in-place on the pred index sequences
    for states, _mean_prob, all_prob in non_empty.values():
        confidences = all_prob[0]
        for position in range(len(states)):
            if confidences[position] < mask_prop:
                states[position] = 20  # 'X'

    with open(out_path, "w+") as fasta_handle:
        for feat in hypotheticals:
            seq_id = feat[id_field]
            pred_tuple = non_empty.get(seq_id)
            if pred_tuple is None:
                feat["3di"] = None  # no prediction (OOM / corrupt entry)
                continue
            threedi_seq = "".join([SS_MAPPING[int(state)] for state in pred_tuple[0]])
            feat["3di"] = threedi_seq  # mutate Bakta feature dict in-place
            fasta_handle.write(f">{seq_id}\n{threedi_seq}\n")

    logger.info(f"Finished writing 3Di FASTA to {out_path}")

autotune_batching_real_data(model_dir, model_name, cpu, threads, probe_seqs, start_bs=1, max_bs=100, step=5, device=None)

Autotunes the batch size for a given model and set of sequences.

Parameters:

Name Type Description Default
model_dir str

The directory where the model is stored.

required
model_name str

The name of the model.

required
cpu bool

Whether to use the CPU or not.

required
threads int

The number of threads to use.

required
probe_seqs list

A list of sequences to use for probing.

required
start_bs int

The starting batch size to use.

1
max_bs int

The maximum batch size to use.

100
step int

The step size to use when increasing the batch size.

5
device Optional[str]

Torch device string (e.g. "cuda:1") to pin autotune to a specific GPU. None preserves the original auto-detection behaviour. Used by the multi-GPU caller.

None

Returns:

Name Type Description
int

The optimal batch size.

int

The maximum number of residues per batch.

Examples:

>>> autotune_batching_real_data("model_dir", "model_name", True, 4, ["ATCG", "GCTA"], 1, 100, 5)
(10, 100)
Source code in src/baktfold/features/autotune.py
def autotune_batching_real_data(
    model_dir,
    model_name,
    cpu,
    threads,
    probe_seqs,
    start_bs=1,
    max_bs=100,
    step=5,  # step size
    device: Optional[str] = None,
):
    """
    Autotunes the batch size for a given model and set of sequences.

    Every candidate batch size from ``start_bs`` to ``max_bs`` (in increments
    of ``step``) is timed on a full pass over ``probe_seqs``. The candidate
    with the lowest time per residue wins. The search stops at the first
    candidate that raises an out-of-memory / runtime error.

    Args:
      model_dir (str): The directory where the model is stored.
      model_name (str): The name of the model.
      cpu (bool): Whether to use the CPU or not.
      threads (int): The number of threads to use.
      probe_seqs (list): A list of sequences to use for probing.
      start_bs (int): The starting batch size to use (must be >= 1).
      max_bs (int): The maximum batch size to use.
      step (int): The step size to use when increasing the batch size
        (must be >= 1).
      device (Optional[str]): Torch device string (e.g. "cuda:1") to pin
        autotune to a specific GPU. None preserves the original
        auto-detection behaviour. Used by the multi-GPU caller.

    Returns:
      int: The optimal batch size.
      int: The mean number of residues per batch at that batch size
        (total probe residues // number of batches).

    Raises:
      ValueError: If ``probe_seqs`` contains no residues, or if ``start_bs``
        or ``step`` is smaller than 1.
      RuntimeError: If no candidate batch size completed successfully.

    Examples:
      Illustrative only; the result depends on the model and hardware.

      >>> autotune_batching_real_data("model_dir", "model_name", True, 4, ["ATCG", "GCTA"], 1, 100, 5)  # doctest: +SKIP
      (10, 100)
    """

    # Validate before loading the model: these inputs would otherwise end in a
    # ZeroDivisionError (no residues), or an endless loop (step <= 0), only
    # after the model load has already been paid for.
    if sum(len(s) for s in probe_seqs) == 0:
        raise ValueError("probe_seqs must contain at least one non-empty sequence")
    if start_bs < 1:
        raise ValueError(f"start_bs must be >= 1 (got {start_bs})")
    if step < 1:
        raise ValueError(f"step must be >= 1 (got {step})")

    # get_T5_model returns the device it actually selected; use that from here on.
    model, tokenizer, device = get_T5_model(
        model_dir, model_name, cpu, threads, device=device
    )
    model.eval()
    model.half()

    bs = start_bs
    results = []

    while bs <= max_bs:
        try:
            logger.info(f"Running with batch size {bs}")

            total_tokens = 0
            total_time = 0.0
            batches = 0

            # iterate over real sequences in batches
            for i in tqdm(range(0, len(probe_seqs), bs), desc="Processing"):
                batch_seqs = probe_seqs[i : i + bs]
                total_tokens += sum(len(s) for s in batch_seqs)

                inputs = tokenizer(
                    batch_seqs,
                    padding=True,
                    return_tensors="pt",
                )
                inputs.pop("token_type_ids", None)
                inputs = {k: v.to(device) for k, v in inputs.items()}

                # timing — device_synchronize handles CUDA/MPS/XPU/CPU (PR #129)
                # Synchronise on both sides so only the forward pass is timed.
                device_synchronize(device)
                t0 = time.perf_counter()
                with torch.no_grad():
                    _ = model(**inputs)
                device_synchronize(device)

                total_time += time.perf_counter() - t0
                batches += 1

            token_per_batch = total_tokens // batches

            results.append({
                "bs": bs,
                "tokens_per_batch": token_per_batch,
                "time": total_time,
                "time_per_token": total_time / total_tokens,
            })

            logger.info(f"Time elapsed {round(total_time,5)}")
            logger.info(f"Tokens per batch {token_per_batch}")

            bs += step

        except (torch.cuda.OutOfMemoryError, RuntimeError) as err:
            # RuntimeError covers XPU/MPS OOM; torch.cuda.OutOfMemoryError covers CUDA.
            # Record the cause: a non-OOM RuntimeError also lands here and
            # would otherwise be indistinguishable from running out of memory.
            logger.warning(f"Stopping batch size search at batch size {bs}: {err}")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            break

    if not results:
        raise RuntimeError("No batch size fits on this GPU")

    best_entry = min(results, key=lambda x: x["time_per_token"])

    best_bs = best_entry["bs"]
    best_residues = best_entry["tokens_per_batch"]

    logger.info(f"##########################")
    logger.info(f"Best batch size: {best_bs}")

    return best_bs, best_residues

run_autotune(input_path, model_dir, model_name, cpu, threads, step, min_batch, max_batch, sample_seqs, gpus=None)

Runs the batch size autotuning process.

Parameters:

Name Type Description Default
input_path str

The path to the input file.

required
model_dir str

The directory where the model is stored.

required
model_name str

The name of the model.

required
cpu bool

Whether to use the CPU or not.

required
threads int

The number of threads to use.

required
step int

The step size to use when increasing the batch size.

required
min_batch int

The minimum batch size to use.

required
max_batch int

The maximum batch size to use.

required
sample_seqs int

The number of sequences to sample for probing.

required
gpus Optional[str]

Comma-separated CUDA indices (e.g. "0,2"). When set, autotune runs on the lowest selected index. Default None = existing behaviour (cuda:0 / mps / xpu / cpu auto-detect).

None

Returns:

Name Type Description
int

The optimal batch size.

Examples:

>>> run_autotune("input_path", "model_dir", "model_name", True, 4, 5, 1, 100, 10)
10
Source code in src/baktfold/features/autotune.py
def run_autotune(
    input_path,
    model_dir,
    model_name,
    cpu,
    threads,
    step,
    min_batch,
    max_batch,
    sample_seqs,
    gpus: Optional[str] = None,
):
    """
    Runs the batch size autotuning process.

    Reads the protein FASTA at ``input_path``, samples up to ``sample_seqs``
    sequences from it and hands them to ``autotune_batching_real_data`` to
    find the fastest batch size.

    Args:
      input_path (str): The path to the input file.
      model_dir (str): The directory where the model is stored.
      model_name (str): The name of the model.
      cpu (bool): Whether to use the CPU or not.
      threads (int): The number of threads to use.
      step (int): The step size to use when increasing the batch size.
      min_batch (int): The minimum batch size to use.
      max_batch (int): The maximum batch size to use.
      sample_seqs (int): The number of sequences to sample for probing.
      gpus (Optional[str]): Comma-separated CUDA indices (e.g. "0,2"). When
        set, autotune runs on the first device returned by ``parse_gpus``.
        Default None = existing behaviour (cuda:0 / mps / xpu / cpu
        auto-detect).

    Returns:
      int: The optimal batch size.

    Examples:
      Illustrative only; the result depends on the model and hardware.

      >>> run_autotune("input_path", "model_dir", "model_name", True, 4, 5, 1, 100, 10)  # doctest: +SKIP
      10
    """

    # Resolve devices early so we can pick the autotune GPU (homogeneous-card
    # assumption: same batch size applies to every GPU we'll later use).
    devices = parse_gpus(cpu, gpus)
    autotune_device: Optional[str] = None
    if len(devices) >= 1 and devices != ["cpu"]:
        autotune_device = devices[0]
    if len(devices) > 1:
        logger.info(
            f"Multi-GPU detected ({len(devices)} devices); autotuning on "
            f"{autotune_device} and applying the chosen batch to all devices."
        )

    with open_protein_fasta_file(input_path) as handle:  # handles gzip too
        records = list(SeqIO.parse(handle, "fasta"))

    # NOTE(review): nothing here stops execution after logger.error unless the
    # logger itself is configured to abort on errors — confirm.
    if not records:
        logger.warning(f"No proteins were found in your input file {input_path}.")
        logger.error(
            f"Your input file {input_path} is likely not a amino acid FASTA file. Please check this."
        )

    # Map record id -> amino-acid string. Keying by id means a repeated id
    # keeps only the last sequence seen for it.
    translations = {record.id: str(record.seq) for record in records}

    if not translations:
        logger.error(f"Error: no AA protein sequences found in {input_path} file")

    # Records with an empty sequence contribute nothing to the probe set.
    seqs = [seq for seq in translations.values() if seq]

    logger.info("Beginning batch size tuning")
    logger.info(
        f"Using minimum batch size of {min_batch} and maximum batch size of {max_batch}"
    )

    # define the sampling
    probe_seqs = sample_probe_sequences(seqs, n=sample_seqs)

    batch_size, max_residues = autotune_batching_real_data(
        model_dir,
        model_name,
        cpu,
        threads,
        probe_seqs,
        start_bs=min_batch,
        max_bs=max_batch,
        step=step,
        device=autotune_device,
    )

    logger.info(f"Optimal batch size is {batch_size} (residues per batch {max_residues})")

    return batch_size

sample_probe_sequences(seqs, n=5000, seed=0)

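Randomly samples up to n sequences (reproducibly, using the given seed) and returns them sorted by length, longest first.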

Source code in src/baktfold/features/autotune.py
def sample_probe_sequences(seqs, n=5000, seed=0):
    """
    Picks a reproducible subset of sequences to use as autotune probes.

    Args:
      seqs (list): The candidate sequences.
      n (int): The number of sequences to draw. When ``n`` is at least
        ``len(seqs)`` every sequence is kept and no random draw happens.
      seed (int): Seed for the private random generator, so the same inputs
        always give the same sample.

    Returns:
      list: A new list of the chosen sequences, ordered from longest to
        shortest. ``seqs`` itself is left untouched.
    """

    if len(seqs) <= n:
        # Nothing to sample: take a copy of everything.
        chosen = list(seqs)
    else:
        chosen = random.Random(seed).sample(seqs, n)

    # Longest sequences first.
    return sorted(chosen, key=len, reverse=True)