LLM DOS 공격 방어

img 11

LLM DOS 공격 방어

img 11

LLM DOS 공격 시나리오

LLM DOS 공격은 다음과 같은 방식으로 이루어집니다.

  • 공격자는 연산 비용이 높은 복잡한 프롬프트를 대량으로 생성합니다.
  • 수천 개의 동시 요청을 API 서버로 전송하여 처리 능력을 포화 상태로 만듭니다.
  • Rate limiting이나 동시 요청 수 제한이 없는 경우, 서버는 과부하로 인해 정상 사용자에 대한 서비스를 거부하게 됩니다.
  • 토큰 검증이 미흡하거나 없는 경우 익명 공격도 가능합니다.

실제 공격 환경에서는 멀티스레딩을 통해 수만 건의 동시 요청을 보내는 방식이 활용됩니다. 이러한 공격에 대응하기 위한 방어 코드를 아래에서 살펴봅니다.

Vulnerable Code

from flask import Flask, request, jsonify
from flask_sqlalchemy import SQLAlchemy
from functools import wraps
import requests
import os
from datetime import datetime

app = Flask(__name__)

# Load environment variables from a .env file if present
# 이 코드는 환경 변수를 로드하고, 여기에서 SQL 서버 자격 증명을 가져오고, Flask 앱에서 SQLAlchemy 구성을 위한 데이터베이스 URI를 구성하며, 이때 PostgreSQL을 위한 기본 포트가 사용됩니다.
from dotenv import load_dotenv
load_dotenv()

# Fetch SQL server credentials and other details from environment variables
DATABASE_USER = os.getenv('DATABASE_USER')
DATABASE_PASSWORD = os.getenv('DATABASE_PASSWORD')
DATABASE_HOST = os.getenv('DATABASE_HOST')
DATABASE_PORT = os.getenv('DATABASE_PORT', '5432') # Default port for PostgreSQL
DATABASE_NAME = os.getenv('DATABASE_NAME')

# Construct the SQLALCHEMY_DATABASE_URI from environment variables
app.config['SQLALCHEMY_DATABASE_URI'] = f'postgresql://{DATABASE_USER}:{DATABASE_PASSWORD}@{DATABASE_HOST}:{DATABASE_PORT}/{DATABASE_NAME}'
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
db = SQLAlchemy(app)

# Model for API tokens
class APIToken(db.Model):
    __tablename__ = 'api_tokens'

    id = db.Column(db.Integer, primary_key=True)
    token = db.Column(db.String(255), unique=True, nullable=False)
    user_id = db.Column(db.Integer, nullable=False)
    expires_at = db.Column(db.DateTime, nullable=False)

# 이 코드는 Flask 경로에서 API 토큰을 위한 데이터베이스 모델과 토큰 기반 인증을 위한 데코레이터를 정의합니다. APIToken 클래스는 ID, 토큰 문자열, 연관된 사용자 ID, 그리고 만료일을 갖는 토큰을 나타냅니다. token_required 데코레이터는 요청 헤더에서 유효한 전달자 토큰을 확인하고, 데이터베이스와 비교하여 유효성을 검사하고, 데코레이팅된 경로에 대한 액세스를 허용하기 전에 만료되지 않았는지 확인합니다.
def token_required(f):
    @wraps(f)
    def decorated_function(*args, **kwargs):
        auth_header = request.headers.get('Authorization')
        if auth_header and auth_header.startswith('Bearer '):
            user_token = auth_header[7:] # Extract token from header
        else:
            return jsonify({'error': 'Token is missing or improperly formatted'}), 401

        # Validate the user token against the database
        api_token = APIToken.query.filter(APIToken.token == user_token, APIToken.expires_at > datetime.utcnow()).first()

        if not api_token:
            return jsonify({'error': 'User token is invalid or expired'}), 401

        return f(*args, **kwargs)
    return decorated_function

# contact_fake_llm 함수는 환경 변수에서 가져온 토큰을 사용하여 언어 모델 서비스에 요청을 보냅니다. 이 함수는 input_text를 페이로드로 게시하고, HTTP 응답을 처리하고, 서비스의 응답 또는 오류 메시지를 반환합니다.
def contact_fake_llm(input_text):
    llm_service_url = 'http://fake-llm-service.com/process'
    # Fetch the API token for the LLM service from environment variables
    EXTERNAL_LLM_API_TOKEN = os.getenv('EXTERNAL_LLM_API_TOKEN', 'default_external_llm_token')
    try:
        payload = {'input_text': input_text}
        headers = {'Authorization': f'Bearer {EXTERNAL_LLM_API_TOKEN}'}
        response = requests.post(llm_service_url, json=payload, headers=headers)

        if response.status_code == 200:
            llm_response = response.json().get('response', 'No response received')
            return llm_response
        else:
            return f"LLM service returned an error: {response.status_code}"
    except Exception as e:
        return f"Error contacting LLM service: {e}"

