Commit 71160c1

Merge pull request #2 from Rganeshk/finetune
Added finetuning script
2 parents 2426f76 + ff6faef commit 71160c1

File tree

1 file changed: +109 −0 lines changed


scripts/finetune.py

Lines changed: 109 additions & 0 deletions
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from datasets import load_dataset
from sklearn.metrics import precision_recall_fscore_support
import torch.nn.functional as F

from ldm.modules.encoders.modules import FrozenCLIPEmbedder

# === Config ===
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 32
epochs = 3
lr = 1e-5
max_length = 77
save_dir = "./checkpoints"
os.makedirs(save_dir, exist_ok=True)
save_every_n_steps = 1000  # Save every 1000 batches

# === Dataset ===
class CocoCountingDataset(torch.utils.data.Dataset):
    """Caption dataset (Conceptual Captions) labelled by whether the caption contains a counting word."""

    def __init__(self, split="train", tokenizer=None, max_length=77):
        self.dataset = load_dataset("conceptual_captions", split=split)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.number_words = ['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten']

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        caption = self.dataset[idx]['caption'].lower()
        label = int(any(word in caption for word in self.number_words))  # label 1 if a counting word exists

        if label == 0:
            caption = "one object."

        encoding = self.tokenizer(caption, truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt")
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        return input_ids, attention_mask, label

# === Model ===
model = FrozenCLIPEmbedder(version="openai/clip-vit-large-patch14", device=device, max_length=max_length)

# Unfreeze the CLIP text transformer so it can be fine-tuned
for param in model.transformer.parameters():
    param.requires_grad = True

model = model.to(device)
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.transformer.parameters()), lr=lr)

# === Dataloader ===
dataset = CocoCountingDataset(split="train", tokenizer=model.tokenizer, max_length=max_length)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# === Training ===
model.train()
global_step = 0
for epoch in range(epochs):
    total_loss = 0
    preds, targets = [], []

    for batch_idx, (input_ids, attention_mask, labels) in enumerate(tqdm(dataloader)):
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        outputs = model.transformer(input_ids=input_ids, attention_mask=attention_mask)
        embeddings = outputs.last_hidden_state

        loss = torch.mean(torch.norm(embeddings, dim=-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Mock "classification" for precision/recall: use embedding norm as a pseudo-score
        scores = torch.norm(embeddings[:, 0, :], dim=-1)  # norm of the first (start-of-text) token embedding
        pred_labels = (scores > scores.mean()).long()

        preds.extend(pred_labels.cpu().tolist())
        targets.extend(labels.cpu().tolist())

        global_step += 1

        # === Save checkpoint mid-epoch ===
        if global_step % save_every_n_steps == 0:
            checkpoint_path = os.path.join(save_dir, f"clip_rope_step{global_step}.pth")
            torch.save(model.transformer.state_dict(), checkpoint_path)
            print(f"[Checkpoint] Saved at step {global_step}")

    # === End of epoch logging ===
    precision, recall, f1, _ = precision_recall_fscore_support(targets, preds, average='binary')
    print(f"Epoch {epoch+1}/{epochs}: Loss={total_loss/len(dataloader):.4f}")
    print(f"Precision: {precision:.4f} Recall: {recall:.4f} F1: {f1:.4f}")

    # Save after each epoch
    checkpoint_path = os.path.join(save_dir, f"clip_rope_epoch{epoch+1}.pth")
    torch.save(model.transformer.state_dict(), checkpoint_path)
    print(f"[Checkpoint] Saved model after epoch {epoch+1}")

# === Final Save ===
torch.save(model.transformer.state_dict(), "./clip_rope_finetuned_final.pth")
print("[Final Save] Fine-tuned text encoder saved!")
