Ssul's Blog

AI Product 개발전략과 개발기 - 임베딩 Top-k전략 본문

AI & ML/사용하기

AI Product 개발전략과 개발기 - 임베딩 Top-k전략

Ssul 2024. 6. 28. 23:13

스팸분류기를 제작하고 있다.

기존에 LLM을 파인튜닝하는

https://issul.tistory.com/455

https://issul.tistory.com/456

두개의 방법과는 다른 접근 방법이다.

 

이 방법은 문자 텍스트 데이터를 성능좋은 임베딩 모델을 가지고, 각 문자를 임베딩하여 N차원의 공간에 배치하는 것이다.

그리고, 새로운 문자가 들어오면, 신규문자를 N차원에 공간에 뿌려서 가장 근처의 문자 3개(Top-k)를 가져와서 문자통계를 내서, 가장 빈도가 높은 문자분류를 입력된 문자의 분류로 결정하는 개념이다.

 

다양한 임베딩 모델이 있지만 오픈AI의 text-embedding-3-small을 활용할 예정이다.

 

 

1. 기본 셋팅을 진행한다

!pip install openai

###########################################
# 1-1. 구글 드라이브 마운트

from google.colab import drive
drive.mount('/content/drive')

import os
from google.colab import userdata
# 환경 변수에 API 키 설정
os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')
import openai
from openai import OpenAI

openai.api_key = os.getenv('OPENAI_API_KEY')
client = OpenAI()

- 환경변수로 가져온 오픈AI키를 설정한다

 

2. 문자 데이터셋 임베딩

# openai embedding함수
def get_embedding(text, model="text-embedding-3-small"):
    text = text.replace("\n", " ")
    return client.embeddings.create(input=[text], model=model).data[0].embedding
    
    
import pandas as pd

# CSV 파일 경로를 지정합니다.
folder_path = '/content/drive/MyDrive/00_AI연구/02_스팸분류기Proj'
embedding_file = 'spam_embedding.csv'
embedding_file_path = os.path.join(folder_path, embedding_file)

# embedding.csv가 존재하면, 데이터프레임 df로 로드
if os.path.exists(embedding_file_path):
    print(f"{embedding_file} is exist")
    df = pd.read_csv(embedding_file_path)
    # string으로 저장된 embedding을 list로 변환
    df['embedding'] = df['embedding'].apply(ast.literal_eval)
else:
    df = pd.read_csv('/content/drive/MyDrive/00_AI연구/02_스팸분류기Proj/spam_dataset_final3.csv')

    # Completion 열에서 텍스트를 읽어와서 임베딩 처리
    df['embedding'] = df['sms'].apply(lambda x: get_embedding(x, model="text-embedding-3-small"))

    # csv파일로 저장
    df.to_csv(embedding_file_path, index=False, encoding='utf-8')

- 기존의 문자데이터(spam_dataset_final3.csv)를 읽어와서, 임베딩 모델을 가지고 임베딩 진행

- 해당 결과를 spam_embedding.csv에 저장한다.(물론 faiss, chromadb를 사용해도 된다)

 

3. 새로운 문자를 입력받아, Top-3결과 확인하기

import ast
import numpy as np
from numpy import dot
from numpy.linalg import norm


# 유사도 측정
def cos_sim(A, B):
    return dot(A, B)/(norm(A)*norm(B))
    
def query_csv_spam(query: str, use_retriever: bool = False):
    # 입력된 문장 임베딩
    query_embedding = get_embedding(
        query,
        model="text-embedding-3-small"
    )
    if use_retriever:
        # csv파일을 읽어서 임베딩값과 가장 가까운 3개 문장을 반환
        df = pd.read_csv(embedding_file_path)

        # 문자열로 저장된 embedding을 실제 숫자 배열로 변환
        df['embedding'] = df['embedding'].apply(ast.literal_eval)

        df['similarity'] = df.embedding.apply(lambda x: cos_sim(np.array(x), np.array(query_embedding)))
        top_docs = df.sort_values('similarity', ascending=False).head(3)
        from pprint import pprint
        pprint(top_docs)
        # top_docs = top_docs['Text'].to_list()
    else:
        # TODO: 다른 경우 처리
        top_docs = []

    return top_docs

- 새로운 문자를 입력받으면, 그 문자를 임베딩 후,

- 기존의 문자데이터들과 비교하여 유사도가 가장 높은 3개를 찾는 함수

 

query = """
[Web발신]
A
오늘하루만
#강원에너지 +23%
#미래나노텍 +20%
내일급등주받으세요
https://xgo.ac/1CE
"""

- 오늘 아침에 받은 스팸문자를 셋팅한다

 

result = query_csv_spam(query, use_retriever=True)

내가 오늘아침에 받은 문자와 가장 유사한 3개의 문자를 데이터셋에서 가져옴.

2개가 스팸, 1개가 프로모션문자.... 결과는 스팸으로 분류!!!

결과가 맞았다.

 

의외로 데이터셋만 잘 구성이 된다면, 임베딩만으로도 스팸 문자를 분류할수 있을 것 같다.

 

다음 글에서는 3개의 모델을 고도화하고, 성능을 비교분석해 보겠다.