Ssul's Blog
Gemma3 finetuning(파인튜닝)하기 본문
최근에 진행하는 R&D프로젝트가
사전학습된 LLM의 지식을 활용하여, 약 4,000~5,000건의 데이터만 학습하여,
특정 도메인에 특화된 과업을 해결하는 모델을 만드는 것이다.
(예: 스팸문자 분류하는 LLM)
그래서 한글을 잘한다고 소문난 모델을 이것 저것 파인튜닝 해보고 있다.
- EEVE, Qwen2.5의 경우 gpt-4o-mini를 api로 파인튜닝 한 모델보다 성능이 떨어졌다.
- EXAONE3.5를 기점으로 gpt-4o-mini와 비슷하거나 높게 나오기 시작했다.
1. Gemma3 발표
이놈의 AI쪽은 허구헛날 새로운 모델이 나오고, 기존 성능을 갱신한다.
EXAONE에서 만족하고 다음 진도를 나가려 했는데..... 그래도 Gemma3가 나왔다고 하니 파인튜닝을 안할수 없었다.
코드를 열심히 검색해봐도 대부분 unsloth코드 밖에 없다.
그래서 gpt/claude와 함께 파인튜닝을 진행해보기로 하였다.
2. Gemma3 차이점
EXAONE을 파인튜닝하며, alpaca템플릿이 아닌 chat_template를 경험했는데,
이번 Gemma3는 멀티모달 모델이라
- tokenizer대신 processor가 있다는 것과
- 이미지를 다루다보니, 기존 chat_template보다 한차원 더 있다는 것을 기억해야 한다
(파인튜닝을 마치고 나서, 텍스트모델만 파인튜닝 하는 방법을 설명한 자료를 찾기는 하였다. 하지만, 디테일하게 파보자)
# pip install accelerate
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from PIL import Image
import requests
import torch
model_id = "google/gemma-3-12b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(
model_id, device_map="auto"
).eval()
processor = AutoProcessor.from_pretrained(model_id)
messages = [
{
"role": "system",
"content": [{"type": "text", "text": "You are a helpful assistant."}]
},
{
"role": "user",
"content": [
{"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
{"type": "text", "text": "Describe this image in detail."}
]
}
]
inputs = processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True,
return_dict=True, return_tensors="pt"
).to(model.device, dtype=torch.bfloat16)
input_len = inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)
# **Overall Impression:** The image is a close-up shot of a vibrant garden scene,
# focusing on a cluster of pink cosmos flowers and a busy bumblebee.
# It has a slightly soft, natural feel, likely captured in daylight.
(huggingface 문서를 보면, tokenizer가 아닌 processor를 사용하는 것을 확인할수 있다. 멀티모달 모델이라 그런듯 하다)
messages = [
{
"role": "system",
"content": [{"type": "text", "text": "You are a helpful assistant."}]
},
{
"role": "user",
"content": [
{"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG"},
{"type": "text", "text": "What animal is on the candy?"}
]
}
]
(템플릿도, content부분이 type으로 텍스트, 이미지를 구분하고 text에 본내용이 있는 것을 확인할 수 있다)
3. Gemma3파인튜닝하기(예: 스팸문자 분류모델 만들기) - 코드첨부로 길어요~
3-1. 프롬프트 및 하이퍼파라미터 설정(config.py)
import torch
import wandb
# Model configuration
MODEL_ID = "google/gemma-3-12b-it" # Gemma3 모델로 변경
HF_TOKEN = "여러분의 허깅페이스 토큰"
# wandb API 키 설정
WANDB_API_KEY = "여러분 wandb api"
# 데이터셋 설정 추가
DATASET_NAME = "허깅페이스 데이터셋 주소"
# 제 데이터 셋은 train만 있습니다. test용 데이터는 다른주소에 따로 저장
# Training configuration
CUTOFF_LEN = 4098
TRAIN_ON_INPUTS = False
ADD_EOS_TOKEN = False
VAL_SIZE = 0.005
# LoRA configuration
LORA_CONFIG = {
"r": 16,
"lora_alpha": 16,
"target_modules": ["q_proj", "k_proj", "v_proj"],
"lora_dropout": 0.05,
"bias": "none",
"task_type": "CAUSAL_LM",
}
# Training arguments
TRAINING_ARGS = {
"output_dir": "./Spam_Gemma3-12b-it",
"per_device_train_batch_size": 1,
"per_device_eval_batch_size": 1,
"gradient_accumulation_steps": 8,
"warmup_ratio": 0.03, # 비율 기반 웜업으로 변경
# "max_steps": 1000, # 에포크 대신 최대 스텝 사용
"num_train_epochs": 1, # 에포크 수만 사용
"learning_rate": 5e-5,
"adam_beta1": 0.9,
"adam_beta2": 0.999, # 표준 값으로 변경
"weight_decay": 0.01, # 가중치 감쇠 추가
"fp16": False,
"bf16": True,
"logging_steps": 10,
"optim": "adamw_torch", # 표준 AdamW 옵티마이저 사용
"lr_scheduler_type": "cosine", # 코사인 스케줄러 명시적 설정
"evaluation_strategy": "steps",
"eval_steps": 100,
"save_strategy": "steps",
"save_steps": 100,
"load_best_model_at_end": True,
"metric_for_best_model": "loss", # 손실 기준으로 최적 모델 선택
"greater_is_better": False, # 손실은 낮을수록 좋음
"group_by_length": False,
"report_to": "wandb",
"run_name": "Spam_Gemma3-12b-it",
"remove_unused_columns": False, # 모든 열 유지
"gradient_checkpointing": False, # 메모리 효율성 향상
"torch_compile": False, # 컴파일 비활성화 (안정성 향상)
"ddp_find_unused_parameters": False, # DDP 최적화
"dataloader_num_workers": 0, # 데이터 로더 워커 수 감소 (안정성 향상)
"seed": 42, # 재현성을 위한 시드 설정
"debug": "underflow_overflow", # 언더플로우/오버플로우 감지
"full_determinism": True, # 완전한 결정론적 동작
}
# Gemma3 모델용 시스템 프롬프트
SYSTEM_PROMPT = """당신은 스팸분류기 모델입니다.
문자내용을 보고, 아래 5가지 결정 카테고리 중에서 **오직 하나**를 골라
정확히 해당 이름만 출력하세요.
카테고리 목록:
- 스팸문자
- 프로모션문자
- 정치문자
- 인증문자
- 개인문자
### 필수 지침
1. **출력 형식**: 카테고리 이름만 정확히 출력.
2. **추가 텍스트**: 어떤 부연 설명, 문장, 구두점도 포함하지 말 것.
3. **불확실할 경우**: 최대한 가장 근접한 카테고리를 선택.
---
아래에 문자메세지가 주어집니다.
이 문자를 토대로 가장 적합한 카테고리를 **하나만** 골라주세요.
"""
# Template for instruction tuning
INSTRUCT_TEMPLATE = {
"prompt_input": "{instruction}",
"prompt_no_input": "{instruction}",
}
# W&B 초기화
wandb.login(key=WANDB_API_KEY) # API 키로 로그인
wandb.init(
project="SpamAI", # 원하는 프로젝트 이름으로 변경
name=TRAINING_ARGS.get("run_name"), # run 이름 설정
config=TRAINING_ARGS, # 설정값 로깅
)
config.py에는 딱히 다른 파인튜닝과 다른점이 별로 없습니다. 하이퍼파라미터만 학습시켜보면서 조정하면 될듯.
저의 경우 학습률(lr)을 너무 작게 잡아서 학습이 안일어난(너무 늦음) 것 같이 보이는 상황이 있었습니다. lr크기 잘 조절하기
3-2. 데이터 셋팅하기(data.py) - 이게 중요, processor사용해서 ㅜㅜ
import pandas as pd
from datasets import Dataset, load_dataset as hf_load_dataset
from typing import Union, List, Dict, Any
from sklearn.model_selection import train_test_split
from transformers import DataCollatorForSeq2Seq
from config import (
INSTRUCT_TEMPLATE,
CUTOFF_LEN,
TRAIN_ON_INPUTS,
ADD_EOS_TOKEN,
SYSTEM_PROMPT,
)
import random
import logging
import requests
from PIL import Image
from io import BytesIO
class Prompter:
def __init__(self, verbose: bool = False):
self.template = INSTRUCT_TEMPLATE
self.system_prompt = SYSTEM_PROMPT
# 유효한 카테고리 목록 정의
self.valid_categories = [
"스팸문자",
"프로모션문자",
"정치문자",
"인증문자",
"개인문자",
]
# Gemma3 모델의 메시지 형식에 맞게 프롬프트 생성
def generate_chat_prompt(
self,
instruction: str,
label: Union[None, str] = None,
image_path: Union[None, str] = None,
) -> List[Dict[str, Any]]:
"""Gemma3 모델의 chat_template 형식에 맞게 프롬프트 생성
Args:
instruction: 사용자 지시문
label: 모델의 응답 (학습 시 정답으로 사용)
image_path: 이미지 경로 (로컬 경로 또는 URL)
"""
# 시스템 메시지
messages = [
{
"role": "system",
"content": [{"type": "text", "text": self.system_prompt}]
}
]
# 사용자 메시지 (이미지가 있는 경우 멀티모달 입력으로 구성)
user_content = []
# 이미지 추가 (존재하는 경우)
if image_path:
user_content.append({"type": "image", "image": image_path})
# 텍스트 추가
user_content.append({"type": "text", "text": instruction})
# 사용자 메시지 완성
messages.append({
"role": "user",
"content": user_content
})
# 어시스턴트 응답 (있는 경우)
if label:
messages.append({
"role": "assistant",
"content": [{"type": "text", "text": label}]
})
return messages
def extract_generated_text(self, processor, full_outputs, inputs):
"""
모델의 전체 출력에서 실제로 생성된 텍스트만 추출합니다.
Args:
processor: 사용된 프로세서
full_outputs: 모델이 생성한 전체 출력 토큰 (입력 + 생성)
inputs: 모델에 입력된 토큰
Returns:
dict: 입력 텍스트, 전체 출력, 생성된 텍스트를 포함하는 딕셔너리
"""
# 입력 길이 계산
input_length = 0
input_text_decoded = ""
# BatchFeature 객체인지 확인
if hasattr(inputs, "data") and isinstance(inputs.data, dict) and "input_ids" in inputs.data:
# input_ids 추출
input_ids_data = inputs.data["input_ids"]
# 텐서인 경우
if hasattr(input_ids_data, "shape"):
# GPU 텐서인 경우 CPU로 이동
if input_ids_data.device.type != 'cpu':
input_ids_cpu = input_ids_data.cpu()
else:
input_ids_cpu = input_ids_data
# 배치 차원이 있는 경우 첫 번째 항목 사용
if len(input_ids_cpu.shape) > 1:
input_ids_one = input_ids_cpu[0]
input_length = input_ids_one.shape[0]
input_text_decoded = processor.decode(input_ids_one, skip_special_tokens=True)
else:
input_length = input_ids_cpu.shape[0]
input_text_decoded = processor.decode(input_ids_cpu, skip_special_tokens=True)
# 리스트인 경우
elif isinstance(input_ids_data, list):
# 중첩 리스트 [[...]] 구조인 경우
if len(input_ids_data) > 0 and isinstance(input_ids_data[0], list):
input_ids_one = input_ids_data[0] # 첫 번째 배치 항목 사용
input_length = len(input_ids_one)
try:
input_text_decoded = processor.decode(input_ids_one, skip_special_tokens=True)
except Exception as e:
logging.warning(f"입력 디코딩 중 오류: {e}")
input_text_decoded = f"디코딩 오류: {str(e)}"
else:
# 일반 1차원 리스트인 경우
input_length = len(input_ids_data)
try:
input_text_decoded = processor.decode(input_ids_data, skip_special_tokens=True)
except Exception as e:
logging.warning(f"입력 디코딩 중 오류: {e}")
input_text_decoded = f"디코딩 오류: {str(e)}"
else:
# 일반 딕셔너리인 경우
try:
input_ids = inputs["input_ids"]
# 텐서인 경우
if hasattr(input_ids, "shape"):
# GPU 텐서인 경우 CPU로 이동
if input_ids.device.type != 'cpu':
input_ids_cpu = input_ids.cpu()
else:
input_ids_cpu = input_ids
# 배치 차원이 있는 경우 첫 번째 항목 사용
if len(input_ids_cpu.shape) > 1:
input_ids_one = input_ids_cpu[0]
input_length = input_ids_one.shape[0]
input_text_decoded = processor.decode(input_ids_one, skip_special_tokens=True)
else:
input_length = input_ids_cpu.shape[0]
input_text_decoded = processor.decode(input_ids_cpu, skip_special_tokens=True)
# 리스트인 경우
elif isinstance(input_ids, list):
# 중첩 리스트 구조인 경우
if len(input_ids) > 0 and isinstance(input_ids[0], list):
input_ids_one = input_ids[0] # 첫 번째 배치 항목 사용
input_length = len(input_ids_one)
try:
input_text_decoded = processor.decode(input_ids_one, skip_special_tokens=True)
except Exception as e:
logging.warning(f"입력 디코딩 중 오류: {e}")
input_text_decoded = f"디코딩 오류: {str(e)}"
else:
# 일반 1차원 리스트인 경우
input_length = len(input_ids)
try:
input_text_decoded = processor.decode(input_ids, skip_special_tokens=True)
except Exception as e:
logging.warning(f"입력 디코딩 중 오류: {e}")
input_text_decoded = f"디코딩 오류: {str(e)}"
except Exception as e:
logging.warning(f"입력 길이 계산 중 오류: {e}")
input_length = 0
input_text_decoded = f"입력 추출 오류: {str(e)}"
# 출력 처리 - 항상 텐서로 가정
try:
# full_outputs가 텐서인 경우 (일반적인 경우)
if hasattr(full_outputs, "shape"):
# GPU 텐서인 경우 CPU로 이동
if full_outputs.device.type != 'cpu':
outputs_cpu = full_outputs.cpu()
else:
outputs_cpu = full_outputs
# 배치 차원이 있는 경우 첫 번째 항목 사용
if len(outputs_cpu.shape) > 1:
outputs_one = outputs_cpu[0]
else:
outputs_one = outputs_cpu
# 생성된 텍스트 추출
if input_length > 0:
generation = outputs_one[input_length:]
else:
generation = outputs_one
# 디코딩
generated_text = processor.decode(generation, skip_special_tokens=True)
full_output = processor.decode(outputs_one, skip_special_tokens=True)
# 리스트인 경우
elif isinstance(full_outputs, list):
# 중첩 리스트 구조인 경우 [[...]]
if len(full_outputs) > 0 and isinstance(full_outputs[0], list):
outputs_one = full_outputs[0] # 첫 번째 배치 항목 사용
else:
outputs_one = full_outputs
# 생성된 텍스트 추출
if input_length > 0 and input_length < len(outputs_one):
generation = outputs_one[input_length:]
else:
generation = outputs_one
# 디코딩
generated_text = processor.decode(generation, skip_special_tokens=True)
full_output = processor.decode(outputs_one, skip_special_tokens=True)
else:
# 예상치 못한 형식
logging.warning(f"예상치 못한, full_outputs 구조: {type(full_outputs)}")
generated_text = "출력 형식 오류"
full_output = "출력 형식 오류"
except Exception as e:
logging.error(f"출력 처리 중 오류: {e}")
generated_text = f"출력 처리 오류: {str(e)}"
full_output = f"출력 처리 오류: {str(e)}"
return {
"input_decoded": input_text_decoded,
"full_output": full_output,
"generated_text": generated_text,
}
def load_dataset(dataset_name: str):
"""
Load a dataset from Hugging Face Hub
Args:
dataset_name: Name of the dataset on Hugging Face Hub
Returns:
train and test splits of the dataset
"""
dataset = hf_load_dataset(dataset_name)
if "train" in dataset and "test" in dataset:
return dataset["train"], dataset["test"]
elif "train" in dataset:
train_test_split_data = dataset["train"].train_test_split(
test_size=0.2, shuffle=True, seed=42
)
return train_test_split_data["train"], train_test_split_data["test"]
else:
# If dataset has a different structure, split the default split
default_split = next(iter(dataset.values()))
train_test_split_data = default_split.train_test_split(
test_size=0.2, shuffle=True, seed=42
)
return train_test_split_data["train"], train_test_split_data["test"]
def generate_and_tokenize_chat_prompt(data_point, prompter, processor, image_field=None, is_sample=False):
"""Gemma3 멀티모달 모델용 채팅 형식 프롬프트 토큰화 함수
Args:
data_point: 데이터셋의 한 항목
prompter: Prompter 클래스 인스턴스
processor: Gemma3 프로세서
image_field: 이미지 경로가 들어있는 필드 이름 (없으면 None)
is_sample: 샘플 데이터인지 여부 (True인 경우에만 로깅)
"""
# 1. 원본 데이터에서 필요한 정보 추출
instruction = data_point["문자"]
response = data_point["문자분류 라벨링"]
# 이미지 경로 (존재하는 경우)
image_path = data_point.get(image_field) if image_field else None
# 2. 채팅 형식 메시지 생성 - [시스템 메시지, 사용자 메시지, 어시스턴트 메시지] 구조
messages = prompter.generate_chat_prompt(
instruction=instruction,
label=response,
image_path=image_path
)
# 샘플 데이터인 경우에만 로깅
if is_sample:
# 생성된 메시지 구조 로깅
logging.info(f"생성된 메시지 구조: {len(messages)}개 메시지")
for i, msg in enumerate(messages):
logging.info(f"메시지 {i+1} 역할: {msg['role']}")
# 채팅 템플릿 적용 및 토큰화
tokenized = processor.apply_chat_template(
messages,
add_generation_prompt=True, # 생성을 위한 프롬프트 추가
tokenize=True, # 토큰화 수행
return_tensors=None, # 텐서 변환 안함
padding=False, # 패딩 안함
return_dict=True # 딕셔너리 형태로 반환
)
# 토큰화 결과 상세 정보 - 샘플 데이터인 경우에만 로깅
# BatchFeature 객체인 경우 처리
token_length = 0
if hasattr(tokenized, "data") and isinstance(tokenized.data, dict) and "input_ids" in tokenized.data:
input_ids_list = tokenized.data["input_ids"]
# input_ids는 리스트의 리스트 형태 [[...]] - 배치 차원이 있음
if isinstance(input_ids_list, list) and len(input_ids_list) > 0 and isinstance(input_ids_list[0], list):
input_ids = input_ids_list[0] # 첫 번째(유일한) 배치 항목 사용
token_length = len(input_ids)
# 샘플 데이터인 경우에만 로깅
if is_sample:
logging.info(f"토크나이징 - input_ids 실제 길이: {token_length}")
# 토큰 ID 통계
if token_length > 0:
logging.info(f"토큰 통계 - 최소값: {min(input_ids)}, 최대값: {max(input_ids)}")
# 특수 토큰 개수 (샘플 데이터에만 적용)
if processor.tokenizer.eos_token_id is not None:
eos_count = input_ids.count(processor.tokenizer.eos_token_id)
logging.info(f"EOS 토큰 개수: {eos_count}")
if processor.tokenizer.bos_token_id is not None:
bos_count = input_ids.count(processor.tokenizer.bos_token_id)
logging.info(f"BOS 토큰 개수: {bos_count}")
else:
# 예상과 다른 구조인 경우 처리
if is_sample:
logging.warning(f"예상과 다른 input_ids 구조: {type(input_ids_list)}")
token_length = 0
# 딕셔너리인 경우 처리 (이전 코드와 동일)
elif isinstance(tokenized, dict) and "input_ids" in tokenized:
token_length = len(tokenized['input_ids'])
# 샘플 데이터인 경우에만 로깅
if is_sample:
logging.info(f"토크나이징 - input_ids 길이: {token_length}")
# 토큰 ID 통계
if token_length > 0:
logging.info(f"토큰 통계 - 최소값: {min(tokenized['input_ids'])}, 최대값: {max(tokenized['input_ids'])}")
# 특수 토큰 개수 (샘플 데이터에만 적용)
if processor.tokenizer.eos_token_id is not None:
eos_count = tokenized['input_ids'].count(processor.tokenizer.eos_token_id)
logging.info(f"EOS 토큰 개수: {eos_count}")
if processor.tokenizer.bos_token_id is not None:
bos_count = tokenized['input_ids'].count(processor.tokenizer.bos_token_id)
logging.info(f"BOS 토큰 개수: {bos_count}")
elif is_sample:
logging.warning("토큰화 결과가 예상 형식이 아닙니다. 토큰 정보를 출력할 수 없습니다.")
token_length = 0
# 라벨링 처리 (응답 부분만 학습하도록)
if not TRAIN_ON_INPUTS:
# 응답 없는 프롬프트 생성
input_messages = prompter.generate_chat_prompt(
instruction=instruction,
image_path=image_path
)
# 응답 없는 프롬프트 토큰화
input_tokenized = processor.apply_chat_template(
input_messages,
add_generation_prompt=True,
tokenize=True,
return_tensors=None,
padding=False,
return_dict=True # 반환값을 딕셔너리로 설정
)
# 샘플 데이터인 경우에만 로깅
if is_sample:
# input_tokenized 타입 확인 추가
logging.info(f"input_tokenized 타입: {type(input_tokenized)}")
if hasattr(input_tokenized, "__dict__"):
logging.info(f"input_tokenized 속성들: {input_tokenized.__dict__.keys()}")
# 입력 부분 길이
input_len = 0
try:
# BatchFeature 객체인 경우 처리
if hasattr(input_tokenized, "data") and isinstance(input_tokenized.data, dict) and "input_ids" in input_tokenized.data:
input_ids_list = input_tokenized.data["input_ids"]
# input_ids는 리스트의 리스트 형태 [[...]] - 배치 차원이 있음
if isinstance(input_ids_list, list) and len(input_ids_list) > 0 and isinstance(input_ids_list[0], list):
input_ids = input_ids_list[0] # 첫 번째(유일한) 배치 항목 사용
input_len = len(input_ids)
# 샘플 데이터인 경우에만 로깅
if is_sample:
logging.info(f"토크나이징 - 입력 부분 실제 길이: {input_len}")
# 전체 길이 중 입력이 차지하는 비율
if token_length > 0:
input_ratio = input_len / token_length
logging.info(f"입력 vs 전체 비율: {input_ratio:.2%}")
else:
# 예상과 다른 구조인 경우 처리
if is_sample:
logging.warning(f"예상과 다른 input_tokenized.data['input_ids'] 구조: {type(input_ids_list)}")
input_len = 0
# 딕셔너리인 경우 처리 (이전 코드와 동일)
elif isinstance(input_tokenized, dict) and "input_ids" in input_tokenized:
input_len = len(input_tokenized["input_ids"])
# 샘플 데이터인 경우에만 로깅
if is_sample:
logging.info(f"토크나이징 - 입력 부분 길이: {input_len}")
# 전체 길이 중 입력이 차지하는 비율
if token_length > 0:
input_ratio = input_len / token_length
logging.info(f"입력 vs 전체 비율: {input_ratio:.2%}")
# 문자열인 경우 처리 (이전 코드와 동일)
elif isinstance(input_tokenized, str):
# 문자열로 반환된 경우 처리 방법 (예: 직접 토큰화)
input_tokens = processor.tokenizer(input_tokenized, return_tensors="pt")
input_len = len(input_tokens["input_ids"][0])
# 샘플 데이터인 경우에만 로깅
if is_sample:
logging.info(f"토크나이징 - 문자열 토큰화 후 입력 길이: {input_len}")
elif is_sample:
logging.warning("입력 토큰화 결과가 예상 형식이 아닙니다.")
input_len = 0
except Exception as e:
if is_sample:
logging.error(f"토크나이징 에러: {e}")
# 에러 발생 시 기본값으로 대체
input_len = 0
# 4-4. 입력 부분은 -100으로 마스킹 (손실 계산에서 제외)
# 여기서 labels 배열이 생성됨:
# - 처음 input_len 개는 -100 (문자 부분, 손실 계산 제외)
# - 그 이후는 input_ids 값 그대로
# BatchFeature 객체인 경우
if hasattr(tokenized, "data") and isinstance(tokenized.data, dict) and "input_ids" in tokenized.data:
input_ids_list = tokenized.data["input_ids"]
# input_ids는 리스트의 리스트 형태 [[...]] - 배치 차원이 있음
if isinstance(input_ids_list, list) and len(input_ids_list) > 0 and isinstance(input_ids_list[0], list):
input_ids = input_ids_list[0] # 첫 번째(유일한) 배치 항목 사용
# labels도 같은 구조로 생성
labels = [-100] * input_len + input_ids[input_len:]
tokenized.data["labels"] = [labels] # 2차원 형태로 설정
# 샘플 데이터인 경우에만 로깅
if is_sample:
# 라벨 마스킹 정보
non_masked = sum(1 for x in labels if x != -100)
logging.info(f"손실 계산 대상 토큰 수: {non_masked}/{token_length} ({non_masked/token_length:.2%})")
elif is_sample:
logging.warning("라벨링을 위한 input_ids의 구조가 예상과 다릅니다.")
# 딕셔너리인 경우
elif isinstance(tokenized, dict) and "input_ids" in tokenized:
tokenized["labels"] = [-100] * input_len + tokenized["input_ids"][input_len:]
# 샘플 데이터인 경우에만 로깅
if is_sample:
# 라벨 마스킹 정보
non_masked = sum(1 for x in tokenized["labels"] if x != -100)
logging.info(f"손실 계산 대상 토큰 수: {non_masked}/{token_length} ({non_masked/token_length:.2%})")
elif is_sample:
logging.warning("토큰화 결과가 예상 형식이 아니므로 라벨링을 수행할 수 없습니다.")
else:
# 전체 시퀀스에 대해 학습 (입력 포함) - 문자도 학습 대상으로 설정
# BatchFeature 객체인 경우
if hasattr(tokenized, "data") and isinstance(tokenized.data, dict) and "input_ids" in tokenized.data:
input_ids_list = tokenized.data["input_ids"]
# input_ids는 리스트의 리스트 형태 [[...]] - 배치 차원이 있음
if isinstance(input_ids_list, list) and len(input_ids_list) > 0 and isinstance(input_ids_list[0], list):
# 첫 번째(유일한) 배치 항목 복사
tokenized.data["labels"] = [input_ids_list[0].copy()]
elif is_sample:
logging.warning("TRAIN_ON_INPUTS=True 처리 중 예상치 못한 input_ids 구조")
# 딕셔너리인 경우
elif isinstance(tokenized, dict) and "input_ids" in tokenized:
tokenized["labels"] = tokenized["input_ids"].copy()
elif is_sample:
logging.warning("토큰화 결과가 예상 형식이 아니므로 라벨링을 수행할 수 없습니다.")
# 샘플 데이터인 경우에만 로깅
if is_sample:
logging.info("전체 시퀀스에 대해 학습 진행 (TRAIN_ON_INPUTS=True)")
# 5. 최종 토큰화된 결과 반환 - DataCollator로 전달됨
return tokenized
- generate_chat_prompt이 기존 Chat_template와 유사하지만, content구조가 다르다. 그리고 멀티모달이기 때문에 type: image가 있다.
-(중요) 하나하나 까보느라 코드가 길다. 사실 구조를 알았기 때문에 삭제해도 되지만, 참고용으로 남겨 놓는다. 아래 내용만 확인하셔도 된다.
기존의 많은 LLM은 "나는 스팸 문자 입니다"를 토크나이징 하면 아래와 같다
tokenized = tokenizer.apply_chat_template(
messages, #여기에 시스템 메세지와 user메세지 들어감
tokenize=True,
add_generation_prompt=True,
return_tensors=None,
padding=False,
truncation=True,
max_length=CUTOFF_LEN,
)
# tokenized는 아래와 같은 구조
{
'input_ids': [01, 02, 88,.....]
'attention_mask': [1,1,1,1,1....]
'labels': [01, 02, 88,....]
}
기본적으로 input_ids, attention_mask, labels를 가지고 있는 구조.
근데 gemma3를 토크나이징 해보면 아래와 같다. 아마 이미지도 함께 처리하는 멀티모달 모델이기 때문에 그런것 같다
두가지 특징이 있다.
특징1. tokenized.data에 'input_ids', 'attention_mask', 'label'이 있다.
특징2. input_ids/attemtion_mask, label이 [[01, 02, 88,...]] 중첩 리스트다.
그래서 최종적으로 tokenized.data['input_ids'][0]이 [01, 02, 88,...]이다
tokenized = processor.apply_chat_template(
messages,
add_generation_prompt=True, # 생성을 위한 프롬프트 추가
tokenize=True, # 토큰화 수행
return_tensors=None, # 텐서 변환 안함
padding=False, # 패딩 안함
return_dict=True # 딕셔너리 형태로 반환
)
tokenized.data가 중첩리스트(아마도 이미지도 처리해야 하기 때문에)
{
'input_ids': [[01, 02, 88,.....]]
'attention_mask': [[1,1,1,1,1....]]
'labels': [[01, 02, 88,....]]
}
data.py의 무수히 많은 if문들은 위 구조를 알아내기 위해, 그리고 오류를 찾기위한 검토용 코드라고 보면 되겠다.
3-3. 모델 양자화하고, LoRA붙이기(model.py)
- Gemma3모델 양자화가 가능하다.
import torch
from transformers import Gemma3ForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
from peft import get_peft_model, prepare_model_for_kbit_training, LoraConfig
from config import MODEL_ID, LORA_CONFIG
import logging
def setup_processor():
"""Initialize and configure the processor for multimodal inputs"""
logging.info(f"멀티모달 프로세서 로딩 중 - {MODEL_ID}")
processor = AutoProcessor.from_pretrained(MODEL_ID)
return processor
def print_model_structure(model, title="모델 구조"):
"""모델의 구조를 출력하는 함수"""
logging.info(f"===== {title} =====")
print(f"\n===== {title} =====")
print(model)
# 모델의 주요 속성과 레이어 정보 출력
for name, module in model.named_modules():
if len(name.split('.')) <= 3: # 너무 깊은 레이어는 출력하지 않음
print(f"{name}: {type(module).__name__}")
print(f"===== {title} 끝 =====\n")
def setup_model():
"""Initialize and configure the multimodal model with LoRA"""
logging.info("모델 양자화 설정 중...")
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_quant_storage=torch.uint8,
)
logging.info(f"멀티모달 베이스 모델 로딩 중 - {MODEL_ID}")
model = Gemma3ForConditionalGeneration.from_pretrained(
MODEL_ID,
quantization_config=quantization_config,
attn_implementation="eager", # eager(일반적인 PyTorch 방식), sdpa(효율적인 Attention 전용 연산, FP16/BF16 최적화 잘 됨)
torch_dtype=torch.bfloat16,
device_map="auto", # 자동 장치 매핑 사용
trust_remote_code=True,
)
# 모델 설정 확인 및 수정
model.config.use_cache = False # 그래디언트 체크포인팅과 호환되도록 캐시 비활성화
# vocab_size 확인 및 로깅 (Gemma3 모델은 text_config 내에 vocab_size가 있음)
try:
# 먼저 text_config에서 vocab_size 확인 시도
if hasattr(model.config, 'text_config') and hasattr(model.config.text_config, 'vocab_size'):
vocab_size = model.config.text_config.vocab_size
logging.info(f"모델 vocab_size (text_config): {vocab_size:,}")
print(f"모델 vocab_size (text_config): {vocab_size:,}")
# 직접 vocab_size 접근 시도
elif hasattr(model.config, 'vocab_size'):
vocab_size = model.config.vocab_size
logging.info(f"모델 vocab_size: {vocab_size:,}")
print(f"모델 vocab_size: {vocab_size:,}")
# 모델의 임베딩 레이어에서 추출 시도
elif hasattr(model, 'model') and hasattr(model.model, 'embed_tokens'):
vocab_size = model.model.embed_tokens.num_embeddings
logging.info(f"모델 vocab_size (임베딩 레이어): {vocab_size:,}")
print(f"모델 vocab_size (임베딩 레이어): {vocab_size:,}")
else:
logging.warning("모델에서 vocab_size를 찾을 수 없습니다.")
print("모델에서 vocab_size를 찾을 수 없습니다.")
except Exception as e:
logging.error(f"vocab_size 확인 중 오류 발생: {e}")
print(f"vocab_size 확인 중 오류 발생: {e}")
# 모델 구성 정보 로깅
logging.info(f"모델 구성 정보:")
for key, value in vars(model.config).items():
if not key.startswith('_'):
logging.info(f" {key}: {value}")
# LoRA 적용 전 모델 구조 출력
print_model_structure(model, "LoRA 적용 전 모델 구조")
logging.info("Preparing model for k-bit training...")
model = prepare_model_for_kbit_training(model)
logging.info("LoRA 설정 적용 중...")
# config.py에서 LORA_CONFIG 사용
lora_config = LoraConfig(**LORA_CONFIG)
# LoRA 적용
model = get_peft_model(model, lora_config)
# 학습 가능한 파라미터 확인
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
all_params = sum(p.numel() for p in model.parameters())
logging.info(f"총 파라미터: {all_params:,}")
logging.info(f"학습 가능한 파라미터: {trainable_params:,} ({trainable_params/all_params:.2%})")
# LoRA 적용 후 모델 구조 출력
print_model_structure(model, "LoRA 적용 후 모델 구조")
return model
# 양자화 없이 원래 모델에 LoRA를 적용하는 함수
def setup_model_without_quantization():
# Initialize and configure the model with LoRA without quantization
logging.info(f"Loading base model from {MODEL_ID} without quantization")
model = Gemma3ForConditionalGeneration.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16, # bfloat16 precision for efficiency
device_map={"": 0}, # GPU 0에 모델 로드
trust_remote_code=True,
)
# vocab_size 확인 및 로깅 (Gemma3 모델은 text_config 내에 vocab_size가 있음)
try:
# 먼저 text_config에서 vocab_size 확인 시도
if hasattr(model.config, 'text_config') and hasattr(model.config.text_config, 'vocab_size'):
vocab_size = model.config.text_config.vocab_size
logging.info(f"양자화 없는 모델 vocab_size (text_config): {vocab_size:,}")
print(f"양자화 없는 모델 vocab_size (text_config): {vocab_size:,}")
# 직접 vocab_size 접근 시도
elif hasattr(model.config, 'vocab_size'):
vocab_size = model.config.vocab_size
logging.info(f"양자화 없는 모델 vocab_size: {vocab_size:,}")
print(f"양자화 없는 모델 vocab_size: {vocab_size:,}")
# 모델의 임베딩 레이어에서 추출 시도
elif hasattr(model, 'model') and hasattr(model.model, 'embed_tokens'):
vocab_size = model.model.embed_tokens.num_embeddings
logging.info(f"양자화 없는 모델 vocab_size (임베딩 레이어): {vocab_size:,}")
print(f"양자화 없는 모델 vocab_size (임베딩 레이어): {vocab_size:,}")
else:
logging.warning("양자화 없는 모델에서 vocab_size를 찾을 수 없습니다.")
print("양자화 없는 모델에서 vocab_size를 찾을 수 없습니다.")
except Exception as e:
logging.error(f"양자화 없는 모델 vocab_size 확인 중 오류 발생: {e}")
print(f"양자화 없는 모델 vocab_size 확인 중 오류 발생: {e}")
# 모델 구성 정보 로깅
logging.info(f"양자화 없는 모델 구성 정보:")
for key, value in vars(model.config).items():
if not key.startswith('_'):
logging.info(f" {key}: {value}")
# LoRA 적용 전 모델 구조 출력
print_model_structure(model, "양자화 없이 LoRA 적용 전 모델 구조")
logging.info("Applying LoRA configuration...")
lora_config = LoraConfig(**LORA_CONFIG)
model = get_peft_model(model, lora_config)
# LoRA 적용 후 모델 구조 출력
print_model_structure(model, "양자화 없이 LoRA 적용 후 모델 구조")
trainable_params = model.print_trainable_parameters()
logging.info(f"Model setup complete. Trainable parameters: {trainable_params}")
return model
# 양자화 없이 원래 모델에 LoRA를 적용하는 함수 (메모리 요구사항 높음)
"""
def setup_full_precision_model():
# 양자화 없이 전체 정밀도로 모델 로드
logging.info(f"Loading base model from {MODEL_ID} in full precision")
model = Gemma3ForConditionalGeneration.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16, # bfloat16 precision for better memory efficiency
device_map="auto", # 자동으로 GPU에 모델 분산 (여러 GPU 사용 가능)
trust_remote_code=True,
)
# LoRA 적용 전 모델 구조 출력
print_model_structure(model, "전체 정밀도 LoRA 적용 전 모델 구조")
logging.info("Applying LoRA configuration...")
lora_config = LoraConfig(**LORA_CONFIG)
model = get_peft_model(model, lora_config)
# LoRA 적용 후 모델 구조 출력
print_model_structure(model, "전체 정밀도 LoRA 적용 후 모델 구조")
trainable_params = model.print_trainable_parameters()
logging.info(f"Full precision model setup complete. Trainable parameters: {trainable_params}")
return model
"""
3-4. 파인튜닝 진행하기(trainer.py)
- 여기도 토큰화된 데이터 구조가 달라서 그걸 점검해주는 코드가 추가 되어 길다.
- data.py에서 언급한 토큰화 이후, 조금 다른 구조를 반영한 것이다.
(*이렇게 파인튜닝 완료이후, 팀원중 한명이 Gemma3 멀티모달 모델에서 텍스트 파인튜닝만 하는 문서를 찾아내기는 하였다....그래도 공부할겸, 자세히 뜯어보는 버전으로 가자)
from transformers import (
Trainer,
TrainingArguments,
DataCollatorForSeq2Seq,
TrainerCallback,
)
from config import TRAINING_ARGS, SYSTEM_PROMPT
import logging
import torch
import numpy as np
import os
from data import Prompter # Prompter 클래스 import 추가
# 커스텀 데이터 콜레이터 추가
class CustomDataCollator:
"""Gemma3Processor에 맞는 커스텀 데이터 콜레이터"""
def __init__(self, processor, pad_to_multiple_of=None, return_tensors="pt", padding=True):
self.processor = processor
self.pad_to_multiple_of = pad_to_multiple_of
self.return_tensors = return_tensors
self.padding = padding
# 로깅 횟수를 제한하기 위한 카운터 추가
self.log_count = 0
self.max_logs = 3 # 최대 로깅 횟수 감소
logging.info(f"CustomDataCollator 초기화됨")
def __call__(self, features):
"""데이터 배치를 정리하여 입력 형식으로 변환"""
# 1. 기본 로깅 설정
verbose_logging = self.log_count < self.max_logs
self.log_count += 1
# 2. 입력 데이터 구조 분석 및 로깅 (디버깅용)
if verbose_logging:
logging.info(f"DataCollator - 배치 크기: {len(features)}")
# 첫 번째 샘플 확인
if len(features) > 0:
first_sample = features[0]
logging.info(f"첫 번째 샘플 타입: {type(first_sample)}")
# 샘플 전체 구조 확인 (디버깅용)
if isinstance(first_sample, dict):
logging.info(f"샘플 키: {list(first_sample.keys())}")
# 각 키의 구조도 확인
for key in first_sample.keys():
value = first_sample[key]
logging.info(f" - {key} 타입: {type(value)}")
if isinstance(value, list):
logging.info(f" - 길이: {len(value)}")
if len(value) > 0:
logging.info(f" - 첫 요소 타입: {type(value[0])}")
if isinstance(value[0], list):
logging.info(f" - 중첩된 리스트, 내부 길이: {len(value[0])}")
# BatchFeature 객체인 경우 더 자세히 확인
if hasattr(first_sample, "data") and isinstance(first_sample.data, dict):
logging.info(f"BatchFeature 객체 감지 - 키: {first_sample.data.keys()}")
# data 내부의 각 키 구조도 확인
for key in first_sample.data.keys():
value = first_sample.data[key]
logging.info(f" - data['{key}'] 타입: {type(value)}")
if isinstance(value, list):
logging.info(f" - 길이: {len(value)}")
if len(value) > 0:
logging.info(f" - 첫 요소 타입: {type(value[0])}")
if isinstance(value[0], list):
logging.info(f" - 중첩된 리스트, 내부 길이: {len(value[0])}")
# input_ids 상세 분석
if "input_ids" in first_sample.data:
input_ids_data = first_sample.data["input_ids"]
# input_ids가 리스트의 리스트인 경우
if isinstance(input_ids_data, list) and len(input_ids_data) > 0 and isinstance(input_ids_data[0], list):
input_ids = input_ids_data[0] # 첫 번째 배치 항목 사용
logging.info(f"샘플 input_ids 길이: {len(input_ids)}")
logging.info(f"중첩된 리스트 구조 감지 - 내부 리스트 사용 (길이: {len(input_ids)})")
# 처음과 끝 토큰 표시
if len(input_ids) > 10:
logging.info(f"처음 10개 토큰: {input_ids[:10]}")
logging.info(f"마지막 10개 토큰: {input_ids[-10:]}")
# 첫 10개 토큰 디코딩
try:
decoded_start = self.processor.decode(input_ids[:10])
logging.info(f"처음 10개 토큰 디코딩: {decoded_start}")
except Exception as e:
logging.warning(f"토큰 디코딩 실패: {e}")
else:
logging.warning(f"예상과 다른 input_ids 구조: {type(input_ids_data)}")
if isinstance(input_ids_data, list):
logging.info(f"input_ids 리스트 길이: {len(input_ids_data)}")
if len(input_ids_data) > 0:
logging.info(f"첫 번째 요소 타입: {type(input_ids_data[0])}")
# labels 상세 분석
if "labels" in first_sample.data:
labels_data = first_sample.data["labels"]
# labels가 리스트의 리스트인 경우
if isinstance(labels_data, list) and len(labels_data) > 0 and isinstance(labels_data[0], list):
labels = labels_data[0] # 첫 번째 배치 항목 사용
non_masked = sum(1 for x in labels if x != -100)
total = len(labels)
logging.info(f"중첩된 labels 구조 감지 - 내부 리스트 사용 (길이: {len(labels)})")
logging.info(f"Labels 샘플 - 전체: {total}, 마스킹 안됨: {non_masked}, 비율: {non_masked/total:.2%}")
# 마스킹되지 않은 첫 부분 찾기
first_non_masked_idx = next((i for i, x in enumerate(labels) if x != -100), -1)
if first_non_masked_idx != -1:
logging.info(f"마스킹되지 않은 첫 토큰 위치: {first_non_masked_idx}")
else:
logging.warning(f"예상과 다른 labels 구조: {type(labels_data)}")
if isinstance(labels_data, list):
logging.info(f"labels 리스트 길이: {len(labels_data)}")
if len(labels_data) > 0:
logging.info(f"첫 번째 요소 타입: {type(labels_data[0])}")
# attention_mask 분석 (일반 딕셔너리 경우)
if "attention_mask" in first_sample:
attention_mask = first_sample['attention_mask']
# attention_mask가 리스트의 리스트인지 확인 (Gemma3 구조 처리)
if isinstance(attention_mask, list) and len(attention_mask) == 1 and isinstance(attention_mask[0], list):
attention_mask = attention_mask[0] # 내부 리스트 사용
logging.info(f"중첩된 attention_mask 구조 감지 - 내부 리스트 사용")
# attention_mask가 리스트인 경우 합을 계산
if isinstance(attention_mask, list):
attention_sum = sum(1 for x in attention_mask if x == 1)
logging.info(f"Attention mask - 길이: {len(attention_mask)}, 합계: {attention_sum} (패딩 제외 토큰 수)")
else:
# 단일 값이거나 다른 경우
logging.info(f"Attention mask - 타입: {type(attention_mask)}")
# BatchFeature 객체의 attention_mask도 처리
if hasattr(first_sample, "data") and isinstance(first_sample.data, dict) and "attention_mask" in first_sample.data:
mask_data = first_sample.data["attention_mask"]
# attention_mask가 리스트의 리스트인지 확인 (Gemma3 구조 처리)
if isinstance(mask_data, list) and len(mask_data) == 1 and isinstance(mask_data[0], list):
mask = mask_data[0] # 내부 리스트 사용
logging.info(f"BatchFeature 중첩된 attention_mask 구조 감지 - 내부 리스트 사용")
attention_sum = sum(1 for x in mask if x == 1)
logging.info(f"BatchFeature attention_mask - 길이: {len(mask)}, 합계: {attention_sum} (패딩 제외 토큰 수)")
elif isinstance(mask_data, list) and len(mask_data) > 0:
if isinstance(mask_data[0], list):
# 첫 번째 배치 항목 사용
mask = mask_data[0]
attention_sum = sum(1 for x in mask if x == 1)
logging.info(f"BatchFeature attention_mask - 길이: {len(mask)}, 합계: {attention_sum} (패딩 제외 토큰 수)")
else:
logging.info(f"BatchFeature attention_mask - 타입: {type(mask_data)}, 길이: {len(mask_data)}")
# 3. 주요 필드 추출
input_ids = []
attention_mask = []
labels = []
for i, feature in enumerate(features):
# BatchFeature 객체인 경우 처리
if hasattr(feature, "data") and isinstance(feature.data, dict):
# input_ids 처리
if "input_ids" in feature.data:
input_ids_data = feature.data["input_ids"]
# 2차원 구조인 경우 첫 번째 항목만 사용
if isinstance(input_ids_data, list) and len(input_ids_data) > 0 and isinstance(input_ids_data[0], list):
ids = input_ids_data[0]
input_ids.append(ids)
if verbose_logging and i == 0:
logging.info(f"BatchFeature input_ids 중첩 구조 처리 (길이: {len(ids)})")
else:
# 중첩 구조가 아닌 경우 그대로 사용
input_ids.append(input_ids_data)
logging.warning(f"추출 중 input_ids 평면 구조 사용: {type(input_ids_data)}")
# attention_mask 처리
if "attention_mask" in feature.data:
mask_data = feature.data["attention_mask"]
# 2차원 구조인 경우 첫 번째 항목만 사용
if isinstance(mask_data, list) and len(mask_data) > 0 and isinstance(mask_data[0], list):
mask = mask_data[0]
attention_mask.append(mask)
if verbose_logging and i == 0:
logging.info(f"BatchFeature - 중첩된 attention_mask 구조 처리 (길이: {len(mask)})")
else:
# 중첩 구조가 아닌 경우 그대로 사용
attention_mask.append(mask_data)
# 모든 값이 0인 경우 (일반적으로 오류) 대체
if attention_mask and isinstance(attention_mask[-1], list) and all(x == 0 for x in attention_mask[-1]):
if verbose_logging and i == 0:
logging.warning(f"모든 값이 0인 BatchFeature attention_mask 감지 - 수정 필요")
# input_ids가 있으면 그 길이에 맞게 1로 채움
if input_ids and input_ids[-1]:
attention_mask[-1] = [1] * len(input_ids[-1])
elif input_ids and input_ids[-1]: # input_ids가 있으면 동일한 길이의 마스크 생성
attention_mask.append([1] * len(input_ids[-1]))
else:
attention_mask.append([1]) # 기본값을 [0]에서 [1]로 변경
# labels 처리
if "labels" in feature.data:
labels_data = feature.data["labels"]
# 2차원 구조인 경우 첫 번째 항목만 사용
if isinstance(labels_data, list) and len(labels_data) > 0 and isinstance(labels_data[0], list):
label = labels_data[0]
labels.append(label)
if verbose_logging and i == 0:
logging.info(f"BatchFeature labels 중첩 구조 처리 (길이: {len(label)})")
else:
# 중첩 구조가 아닌 경우 그대로 사용
labels.append(labels_data)
elif input_ids and input_ids[-1]: # input_ids가 있으면 모두 -100으로 설정 (손실 계산 제외)
labels.append([-100] * len(input_ids[-1]))
else:
labels.append([-100])
# 일반 딕셔너리인 경우 (기존 코드)
else:
if "input_ids" in feature:
# input_ids 추출
ids = feature["input_ids"]
# 중첩된 리스트 구조인지 확인 (길이가 1인 리스트 속 실제 토큰)
if isinstance(ids, list) and len(ids) == 1 and isinstance(ids[0], list):
ids = ids[0] # 내부 리스트 사용
if verbose_logging and i == 0: # 첫 항목에 대해서만 로깅
logging.info(f"일반 dict - 중첩된 input_ids 구조 감지 (길이 {len(ids)})")
# 빈 input_ids인 경우 예외 처리
if not ids:
ids = [0]
input_ids.append(ids)
# attention_mask 추출 (제공되지 않으면 1로 채움)
if "attention_mask" in feature:
mask = feature["attention_mask"]
# 중첩된 리스트 구조인지 확인 - Gemma3에서는 [0]이 중요할 수 있음
if isinstance(mask, list) and len(mask) == 1 and isinstance(mask[0], list):
mask = mask[0] # 내부 리스트 사용
if verbose_logging and i == 0:
logging.info(f"일반 dict - 중첩된 attention_mask 구조 감지 (길이 {len(mask)})")
# 모든 값이 0인 경우 (일반적으로 오류) 대체
if isinstance(mask, list) and all(x == 0 for x in mask):
if verbose_logging and i == 0:
logging.warning(f"모든 값이 0인 attention_mask 감지 - 수정 필요")
# input_ids가 있으면 그 길이에 맞게 1로 채움
if input_ids and input_ids[-1]:
mask = [1] * len(input_ids[-1])
attention_mask.append(mask)
elif input_ids and input_ids[-1]: # input_ids가 있으면 그 길이에 맞춰 1로 채움
attention_mask.append([1] * len(input_ids[-1]))
else:
attention_mask.append([1]) # 기본값을 [0]에서 [1]로 변경 - 최소한 하나의 토큰은 주의해야 함
# labels 추출 (제공되지 않으면 -100으로 채움)
if "labels" in feature:
label_data = feature["labels"]
# 중첩된 리스트 구조인지 확인
if isinstance(label_data, list) and len(label_data) == 1 and isinstance(label_data[0], list):
label_data = label_data[0] # 내부 리스트 사용
if verbose_logging and i == 0: # 첫 항목에 대해서만 로깅
logging.info(f"일반 dict - 중첩된 labels 구조 감지 (길이 {len(label_data)})")
labels.append(label_data)
elif input_ids and input_ids[-1]: # input_ids가 있으면 그 길이에 맞춰 -100으로 채움 (손실 계산 제외)
labels.append([-100] * len(input_ids[-1]))
else:
labels.append([-100])
# 4. 빈 배치 처리
if not input_ids:
error_msg = "No valid input_ids found in features"
logging.error(error_msg)
return {
"input_ids": torch.zeros((1, 1), dtype=torch.long),
"attention_mask": torch.zeros((1, 1), dtype=torch.long),
"labels": torch.ones((1, 1), dtype=torch.long) * -100, # -100: 손실 계산에서 제외
}
# 5. 시퀀스 길이 확인 및 로깅
if verbose_logging:
# 전체 배치의 길이 통계
all_input_lengths = [len(ids) for ids in input_ids]
logging.info(f"배치 길이 통계 - 최소: {min(all_input_lengths)}, 최대: {max(all_input_lengths)}, 평균: {sum(all_input_lengths)/len(all_input_lengths):.1f}")
# 6. 모든 시퀀스 길이 동기화
# input_ids, attention_mask, labels의 길이가 일치하도록 조정
mismatch_found = False
for i in range(len(input_ids)):
# 6-1. input_ids와 attention_mask 길이 일치시키기
if i < len(attention_mask) and len(input_ids[i]) != len(attention_mask[i]):
mismatch_found = True
if len(input_ids[i]) > len(attention_mask[i]):
# attention_mask 확장
attention_mask[i] = attention_mask[i] + [0] * (len(input_ids[i]) - len(attention_mask[i]))
else:
# input_ids 확장
input_ids[i] = input_ids[i] + [0] * (len(attention_mask[i]) - len(input_ids[i]))
# 6-2. input_ids와 labels 길이 일치시키기
if i < len(labels) and len(input_ids[i]) != len(labels[i]):
mismatch_found = True
if len(input_ids[i]) > len(labels[i]):
# labels 확장 (패딩 토큰은 -100으로 설정해 loss 계산에서 제외)
labels[i] = labels[i] + [-100] * (len(input_ids[i]) - len(labels[i]))
else:
# input_ids와 attention_mask 확장
padding_len = len(labels[i]) - len(input_ids[i])
input_ids[i] = input_ids[i] + [0] * padding_len
if i < len(attention_mask):
attention_mask[i] = attention_mask[i] + [0] * padding_len
# 7. 길이 일치 여부 로깅
if mismatch_found and verbose_logging:
logging.info("텐서 길이 불일치가 발견되어 수정되었습니다")
# 8. 배치 내 최대 시퀀스 길이 계산
max_length = max(max(len(ids) for ids in input_ids),
max(len(mask) for mask in attention_mask) if attention_mask else 0,
max(len(lbl) for lbl in labels) if labels else 0)
if verbose_logging:
logging.info(f"배치 내 최대 시퀀스 길이: {max_length}")
# 9. 필요시 max_length를 pad_to_multiple_of에 맞게 조정
if self.pad_to_multiple_of is not None:
max_length = ((max_length + self.pad_to_multiple_of - 1)
// self.pad_to_multiple_of
* self.pad_to_multiple_of)
# 10. 패딩 적용하여 모든 시퀀스 길이 동일하게 만들기
padded_input_ids = []
padded_attention_mask = []
padded_labels = []
# 입력 데이터 길이 확인 (디버깅용)
if verbose_logging:
logging.info(f"=== 배치 항목 길이 확인 ===")
logging.info(f"배치 크기: {len(input_ids)}")
# 각 항목 길이 출력
for idx, (ids, mask, label) in enumerate(zip(input_ids, attention_mask, labels)):
if idx < 5: # 처음 5개 항목만 출력
logging.info(f"항목 {idx}: input_ids={len(ids)}, attention_mask={len(mask)}, labels={len(label)}")
# 내부 구조 샘플링
if len(input_ids) > 0:
first_ids = input_ids[0]
logging.info(f"첫 번째 input_ids 타입: {type(first_ids)}, 길이: {len(first_ids)}")
if len(first_ids) > 0:
logging.info(f"첫 번째 토큰 타입: {type(first_ids[0])}")
for i in range(len(input_ids)):
try:
# input_ids 패딩
ids = input_ids[i] if i < len(input_ids) else []
if len(ids) == 0:
ids = [0]
padding_length = max_length - len(ids)
padded_ids = ids + [0] * padding_length
padded_input_ids.append(padded_ids)
# attention_mask 패딩
mask = attention_mask[i] if i < len(attention_mask) else []
if len(mask) == 0:
mask = [0]
padding_length = max_length - len(mask)
padded_mask = mask + [0] * padding_length
padded_attention_mask.append(padded_mask)
# labels 패딩
lbl = labels[i] if i < len(labels) else []
if len(lbl) == 0:
lbl = [-100]
padding_length = max_length - len(lbl)
padded_lbl = lbl + [-100] * padding_length
padded_labels.append(padded_lbl)
except Exception as e:
logging.error(f"Error padding item {i}: {e}")
# 오류 발생 시 기본값으로 대체
padded_input_ids.append([0] * max_length)
padded_attention_mask.append([0] * max_length)
padded_labels.append([-100] * max_length)
try:
# 텐서로 변환
if self.return_tensors == "pt":
# 텐서 형태 로깅 (제한된 횟수만)
if verbose_logging:
logging.info(f"텐서 변환 - 배치 크기: {len(padded_input_ids)}, 시퀀스 길이: {max_length}")
# 텐서 변환 전 데이터 검증
if any(len(ids) != max_length for ids in padded_input_ids):
warning_msg = "패딩 후에도 시퀀스 길이가 일관되지 않음!"
logging.warning(warning_msg)
# 길이가 일치하지 않는 시퀀스 수정
for i, ids in enumerate(padded_input_ids):
if len(ids) != max_length:
logging.info(f" - 시퀀스 {i} 길이 불일치: {len(ids)} != {max_length}")
padded_input_ids = [ids[:max_length] + [0] * (max_length - len(ids)) if len(ids) != max_length else ids for ids in padded_input_ids]
padded_attention_mask = [mask[:max_length] + [0] * (max_length - len(mask)) if len(mask) != max_length else mask for mask in padded_attention_mask]
padded_labels = [lbl[:max_length] + [-100] * (max_length - len(lbl)) if len(lbl) != max_length else lbl for lbl in padded_labels]
# 모든 텐서가 같은 길이를 가지도록 보장
assert all(len(ids) == max_length for ids in padded_input_ids), "Input IDs lengths don't match max_length"
assert all(len(mask) == max_length for mask in padded_attention_mask), "Attention mask lengths don't match max_length"
assert all(len(lbl) == max_length for lbl in padded_labels), "Labels lengths don't match max_length"
# 모든 항목이 정수 리스트인지 확인 (중첩 리스트 제거)
for i in range(len(padded_input_ids)):
# input_ids 확인
if not all(isinstance(token, int) for token in padded_input_ids[i]):
logging.warning(f"항목 {i}의 input_ids에 정수가 아닌 값 또는 중첩 리스트 발견")
logging.warning(f"문제 있는 데이터: {padded_input_ids[i][:10]}...")
try:
# 중첩 리스트가 있을 경우 flatten 시도
flattened = []
for item in padded_input_ids[i]:
if isinstance(item, list):
flattened.extend(item)
else:
flattened.append(item)
padded_input_ids[i] = flattened[:max_length] + [0] * max(0, max_length - len(flattened))
logging.info(f"중첩 리스트 평탄화 후 길이: {len(padded_input_ids[i])}")
except Exception as e:
logging.error(f"input_ids 수정 실패: {e}")
padded_input_ids[i] = [0] * max_length
# attention_mask 확인
if not all(isinstance(token, int) for token in padded_attention_mask[i]):
logging.warning(f"항목 {i}의 attention_mask에 정수가 아닌 값 또는 중첩 리스트 발견")
try:
# 중첩 리스트가 있을 경우 flatten 시도
flattened = []
for item in padded_attention_mask[i]:
if isinstance(item, list):
flattened.extend(item)
else:
flattened.append(item)
padded_attention_mask[i] = flattened[:max_length] + [0] * max(0, max_length - len(flattened))
except Exception as e:
logging.error(f"attention_mask 수정 실패: {e}")
padded_attention_mask[i] = [0] * max_length
# labels 확인
if not all(isinstance(token, int) for token in padded_labels[i]):
logging.warning(f"항목 {i}의 labels에 정수가 아닌 값 또는 중첩 리스트 발견")
try:
# 중첩 리스트가 있을 경우 flatten 시도
flattened = []
for item in padded_labels[i]:
if isinstance(item, list):
flattened.extend(item)
else:
flattened.append(item)
padded_labels[i] = flattened[:max_length] + [-100] * max(0, max_length - len(flattened))
except Exception as e:
logging.error(f"labels 수정 실패: {e}")
padded_labels[i] = [-100] * max_length
batch = {
"input_ids": torch.tensor(padded_input_ids, dtype=torch.long),
"attention_mask": torch.tensor(padded_attention_mask, dtype=torch.long),
"labels": torch.tensor(padded_labels, dtype=torch.long),
}
# 텐서 형태 확인 (제한된 횟수만)
if verbose_logging:
logging.info(f"=== 최종 텐서 정보 ===")
for key, tensor in batch.items():
logging.info(f"{key} 텐서 형태: {tensor.shape}, 타입: {tensor.dtype}")
if key == "labels":
non_masked = (tensor != -100).sum().item()
total = tensor.numel()
logging.info(f" - 마스킹되지 않은 라벨: {non_masked}/{total} ({non_masked/total:.2%})")
logging.info(f"배치 처리 완료 - 형태: {batch['input_ids'].shape}")
elif self.return_tensors == "np":
batch = {
"input_ids": np.array(padded_input_ids),
"attention_mask": np.array(padded_attention_mask),
"labels": np.array(padded_labels),
}
else:
batch = {
"input_ids": padded_input_ids,
"attention_mask": padded_attention_mask,
"labels": padded_labels,
}
return batch
except Exception as e:
error_msg = f"배치 생성 중 오류 발생: {e}"
logging.error(error_msg)
# 오류 발생 시 기본 배치 반환 (형태 수정)
batch_size = len(features)
# 모든 텐서가 같은 길이를 가지도록 함
default_batch = {
"input_ids": torch.zeros((batch_size, 8), dtype=torch.long),
"attention_mask": torch.zeros((batch_size, 8), dtype=torch.long),
"labels": torch.full((batch_size, 8), -100, dtype=torch.long)
}
logging.info(f"기본 배치 반환: 형태 {default_batch['input_ids'].shape}")
return default_batch
class TrainingMonitorCallback(TrainerCallback):
"""학습 진행 상황을 모니터링하고 샘플 생성 테스트를 수행하는 콜백"""
def __init__(self, processor, prompter, model, train_dataset, eval_dataset=None):
self.processor = processor
self.prompter = prompter
self.model = model
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.sample_interval = 500 # 몇 스텝마다 샘플을 생성할지 (기존 100에서 증가)
self.last_generated_step = 0 # 마지막으로 샘플을 생성한 스텝
self.test_prompt = "클릭한번하시면, 수익률 500% 보장..."
def on_step_end(self, args, state, control, **kwargs):
"""각 학습 스텝 이후 호출되는 콜백 메서드"""
# 일정 간격으로 샘플 생성 테스트 수행
if state.global_step - self.last_generated_step >= self.sample_interval:
self.last_generated_step = state.global_step
logging.info(f"스텝 {state.global_step}: 샘플 생성 테스트")
# Gemma3 모델 형식에 맞는 입력 생성
messages = self.prompter.generate_chat_prompt(
instruction=self.test_prompt,
label=None
)
# 모델의 현재 상태 저장 (추론 모드 변환 전)
training_mode = self.model.training
try:
# 추론 모드로 전환
self.model.eval()
# 채팅 템플릿 적용 및 토큰화
inputs = self.processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt"
).to(self.model.device, dtype=torch.bfloat16)
# 생성 설정
generation_config = {
"max_new_tokens": 100,
"do_sample": True,
"temperature": 0.7,
"top_p": 0.9,
}
# 텍스트 생성
with torch.inference_mode():
outputs = self.model.generate(**inputs, **generation_config)
# 생성된 텍스트 추출
result = self.prompter.extract_generated_text(self.processor, outputs, inputs)
# 결과 로깅
logging.info(f"샘플 입력: {self.test_prompt}")
logging.info(f"샘플 출력: {result['generated_text']}")
# 검증 데이터셋에서 샘플 하나를 가져와 테스트 수행
if self.eval_dataset is not None and len(self.eval_dataset) > 0:
# 랜덤 인덱스 선택
idx = np.random.randint(0, len(self.eval_dataset))
eval_sample = self.eval_dataset[idx]
# 원본 데이터 구조 확인
if isinstance(eval_sample, dict) and "instruction" in eval_sample:
# 입력값 추출
eval_input = eval_sample["instruction"]
eval_label = eval_sample.get("output", "")
# 채팅 프롬프트 생성
eval_messages = self.prompter.generate_chat_prompt(
instruction=eval_input,
label=None
)
# 모델 입력 준비
eval_inputs = self.processor.apply_chat_template(
eval_messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt"
).to(self.model.device, dtype=torch.bfloat16)
# 텍스트 생성
with torch.inference_mode():
eval_outputs = self.model.generate(**eval_inputs, **generation_config)
# 생성된 텍스트 추출
eval_result = self.prompter.extract_generated_text(self.processor, eval_outputs, eval_inputs)
# 검증 데이터 결과 로깅
logging.info(f"=== 검증 데이터 샘플 결과 ===")
logging.info(f"검증 입력: {eval_input}")
if eval_label:
logging.info(f"정답 출력: {eval_label}")
logging.info(f"모델 출력: {eval_result['generated_text']}")
logging.info(f"==============================")
except Exception as e:
logging.error(f"샘플 생성 중 오류 발생: {str(e)}")
finally:
# 원래 모드로 복원
self.model.train(training_mode)
return control
def setup_trainer(model, processor, train_data, val_data):
"""
Trainer 설정 및 생성
Args:
model: 파인튜닝할 모델
processor: 토크나이저/프로세서
train_data: 훈련 데이터셋
val_data: 검증 데이터셋
Returns:
Trainer 인스턴스
"""
logging.info("훈련 설정 중...")
# 데이터 콜레이터 설정
data_collator = CustomDataCollator(
processor=processor,
pad_to_multiple_of=8, # 8의 배수로 패딩
return_tensors="pt",
padding=True
)
# 훈련 인자 설정
training_args = TrainingArguments(**TRAINING_ARGS)
# Prompter 인스턴스 생성
prompter = Prompter()
# 콜백 설정
callbacks = [
TrainingMonitorCallback(
processor=processor,
prompter=prompter,
model=model,
train_dataset=train_data,
eval_dataset=val_data
)
]
# Trainer 인스턴스 생성
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_data,
eval_dataset=val_data,
data_collator=data_collator,
callbacks=callbacks
)
logging.info("훈련 준비 완료")
return trainer
3-5. 훈련 실행하기(main.py)
- 우선 코드 맨밑에 언급해둔 라이브러리들을 pip로 설치
- 트랜스포머도 밑에 언급해둔 버전으로 설치(2025.03.22기준)
- python main.py로 학습시작
from huggingface_hub import login
from data import Prompter, load_dataset, generate_and_tokenize_chat_prompt
from model import setup_processor, setup_model
from trainer import setup_trainer
from inference import load_trained_model, run_inference
from config import HF_TOKEN, MODEL_ID, DATASET_NAME, SYSTEM_PROMPT, LORA_CONFIG
import logging
import torch
from transformers import Gemma3ForConditionalGeneration, AutoProcessor
import os
# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
def test_model_generation(
model, processor, prompter, text, system_prompt=SYSTEM_PROMPT, image_path=None
):
"""Test model's generation with a sample input using chat_template
Args:
model: Gemma3ForConditionalGeneration 모델
processor: AutoProcessor
prompter: Prompter 클래스 인스턴스
text: 생성할 텍스트 프롬프트
system_prompt: 시스템 프롬프트
image_path: 이미지 경로 (있는 경우)
"""
# 메시지 생성
messages = [
{
"role": "system",
"content": [{"type": "text", "text": system_prompt}]
}
]
# 사용자 메시지 구성 (이미지가 있는 경우 포함)
user_content = []
# 이미지 추가 (있는 경우)
if image_path:
user_content.append({"type": "image", "image": image_path})
# 텍스트 추가
user_content.append({"type": "text", "text": text})
# 사용자 메시지 추가
messages.append({
"role": "user",
"content": user_content
})
# 프로세서로 토큰화 및 모델 입력 준비
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt"
).to(model.device, dtype=torch.bfloat16) # 양자화 사용 시 torch.bfloat16 사용
# 추론 모드로 텍스트 생성
with torch.inference_mode():
outputs = model.generate(
**inputs,
max_new_tokens=100,
do_sample=True, # 샘플링 활성화
temperature=0.7,
top_p=0.95,
top_k=64,
)
# 모델 출력에서 생성된 텍스트 추출
result = prompter.extract_generated_text(processor, outputs, inputs)
return result
def main():
"""
메인 훈련 함수
"""
# Configure logging
logging.info("멀티모달 Gemma3 파인튜닝 파이프라인 시작...")
# 환경 변수 설정
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Login to Hugging Face
logging.info("Hugging Face 로그인 중...")
login(HF_TOKEN)
logging.info(f"훈련 설정 - 모델 ID: {MODEL_ID}")
logging.info(f"훈련 설정 - LORA 설정: {LORA_CONFIG}")
# Setup processor and prompter
logging.info("프로세서와 프롬프터 초기화 중...")
processor = AutoProcessor.from_pretrained(MODEL_ID)
tokenizer = processor.tokenizer
logging.info(f"토크나이저 클래스: {type(tokenizer).__name__}")
# logging.info(f"토크나이저 세부 정보: {tokenizer}")
prompter = Prompter()
logging.info(f"시스템 프롬프트: {prompter.system_prompt}")
# Load and process dataset
logging.info("데이터셋 로드 및 처리 중...")
train_data, val_data = load_dataset(DATASET_NAME)
logging.info(
f"로드된 데이터: {len(train_data)} 훈련 샘플, {len(val_data)} 검증 샘플"
)
# 데이터셋 샘플 확인
logging.info("#1. 데이터셋 샘플 확인:")
for i in range(min(3, len(train_data))):
sample = train_data[i]
logging.info(f"샘플 {i+1}:")
logging.info(f"문자: {sample['문자']}...")
logging.info(f"분자분류 라벨링: {sample['문자분류 라벨링']}")
logging.info("-" * 50)
# 데이터셋에 이미지 컬럼이 있는지 확인 (예: 'image_path' 컬럼이 있다고 가정)
image_column = None
if 'image_path' in train_data.column_names:
image_column = 'image_path'
logging.info(f"이미지 컬럼 감지됨: {image_column}")
# Tokenize datasets using chat_template
logging.info("#2.채팅 템플릿으로 데이터셋 토큰화 중...")
# ===== 토큰화 과정 시작 =====
# 1. 샘플 데이터 토큰화 테스트 (첫 번째 데이터로 작동 확인)
test_sample = train_data[0]
logging.info(f"샘플 데이터 구조: {test_sample.keys()}")
logging.info(f"샘플 문자 예시: {test_sample['문자'][:100]}...")
logging.info(f"샘플 문자분류 예시: {test_sample['문자분류 라벨링'][:100]}...")
# 2. 단일 샘플 토큰화 결과 확인 (데이터 형태 및 길이)
test_tokenized = generate_and_tokenize_chat_prompt(test_sample, prompter, processor, image_field=image_column, is_sample=True)
# BatchFeature 객체인지 확인하고 적절히 처리
if hasattr(test_tokenized, "data") and isinstance(test_tokenized.data, dict):
# BatchFeature 객체인 경우 data 속성에서 값 가져오기
input_ids_list = test_tokenized.data.get("input_ids", [])
labels_list = test_tokenized.data.get("labels", [])
# 2차원 리스트 처리
if input_ids_list and isinstance(input_ids_list[0], list):
input_ids = input_ids_list[0]
labels = labels_list[0] if labels_list and isinstance(labels_list[0], list) else []
logging.info(f"토큰화 테스트 - BatchFeature 객체, 2차원 구조에서 데이터 추출")
else:
input_ids = input_ids_list
labels = labels_list
logging.info(f"토큰화 테스트 - BatchFeature 객체에서 데이터 추출")
else:
# 일반 딕셔너리인 경우 직접 접근
input_ids = test_tokenized.get("input_ids", [])
labels = test_tokenized.get("labels", [])
# 기본 정보 출력
logging.info("=== 토큰화 테스트 결과 ===")
logging.info(f"input_ids 길이: {len(input_ids)}")
logging.info(f"labels 길이: {len(labels)}")
# 라벨 마스킹 확인
non_masked = sum(1 for x in labels if x != -100)
first_non_masked_idx = next((i for i, x in enumerate(labels) if x != -100), -1)
if first_non_masked_idx != -1:
logging.info(f"첫 번째 마스킹되지 않은 토큰 위치: {first_non_masked_idx}")
# 마스킹되지 않은 라벨 디코딩
non_masked_labels = [x for x in labels if x != -100]
if non_masked_labels:
non_masked_text = processor.decode(non_masked_labels)
logging.info(f"마스킹되지 않은 라벨: {non_masked_text}")
# 전체 입력 프롬프트 디코딩
if input_ids:
input_text = processor.decode(input_ids)
logging.info(f"전체 입력 프롬프트 디코딩: {input_text}")
logging.info("=== 토큰화 테스트 완료 ===")
# 3. 전체 데이터셋 토큰화
# 원본 텍스트 컬럼 제거 설정 (토큰화 후에는 필요 없음)
columns_to_remove = ['문자', '문자분류 라벨링']
if image_column:
columns_to_remove.append(image_column)
# 3-1. 훈련 데이터셋 토큰화
train_data = train_data.map(
lambda x: generate_and_tokenize_chat_prompt(x, prompter, processor, image_field=image_column),
remove_columns=columns_to_remove, # 원본 컬럼 제거
desc="토큰화 중...",
)
# 3-2. 검증 데이터셋 토큰화
val_data = val_data.map(
lambda x: generate_and_tokenize_chat_prompt(x, prompter, processor, image_field=image_column),
remove_columns=columns_to_remove, # 원본 컬럼 제거
desc="토큰화 중...",
)
# 4. 토큰화 결과 확인
logging.info(f"토큰화된 훈련 데이터셋 크기: {len(train_data)}")
logging.info(f"토큰화된 검증 데이터셋 크기: {len(val_data)}")
logging.info(f"토큰화된 데이터 컬럼: {train_data.column_names}")
# ===== 토큰화 과정 완료 =====
# Setup model
logging.info("멀티모달 Gemma3 모델 초기화 중...")
# 원본 모델 테스트는 건너뜀 (메모리 절약)
logging.info("원본 모델 테스트는 건너뜁니다 (메모리 절약)")
# Setup model with LoRA
model = setup_model()
# 모델 설정 확인
logging.info(f"모델 설정: {model.config}")
logging.info(f"모델 디바이스: {model.device}")
# 모델 캐시 설정 확인 및 수정
model.config.use_cache = False
logging.info(f"모델 캐시 설정: {model.config.use_cache}")
# Setup trainer and train
logging.info("학습 시작...")
trainer = setup_trainer(model, processor, train_data, val_data)
# 학습 시작 전 모델 상태 확인
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
all_params = sum(p.numel() for p in model.parameters())
logging.info(f"학습 가능한 파라미터: {trainable_params:,} / {all_params:,} ({trainable_params/all_params:.2%})")
# 학습 시작
trainer.train()
logging.info("학습 완료!")
# Save model to Hub
logging.info("모델을 Hugging Face Hub에 저장 중...")
model_name = "여러분 허깅패이스 주소/SpamAI_Gemma3-12b-it-v1"
model.push_to_hub(model_name, use_auth_token=True)
logging.info(f"모델 저장됨: {model_name}")
# 샘플 추론 테스트
#logging.info("샘플 텍스트로 추론 테스트 중...")
#sample_text = "일단 클릭하면 수익률 500%보장..."
#result = test_model_generation(model, processor, prompter, sample_text)
#logging.info(f"샘플 추론 결과: {result['generated_text']}")
# Load trained model and run inference
#logging.info("검증 데이터로 추론 실행 중...")
#trained_model = load_trained_model(model_name, processor)
#run_inference(trained_model, processor, prompter, val_data)
#logging.info("파이프라인 성공적으로 완료!")
if __name__ == "__main__":
main()
# pip install torch transformers peft bitsandbytes accelerate pandas datasets openpyxl huggingface-hub wandb scikit-learn tqdm
# pip install git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3
# python main.py
4. 결론
- 우여곡절 끝에 멀티모달 모델의 템플릿 구조와 중첩리스트 형태를 알수 있었다.
- 파인튜닝 결과는... 아직 1epoch를 돌리긴 했지만....잘 된다.
- 근데 멀티모달 모델이라 그런지, 초반 loss가 80을 찍는다. 보통 EXAONE, Qwen은 4점대에서 시작하는데... 그래도 결국에는 수렴한다.
- 학습 속도가 생각보다 많이 느리다. 물론 EXAONE7.8B보다 큰 Gemma3 12B모델로 했지만, 그래도 체감상 엄청 느렸다.
5개의 py코드에 자신의 데이터를 넣으면 gemma3 파인튜닝이 가능할 것이다(텍스트 기준)
'AI & ML > 학습하기' 카테고리의 다른 글
AI서비스를 위한 GPU이해하기 (0) | 2025.06.17 |
---|---|
허깅페이스(huggingface) 토크나이저 사용해서 모델 추론하는 3가지 방법 (2) | 2025.06.12 |
Chat_template 구조 파인튜닝하기(feat. EXAONE-3.5-7B) (0) | 2025.03.07 |
DeepSeek-R1 정리(공부하기) + Open r1 (0) | 2025.02.04 |
패캠(패스트캠퍼스) "LLM 모델 파인튜닝을 위한 GPU 최적화" 후기 (2) | 2024.12.02 |