txtcaptcha — Training a unified captcha model

This notebook downloads every labeled dataset published by the original R captcha package, merges them into a single training corpus, and fits a CRNN+CTC model on the full alphanumeric vocabulary (0-9a-zA-Z).
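For reference, here is a minimal sketch of how a 62-character alphanumeric vocabulary is commonly laid out for CTC. It also shows the standard best-path (greedy) decoding rule: take the top-scoring class at each time step, merge consecutive repeats, then drop blanks. Reserving index 0 for the blank, and the names VOCAB, encode and greedy_decode, are assumptions for illustration. This is not txtcaptcha's internal code.

```python
import string

# Full alphanumeric vocabulary: 10 digits + 26 lowercase + 26 uppercase = 62 characters.
VOCAB = string.digits + string.ascii_lowercase + string.ascii_uppercase

# Assumption: index 0 is reserved for the CTC blank, so characters use indices 1..62.
BLANK = 0
CHAR_TO_IDX = {c: i + 1 for i, c in enumerate(VOCAB)}
IDX_TO_CHAR = {i: c for c, i in CHAR_TO_IDX.items()}


def encode(label: str) -> list[int]:
    """Map a label such as 'a8Bx3' to CTC target indices."""
    return [CHAR_TO_IDX[c] for c in label]


def greedy_decode(best_path: list[int]) -> str:
    """Best-path CTC decoding: merge consecutive repeats, then drop blanks."""
    out = []
    prev = None
    for idx in best_path:
        # Emit a character only when it differs from the previous step and is not a blank.
        if idx != prev and idx != BLANK:
            out.append(IDX_TO_CHAR[idx])
        prev = idx
    return ''.join(out)
```

Because repeats are merged before blanks are removed, a doubled letter survives only if the model emits a blank between the two copies. For example, the index path a, a, blank, a, 8 decodes to 'aa8'.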

Run this on a GPU machine: the combined download is about 300 MB, and training is much faster on CUDA.

0. Install

From the repo root:

uv sync --extra dev --extra notebook
uv run python -m ipykernel install --user --name txtcaptcha

Then select the txtcaptcha kernel and run the cells below.

import shutil
from pathlib import Path

import torch

from txtcaptcha import (
    available_datasets,
    download_dataset,
    fit_model,
    save_model,
    decrypt,
)

print('CUDA available:', torch.cuda.is_available())
if torch.cuda.is_available():
    print('Device:', torch.cuda.get_device_name(0))

1. Download every dataset

Each release zip is extracted under data/<name>/. Total download size is ~300 MB.

DATA_ROOT = Path('data')
DATA_ROOT.mkdir(exist_ok=True)

for name in available_datasets():
    print(f'>>> {name}')
    download_dataset(name, DATA_ROOT)
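Once the downloads finish, a quick per-dataset image count can confirm that every archive extracted where expected. The sketch below assumes the data/<name>/ layout described above and uses the same extension set as the flattening step. count_images is a hypothetical helper and is not part of txtcaptcha.

```python
from collections import Counter
from pathlib import Path

# Same image extensions the flattening step accepts.
IMAGE_EXTS = {'.png', '.jpg', '.jpeg', '.bmp'}


def count_images(root: Path) -> Counter:
    """Count image files (recursively) inside each top-level dataset folder under root."""
    counts = Counter()
    for dataset_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        for f in dataset_dir.rglob('*'):
            # Skip directories and non-image files such as READMEs or CSV manifests.
            if f.is_file() and f.suffix.lower() in IMAGE_EXTS:
                counts[dataset_dir.name] += 1
    return counts
```

For example, print(count_images(DATA_ROOT)). A dataset that is missing from the result, or that has far fewer images than expected, is worth inspecting before training.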

2. Flatten everything into a single training directory

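fit_model expects a flat folder of <id>_<label>.<ext> files. Each image is copied (not moved) under the name <dataset>-<original name>, so ids from different datasets cannot collide. Files whose stem has no underscore followed by an alphanumeric label are skipped.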
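The label_re check in the next cell only requires that some alphanumeric run follows an underscore somewhere in the stem. If you want a stricter check, the sketch below assumes the label is the entire alphanumeric run after the last underscore. That matches the <id>_<label>.<ext> convention and still holds after prefixing, because the prefix only changes the start of the name. This sketch does not show how fit_model itself parses names. prefixed_name, label_from_name and the dataset name 'rfb' are hypothetical.

```python
import re
from pathlib import Path
from typing import Optional

# Assumption: the label is the alphanumeric run between the LAST underscore and the extension.
LABEL_RE = re.compile(r'_([0-9a-zA-Z]+)$')


def prefixed_name(dataset: str, src: Path) -> str:
    """Merged filename '<dataset>-<original name>', as built by the flattening cell."""
    return f'{dataset}-{src.name}'


def label_from_name(name: str) -> Optional[str]:
    """Return the label of '<id>_<label>.<ext>', or None if the stem does not end in one."""
    match = LABEL_RE.search(Path(name).stem)
    return match.group(1) if match else None
```

label_from_name('rfb-001_a8Bx3.png') returns 'a8Bx3'. A name like 'x_ab-c.png', which label_re would accept, returns None.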

import re

MERGED = Path('data_merged')
if MERGED.exists():
    shutil.rmtree(MERGED)
MERGED.mkdir()

label_re = re.compile(r'(?<=_)[0-9a-zA-Z]+')
exts = {'.png', '.jpg', '.jpeg', '.bmp'}

n_total = 0
for name in available_datasets():
    src = DATA_ROOT / name
    n = 0
    for f in src.rglob('*'):
        if f.suffix.lower() not in exts:
            continue
        if not label_re.search(f.stem):
            continue
        new_name = f'{name}-{f.name}'
        shutil.copy2(f, MERGED / new_name)
        n += 1
    print(f'{name}: {n} images')
    n_total += n
print('Total:', n_total)
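rglob('*') walks subfolders. If a dataset ships the same file name in two subfolders (for example, separate train and test folders), both files map to the same merged name. shutil.copy2 then silently overwrites the first copy, while n still counts both. The minimal check below groups source files by their target name so such collisions surface before you trust n_total. find_name_collisions is a hypothetical helper.

```python
from collections import defaultdict
from pathlib import Path


def find_name_collisions(pairs):
    """Return {merged_name: [sources]} for merged names produced by more than one source.

    pairs: iterable of (dataset_name, source_path) tuples.
    """
    targets = defaultdict(list)
    for dataset, src in pairs:
        # Same naming rule as the flattening cell: '<dataset>-<original name>'.
        targets[f'{dataset}-{src.name}'].append(src)
    return {name: srcs for name, srcs in targets.items() if len(srcs) > 1}
```

A cheaper after-the-fact test is to compare sum(1 for _ in MERGED.iterdir()) with n_total. Because MERGED is recreated empty, any shortfall means some files were overwritten.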

3. Train the unified model

fit_model builds a CRNN with the full alphanumeric vocabulary, trains with CTC loss, and applies early stopping based on validation sequence accuracy.
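Sequence accuracy counts a captcha as correct only when every character matches. Patience-based early stopping halts once that metric has gone a set number of epochs (early_stopping_patience in the next cell) without improving. The plain-Python sketch below illustrates the general technique and is not fit_model's actual implementation. In particular, it does not say whether fit_model keeps the best or the last epoch's weights.

```python
def sequence_accuracy(preds, targets):
    """Fraction of samples whose full predicted string equals the target exactly."""
    if not targets:
        return 0.0
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


class EarlyStopping:
    """Signal a stop when validation accuracy fails to improve for `patience` epochs in a row."""

    def __init__(self, patience: int):
        self.patience = patience
        self.best = float('-inf')
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, epoch: int, valid_acc: float) -> bool:
        """Record this epoch's accuracy; return True if training should stop."""
        if valid_acc > self.best:
            # New best: remember it and reset the counter.
            self.best, self.best_epoch, self.bad_epochs = valid_acc, epoch, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

With patience=2 and validation accuracies 0.1, 0.5, 0.4, 0.45, step returns True at the fourth epoch, and best_epoch is 1 (zero-based).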

model, history = fit_model(
    dir=MERGED,
    prop_valid=0.1,
    batch_size=64,
    epochs=80,
    lr=1e-3,
    weight_decay=1e-4,
    dropout=0.3,
    early_stopping_patience=10,
    case_sensitive=True,
    num_workers=4,
    verbose=True,
)
save_model(model, 'txtcaptcha_unified.pt')

4. Inspect training curves

import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(10, 3))
axes[0].plot(history.train_loss, label='train')
axes[0].plot(history.valid_loss, label='valid')
axes[0].set_title('Loss')
axes[0].legend()
axes[1].plot(history.valid_acc, color='green')
axes[1].set_title('Validation sequence accuracy')
fig.tight_layout()

5. Quick smoke test with masking

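decrypt accepts a mask (a list of allowed characters or a regex character class) and a case_sensitive flag. Masking is useful when you know the target site uses, say, only lowercase letters. The mask '[0-9a-zA-Z]' used below allows every character in the vocabulary, so it does not constrain predictions; pass a narrower class such as '[a-z]' to restrict them. The sampled images come from the training corpus, so this cell only checks that the pipeline runs end to end, not how well the model generalizes.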
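decrypt applies the mask internally. One plausible mechanism, sketched below under stated assumptions, expands the mask into a set of allowed characters. It then sets every disallowed class to negative infinity before the per-step argmax of greedy CTC decoding, so the decoder emits the model's best allowed character instead. The score layout (one vector per time step, blank at index 0, then 0-9a-zA-Z) and the names allowed_chars and masked_greedy_decode are assumptions, not txtcaptcha's API. case_sensitive handling is omitted.

```python
import math
import re
import string

# Assumed output alphabet: blank at score index 0, then these 62 characters.
VOCAB = string.digits + string.ascii_lowercase + string.ascii_uppercase


def allowed_chars(mask, vocab=VOCAB):
    """Expand a mask (regex character class or list of characters) into allowed characters."""
    if isinstance(mask, str):
        # Treat a string mask as a regex that must match exactly one character.
        return {c for c in vocab if re.fullmatch(mask, c)}
    # Otherwise treat it as an explicit collection; ignore characters outside the vocabulary.
    return set(mask) & set(vocab)


def masked_greedy_decode(step_scores, allowed, vocab=VOCAB):
    """Greedy CTC decoding in which disallowed characters can never win the argmax."""
    keep = [True] + [c in allowed for c in vocab]  # the blank is always allowed
    out, prev = [], None
    for scores in step_scores:
        masked = [s if k else -math.inf for s, k in zip(scores, keep)]
        idx = max(range(len(masked)), key=masked.__getitem__)
        # Standard CTC collapse: skip repeats of the previous index and skip blanks.
        if idx != prev and idx != 0:
            out.append(vocab[idx - 1])
        prev = idx
    return ''.join(out)
```

If the model's top choice at some step is 'A' and 'a' is second, the mask '[a-z]' makes the decoder emit 'a' at that step.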

import random
sample = random.sample(list(MERGED.iterdir()), 16)
preds = decrypt(sample, model, mask='[0-9a-zA-Z]')
for p, f in zip(preds, sample):
    print(f'{f.name:40s} -> {p}')