autotag: allow recursively processing directories.
* Make it so you can give the `autotag` script a mixed list of files and directories, and it will recursively process every file in each directory. * Allow choosing the output format: one (filename, tag, score) tuple per line, or one (filename, tags) tuple per line.
This commit is contained in:
@@ -2,22 +2,40 @@
|
||||
|
||||
import json
|
||||
import click
|
||||
import itertools
|
||||
from autotagger import Autotagger
|
||||
from fastai.vision.core import PILImage
|
||||
from pathlib import Path
|
||||
from more_itertools import ichunked
|
||||
|
||||
@click.command(help="Automatically generate tags for an image.")
|
||||
@click.command(help="Automatically generate tags for an image.", context_settings=dict(max_content_width=140))
|
||||
@click.option("-t", "--threshold", default=0.01, type=float, show_default=True, help="The minimum tag confidence level.")
|
||||
@click.option("-n", "--limit", default=50, type=int, show_default=True, help="The maximum number of tags to return.")
|
||||
@click.option("-m", "--model", default="models/model.pth", type=click.Path(exists=True), help="The model to use.")
|
||||
@click.argument("file", nargs=-1, type=click.File("rb"), required=True)
|
||||
def main(file, threshold, limit, model):
|
||||
@click.option("-n", "--limit", default=50, type=int, show_default=True, help="The maximum number of tags to return per image.")
|
||||
@click.option("-b", "--batch", "bs", default=128, type=int, show_default=True, help="The number of images to process per batch.")
|
||||
@click.option("--group-tags/--flatten-tags", default=True, show_default=True, help="Output rows in {filename, tags} format or {filename, tag, score} format.")
|
||||
@click.option("-m", "--model", default="models/model.pth", type=click.Path(exists=True), show_default=True, help="The model to use.")
|
||||
@click.argument("file", nargs=-1, type=click.Path(exists=True, allow_dash=True), required=True)
|
||||
def main(file, threshold, limit, bs, group_tags, model):
|
||||
autotagger = Autotagger(model)
|
||||
images = [PILImage.create(f) for f in file]
|
||||
predictions = autotagger.predict(images, threshold=threshold, limit=limit)
|
||||
|
||||
for i, tags in enumerate(predictions):
|
||||
data = { "filename": file[i].name, "tags": tags }
|
||||
click.echo(json.dumps(data))
|
||||
for filepaths in ichunked(get_filepaths(file), bs):
|
||||
filepaths = list(filepaths)
|
||||
predictions = autotagger.predict(filepaths, threshold=threshold, limit=limit, bs=bs)
|
||||
|
||||
for filepath, tags in zip(filepaths, predictions):
|
||||
if group_tags:
|
||||
data = { "filename": filepath.name, "tags": tags }
|
||||
click.echo(json.dumps(data))
|
||||
else:
|
||||
for tag, score in tags.items():
|
||||
data = { "filename": filepath.name, "tag": tag, "score": score }
|
||||
click.echo(json.dumps(data))
|
||||
|
||||
def get_filepaths(paths):
|
||||
files = (recurse_dir(p) if Path(p).is_dir() else iter([p]) for p in paths)
|
||||
return itertools.chain(*files)
|
||||
|
||||
def recurse_dir(directory):
|
||||
return (path for path in Path(directory).glob("**/*") if not path.is_dir())
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Generated
+15
-3
@@ -254,13 +254,13 @@ xxhash = "*"
|
||||
apache-beam = ["apache-beam (>=2.26.0)"]
|
||||
audio = ["librosa"]
|
||||
benchmarks = ["numpy (==1.18.5)", "tensorflow (==2.3.0)", "torch (==1.6.0)", "transformers (==3.0.2)"]
|
||||
dev = ["absl-py", "pytest", "pytest-datadir", "pytest-xdist", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "aiobotocore (==1.4.2)", "boto3 (==1.17.106)", "botocore (==1.20.106)", "faiss-cpu (>=1.6.4)", "fsspec", "moto[s3,server] (==2.0.4)", "rarfile (>=4.0)", "s3fs (==2021.08.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "torch", "torchaudio", "soundfile", "transformers", "bs4", "conllu", "h5py", "langdetect", "lxml", "lz4", "mwparserfromhell", "nltk", "openpyxl", "py7zr", "tldextract", "zstandard", "sentencepiece", "sacremoses", "bert-score (>=0.3.6)", "jiwer", "mauve-text", "rouge-score", "sacrebleu", "scikit-learn", "scipy", "seqeval", "toml (>=0.10.1)", "requests-file (>=1.5.1)", "tldextract (>=3.1.0)", "texttable (>=1.6.3)", "Werkzeug (>=1.0.1)", "six (>=1.15.0,<1.16.0)", "Pillow (>=6.2.1)", "librosa", "black (>=22.0,<23.0)", "flake8 (>=3.8.3)", "isort (>=5.0.0)", "pyyaml (>=5.3.1)", "importlib-resources"]
|
||||
dev = ["absl-py", "pytest", "pytest-datadir", "pytest-xdist", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "aiobotocore (==1.4.2)", "boto3 (==1.17.106)", "botocore (==1.20.106)", "faiss-cpu (>=1.6.4)", "fsspec", "moto[server,s3] (==2.0.4)", "rarfile (>=4.0)", "s3fs (==2021.08.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "torch", "torchaudio", "soundfile", "transformers", "bs4", "conllu", "h5py", "langdetect", "lxml", "lz4", "mwparserfromhell", "nltk", "openpyxl", "py7zr", "tldextract", "zstandard", "sentencepiece", "sacremoses", "bert-score (>=0.3.6)", "jiwer", "mauve-text", "rouge-score", "sacrebleu", "scikit-learn", "scipy", "seqeval", "toml (>=0.10.1)", "requests-file (>=1.5.1)", "tldextract (>=3.1.0)", "texttable (>=1.6.3)", "Werkzeug (>=1.0.1)", "six (>=1.15.0,<1.16.0)", "Pillow (>=6.2.1)", "librosa", "black (>=22.0,<23.0)", "flake8 (>=3.8.3)", "isort (>=5.0.0)", "pyyaml (>=5.3.1)", "importlib-resources"]
|
||||
docs = ["s3fs"]
|
||||
quality = ["black (>=22.0,<23.0)", "flake8 (>=3.8.3)", "isort (>=5.0.0)", "pyyaml (>=5.3.1)"]
|
||||
s3 = ["fsspec", "boto3", "botocore", "s3fs"]
|
||||
tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)"]
|
||||
tensorflow_gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"]
|
||||
tests = ["absl-py", "pytest", "pytest-datadir", "pytest-xdist", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "aiobotocore (==1.4.2)", "boto3 (==1.17.106)", "botocore (==1.20.106)", "faiss-cpu (>=1.6.4)", "fsspec", "moto[s3,server] (==2.0.4)", "rarfile (>=4.0)", "s3fs (==2021.08.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "torch", "torchaudio", "soundfile", "transformers", "bs4", "conllu", "h5py", "langdetect", "lxml", "lz4", "mwparserfromhell", "nltk", "openpyxl", "py7zr", "tldextract", "zstandard", "sentencepiece", "sacremoses", "bert-score (>=0.3.6)", "jiwer", "mauve-text", "rouge-score", "sacrebleu", "scikit-learn", "scipy", "seqeval", "toml (>=0.10.1)", "requests-file (>=1.5.1)", "tldextract (>=3.1.0)", "texttable (>=1.6.3)", "Werkzeug (>=1.0.1)", "six (>=1.15.0,<1.16.0)", "Pillow (>=6.2.1)", "librosa", "importlib-resources"]
|
||||
tests = ["absl-py", "pytest", "pytest-datadir", "pytest-xdist", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "aiobotocore (==1.4.2)", "boto3 (==1.17.106)", "botocore (==1.20.106)", "faiss-cpu (>=1.6.4)", "fsspec", "moto[server,s3] (==2.0.4)", "rarfile (>=4.0)", "s3fs (==2021.08.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "torch", "torchaudio", "soundfile", "transformers", "bs4", "conllu", "h5py", "langdetect", "lxml", "lz4", "mwparserfromhell", "nltk", "openpyxl", "py7zr", "tldextract", "zstandard", "sentencepiece", "sacremoses", "bert-score (>=0.3.6)", "jiwer", "mauve-text", "rouge-score", "sacrebleu", "scikit-learn", "scipy", "seqeval", "toml (>=0.10.1)", "requests-file (>=1.5.1)", "tldextract (>=3.1.0)", "texttable (>=1.6.3)", "Werkzeug (>=1.0.1)", "six (>=1.15.0,<1.16.0)", "Pillow (>=6.2.1)", "librosa", "importlib-resources"]
|
||||
torch = ["torch"]
|
||||
vision = ["Pillow (>=6.2.1)"]
|
||||
|
||||
@@ -842,6 +842,14 @@ category = "main"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
|
||||
[[package]]
|
||||
name = "more-itertools"
|
||||
version = "8.13.0"
|
||||
description = "More routines for operating on iterables, beyond itertools"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.5"
|
||||
|
||||
[[package]]
|
||||
name = "multidict"
|
||||
version = "6.0.2"
|
||||
@@ -1876,7 +1884,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
|
||||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = "~3.9"
|
||||
content-hash = "be9447310d082f76089168072555978e3162f27b5be0ebc097b3d3ab9453f3c7"
|
||||
content-hash = "aff9fd920f83203dace39d1840200a4e2f51685ecaac539e6b8168194184356c"
|
||||
|
||||
[metadata.files]
|
||||
aiohttp = [
|
||||
@@ -2483,6 +2491,10 @@ mistune = [
|
||||
{file = "mistune-0.8.4-py2.py3-none-any.whl", hash = "sha256:88a1051873018da288eee8538d476dffe1262495144b33ecb586c4ab266bb8d4"},
|
||||
{file = "mistune-0.8.4.tar.gz", hash = "sha256:59a3429db53c50b5c6bcc8a07f8848cb00d7dc8bdb431a4ab41920d201d4756e"},
|
||||
]
|
||||
more-itertools = [
|
||||
{file = "more-itertools-8.13.0.tar.gz", hash = "sha256:a42901a0a5b169d925f6f217cd5a190e32ef54360905b9c39ee7db5313bfec0f"},
|
||||
{file = "more_itertools-8.13.0-py3-none-any.whl", hash = "sha256:c5122bffc5f104d37c1626b8615b511f3427aa5389b94d61e5ef8236bfbc3ddb"},
|
||||
]
|
||||
multidict = [
|
||||
{file = "multidict-6.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b9e95a740109c6047602f4db4da9949e6c5945cefbad34a1299775ddc9a62e2"},
|
||||
{file = "multidict-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ac0e27844758d7177989ce406acc6a83c16ed4524ebc363c1f748cba184d89d3"},
|
||||
|
||||
@@ -15,6 +15,7 @@ ipywidgets = "^7.7.0"
|
||||
timm = "^0.5.4"
|
||||
scipy = "^1.8.1"
|
||||
gunicorn = "^20.1.0"
|
||||
more-itertools = "^8.13.0"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user