Per-dataset evaluation of the unified model

Loads the unified CRNN checkpoint and measures its accuracy on a random sample (`SAMPLE_SIZE` images) from each of the ten R captcha datasets. It reports both captcha-level (exact-match) and character-level accuracy and shows a labeled grid of predictions per dataset.

import random
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd

from txtcaptcha import (
    Captcha,
    available_datasets,
    decrypt,
    load_model,
    read_captcha,
    sequence_accuracy,
)
from txtcaptcha.dataset import _label_from_path, _list_files

MODEL_PATH = Path('txtcaptcha_unified.pt')  # unified CRNN checkpoint
DATA_ROOT = Path('data')                    # datasets live in DATA_ROOT / <name>
SAMPLE_SIZE = 200                           # images drawn per dataset
SEED = 0                                    # shared by sampling and plotting

random.seed(SEED)

# Load the checkpoint once and put it in eval mode before any prediction.
model = load_model(MODEL_PATH)
model.eval()
print('Loaded model with vocab size', len(model.vocab))
Loaded model with vocab size 62
def sample_dataset(name: str, n: int, seed: int = 0) -> Captcha:
    """Load a reproducible random sample of labeled captchas from one dataset.

    Args:
        name: Dataset sub-directory under ``DATA_ROOT``.
        n: Number of files to draw; capped at the number of labeled files.
        seed: Seed for a local ``random.Random``, so the global RNG is untouched.

    Returns:
        The picked files loaded with ``read_captcha(..., lab_in_path=True)``.

    Raises:
        FileNotFoundError: If the dataset directory has no labeled files.
    """
    root = DATA_ROOT / name
    # Sort before sampling. A fixed seed only reproduces the same draw if the
    # candidate list has a stable order, and the listing order returned by
    # _list_files may depend on the filesystem.
    files = sorted(p for p in _list_files(root) if _label_from_path(p) is not None)
    if not files:
        raise FileNotFoundError(f'No labeled files in {root}')
    rng = random.Random(seed)
    picked = rng.sample(files, min(n, len(files)))
    return read_captcha(picked, lab_in_path=True)

def char_accuracy(preds, golds) -> float:
    """Return position-wise character accuracy over paired predictions/labels.

    Characters of each (pred, gold) pair are compared index by index. Each
    pair contributes ``max(len(pred), len(gold))`` positions to the
    denominator, so missing or extra characters count as errors. The score is
    therefore not normalized by the gold length alone.

    Args:
        preds: Predicted strings.
        golds: Gold label strings, aligned one-to-one with ``preds``.

    Returns:
        Correct positions divided by compared positions; ``0.0`` when there
        are no characters at all.

    Raises:
        ValueError: If ``preds`` and ``golds`` contain a different number of
            items (``zip`` would otherwise silently drop the tail).
    """
    preds = list(preds)
    golds = list(golds)
    if len(preds) != len(golds):
        raise ValueError(
            f'preds and golds differ in length: {len(preds)} != {len(golds)}'
        )
    total = 0
    correct = 0
    for p, g in zip(preds, golds):
        total += max(len(p), len(g))
        # zip stops at the shorter string; the unmatched tail is already
        # counted as wrong through the max() above.
        correct += sum(a == b for a, b in zip(p, g))
    return correct / max(1, total)
def evaluate_dataset(name: str):
    """Run the model on a seeded sample of one dataset and score it.

    Predictions and gold labels are both lower-cased before scoring, so the
    comparison is case-insensitive regardless of the case ``decrypt`` returns.

    Args:
        name: Dataset name (sub-directory of ``DATA_ROOT``).

    Returns:
        dict with keys ``dataset``, ``n`` (sample size), ``captcha_acc``
        (exact-match accuracy), ``char_acc`` (see ``char_accuracy``),
        ``captcha`` (the loaded sample), ``preds`` and ``golds`` (aligned
        lower-cased strings).
    """
    cap = sample_dataset(name, SAMPLE_SIZE, seed=SEED)
    golds = [lab.lower() for lab in cap.labels]
    # NOTE(review): decrypt(..., case_sensitive=False) presumably already
    # returns lower-case text. Normalizing here makes the case-insensitive
    # comparison explicit instead of relying on that.
    preds = [p.lower() for p in decrypt(cap, model, case_sensitive=False)]
    return {
        'dataset': name,
        'n': len(cap),
        'captcha_acc': sequence_accuracy(preds, golds),
        'char_acc': char_accuracy(preds, golds),
        'captcha': cap,
        'preds': preds,
        'golds': golds,
    }

