CLIPTextModel / CLIPVisionModel fail to load old checkpoints after architecture flattening #45390

@sayakpaul

Description

After the recent refactoring that flattened CLIPTextModel (removed the self.text_model wrapper) and CLIPVisionModel (removed the self.vision_model wrapper), old checkpoints that were saved with the nested structure can no longer be loaded correctly.

All weights end up randomly initialized because the checkpoint keys (e.g. text_model.embeddings.token_embedding.weight) don't match the new model's state dict keys (e.g. embeddings.token_embedding.weight).
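Until this is fixed upstream, one possible workaround is to strip the legacy prefix from the checkpoint keys before loading them into the flattened model. A minimal sketch (the helper name is illustrative, not part of the transformers API; use `prefix="vision_model."` for CLIPVisionModel checkpoints):

```python
# Hypothetical workaround: rename nested-era keys
# (e.g. "text_model.embeddings...") to the flattened layout
# (e.g. "embeddings..."). Keys without the prefix are left unchanged.
def remap_legacy_clip_keys(state_dict, prefix="text_model."):
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

# Dummy values standing in for real checkpoint tensors:
legacy = {"text_model.embeddings.token_embedding.weight": 0}
print(remap_legacy_clip_keys(legacy))
# → {'embeddings.token_embedding.weight': 0}
```

With a real checkpoint one would then call `model.load_state_dict(remap_legacy_clip_keys(sd))` on a freshly constructed `CLIPTextModel` instead of relying on `from_pretrained`.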

Minimal reproducer

import torch
from transformers import CLIPTextModel, CLIPTextConfig

# Any old-format CLIP checkpoint works; this one ships with diffusers tests
model_path = "hf-internal-testing/tiny-stable-diffusion-torch"

# Download the checkpoint file so it's cached locally
from huggingface_hub import hf_hub_download
ckpt_path = hf_hub_download(model_path, "text_encoder/pytorch_model.bin")

# Show that the checkpoint has text_model.* keys
sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)
print("Checkpoint key example:", list(sd.keys())[1])
# → text_model.embeddings.token_embedding.weight

expected_sum = sd["text_model.embeddings.token_embedding.weight"].sum().item()
print(f"Expected token_embedding sum: {expected_sum:.4f}")

# Load via from_pretrained
te = CLIPTextModel.from_pretrained(
    model_path, subfolder="text_encoder"
)
actual_sum = te.state_dict()["embeddings.token_embedding.weight"].sum().item()
print(f"Actual token_embedding sum:   {actual_sum:.4f}")

assert abs(expected_sum - actual_sum) < 1e-5, (
    f"Weights were NOT loaded! expected={expected_sum:.4f}, got={actual_sum:.4f}"
)

Output (failing):

Checkpoint key example: text_model.embeddings.token_embedding.weight
Expected token_embedding sum: -4.9096
Actual token_embedding sum:   -0.0497    # ← random init, not checkpoint value

AssertionError: Weights were NOT loaded! expected=-4.9096, got=-0.0497

Impact

This breaks any downstream code that loads CLIPTextModel or CLIPVisionModel from checkpoints saved with previous transformers versions — including all Stable Diffusion pipelines in diffusers.