# 이 코드 블록은 POST 요청이 필요한 Flask 경로 /process를 정의합니다. token_required 데코레이터를 사용하여 요청이 승인되었는지 확인합니다. 이 경로는 요청의 JSON 페이로드에서 input_text를 추출합니다. input_text가 없으면 오류를 반환합니다. input_text가 없으면 contact_fake_llm 함수를 통해 가상 언어 모델 서비스로 input_text를 전달하고 해당 서비스의 응답을 반환합니다. 앱은 직접 실행될 경우 기본 Flask 서버에서 실행됩니다.
@app.route('/process', methods=['POST'])
@token_required
def process_input():
    data = request.get_json()
    if not data or 'input_text' not in data:
        return jsonify({'error': 'No input_text provided'}), 400

    response = contact_fake_llm(data['input_text'])
    return jsonify({'response': response})

if __name__ == '__main__':
    app.run()

 
 
DoS 취약성을 완화

DoS 취약성을 완화하려면 여러 가지 메커니즘이 필요합니다.

1. 동시에 처리할 수 있는 요청 수 제한
2. 동시에 실행할 수 있는 계산적으로 비용이 많이 드는 프롬프트의 수를 감지하고 제한합니다
3. API가 LLM에 다른 프롬프트를 전송하여 처리하기 전에 기다려야 하는 기간을 제한합니다

이러한 메커니즘을 프로덕션 환경에 배포하기 전에 LLM API에 대한 스트레스 테스트를 수행하여 보안 메커니즘의 적절한 값을 결정해야 합니다
 
 

from flask import Flask, request, jsonify
from flask_sqlalchemy import SQLAlchemy
from functools import wraps
import requests
import os
from datetime import datetime
import time
from threading import Semaphore

app = Flask(__name__)

# Load environment variables from a .env file if present
from dotenv import load_dotenv
load_dotenv()

# Fetch SQL server credentials and other details from environment variables
DATABASE_USER = os.getenv('DATABASE_USER')
DATABASE_PASSWORD = os.getenv('DATABASE_PASSWORD')
DATABASE_HOST = os.getenv('DATABASE_HOST')
DATABASE_PORT = os.getenv('DATABASE_PORT', '5432') # Default port for PostgreSQL
DATABASE_NAME = os.getenv('DATABASE_NAME')

# Construct the SQLALCHEMY_DATABASE_URI from environment variables
app.config['SQLALCHEMY_DATABASE_URI'] = f'postgresql://{DATABASE_USER}:{DATABASE_PASSWORD}@{DATABASE_HOST}:{DATABASE_PORT}/{DATABASE_NAME}'
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
db = SQLAlchemy(app)

# Model for API tokens
class APIToken(db.Model):
    __tablename__ = 'api_tokens'

    id = db.Column(db.Integer, primary_key=True)
    token = db.Column(db.String(255), unique=True, nullable=False)
    user_id = db.Column(db.Integer, nullable=False)
    expires_at = db.Column(db.DateTime, nullable=False)

def token_required(f):
    @wraps(f)
    def decorated_function(*args, **kwargs):
        auth_header = request.headers.get('Authorization')
        if auth_header and auth_header.startswith('Bearer '):
            user_token = auth_header[7:] # Extract token from header
        else:
            return jsonify({'error': 'Token is missing or improperly formatted'}), 401

        # Validate the user token against the database
        api_token = APIToken.query.filter(APIToken.token == user_token, APIToken.expires_at > datetime.utcnow()).first()

        if not api_token:
            return jsonify({'error': 'User token is invalid or expired'}), 401

        return f(*args, **kwargs)
    return decorated_function
# detect_costly_prompt 함수는 주어진 프롬프트에서 미리 정의된 계산 집약적 키워드를 검색하여, 키워드가 발견되면 부울 값과 검색된 키워드 목록을 반환하고, 발견되지 않으면 False와 None을 반환합니다. 정확성을 보장하기 위해 대소문자를 구분하지 않는 매칭 방식을 사용합니다.
# 스트레스 테스트를 실시한 후, 탐지에 사용되는 프롬프트 목록도 LLM의 역량에 따라 수정해야 합니다.
def detect_costly_prompt(prompt):
    keywords = [
        "long-form generation",
        "language translation",
        "natural language understanding",
        "data analysis",
        "multimodal tasks",
        "continual learning",
        "personalized recommendations",
        "adversarial attacks"
    ]

    # Convert prompt to lowercase for case-insensitive matching
    prompt_lower = prompt.lower()

    # Check if any keywords are present in the prompt
    detected_keywords = [keyword for keyword in keywords if keyword in prompt_lower]

    if detected_keywords:
        return True, detected_keywords
    else:
        return False, None