# Full per-dataset results (including images and predictions) for the plots below.
results = {
    ds: evaluate_dataset(ds)
    for ds in available_datasets()
}

# Scalar metrics only, best exact-match accuracy first.
summary = (
    pd.DataFrame(
        [
            {key: res[key] for key in ('dataset', 'n', 'captcha_acc', 'char_acc')}
            for res in results.values()
        ]
    )
    .sort_values('captcha_acc', ascending=False)
    .reset_index(drop=True)
)
summary  # displayed as the cell output
dataset n captcha_acc char_acc
0 cadesp 200 1.000 1.000000
1 tjrs 200 1.000 1.000000
2 trt 200 0.995 0.999167
3 jucesp 200 0.985 0.997000
4 rfb 200 0.980 0.996667
5 esaj 200 0.975 0.990010
6 tjmg 200 0.970 0.994000
7 trf5 200 0.955 0.987500
8 sei 200 0.935 0.970000
9 tjpe 200 0.000 0.025000
fig, ax = plt.subplots(figsize=(8, 4))
# Exact-match bars first; char accuracy is overlaid translucently on the same rows.
ax.barh(summary['dataset'], summary['captcha_acc'], label='captcha exact-match')
ax.barh(summary['dataset'], summary['char_acc'], alpha=0.35, label='char accuracy')
ax.set(
    xlim=(0, 1),
    xlabel='accuracy',
    title=f'Unified model @ {SAMPLE_SIZE} random samples per dataset',
)
# summary is sorted best-first; invert so the best dataset is drawn at the top.
ax.invert_yaxis()
ax.legend(loc='lower right')
fig.tight_layout()

Labeled sample grids

For each dataset we show up to 12 randomly chosen images from its evaluation sample. A green title shows the prediction when it matched the gold label (case-insensitive). A red title marks a mistake and reads `pred / gold`.

def show_predictions(result, n_show: int = 12, cols: int = 4):
    """Plot a grid of sampled captchas titled with their predictions.

    Titles are green when the prediction equals the gold label, and red with
    ``pred / gold`` otherwise.

    Args:
        result: One value of ``results`` as built by ``evaluate_dataset``. It
            must provide ``captcha``, ``preds``, ``golds``, ``dataset``,
            ``captcha_acc`` and ``char_acc``.
        n_show: Maximum number of images to draw.
        cols: Number of grid columns.

    Returns:
        The created matplotlib figure.

    Raises:
        ValueError: If there is nothing to draw (``n_show`` < 1 or an empty
            sample) or ``cols`` < 1.
    """
    cap = result['captcha']
    preds = result['preds']
    golds = result['golds']
    k = min(n_show, len(cap))
    if k < 1 or cols < 1:
        raise ValueError(
            f'nothing to plot: n_show={n_show}, sample size={len(cap)}, cols={cols}'
        )
    rng = random.Random(SEED)
    idxs = rng.sample(range(len(cap)), k)
    rows = (k + cols - 1) // cols  # ceiling division
    # squeeze=False always returns a 2-D array of Axes, so a single row or a
    # single cell needs no special case.
    fig, axes = plt.subplots(
        rows, cols, figsize=(cols * 2.2, rows * 1.4), squeeze=False
    )
    axes = axes.ravel()
    for ax, idx in zip(axes, idxs):
        # cmap only applies to single-channel images; RGB(A) data ignores it.
        ax.imshow(cap.images[idx], cmap='gray')
        ax.axis('off')
        p, g = preds[idx], golds[idx]
        ok = p == g
        title = p if ok else f'{p} / {g}'
        ax.set_title(title, fontsize=9, color='green' if ok else 'red')
    # Hide the unused cells of the last row.
    for ax in axes[k:]:
        ax.axis('off')
    fig.suptitle(
        f"{result['dataset']}  "
        f"captcha_acc={result['captcha_acc']:.2%}  "
        f"char_acc={result['char_acc']:.2%}",
        fontsize=11,
    )
    fig.tight_layout()
    return fig

# One labeled grid per dataset, in the same order the datasets were evaluated.
for name in available_datasets():
    show_predictions(result=results[name])
plt.show()