diff --git a/autotag b/autotag index 570fc9b..1b9cdd9 100755 --- a/autotag +++ b/autotag @@ -2,22 +2,40 @@ import json import click +import itertools from autotagger import Autotagger -from fastai.vision.core import PILImage +from pathlib import Path +from more_itertools import ichunked -@click.command(help="Automatically generate tags for an image.") +@click.command(help="Automatically generate tags for an image.", context_settings=dict(max_content_width=140)) @click.option("-t", "--threshold", default=0.01, type=float, show_default=True, help="The minimum tag confidence level.") -@click.option("-n", "--limit", default=50, type=int, show_default=True, help="The maximum number of tags to return.") -@click.option("-m", "--model", default="models/model.pth", type=click.Path(exists=True), help="The model to use.") -@click.argument("file", nargs=-1, type=click.File("rb"), required=True) -def main(file, threshold, limit, model): +@click.option("-n", "--limit", default=50, type=int, show_default=True, help="The maximum number of tags to return per image.") +@click.option("-b", "--batch", "bs", default=128, type=int, show_default=True, help="The number of images to process per batch.") +@click.option("--group-tags/--flatten-tags", default=True, show_default=True, help="Output rows in {filename, tags} format or {filename, tag, score} format.") +@click.option("-m", "--model", default="models/model.pth", type=click.Path(exists=True), show_default=True, help="The model to use.") +@click.argument("file", nargs=-1, type=click.Path(exists=True, allow_dash=True), required=True) +def main(file, threshold, limit, bs, group_tags, model): autotagger = Autotagger(model) - images = [PILImage.create(f) for f in file] - predictions = autotagger.predict(images, threshold=threshold, limit=limit) - for i, tags in enumerate(predictions): - data = { "filename": file[i].name, "tags": tags } - click.echo(json.dumps(data)) + for filepaths in ichunked(get_filepaths(file), bs): + filepaths = list(filepaths) + predictions = autotagger.predict(filepaths, threshold=threshold, limit=limit, bs=bs) + + for filepath, tags in zip(filepaths, predictions): + if group_tags: + data = { "filename": filepath.name, "tags": tags } + click.echo(json.dumps(data)) + else: + for tag, score in tags.items(): + data = { "filename": filepath.name, "tag": tag, "score": score } + click.echo(json.dumps(data)) + +def get_filepaths(paths): + files = (recurse_dir(p) if Path(p).is_dir() else iter([p]) for p in paths) + return itertools.chain(*files) + +def recurse_dir(directory): + return (path for path in Path(directory).glob("**/*") if not path.is_dir()) if __name__ == "__main__": main() diff --git a/poetry.lock b/poetry.lock index 2ee96b6..a1b64d5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -254,13 +254,13 @@ xxhash = "*" apache-beam = ["apache-beam (>=2.26.0)"] audio = ["librosa"] benchmarks = ["numpy (==1.18.5)", "tensorflow (==2.3.0)", "torch (==1.6.0)", "transformers (==3.0.2)"] -dev = ["absl-py", "pytest", "pytest-datadir", "pytest-xdist", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "aiobotocore (==1.4.2)", "boto3 (==1.17.106)", "botocore (==1.20.106)", "faiss-cpu (>=1.6.4)", "fsspec", "moto[s3,server] (==2.0.4)", "rarfile (>=4.0)", "s3fs (==2021.08.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "torch", "torchaudio", "soundfile", "transformers", "bs4", "conllu", "h5py", "langdetect", "lxml", "lz4", "mwparserfromhell", "nltk", "openpyxl", "py7zr", "tldextract", "zstandard", "sentencepiece", "sacremoses", "bert-score (>=0.3.6)", "jiwer", "mauve-text", "rouge-score", "sacrebleu", "scikit-learn", "scipy", "seqeval", "toml (>=0.10.1)", "requests-file (>=1.5.1)", "tldextract (>=3.1.0)", "texttable (>=1.6.3)", "Werkzeug (>=1.0.1)", "six (>=1.15.0,<1.16.0)", "Pillow (>=6.2.1)", "librosa", "black (>=22.0,<23.0)", "flake8 (>=3.8.3)", "isort (>=5.0.0)", "pyyaml (>=5.3.1)", "importlib-resources"] +dev = ["absl-py", "pytest", "pytest-datadir", "pytest-xdist", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "aiobotocore (==1.4.2)", "boto3 (==1.17.106)", "botocore (==1.20.106)", "faiss-cpu (>=1.6.4)", "fsspec", "moto[server,s3] (==2.0.4)", "rarfile (>=4.0)", "s3fs (==2021.08.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "torch", "torchaudio", "soundfile", "transformers", "bs4", "conllu", "h5py", "langdetect", "lxml", "lz4", "mwparserfromhell", "nltk", "openpyxl", "py7zr", "tldextract", "zstandard", "sentencepiece", "sacremoses", "bert-score (>=0.3.6)", "jiwer", "mauve-text", "rouge-score", "sacrebleu", "scikit-learn", "scipy", "seqeval", "toml (>=0.10.1)", "requests-file (>=1.5.1)", "tldextract (>=3.1.0)", "texttable (>=1.6.3)", "Werkzeug (>=1.0.1)", "six (>=1.15.0,<1.16.0)", "Pillow (>=6.2.1)", "librosa", "black (>=22.0,<23.0)", "flake8 (>=3.8.3)", "isort (>=5.0.0)", "pyyaml (>=5.3.1)", "importlib-resources"] docs = ["s3fs"] quality = ["black (>=22.0,<23.0)", "flake8 (>=3.8.3)", "isort (>=5.0.0)", "pyyaml (>=5.3.1)"] s3 = ["fsspec", "boto3", "botocore", "s3fs"] tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)"] tensorflow_gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"] -tests = ["absl-py", "pytest", "pytest-datadir", "pytest-xdist", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "aiobotocore (==1.4.2)", "boto3 (==1.17.106)", "botocore (==1.20.106)", "faiss-cpu (>=1.6.4)", "fsspec", "moto[s3,server] (==2.0.4)", "rarfile (>=4.0)", "s3fs (==2021.08.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "torch", "torchaudio", "soundfile", "transformers", "bs4", "conllu", "h5py", "langdetect", "lxml", "lz4", "mwparserfromhell", "nltk", "openpyxl", "py7zr", "tldextract", "zstandard", "sentencepiece", "sacremoses", "bert-score (>=0.3.6)", "jiwer", "mauve-text", "rouge-score", "sacrebleu", "scikit-learn", "scipy", "seqeval", "toml (>=0.10.1)", "requests-file (>=1.5.1)", "tldextract (>=3.1.0)", "texttable (>=1.6.3)", "Werkzeug (>=1.0.1)", "six (>=1.15.0,<1.16.0)", "Pillow (>=6.2.1)", "librosa", "importlib-resources"] +tests = ["absl-py", "pytest", "pytest-datadir", "pytest-xdist", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "aiobotocore (==1.4.2)", "boto3 (==1.17.106)", "botocore (==1.20.106)", "faiss-cpu (>=1.6.4)", "fsspec", "moto[server,s3] (==2.0.4)", "rarfile (>=4.0)", "s3fs (==2021.08.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "torch", "torchaudio", "soundfile", "transformers", "bs4", "conllu", "h5py", "langdetect", "lxml", "lz4", "mwparserfromhell", "nltk", "openpyxl", "py7zr", "tldextract", "zstandard", "sentencepiece", "sacremoses", "bert-score (>=0.3.6)", "jiwer", "mauve-text", "rouge-score", "sacrebleu", "scikit-learn", "scipy", "seqeval", "toml (>=0.10.1)", "requests-file (>=1.5.1)", "tldextract (>=3.1.0)", "texttable (>=1.6.3)", "Werkzeug (>=1.0.1)", "six (>=1.15.0,<1.16.0)", "Pillow (>=6.2.1)", "librosa", "importlib-resources"] torch = ["torch"] vision = ["Pillow (>=6.2.1)"] @@ -842,6 +842,14 @@ category = "main" optional = false python-versions = "*" +[[package]] +name = "more-itertools" +version = "8.13.0" +description = "More routines for operating on iterables, beyond itertools" +category = "main" +optional = false +python-versions = ">=3.5" + [[package]] name = "multidict" version = "6.0.2" @@ -1876,7 +1884,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [metadata] lock-version = "1.1" python-versions = "~3.9" -content-hash = "be9447310d082f76089168072555978e3162f27b5be0ebc097b3d3ab9453f3c7" +content-hash = "aff9fd920f83203dace39d1840200a4e2f51685ecaac539e6b8168194184356c" [metadata.files] aiohttp = [ @@ -2483,6 +2491,10 @@ mistune = [ {file = "mistune-0.8.4-py2.py3-none-any.whl", hash = "sha256:88a1051873018da288eee8538d476dffe1262495144b33ecb586c4ab266bb8d4"}, {file = "mistune-0.8.4.tar.gz", hash = "sha256:59a3429db53c50b5c6bcc8a07f8848cb00d7dc8bdb431a4ab41920d201d4756e"}, ] +more-itertools = [ + {file = "more-itertools-8.13.0.tar.gz", hash = "sha256:a42901a0a5b169d925f6f217cd5a190e32ef54360905b9c39ee7db5313bfec0f"}, + {file = "more_itertools-8.13.0-py3-none-any.whl", hash = "sha256:c5122bffc5f104d37c1626b8615b511f3427aa5389b94d61e5ef8236bfbc3ddb"}, +] multidict = [ {file = "multidict-6.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b9e95a740109c6047602f4db4da9949e6c5945cefbad34a1299775ddc9a62e2"}, {file = "multidict-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ac0e27844758d7177989ce406acc6a83c16ed4524ebc363c1f748cba184d89d3"}, diff --git a/pyproject.toml b/pyproject.toml index 23674f3..bbacc46 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ ipywidgets = "^7.7.0" timm = "^0.5.4" scipy = "^1.8.1" gunicorn = "^20.1.0" +more-itertools = "^8.13.0" [tool.poetry.dev-dependencies]