def contact_fake_llm(input_text):
    llm_service_url = 'http://fake-llm-service.com/process'
    # Fetch the API token for the LLM service from environment variables
    EXTERNAL_LLM_API_TOKEN = os.getenv('EXTERNAL_LLM_API_TOKEN', 'default_external_llm_token')
    # Define the maximum execution time for requests (in seconds)
    # MAX_EXECUTION_TIME = 5 외부 LLM 서비스에 대한 HTTP 요청에 대해 5초의 시간 제한을 설정합니다. 이를 통해 함수가 응답을 너무 오래 기다리는 것을 방지하여 애플리케이션의 응답성을 유지하고 요청 중단으로 인한 리소스 고갈을 방지합니다.
    # 적절한 시간 초과 설정을 위해서는 LLM과 해당 API에 대해 예상 사용자 트래픽을 기반으로 스트레스 테스트를 실시해야 합니다. 테스트 결과를 분석하여 최적의 시간 초과 기간으로 설정을 조정해야 합니다.
    MAX_EXECUTION_TIME = 5 # Adjust as needed
    try:
        payload = {'input_text': input_text}
        headers = {'Authorization': f'Bearer {EXTERNAL_LLM_API_TOKEN}'}
        # Send a request to the LLM service with a timeout
        response = requests.post(llm_service_url, json=payload, headers=headers,
                                 timeout=MAX_EXECUTION_TIME)
        if response.status_code == 200:
            llm_response = response.json().get('response', 'No response received')
            return llm_response
        else:
            return f"LLM service returned an error: {response.status_code}"
    except requests.Timeout:
        return "Request timed out: Exceeded maximum execution time"
    except requests.RequestException as e:
        return f"Error contacting LLM service: {e}"

# 이 Flask 경로 /process POST 요청을 예상하고 token_required 데코레이터를 통한 인증을 요구합니다. 요청에서 JSON 데이터를 추출하여 input_text 필드가 포함되어 있는지 확인합니다. 포함되어 있지 않으면 오류 응답을 반환합니다.
@app.route('/process', methods=['POST'])
@token_required
# process_input 함수는 MAX_CONCURRENT_REQUESTS(10)로 정의된 세마포어를 사용하여 외부 LLM 서비스에 대한 동시 액세스를 제한함으로써 시스템이 과부하되지 않고 응답성을 유지하도록 보장합니다.
def process_input():
    data = request.get_json()
    if not data or 'input_text' not in data:
        return jsonify({'error': 'No input_text provided'}), 400

    # Detect if the prompt is computationally costly
    # 그런 다음 함수는 detect_costly_prompt 사용하여 입력 프롬프트가 계산 집약적인지 확인합니다. 계산 집약적인 경우, 실행을 의도적으로 30초 지연합니다. 이 지연 시간은 필요에 따라 조정할 수 있습니다.
    is_costly, detected_keywords = detect_costly_prompt(data['input_text'])

    if is_costly:
        # If the prompt is computationally costly, delay the execution substantially
        time.sleep(30) # Adjust the delay time as needed

    # Define the maximum number of concurrent requests
    MAX_CONCURRENT_REQUESTS = 10 # Adjust as needed

    # Semaphore to control access to shared resources (e.g., number of concurrent requests)
    # 세마포어는 동시 시스템에서 여러 프로세스가 공유 리소스에 접근하는 것을 제어하여 질서 있고 안전한 접근을 보장하는 동기화 메커니즘입니다.
    # 이 경우에도 세마포어가 활용하는 MAX_CONCURRENT_REQUESTS 값은 적절한 값을 결정하기 위해 스트레스 테스트를 기반으로 조정되어야 합니다.
    request_semaphore = Semaphore(MAX_CONCURRENT_REQUESTS)

    # Acquire semaphore before processing the request
    with request_semaphore:
        response = contact_fake_llm(data['input_text'])
        return jsonify({'response': response})

if __name__ == '__main__':
    app.run()

댓글 달기

이메일 주소는 공개되지 않습니다. 필수 필드는 *로 표시됩니다

위로 스크롤