python · pytorch · machine-learning · computer-vision · deep-dive

I trained a neural network to find hamster memes

An artist I follow on Instagram posts hundreds of drawings, but mixed in are variations of this one specific hamster reaction image. I scraped their entire profile and built a classifier to find them all.


@almarts27 draws all kinds of stuff on Instagram, but scattered throughout their hundreds of posts are these variations of one specific hamster reaction image. Different expressions, different situations, same hamster. I wanted all of them, and I didn't want to scroll through 400+ posts to get them.

The hamster in question

Three scripts. Scrape the profile, download the images, train a model to sort them.

Scraping

Instagram doesn't want you doing this, and their API makes that clear. So I wrote a Tampermonkey userscript instead. Visit the profile, click the button, and it auto-scrolls through the page collecting every post:

async function startScraping() {
  if (isRunning) return;
  isRunning = true;

  let allPosts = [];
  let previousCount = 0;
  let noChangeRounds = 0;

  while (true) {
    const currentPosts = collectPosts();
    allPosts = deduplicatePosts([...allPosts, ...currentPosts]);

    if (allPosts.length === previousCount) {
      noChangeRounds++;
    } else {
      noChangeRounds = 0;
    }

    // 5 scrolls with no new posts = we've hit the bottom
    if (noChangeRounds >= 5) break;

    previousCount = allPosts.length;
    window.scrollBy(0, window.innerHeight);
    await sleep(1500);
  }

  // dump to clipboard, paste into a file, done
  GM_setClipboard(JSON.stringify(allPosts, null, 2), 'text');
  isRunning = false;
}

It grabs every <a> matching Instagram's /p/ URL pattern and pulls the <img> src from inside. When the post count stops changing for 5 consecutive scrolls, it assumes it's hit the bottom and copies everything as JSON to your clipboard.

388 posts from @almarts27. Paste into posts.json, move on.

Downloading

Nothing interesting here. Feed it the JSON, it downloads every image with the right headers so Instagram's CDN doesn't reject the requests. Names each file {profile}_{shortcode}.jpg. Skips duplicates.

import urllib.request
from pathlib import Path

# posts comes from posts.json, profile is the account name
output_dir = Path("downloads")
output_dir.mkdir(exist_ok=True)
skip_count = 0

for post in posts:
    image_url = post.get("imageSrc")
    post_url = post.get("postUrl", "")

    shortcode = post_url.split("/p/")[-1].strip("/").split("/")[0]
    filename = f"{profile}_{shortcode}.jpg"
    filepath = output_dir / filename

    if filepath.exists():
        skip_count += 1
        continue

    req = urllib.request.Request(image_url, headers={
        "User-Agent": "Mozilla/5.0 ...",
        "Referer": "https://www.instagram.com/",
    })
    with urllib.request.urlopen(req, timeout=30) as response:
        with open(filepath, "wb") as out_file:
            out_file.write(response.read())

388 images downloaded. 233 of them were hamsters.

The classifier

Pretrained ResNet18. Freeze everything except the last residual block, replace the head with a single binary output:

import torch.nn as nn
from torchvision import models

def build_model():
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

    for param in model.parameters():
        param.requires_grad = False

    # unfreeze layer4 - the earlier layers handle generic stuff
    # (edges, textures) that transfers fine from ImageNet, but
    # layer4 is where it starts recognizing higher-level patterns.
    # cartoon hamsters ≠ ImageNet photos, so this needs to adapt.
    for param in model.layer4.parameters():
        param.requires_grad = True

    model.fc = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(model.fc.in_features, 1),
    )

    return model

The classes aren't evenly split. 233 hamsters vs 155 non-hamsters. Without correction, the model would be biased toward predicting "hamster" on everything and coast on majority-class accuracy. So the loss function weights each class inversely to its frequency:

pos_weight = torch.tensor([neg_count / pos_count], device=device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

The training data was labeled by hand. I just dragged images into hamster/ and not_hamster/ folders. Took about 10 minutes. Augmentation handles the rest: random crops, flips, rotations, and color jitter so the model doesn't just memorize 388 images.

python ig_classifier.py train hamster/ not_hamster/ hamster_model.pth
# 20 epochs, done in under a minute

Classification sorts everything into directories and spits out a JSON with confidence scores per image:

python ig_classifier.py classify hamster_model.pth downloads/ output/ 0.5

The threshold's adjustable. Lower it if you want to catch borderline hamsters at the cost of false positives.
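Under the hood, that thresholding amounts to a sigmoid over the single logit. A sketch of the decision step (`predict_hamster` is a hypothetical helper, not a function from the original script):

```python
import torch

def predict_hamster(model, batch, threshold=0.5):
    """Return (is_hamster, confidence) for a batch of preprocessed images.

    The head outputs one raw logit per image; sigmoid maps it to a
    probability, and the threshold decides which directory it lands in.
    """
    model.eval()
    with torch.no_grad():
        probs = torch.sigmoid(model(batch)).squeeze(1)
    return probs >= threshold, probs
```

Lowering the threshold trades precision for recall: more borderline images get filed as hamsters, including some that aren't.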

Did it work?

Yeah. The whole thing took maybe an hour to build, and I could've just manually saved the images in 20 minutes. But now the scraper works on any Instagram profile, the classifier works on any binary image task, and I have every single hamster.

Code

No repo for this one. The userscript is on Greasyfork, and the Python scripts are here: