diff --git a/articles/gpt-oss/fine-tune-korean.ipynb b/articles/gpt-oss/fine-tune-korean.ipynb new file mode 100644 index 0000000000..14d485d56b --- /dev/null +++ b/articles/gpt-oss/fine-tune-korean.ipynb @@ -0,0 +1,1196 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "538f25ce", + "metadata": {}, + "source": [ + "\n", + "이 노트북은 OpenAI의 **gpt-oss (open‑weight)** 모델을 **한국 뉴스 문체 + 최신 대화체**로 세밀 튜닝하는 방법을\n", + "한국어/영어 **이중 언어**로 제공합니다. \n", + "This notebook shows how to fine‑tune OpenAI's **gpt-oss (open‑weight)** models for **Korean news style + modern chat tone**, in **Korean & English**.\n", + "\n", + "---\n", + "\n", + "### MXFP4 workflow clarifications · MXFP4 워크플로 정리\n", + "\n", + "**EN:** \n", + "- Training or fine-tuning **directly in MXFP4 is not supported** by public frameworks today. \n", + "- Recommended path: train in **BF16** (or **QLoRA 4‑bit nf4**) → **merge LoRA** → **post‑training quantize to MXFP4** → `save_pretrained()` for deployment. \n", + "- If you need an MXFP4 artifact, you must **re‑quantize from BF16** after merging adapters. (Export utilities are evolving; if your toolchain already supports MXFP4 serialization, that’s ideal.)\n", + "\n", + "**KR:** \n", + "- 현재 공개 프레임워크에서는 **MXFP4로 직접 학습/파인튜닝**이 지원되지 않습니다. \n", + "- 권장 경로: **BF16**(또는 **QLoRA 4‑bit nf4**)로 학습 → **LoRA 병합** → **사후(MXFP4) 양자화** → 배포용으로 `save_pretrained()` 저장. \n", + "- MXFP4 아티팩트가 필요하면, 어댑터 병합 후 **BF16 → MXFP4 재양자화**가 필요합니다. (직렬화 유틸은 진화 중이며, 툴체인에서 MXFP4 저장을 지원하면 가장 좋습니다.)\n", + "\n", + "---\n", + "\n", + "### LoRA targets (MoE) · LoRA 타깃(MoE 포함)\n", + "\n", + "**EN:** \n", + "- Minimal config (fast, low VRAM): target attention only, e.g. `[\"q_proj\",\"v_proj\"]`. \n", + "- MoE‑aware config (better domain adaptation, more VRAM/time): include **expert projection layers** in addition to attention. \n", + "\n", + "```python\n", + "from peft import LoraConfig\n", + "\n", + "TARGET_MODULES = [\"q_proj\", \"v_proj\"] # baseline\n", + "MOE_TARGET_PARAMETERS = [\n", + " # example expert layers; adjust indices to your model depth\n", + " \"mlp.experts.gate_up_proj\",\n", + " \"mlp.experts.down_proj\",\n", + "]\n", + "\n", + "lora_cfg = LoraConfig(\n", + " r=16, lora_alpha=32, lora_dropout=0.05,\n", + " target_modules=\"all-linear\", # cover all linear layers\n", + " target_parameters=MOE_TARGET_PARAMETERS, # add expert projections\n", + " bias=\"none\", task_type=\"CAUSAL_LM\",\n", + ")\n", + "```\n", + "\n", + "- Start with attention‑only; if KR domain fit is insufficient, enable MoE targets and re‑eval.\n", + "\n", + "**KR:** \n", + "- 최소 구성(빠르고 VRAM 절약): `[\"q_proj\",\"v_proj\"]` 등 **어텐션만** 적용. \n", + "- **MoE 인지 구성**(도메인 적합성↑, 자원 소모↑): 어텐션에 **전문가(Expert) 투영 레이어**를 추가로 포함. \n", + "- 먼저 어텐션만으로 시도한 뒤, 한국어 도메인 적합성이 부족하면 MoE 타깃을 켜고 재평가하세요." 
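+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7c03b1aa",
+ "metadata": {},
+ "source": [
+ "**Load‑for‑training sketch · 학습용 로드 예시**: a minimal sketch of the first step above (unpack MXFP4 → BF16 before training), assuming a recent `transformers` build whose gpt-oss integration provides `Mxfp4Config`; the 20b model id is used purely for illustration:\n",
+ "\n",
+ "```python\n",
+ "from transformers import AutoModelForCausalLM, Mxfp4Config\n",
+ "\n",
+ "# Unpack the released MXFP4 weights to BF16 so LoRA/QLoRA training can run\n",
+ "quant_cfg = Mxfp4Config(dequantize=True)\n",
+ "model = AutoModelForCausalLM.from_pretrained(\n",
+ "    \"openai/gpt-oss-20b\",  # illustrative id; use the checkpoint you actually train\n",
+ "    torch_dtype=\"auto\",\n",
+ "    device_map=\"auto\",\n",
+ "    quantization_config=quant_cfg,\n",
+ ")\n",
+ "```"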
+ ] + }, + { + "cell_type": "markdown", + "id": "bd7c12ff", + "metadata": {}, + "source": [ + "## Contents · 목차\n", + "0) Goals & Scope · 목표 & 범위 \n", + "1) Environment check · 환경 점검 \n", + "2) 설정값 · Config \n", + "3) 패키지 설치 · Install Deps \n", + "4) 데이터 소싱(한국형) · KR‑Context Data Sourcing \n", + "5) 샘플 데이터 생성 · Create Sample Data \n", + "6) 전처리(PIPA) & 스타일 라벨 · PII Scrubbing & Style Tags \n", + "7) 데이터 로딩/포맷팅 · Load & Format \n", + "8) 모델/토크나이저 로드 · Load Model & Tokenizer \n", + "9) Fine‑Tuning (LoRA/QLoRA) · 세밀 튜닝 \n", + " 9a) Data curation & splits \n", + " 9b) Hyperparameters (r/alpha/dropout) \n", + " 9c) Merge adapters (BF16) \n", + " 9d) Save merged BF16 (`save_pretrained`) \n", + " 9e) Export & Quantize (BF16 → MXFP4) · 내보내기 & 양자화 \n", + "10) 평가(뉴스/대화) · Evaluation (News/Chat) \n", + "11) Inference Prompt Templates · 추론 프롬프트 템플릿 \n", + "12) 최신성 유지 · Freshness Strategy \n", + "13) 안전/컴플라이언스 · Safety & Compliance \n", + "14) 문제해결 & 다음 단계 · Troubleshooting & Next Steps\n" + ] + }, + { + "cell_type": "markdown", + "id": "bb8655d2", + "metadata": {}, + "source": [ + "### ⚙️ Training vs Quantization — What’s supported\n", + "- **Do:** Train with BF16/FP16 or QLoRA; export merged weights.\n", + "- **Then:** Quantize to **MXFP4** for inference using provided conversion scripts/utilities.\n", + "- **Don’t:** Attempt to run an end‑to‑end “train in MXFP4” pipeline — not supported today." + ] + }, + { + "cell_type": "markdown", + "id": "bb24a3d9", + "metadata": {}, + "source": [ + "> **PII & Compliance Reminder:** For KR data, follow your enterprise policy (mask RRN/phone/account IDs, remove emails) **before** training & logging. Keep train/val/test splits stratified by source and style tags." + ] + }, + { + "cell_type": "markdown", + "id": "e1e883f5", + "metadata": {}, + "source": [ + "### 🧪 MoE adapters (optional)\n", + "You can target MoE layers with adapters, but treat this as **advanced/experimental**. Start with attention projections first and validate KR benchmarks before expanding scope." + ] + }, + { + "cell_type": "markdown", + "id": "179543e6", + "metadata": {}, + "source": [ + "> **Note:** Keep `transformers`, `peft`, `accelerate`, and `trl` at versions known to support BF16/4‑bit LoRA. \n", + "If you pin `safetensors`, remember that **native MXFP4 serialization is not yet standardized**; loaders may upcast internally." + ] + }, + { + "cell_type": "markdown", + "id": "f8e743f0", + "metadata": {}, + "source": [ + "### 🔎 Support Matrix — At a glance\n", + "- **Fine‑tuning precision:** BF16/FP16 ✅ · QLoRA 4‑bit ✅ · **MXFP4 FT ❌**\n", + "- **Quantization target:** MXFP4 ✅ (post‑training)\n", + "- **API FT (hosted) for OSS models:** ❌\n", + "- **Open‑source FT (Transformers/TRL/PEFT):** ✅\n", + "- **LoRA targets:** `q_proj`, `k_proj`, `v_proj`, `o_proj` ✅; MoE expert adapters **experimental** ⚠️" + ] + }, + { + "cell_type": "markdown", + "id": "f4dec1f6", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "e3d489c2", + "metadata": {}, + "source": [ + "## 0) Goals & Scope · 목표 & 범위\n", + "- **KR**: 한국어 일반 뉴스 + 일상/상담 대화체에 최적화. `style=news_headline|news_lead|news_body|kakao_casual|kakao_formal` 제어.\n", + "- **EN**: Optimize for Korean news writing and modern chat tone; control output via style tags above.\n", + "- **Stack**: `transformers`, `trl(SFTTrainer)`, `peft(LoRA/QLoRA)`, `datasets`.\n", + "- **Hardware**: Single/few GPUs (BF16 preferred). CPU/Mac for lightweight tests." 
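+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2e91c4d7",
+ "metadata": {},
+ "source": [
+ "**Style‑tag routing sketch · 스타일 태그 예시**: a tiny illustrative helper (the names are ours, not a fixed API) showing how the style tags above can be turned into a system prompt at train and inference time:\n",
+ "\n",
+ "```python\n",
+ "STYLE_TAGS = {\"news_headline\", \"news_lead\", \"news_body\", \"kakao_casual\", \"kakao_formal\"}\n",
+ "\n",
+ "def system_prompt_for(style: str) -> str:\n",
+ "    # Route a style tag into the system message used for both training and inference\n",
+ "    if style not in STYLE_TAGS:\n",
+ "        raise ValueError(f\"unknown style: {style}\")\n",
+ "    return f\"style={style} 규칙을 따라 한국어로 작성하라.\"\n",
+ "\n",
+ "print(system_prompt_for(\"news_lead\"))\n",
+ "```"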
+ ] + }, + { + "cell_type": "markdown", + "id": "db97218d", + "metadata": {}, + "source": [ + "## 1) Environment check · 환경 점검" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "5babb2c3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Python: 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0]\n", + "OS/Platform: Linux-6.8.0-60-generic-x86_64-with-glibc2.35\n", + "CUDA_VISIBLE_DEVICES: \n", + "Torch: 2.7.1+cu126 CUDA: True\n", + "GPU: NVIDIA H100 80GB HBM3\n" + ] + } + ], + "source": [ + "import os, sys, platform\n", + "print(\"Python:\", sys.version)\n", + "print(\"OS/Platform:\", platform.platform())\n", + "print(\"CUDA_VISIBLE_DEVICES:\", os.environ.get(\"CUDA_VISIBLE_DEVICES\", \"\"))\n", + "\n", + "try:\n", + " import torch\n", + " print(\"Torch:\", torch.__version__, \"CUDA:\", torch.cuda.is_available())\n", + " if torch.cuda.is_available():\n", + " print(\"GPU:\", torch.cuda.get_device_name(0))\n", + "except Exception as e:\n", + " print(\"Torch not installed or GPU not detected:\", e)" + ] + }, + { + "cell_type": "markdown", + "id": "25688688", + "metadata": {}, + "source": [ + "## 2) 설정값 · Config" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "c15817f7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Config ready.\n" + ] + } + ], + "source": [ + "from pathlib import Path\n", + "import os\n", + "\n", + "# === Model & Training Params ===\n", + "BASE_URL = \"http://localhost:8000/v1\" # vLLM OpenAI-compatible endpoint\n", + "API_KEY = \"dummy-key\" # vLLM ignores; SDK requires a value\n", + "MODEL = \"openai/gpt-oss-120b\" # must match the model vLLM loaded\n", + "OUTPUT_DIR = \"ft-oss-kr-news-chat-bilingual\"\n", + "\n", + "# Data mix (news : chat)\n", + "MIX_NEWS = 0.6\n", + "MIX_CHAT = 0.4\n", + "\n", + "# LoRA\n", + "LORA_R = 8\n", + "LORA_ALPHA = 16\n", + "LORA_DROPOUT = 0.05\n", + "TARGET_MODULES = [\"q_proj\", \"v_proj\"] # adjust per model\n", + "\n", + "# Training\n", + "EPOCHS = 1\n", + "PER_DEVICE_BS = 2\n", + "GRAD_ACCUM = 8\n", + "LEARNING_RATE = 2e-4\n", + "BF16 = True\n", + "LOG_STEPS = 20\n", + "SAVE_STEPS = 200\n", + "SAVE_TOTAL_LIMIT = 2\n", + "\n", + "print(\"Config ready.\")" + ] + }, + { + "cell_type": "markdown", + "id": "85f258eb", + "metadata": {}, + "source": [ + "## 3) 패키지 설치 · Install Deps" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "b1b75968", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "transformers: 4.55.3\n", + "accelerate: 1.10.0\n", + "datasets: 4.0.0\n", + "peft: not installed\n", + "trl: 0.21.0\n", + "bitsandbytes: not installed\n", + "sentencepiece: 0.2.1\n", + "vllm: 0.10.1\n", + "llama_cpp: 0.3.16\n", + "pip: 25.2\n", + "Install cells are commented. 
Un-comment in your environment.\n" + ] + } + ], + "source": [ + "# %pip install --upgrade pip\n", + "# %pip install transformers accelerate datasets peft trl bitsandbytes sentencepiece\n", + "# (optional) serving/runtimes\n", + "# %pip install vllm\n", + "# %pip install llama-cpp-python\n", + "\n", + "import importlib, pip\n", + "\n", + "for dep in [\"transformers\",\"accelerate\",\"datasets\",\"peft\",\"trl\",\n", + " \"bitsandbytes\",\"sentencepiece\",\"vllm\",\"llama_cpp\"]:\n", + " try:\n", + " print(f\"{dep}: {importlib.import_module(dep).__version__}\")\n", + " except Exception:\n", + " print(f\"{dep}: not installed\")\n", + "\n", + "print(f\"pip: {pip.__version__}\")\n", + "\n", + "print(\"Install cells are commented. Un-comment in your environment.\")" + ] + }, + { + "cell_type": "markdown", + "id": "de8647fd", + "metadata": {}, + "source": [ + "## 4) 데이터 소싱(한국형) · KR‑Context Data Sourcing" + ] + }, + { + "cell_type": "markdown", + "id": "da22cbd6", + "metadata": {}, + "source": [ + "**KR** \n", + "- 공개 벤치마크(주제 분류/요약/QA) + **허용된 뉴스 API의 메타데이터(제목/요약/섹션)** 중심으로 스타일 보정.\n", + "- 기사 **원문 대량 재학습은 저작권/약관 이슈** → 메타데이터·공개 코퍼스 위주.\n", + "- 대화체는 합법 공개 코퍼스(반말/존댓말/이모티콘/축약어 라벨 포함) 우선.\n", + "- PIPA: 주민번호/연락처/이메일/계좌 등 개인정보는 **훈련 전/로그 전** 스크러빙.\n", + "\n", + "**EN** \n", + "- Prefer public KR benchmarks (topic classification / summarization / QA) and **allowed news API metadata** for style calibration.\n", + "- Avoid mass training on news full texts due to license/ToS constraints; use metadata + open corpora.\n", + "- For chat, use lawful open corpora with tone/emoji/informal‑formal annotations.\n", + "- Scrub PII (phone, RRNs, emails, accounts) before training/logging." + ] + }, + { + "cell_type": "markdown", + "id": "9b918411", + "metadata": {}, + "source": [ + "## 5) 샘플 데이터 생성 · Create Sample Data" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "18db10a6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Created: data/news.jsonl, data/chat.jsonl\n" + ] + } + ], + "source": [ + "import json, pathlib\n", + "pathlib.Path(\"data\").mkdir(exist_ok=True)\n", + "\n", + "news_samples = [\n", + " {\"style\":\"news_lead\",\"topic\":\"경제\",\"title\":\"반도체 수출 호조… 7월 수출액 20% 증가\",\"summary\":\"수출 개선세가 이어지며 경기 회복 기대가 커졌다.\"},\n", + " {\"style\":\"news_headline\",\"topic\":\"정치\",\"title\":\"국회, 데이터 산업 육성법 본회의 통과\",\"summary\":\"데이터 활용 촉진과 개인정보 보호를 강화하는 내용.\"},\n", + " {\n", + " \"style\": \"news_lead\",\n", + " \"topic\": \"경제\",\n", + " \"title\": \"카카오페이 보안 점검… 고객문의: help+vip@corp.co.kr\",\n", + " \"summary\": \"고객센터 010-1234-5678로 문의 폭주. 계좌 110-123-456789 관련 결제 오류 논란.\"\n", + " },\n", + " {\n", + " \"style\": \"news_headline\",\n", + " \"topic\": \"사회\",\n", + " \"title\": \"개인정보 유출 의혹… 주민번호 901010-1234567 유통 주장\",\n", + " \"summary\": \"서울특별시 강남구 테헤란로 123에서 자료 확보… 담당자 john.doe+news@example.com\"\n", + " }\n", + "]\n", + "\n", + "chat_samples = [\n", + " {\"style\":\"kakao_casual\",\"dialog\":[\"주말에 비 온대?\",\"응 일요일에 꽤 온다더라 ☔\",\"헐 우산 챙겨야겠다\"]},\n", + " {\"style\":\"kakao_formal\",\"dialog\":[\"안녕하세요. 배송 일정 확인 부탁드립니다.\",\"내일 중 도착 예정입니다.\",\"안내 감사합니다.\"]},\n", + " {\n", + " \"style\": \"kakao_formal\",\n", + " \"dialog\": [\n", + " \"배송 확인 부탁드립니다. 주문번호 ORD-2025-0001 입니다.\",\n", + " \"연락처는 010-2222-3333 입니다. 
(유니코드 하이픈)\",\n", + " \"주민등록번호는 제공할 수 없습니다.\"\n", + " ]\n", + " }\n", + "]\n", + "\n", + "with open(\"data/news.jsonl\",\"w\",encoding=\"utf-8\") as f:\n", + " for ex in news_samples: f.write(json.dumps(ex, ensure_ascii=False)+\"\\n\")\n", + "with open(\"data/chat.jsonl\",\"w\",encoding=\"utf-8\") as f:\n", + " for ex in chat_samples: f.write(json.dumps(ex, ensure_ascii=False)+\"\\n\")\n", + "\n", + "print(\"Created: data/news.jsonl, data/chat.jsonl\")" + ] + }, + { + "cell_type": "markdown", + "id": "4f1eaa27", + "metadata": {}, + "source": [ + "## 6) 전처리(PIPA) & 스타일 라벨 · PII Scrubbing & Style Tags" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "430c1b68", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "data/news.jsonl -> data/news_clean.jsonl | rows: 4, redacted_rows: 2, hits: {'[EMAIL]': 2, '[ACCOUNT]': 1, '[RRN]': 1, '[CITY]': 1}\n", + "data/chat.jsonl -> data/chat_clean.jsonl | rows: 3, redacted_rows: 1, hits: {'[PHONE]': 1}\n" + ] + } + ], + "source": [ + "# Step 6 — PII scrubbing + style tags (no Harmony here)\n", + "import json, re, unicodedata\n", + "from pathlib import Path\n", + "\n", + "# --- Normalization helpers ---\n", + "HYPHENS = dict.fromkeys(map(ord, \"‐-‒–—―﹘﹣-\"), ord(\"-\")) # map unicode hyphens → ASCII\n", + "def normalize(s: str) -> str:\n", + " if not isinstance(s, str): return s\n", + " s = unicodedata.normalize(\"NFKC\", s)\n", + " s = s.translate(HYPHENS)\n", + " return s\n", + "\n", + "# --- PII patterns (illustrative; tune for production) ---\n", + "RE_EMAIL = re.compile(r\"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}\")\n", + "# KR mobile numbers with spaces/hyphens: 010-1234-5678, 010 1234 5678, etc.\n", + "RE_PHONE = re.compile(r\"\\b01[016789][-\\s]?\\d{3,4}[-\\s]?\\d{4}\\b\")\n", + "# Korean RRN (주민등록번호) basic pattern\n", + "RE_RRN = re.compile(r\"\\b\\d{6}-\\d{7}\\b\")\n", + "# Bank-ish account numbers: strictly digits in groups (avoid codes with letters)\n", + "RE_ACCOUNT = re.compile(r\"\\b\\d{2,3}-\\d{2,4}-\\d{3,6}\\b\")\n", + "# Very simple postal address cue (city names) – conservative, just redact the token (optional)\n", + "RE_CITY = re.compile(r\"(서울특별시|부산광역시|대구광역시|인천광역시|광주광역시|대전광역시|울산광역시|세종특별자치시|경기도|강원도|충청북도|충청남도|전라북도|전라남도|경상북도|경상남도|제주특별자치도)\")\n", + "\n", + "# Allowlist: things that look like PII but aren’t (e.g., bill/order codes w/ letters)\n", + "def looks_like_code(s: str) -> bool:\n", + " return bool(re.search(r\"[A-Za-z]\", s)) # if letters present, treat as code, not account/phone\n", + "\n", + "# Order of application matters (longest/most specific first sometimes helps)\n", + "SCRUBBERS = [\n", + " (\"[RRN]\", RE_RRN),\n", + " (\"[EMAIL]\", RE_EMAIL),\n", + " (\"[PHONE]\", RE_PHONE),\n", + " (\"[ACCOUNT]\", RE_ACCOUNT),\n", + " (\"[CITY]\", RE_CITY), # optional; comment out if you don't want to redact city tokens\n", + "]\n", + "\n", + "def scrub_text(text: str) -> tuple[str, dict]:\n", + " \"\"\"Return (scrubbed_text, hits_dict). 
Avoid false positives with basic allowlisting.\"\"\"\n", + " if not isinstance(text, str) or not text:\n", + " return text, {}\n", + " orig = text\n", + " text = normalize(text)\n", + " hits = {}\n", + "\n", + " # Guard account-like and phone-like strings that contain letters (likely codes)\n", + " guarded = set()\n", + " for m in RE_ACCOUNT.finditer(text):\n", + " if looks_like_code(m.group(0)):\n", + " guarded.add(m.span())\n", + " for m in RE_PHONE.finditer(text):\n", + " if looks_like_code(m.group(0)):\n", + " guarded.add(m.span())\n", + "\n", + " # Apply scrubs\n", + " for label, pattern in SCRUBBERS:\n", + " out = []\n", + " last = 0\n", + " count = 0\n", + " for m in pattern.finditer(text):\n", + " span = m.span()\n", + " if pattern in (RE_ACCOUNT, RE_PHONE) and span in guarded:\n", + " continue\n", + " out.append(text[last:span[0]])\n", + " out.append(label)\n", + " last = span[1]\n", + " count += 1\n", + " out.append(text[last:])\n", + " text = \"\".join(out)\n", + " if count:\n", + " hits[label] = hits.get(label, 0) + count\n", + "\n", + " return text, hits if text != orig else {}\n", + "\n", + "def scrub_record(rec: dict, kind: str) -> tuple[dict, dict]:\n", + " \"\"\"Scrub fields in a news/chat record; return (new_rec, hits).\"\"\"\n", + " rec = dict(rec) # shallow copy\n", + " total_hits = {}\n", + "\n", + " def scrub_field(key):\n", + " val = rec.get(key)\n", + " new, hits = scrub_text(val) if isinstance(val, str) else (val, {})\n", + " rec[key] = new\n", + " for k, v in hits.items():\n", + " total_hits[k] = total_hits.get(k, 0) + v\n", + "\n", + " if kind == \"news\":\n", + " for key in (\"title\", \"summary\", \"topic\"):\n", + " scrub_field(key)\n", + " elif kind == \"chat\":\n", + " scrub_field(\"style\")\n", + " if isinstance(rec.get(\"dialog\"), list):\n", + " cleaned_dialog = []\n", + " for turn in rec[\"dialog\"]:\n", + " new, hits = scrub_text(turn) if isinstance(turn, str) else (turn, {})\n", + " cleaned_dialog.append(new)\n", + " for k, v in hits.items():\n", + " total_hits[k] = total_hits.get(k, 0) + v\n", + " rec[\"dialog\"] = cleaned_dialog\n", + "\n", + " return rec, total_hits\n", + "\n", + "# --- Style tagger (lightweight labels for later routing/metrics) ---\n", + "def build_style_tags(rec: dict, kind: str) -> list[str]:\n", + " tags = []\n", + " if kind == \"news\":\n", + " tags.append(\"domain:\" + (rec.get(\"topic\") or \"unknown\"))\n", + " tags.append(\"style:\" + (rec.get(\"style\") or \"news\"))\n", + " tags.append(\"tone:formal\")\n", + " tags.append(\"medium:news\")\n", + " elif kind == \"chat\":\n", + " style = (rec.get(\"style\") or \"\").lower()\n", + " tags.append(\"style:\" + (style or \"chat\"))\n", + " tags.append(\"tone:\" + (\"formal\" if \"formal\" in style else \"casual\"))\n", + " tags.append(\"medium:kakao\")\n", + " return [t.replace(\" \", \"_\") for t in tags]\n", + "\n", + "# --- Process files ---\n", + "def process_file(src: str, dst: str, kind: str):\n", + " total = 0\n", + " redacted = 0\n", + " counters = {}\n", + " with open(src, encoding=\"utf-8\") as fin, open(dst, \"w\", encoding=\"utf-8\") as fout:\n", + " for line in fin:\n", + " if not line.strip(): continue\n", + " rec = json.loads(line)\n", + " total += 1\n", + " cleaned, hits = scrub_record(rec, kind)\n", + " cleaned[\"style_tags\"] = build_style_tags(cleaned, kind)\n", + " cleaned[\"_pii_hits\"] = hits # keep for inspection; drop later if you want\n", + " if hits: redacted += 1\n", + " for k, v in hits.items():\n", + " counters[k] = counters.get(k, 0) + v\n", + " 
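# one JSON object per line; drop the _pii_hits audit field here before shipping data\n",
+ "            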
fout.write(json.dumps(cleaned, ensure_ascii=False) + \"\\n\")\n", + " print(f\"{src} -> {dst} | rows: {total}, redacted_rows: {redacted}, hits: {counters}\")\n", + "\n", + "process_file(\"data/news.jsonl\", \"data/news_clean.jsonl\", kind=\"news\")\n", + "process_file(\"data/chat.jsonl\", \"data/chat_clean.jsonl\", kind=\"chat\")" + ] + }, + { + "cell_type": "markdown", + "id": "6ac01dca", + "metadata": {}, + "source": [ + "## 7) 데이터 로딩/포맷팅 · Load & Format" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "9cd825e3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Created: data/news_harmony.jsonl data/chat_harmony.jsonl\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6f769d524f424ed5a11781a157cfa796", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Generating news split: 0 examples [00:00, ? examples/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "af2e4dc971884747a719d500caf52722", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Generating chat split: 0 examples [00:00, ? examples/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'train': 3, 'validation': 4}\n" + ] + } + ], + "source": [ + "# Step 7 — Harmony conversion + dataset loading & tokenization\n", + "import json, math\n", + "from pathlib import Path\n", + "from datasets import load_dataset, Dataset, concatenate_datasets\n", + "from transformers import AutoTokenizer\n", + "\n", + "DATA = Path(\"data\")\n", + "assert (DATA / \"news_clean.jsonl\").exists(), \"Run Step 6 first\"\n", + "assert (DATA / \"chat_clean.jsonl\").exists(), \"Run Step 6 first\"\n", + "\n", + "# ---------- 7A) Convert cleaned → Harmony messages ----------\n", + "\n", + "def news_to_messages(rec):\n", + " # system style from Step 6 tags; default to KR news tone\n", + " system = \"한국 뉴스 문체로 간결하고 사실 위주로 작성.\"\n", + " # user asks for a headline+lead from topic; assistant is the expected formatted answer\n", + " user = f\"주제: {rec.get('topic','알수없음')}. 기사 제목과 요약을 생성해줘.\"\n", + " assistant = f\"{rec.get('title','')} — {rec.get('summary','')}\"\n", + " return [{\"role\":\"system\",\"content\":system},\n", + " {\"role\":\"user\",\"content\":user},\n", + " {\"role\":\"assistant\",\"content\":assistant}]\n", + "\n", + "def chat_to_messages(rec):\n", + " # Keep style hint (casual/formal) in system\n", + " style = (rec.get(\"style\") or \"\").lower()\n", + " system = f\"카카오톡 대화 스타일. 
style={style or 'chat'}\"\n", + " dialog = rec.get(\"dialog\") or []\n", + " msgs = [{\"role\":\"system\",\"content\":system}]\n", + " # Alternate user/assistant turns; if odd length, last user stays without assistant label\n", + " roles = [\"user\",\"assistant\"]\n", + " for i, turn in enumerate(dialog[:6]): # cap tiny demos to avoid runaway\n", + " msgs.append({\"role\": roles[i % 2], \"content\": str(turn)})\n", + " # Ensure there is at least one assistant turn for SFT\n", + " if not any(m[\"role\"]==\"assistant\" for m in msgs):\n", + " msgs.append({\"role\":\"assistant\",\"content\":\"네, 확인했습니다.\"})\n", + " return msgs\n", + "\n", + "def write_harmony(src, dst, kind):\n", + " convert = news_to_messages if kind==\"news\" else chat_to_messages\n", + " with open(src, encoding=\"utf-8\") as fin, open(dst, \"w\", encoding=\"utf-8\") as fout:\n", + " for line in fin:\n", + " if not line.strip(): continue\n", + " rec = json.loads(line)\n", + " msgs = convert(rec)\n", + " fout.write(json.dumps({\"messages\": msgs}, ensure_ascii=False) + \"\\n\")\n", + "\n", + "write_harmony(DATA/\"news_clean.jsonl\", DATA/\"news_harmony.jsonl\", \"news\")\n", + "write_harmony(DATA/\"chat_clean.jsonl\", DATA/\"chat_harmony.jsonl\", \"chat\")\n", + "print(\"Created:\", DATA/\"news_harmony.jsonl\", DATA/\"chat_harmony.jsonl\")\n", + "\n", + "# ---------- 7B) Load Harmony JSONL with 🤗 Datasets ----------\n", + "raw = load_dataset(\n", + " \"json\",\n", + " data_files={\"news\": str(DATA/\"news_harmony.jsonl\"),\n", + " \"chat\": str(DATA/\"chat_harmony.jsonl\")}\n", + ")\n", + "\n", + "# Mix train split using your Step-2 mix ratios\n", + "news = raw[\"news\"]\n", + "chat = raw[\"chat\"]\n", + "\n", + "def take_portion(ds, frac):\n", + " n = max(1, int(round(len(ds) * frac)))\n", + " return ds.select(range(n)) if n < len(ds) else ds\n", + "\n", + "news_part = take_portion(news, MIX_NEWS if 'MIX_NEWS' in globals() else 0.5)\n", + "chat_part = take_portion(chat, MIX_CHAT if 'MIX_CHAT' in globals() else 0.5)\n", + "train_ds = concatenate_datasets([news_part, chat_part]).shuffle(seed=42)\n", + "\n", + "# Tiny validation built from remaining examples (if any)\n", + "remaining_news = news.select(range(len(news_part), len(news))) if len(news) > len(news_part) else news_part\n", + "remaining_chat = chat.select(range(len(chat_part), len(chat))) if len(chat) > len(chat_part) else chat_part\n", + "val_candidates = concatenate_datasets([remaining_news, remaining_chat])\n", + "val_ds = val_candidates.shuffle(seed=43).select(range(min(64, len(val_candidates)))) if len(val_candidates) else train_ds.select(range(min(32, len(train_ds))))\n", + "\n", + "dataset = {\"train\": train_ds, \"validation\": val_ds}\n", + "print({k: len(v) for k, v in dataset.items()})\n" + ] + }, + { + "cell_type": "markdown", + "id": "c95c9122", + "metadata": {}, + "source": [ + "## 8) 모델/토크나이저 로드 · Load Model & Tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "db67b6b3", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1cfc411479e145e4b5b161df311d4b13", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "tokenizer_config.json: 0.00B [00:00, ?B/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ebea3ddd62e340cc83e2a484a04e3e89", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "tokenizer_config.json: 0.00B [00:00, ?B/s]" + ] + }, + 
"metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "330fd60c5e1248998f0f5bc8c394b2ce", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "tokenizer.json: 0%| | 0.00/27.9M [00:00\n", + "{{ m['content'] }}<|end|>\n", + "{%- elif m['role'] == 'user' -%}<|user|>\n", + "{{ m['content'] }}<|end|>\n", + "{%- elif m['role'] == 'assistant' -%}<|assistant|>\n", + "{{ m['content'] }}<|end|>\n", + "{%- endif -%}\n", + "{%- endfor -%}\"\"\"\n", + "\n", + "# Ensure pad/eos are sane\n", + "tokenizer.pad_token = tokenizer.eos_token or tokenizer.pad_token\n", + "\n", + "# ---------- 7D) Tokenize with assistant-only labels ----------\n", + "ASST_TOKEN = None\n", + "END_TOKEN = None\n", + "try:\n", + " ASST_TOKEN = tokenizer.convert_tokens_to_ids(\"<|assistant|>\")\n", + " END_TOKEN = tokenizer.convert_tokens_to_ids(\"<|end|>\")\n", + "except Exception:\n", + " # If the base vocab lacks these tokens, it's okay; masking fallback below will still work heuristically\n", + " pass\n", + "\n", + "MAX_LEN = 2048 # you can raise this if you have room\n", + "\n", + "def tokenize_with_labels(example):\n", + " # 1) Render with chat template (includes assistant answer)\n", + " text = tokenizer.apply_chat_template(example[\"messages\"], tokenize=False, add_generation_prompt=False)\n", + " # 2) Tokenize\n", + " enc = tokenizer(text, truncation=True, max_length=MAX_LEN)\n", + " input_ids = enc[\"input_ids\"]\n", + " labels = [-100] * len(input_ids)\n", + "\n", + " # 3) Label only assistant content\n", + " if ASST_TOKEN is not None and END_TOKEN is not None:\n", + " start = None\n", + " for i, tid in enumerate(input_ids):\n", + " if tid == ASST_TOKEN:\n", + " start = i + 1 # learn after the tag\n", + " elif start is not None and tid == END_TOKEN:\n", + " start = None\n", + " elif start is not None:\n", + " labels[i] = input_ids[i]\n", + " else:\n", + " # Heuristic fallback: learn on the last third of tokens (crude but avoids total silence)\n", + " start = int(len(input_ids) * 0.66)\n", + " for i in range(start, len(input_ids)):\n", + " labels[i] = input_ids[i]\n", + "\n", + " return {\"input_ids\": input_ids, \"attention_mask\": enc[\"attention_mask\"], \"labels\": labels}\n", + "\n", + "tokenized_train = dataset[\"train\"].map(tokenize_with_labels, remove_columns=[\"messages\"])\n", + "tokenized_val = dataset[\"validation\"].map(tokenize_with_labels, remove_columns=[\"messages\"])\n", + "\n", + "print(\"Tokenization done.\",\n", + " \"train:\", len(tokenized_train),\n", + " \"val:\", len(tokenized_val),\n", + " \"example lens:\", tokenized_train[0][\"input_ids\"][:12], \"...\")" + ] + }, + { + "cell_type": "markdown", + "id": "f67dd4ef", + "metadata": {}, + "source": [ + "## 9) Fine‑Tuning (LoRA/QLoRA) · 세밀 튜닝\n", + "### 9a) Data curation & splits\n", + "_(See Section 7/8 for dataset prep; move relevant snippets here if needed.)_\n", + "### 9b) Hyperparameters (r/alpha/dropout)\n", + "```python\n", + "# Example LoRA hyperparameters\n", + "LORA_R = 8\n", + "LORA_ALPHA = 16\n", + "LORA_DROPOUT = 0.05\n", + "```\n", + "\n", + "### 9c) Merge adapters (BF16)\n", + "```python\n", + "# Example merge step (after training)\n", + "# model = PeftModel.from_pretrained(base_model, adapter_path)\n", + "# merged_model = model.merge_and_unload()\n", + "```\n", + "\n", + "### 9d) Save merged BF16 (`save_pretrained`)\n", + "```python\n", + "# merged_model.save_pretrained(OUTPUT_DIR)\n", + "```\n" + ] + }, + { + "cell_type": "markdown", + "id": 
"c9157315", + "metadata": {}, + "source": [ + "### 9e) Export & Quantize (BF16 → MXFP4) · 내보내기 & 양자화\n", + "\n", + "**EN (neutral, framework-agnostic):** \n", + "Public libraries currently do **not** support training/fine‑tuning *directly* in MXFP4. The common pipeline is:\n", + "1) **Train/SFT** in **BF16** (or **QLoRA 4‑bit nf4**). \n", + "2) **Merge LoRA adapters** into the base model (BF16). \n", + "3) **Save** the merged BF16 checkpoint with `save_pretrained()`. \n", + "4) **Post‑training quantize** the merged BF16 tensors to **MXFP4** using a **vendor/toolchain‑provided packer**. \n", + "5) **Save/export** the MXFP4 artifact (same shape as Hugging Face `save_pretrained()` output) for deployment/serving.\n", + "\n", + "> Notes: \n", + "> - If your serving stack supports **LoRA at inference**, you may skip merging and quantization and ship: **base (MXFP4 or BF16) + LoRA adapters**. \n", + "> - If your runtime requires **merged MXFP4**, you must run a **BF16 → MXFP4** quantization step after merging adapters. \n", + "> - Keep **tokenizer/config** files aligned across BF16 and MXFP4 exports.\n", + "\n", + "**KR (중립적, 도구 비의존):** \n", + "현재 공개 라이브러리는 MXFP4에서 **직접 학습/파인튜닝을 지원하지 않습니다**. 일반적인 파이프라인은 다음과 같습니다: \n", + "1) **BF16**(또는 **QLoRA 4‑bit nf4**)로 **학습/파인튜닝** \n", + "2) **LoRA 어댑터 병합**(BF16 기준) \n", + "3) `save_pretrained()`로 **병합된 BF16 체크포인트 저장** \n", + "4) 벤더/툴체인에서 제공하는 **양자화 도구**로 **BF16 → MXFP4 사후 양자화** \n", + "5) 배포/서빙용 **MXFP4 아티팩트 저장/내보내기** (Hugging Face `save_pretrained()` 구조와 동일)\n", + "\n", + "> 참고: \n", + "> - **서빙에서 LoRA를 지원**한다면, 병합·양자화를 생략하고 **기저( MXFP4 또는 BF16 ) + LoRA 어댑터**로 제공할 수 있습니다. \n", + "> - **병합된 MXFP4**가 필요한 런타임의 경우, 어댑터 병합 후 **BF16 → MXFP4 재양자화** 단계가 필요합니다. \n", + "> - **tokenizer/config** 파일은 BF16과 MXFP4 아티팩트 간에 일관되게 유지하세요.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "48a5cbc9", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fine‑tuning skeleton ready. Un‑comment on your machine.\n" + ] + } + ], + "source": [ + "from trl import SFTTrainer, SFTConfig\n", + "from peft import LoraConfig, get_peft_model\n", + "\n", + "lora_cfg = LoraConfig(\n", + " task_type=\"CAUSAL_LM\",\n", + " r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT,\n", + " target_modules=TARGET_MODULES\n", + ")\n", + "\n", + "# base_model = get_peft_model(base_model, lora_cfg)\n", + "\n", + "sft_args = SFTConfig(\n", + " output_dir=OUTPUT_DIR,\n", + " num_train_epochs=EPOCHS,\n", + " per_device_train_batch_size=PER_DEVICE_BS,\n", + " gradient_accumulation_steps=GRAD_ACCUM,\n", + " learning_rate=LEARNING_RATE,\n", + " lr_scheduler_type=\"cosine\",\n", + " bf16=BF16,\n", + " logging_steps=LOG_STEPS,\n", + " save_steps=SAVE_STEPS,\n", + " save_total_limit=SAVE_TOTAL_LIMIT\n", + ")\n", + "\n", + "# trainer = SFTTrainer(model=base_model, args=sft_args, train_dataset=combined, tokenizer=tokenizer)\n", + "# trainer.train()\n", + "# trainer.save_model(OUTPUT_DIR)\n", + "print(\"Fine‑tuning skeleton ready. Un‑comment on your machine.\")" + ] + }, + { + "cell_type": "markdown", + "id": "490798f2", + "metadata": {}, + "source": [ + "## 10) 평가(뉴스/대화) · Evaluation (News/Chat)" + ] + }, + { + "cell_type": "markdown", + "id": "d1bdafe4", + "metadata": {}, + "source": [ + "**KR 지표 · KR Metrics** \n", + "- 뉴스성: 주제 분류 적합도(F1), 요약 품질(ROUGE‑1/2/L), 독해 QA(EM/F1). 
\n", + "- 대화성: 자연성/맥락 유지, 경어/반말 전환 정확도, 이모티콘/축약어 적절성.\n", + "\n", + "**EN Notes** \n", + "- Use public KR benchmarks (e.g., topic classification, KorQuAD‑like QA) where licenses permit.\n", + "- Mix automatic metrics (F1/ROUGE) with human eval for tone & politeness." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "971b8dbd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Eval stubs ready.\n" + ] + } + ], + "source": [ + "# Example helpers (stub)\n", + "def simple_accuracy(preds, labels):\n", + " return sum(int(p==g) for p,g in zip(preds, labels)) / max(1, len(labels))\n", + "\n", + "# For ROUGE:\n", + "# import evaluate\n", + "# rouge = evaluate.load(\"rouge\")\n", + "# result = rouge.compute(predictions=pred_texts, references=ref_texts)\n", + "# print(result)\n", + "\n", + "print(\"Eval stubs ready.\")" + ] + }, + { + "cell_type": "markdown", + "id": "e0b5594e", + "metadata": {}, + "source": [ + "## 11) Inference Prompt Templates · 추론 프롬프트 템플릿" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "1f690452", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.\n", + "Knowledge cutoff: 2024-06\n", + "Current date: 2025-08-21\n", + "\n", + "Reasoning: medium\n", + "\n", + "# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions\n", + "\n", + "너는 한국 고객을 돕는 유능한 AI 어시스턴트다.\n", + "\n", + "<|end|><|start|>user<|message|>국내 PIPA 규정을 준수하면서 사내 문서 요약기를 구성하려면 어떤 아키텍처가 좋을까?<|end|><|start|>assistant\n" + ] + } + ], + "source": [ + "from openai_harmony import Message, ChatFormatter\n", + "\n", + "# Example prompt construction using Harmony\n", + "messages = [\n", + " Message(role=\"system\", content=\"너는 한국 고객을 돕는 유능한 AI 어시스턴트다.\"),\n", + " Message(role=\"user\", content=\"국내 PIPA 규정을 준수하면서 사내 문서 요약기를 구성하려면 어떤 아키텍처가 좋을까?\")\n", + "]\n", + "\n", + "prompt = ChatFormatter.to_chat_prompt(messages)\n", + "print(prompt) # For preview; pass to tokenizer when running inference\n" + ] + }, + { + "cell_type": "markdown", + "id": "5216d049", + "metadata": {}, + "source": [ + "## 12) 최신성 유지 · Freshness Strategy" + ] + }, + { + "cell_type": "markdown", + "id": "452decd1", + "metadata": {}, + "source": [ + "- **주간 보정 SFT**: 허용된 뉴스 API **메타데이터(제목/요약/섹션)** 샘플링 → 스타일 보정. \n", + "- **대화체 업데이트**: 최신 축약어/신조어/이모티콘 사전 반영(예: ㄱㄱ, ㅇㅋ, ㅋㅋ, ㄹㅇ). \n", + "- **회귀 평가**: 동일 지표로 before/after 비교 → 혼합비/온도/패널티 튜닝.\n", + "\n", + "- Weekly calibration SFT using **allowed news API metadata** for style; \n", + "- Update slang/emoji lexicons; \n", + "- Regression evals to track drift and adjust data mix/decoding." + ] + }, + { + "cell_type": "markdown", + "id": "718b9f2a", + "metadata": {}, + "source": [ + "## 13) 안전/컴플라이언스 · Safety & Compliance" + ] + }, + { + "cell_type": "markdown", + "id": "61ad24ef", + "metadata": {}, + "source": [ + "- 데이터 출처/라이선스 확인(벤치마크, API, 내부 데이터) · Verify dataset/API licenses.\n", + "- 개인정보 스크러빙(훈련/로그/평가 전) · Scrub PII before training/logging/eval.\n", + "- 저작권/약관 준수(기사 **원문 대량 재학습 금지**) · Avoid mass training on full news articles.\n", + "- 출력 검증(스키마/금칙어/민감도 규칙) · Output validation & forbidden‑term filters.\n", + "- 버전/평가 리포트 관리 · Version datasets/models and keep eval reports." 
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5cb8464b",
+ "metadata": {},
+ "source": [
+ "## 14) 문제해결 & 다음 단계 · Troubleshooting & Next Steps"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8ee17077",
+ "metadata": {},
+ "source": [
+ "- 혼합 비율 튜닝: (뉴스:대화) 6:4 → 7:3 또는 5:5로 조정 \n",
+ "- LoRA 하이퍼파라미터: r=8~16, α=16~32, dropout=0.05~0.1 \n",
+ "- 서비스화: vLLM/llama.cpp 서빙 + 토픽/스타일 라우팅 \n",
+ "- RAG 결합: 최신 사실성 보강을 위해 뉴스/문서 인덱스 결합 \n",
+ "- A/B 테스트: 톤/길이/이모티콘 사용량 등 사용자 만족도 측정\n",
+ "\n",
+ "- Tune mix ratios, run A/B tests, consider vLLM serving, and pair with RAG for factuality."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/registry.yaml b/registry.yaml
index aafb4a938c..b462dbd069 100644
--- a/registry.yaml
+++ b/registry.yaml
@@ -4,6 +4,17 @@
 # should build pages for, and indicates metadata such as tags, creation date and
 # authors for each page.
 
+- title: "Fine-tune gpt-oss for better Korean language performance"
+  path: articles/gpt-oss/fine-tune-korean.ipynb
+  description: "Guide to fine-tuning an open-weight model for Korean, with workflow tips."
+  authors:
+    - heejingithub
+    - danial-openai
+    - joanneshin-openai
+  tags:
+    - gpt-oss
+    - open-models
+
 - title: Verifying gpt-oss implementations
   path: articles/gpt-oss/verifying-implementations.md
   date: 2025-08-11