backend-junPreP Update(Update whitelist model), front-urlpredictor.jsx UI Update
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
backend/app/__pycache__/predictor.cpython-310.pyc
Normal file
BIN
backend/app/__pycache__/predictor.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
@@ -1,52 +1,53 @@
|
||||
from app.junPreP import extract_features
|
||||
import numpy as np
|
||||
import pickle
|
||||
import pandas as pd
|
||||
from sklearn.preprocessing import MinMaxScaler
|
||||
from tensorflow.keras.models import load_model
|
||||
import tensorflow as tf
|
||||
import os
|
||||
|
||||
# Artefact locations are resolved relative to this module's own directory,
# so loading does not depend on the process working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(BASE_DIR, "models", "Recall_0.77.keras")
SCALER_PATH = os.path.join(BASE_DIR, "models", "scaler.pkl")

# Load the Keras model and the pickled scaler once, at import time.
model = load_model(MODEL_PATH)
with open(SCALER_PATH, 'rb') as f:
    scaler = pickle.load(f)
|
||||
|
||||
# Inference goes through a tf.function-wrapped call.
@tf.function(reduce_retracing=True)
def predict_with_model(model, input_data):
    """Call ``model`` on ``input_data`` and return whatever the model returns."""
    return model(input_data)
|
||||
|
||||
# Threshold (적절히 조정 가능)
|
||||
BEST_THRESHOLD = 0.4034
|
||||
|
||||
# Prediction entry point (imported and called by the FastAPI app).
def predict_url_maliciousness(url: str) -> dict:
    """Score a URL with the loaded model and apply ``BEST_THRESHOLD``.

    Args:
        url: URL string to classify.

    Returns:
        dict with ``url``, ``malicious_probability``, ``is_malicious`` and
        ``threshold``, all as native Python types.
    """
    # Feature extraction: one row, columns named after the feature keys.
    feature_map = extract_features(url)
    frame = pd.DataFrame([list(feature_map.values())], columns=feature_map.keys())

    # Scaling with the scaler loaded at import time.
    scaled = scaler.transform(frame)

    # Inference; the first element of the first row is read as the score.
    raw_output = predict_with_model(model, scaled)
    malicious_prob = float(raw_output[0][0])

    # Both operands are Python floats, so the comparison yields a plain bool.
    return {
        "url": str(url),
        "malicious_probability": malicious_prob,
        "is_malicious": malicious_prob > BEST_THRESHOLD,
        "threshold": float(BEST_THRESHOLD)
    }
|
||||
|
||||
|
||||
from app.junPreP import extract_features
|
||||
import numpy as np
|
||||
import pickle
|
||||
import pandas as pd
|
||||
from sklearn.preprocessing import MinMaxScaler
|
||||
from tensorflow.keras.models import load_model
|
||||
import tensorflow as tf
|
||||
import os
|
||||
|
||||
# Artefact locations are resolved relative to this module's own directory,
# so loading does not depend on the process working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(BASE_DIR, "models", "White_list_model.keras")
SCALER_PATH = os.path.join(BASE_DIR, "models", "scaler.pkl")

# Load the Keras model and the pickled scaler once, at import time.
model = load_model(MODEL_PATH)
with open(SCALER_PATH, 'rb') as f:
    scaler = pickle.load(f)
|
||||
|
||||
# Inference goes through a tf.function-wrapped call.
@tf.function(reduce_retracing=True)
def predict_with_model(model, input_data):
    """Call ``model`` on ``input_data`` and return whatever the model returns."""
    return model(input_data)
|
||||
|
||||
# Threshold (적절히 조정 가능)
|
||||
BEST_THRESHOLD = 0.4034
|
||||
|
||||
# Prediction entry point (imported and called by the FastAPI app).
def predict_url_maliciousness(url: str) -> dict:
    """Score a URL with the loaded model and apply ``BEST_THRESHOLD``.

    Args:
        url: URL string to classify.

    Returns:
        dict with ``url``, ``malicious_probability``, ``is_malicious`` and
        ``threshold``, all as native Python types.
    """
    # Feature extraction: one row, columns named after the feature keys.
    feature_map = extract_features(url)
    frame = pd.DataFrame([list(feature_map.values())], columns=feature_map.keys())

    # Scaling with the scaler loaded at import time.
    scaled = scaler.transform(frame)

    # Inference; the score tensor is converted to a Python float via .numpy().
    raw_output = predict_with_model(model, scaled)
    malicious_prob = float(raw_output[0][0].numpy())

    # Both operands are Python floats, so the comparison yields a plain bool.
    return {
        "url": str(url),
        "malicious_probability": malicious_prob,
        "is_malicious": malicious_prob > BEST_THRESHOLD,
        "threshold": float(BEST_THRESHOLD)
    }
|
||||
|
||||
|
||||
|
||||
@@ -1,204 +1,274 @@
|
||||
import re
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
import tldextract
|
||||
import zlib
|
||||
import re
|
||||
from urllib.parse import urlparse
|
||||
from collections import Counter
|
||||
import math
|
||||
|
||||
|
||||
|
||||
def check_similar_brand(url):
    """Report whether the URL's host contains a look-alike of a known brand.

    For every brand that does NOT appear verbatim in the host, a set of
    variants is generated (digit-for-letter swaps, trailing '-' / '_',
    last character dropped, every character doubled). The function returns
    True as soon as any variant occurs in the host.

    Args:
        url: URL string; a scheme is optional.

    Returns:
        True if a look-alike variant is found, otherwise False. Any
        exception raised while inspecting the URL also yields False.
    """
    # Frequently impersonated brand / domain names.
    common_brands = {
        'google', 'facebook', 'amazon', 'microsoft', 'apple',
        'netflix', 'paypal', 'twitter', 'instagram', 'linkedin',
        'youtube', 'yahoo', 'gmail', 'whatsapp', 'tiktok',
        'geocities', 'angelfire', 'newadvent', 'wikipedia',
    }
    digit_swaps = (
        ('o', '0'), ('i', '1'), ('l', '1'),
        ('e', '3'), ('a', '4'), ('s', '5'),
    )

    def lookalikes(brand):
        # Same variants, in the same order, as the original pattern list.
        for letter, digit in digit_swaps:
            yield brand.replace(letter, digit)
        yield brand + '-'
        yield brand + '_'
        yield brand[:-1]                       # last character removed
        yield ''.join(ch * 2 for ch in brand)  # every character doubled

    try:
        # urlparse only fills netloc when the input contains '//'.
        netloc = urlparse(url if '//' in url else '//' + url).netloc
        # Fall back to the whole URL when no netloc could be parsed.
        haystack = (netloc or url).lower()
        return any(
            variant in haystack
            for brand in common_brands
            if brand not in haystack
            for variant in lookalikes(brand)
        )
    except Exception:
        return False
|
||||
|
||||
|
||||
|
||||
# URL compression-ratio helper.
def compression_ratio(url: str) -> float:
    """Return the zlib-compressed size of ``url`` divided by its UTF-8 size.

    An empty (falsy) ``url`` yields 0.0. Short strings give ratios above 1
    because of zlib's fixed overhead.
    """
    if not url:
        return 0.0
    raw = url.encode('utf-8')
    return len(zlib.compress(raw)) / len(raw)
|
||||
|
||||
|
||||
def extract_features(url):
    """Compute the lexical feature dictionary for a single URL.

    The key names (including the space in ``"is country_tld"``) and their
    insertion order are part of the contract: consumers build DataFrame
    columns directly from the keys, so neither may change.

    Args:
        url: URL string to analyse; a scheme is optional.

    Returns:
        dict mapping feature name to an int or float value.
    """
    parsed_url = urlparse(url)

    # Union of the two original keyword lists; duplicates collapse in the set.
    keywords = {
        'login', 'verify', 'account', 'update', 'secure', 'banking',
        'paypal', 'confirm', 'signin', 'auth', 'redirect', 'free',
        'bonus', 'admin', 'support', 'server', 'password', 'click',
        'urgent', 'immediate', 'alert', 'security', 'prompt',
        'wallet', 'cryptocurrency', 'bitcoin', 'ethereum',
        'validation', 'authenticate', 'reset', 'recover', 'access',
        'limited', 'offer', 'prize', 'win', 'winner', 'payment',
        'bank', 'credit', 'debit', 'card', 'expire', 'suspension',
        'unusual', 'activity', 'document', 'invoice',
    }
    # 1 when at least one keyword occurs as a whole word, case-insensitively.
    contains_keyword = int(any(
        re.search(r'\b' + keyword + r'\b', url, re.IGNORECASE)
        for keyword in keywords
    ))

    url_length = len(url)

    # Registered-domain decomposition.
    extracted = tldextract.extract(url)
    tld = extracted.suffix
    domain = extracted.domain
    subdomain = extracted.subdomain

    common_tlds = ['com', 'org', 'net', 'edu', 'gov', 'mil', 'io', 'co', 'info', 'biz']
    country_tlds = ['us', 'uk', 'ca', 'au', 'de', 'fr', 'jp', 'cn', 'ru', 'br', 'in', 'it', 'es']
    suspicious_tlds = ['xyz', 'top', 'club', 'online', 'site', 'icu', 'vip', 'work', 'rest', 'fit']
    url_shorteners = ['bit.ly', 'tinyurl.com', 'goo.gl', 't.co', 'ow.ly', 'is.gd', 'buff.ly', 'adf.ly', 'tiny.cc']
    full_domain = f"{domain}.{tld}" if tld else domain

    query = parsed_url.query
    path = parsed_url.path

    # Character-class counts, each computed once and reused below.
    digit_count = sum(c.isdigit() for c in url)
    special_char_count = len(re.findall(r'[^a-zA-Z0-9]', url))

    # Shannon entropy (bits per character) of the URL's character distribution.
    if url:
        entropy = -sum(
            freq * math.log2(freq)
            for freq in (count / url_length for count in Counter(url).values())
        )
    else:
        entropy = 0

    # Bucket the raw length into four categories.
    if url_length <= 13:
        url_length_cat = 0
    elif url_length <= 18:
        url_length_cat = 1
    elif url_length <= 25:
        url_length_cat = 2
    else:
        url_length_cat = 3

    return {
        "url_length_cat": url_length_cat,
        "num_dots": url.count("."),
        "num_digits": digit_count,
        "num_special_chars": special_char_count,
        "url_keyword": contains_keyword,
        "num_underbar": url.count("_"),
        # 1 when some digit is immediately repeated (e.g. "11").
        "extract_consecutive_numbers": int(bool(re.search(r'(\d)\1+', url))),
        # 1 when three adjacent digits occur with no two neighbours equal.
        "number": int(bool(re.search(r'(\d)(?!\1)(\d)(?!\2)(\d)', url))),
        "upper": int(any(c.isupper() for c in url)),

        "is_common_tld": 1 if tld in common_tlds else 0,
        "is country_tld": 1 if tld in country_tlds else 0,
        "is_suspicious_tld": 1 if tld in suspicious_tlds else 0,

        "domain_length": len(domain) if domain else 0,
        "has_subdomain": 1 if subdomain else 0,
        "subdomain_length": len(subdomain) if subdomain else 0,
        "subdomain_count": len(subdomain.split('.')) if subdomain else 0,

        "path_depth": path.count('/') if path else 0,
        "has_query": 1 if query else 0,
        "query_length": len(query) if query else 0,
        "query_param_count": len(parse_qs(query)),
        "url_shorteners": 1 if full_domain in url_shorteners else 0,

        "compression_ratio": compression_ratio(url),
        # Cast so every feature is numeric rather than a stray bool.
        "check_similar_brand": int(check_similar_brand(url)),

        "entropy": entropy,
        "digit_ratio": digit_count / url_length if url_length > 0 else 0,
        "special_char_ratio": special_char_count / url_length if url_length > 0 else 0
    }
|
||||
import re
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
import tldextract
|
||||
import zlib
|
||||
from collections import Counter
|
||||
import math
|
||||
|
||||
def url_is_whitelisted(url):
    """Return True when the URL's host is a trusted domain or a subdomain of one.

    Matching is done on whole DNS labels: ``www.naver.com`` matches
    ``naver.com``, but ``evilnaver.com`` does not. Port and userinfo are
    ignored because the parsed hostname is compared, not the raw netloc.

    Note that the ``gov.kr`` and ``ac.kr`` entries trust every domain
    registered under those suffixes.

    Args:
        url: URL string; a scheme is optional.

    Returns:
        True if the host is whitelisted; False otherwise, including when
        ``url`` is not a string or cannot be parsed.
    """
    trusted_domains = (
        # 1. Portals / search engines
        'naver.com', 'daum.net', 'google.com', 'bing.com', 'yahoo.com',

        # 2. Social media / communication
        'facebook.com', 'instagram.com', 'twitter.com', 'x.com', 'linkedin.com',
        'whatsapp.com', 'kakao.com', 'kakaocorp.com',

        # 3. Video / streaming
        'youtube.com', 'netflix.com', 'twitch.tv', 'tving.com', 'watcha.com',

        # 4. Shopping / e-commerce
        'amazon.com', 'gmarket.co.kr', '11st.co.kr', 'coupang.com', 'ssg.com', 'wemakeprice.com',

        # 5. Finance / payments
        'paypal.com', 'kbfg.com', 'shinhan.com', 'hanafn.com', 'wooribank.com',
        'kakaobank.com', 'toss.im',

        # 6. Public institutions / education
        'gov.kr', 'moe.go.kr', 'epeople.go.kr', 'pusan.ac.kr', 'ac.kr',

        # 7. IT / technology
        'apple.com', 'microsoft.com', 'adobe.com', 'github.com', 'stackoverflow.com',
    )

    try:
        candidate = url.strip()
        # urlparse only recognises the authority when it is introduced by
        # '//'. Checking for a leading scheme (rather than '//' anywhere)
        # keeps scheme-less URLs with '//' in the path or query working.
        has_authority = candidate.startswith('//') or re.match(
            r'[a-zA-Z][a-zA-Z0-9+.\-]*://', candidate
        )
        parsed = urlparse(candidate if has_authority else '//' + candidate)
        # .hostname is lower-cased and excludes userinfo and port.
        host = (parsed.hostname or '').rstrip('.')
    except (AttributeError, TypeError, ValueError):
        # Non-string input or a malformed authority: treat as not trusted.
        return False

    if not host:
        return False

    # Exact match or a dot-delimited subdomain; a bare suffix match would
    # also accept look-alikes such as "evilnaver.com".
    return any(
        host == trusted or host.endswith('.' + trusted)
        for trusted in trusted_domains
    )
|
||||
|
||||
|
||||
|
||||
def check_similar_brand(url):
    """Report whether the URL's host contains a look-alike of a known brand.

    For every brand that does NOT appear verbatim in the host, a set of
    variants is generated (digit-for-letter swaps, trailing '-' / '_',
    last character dropped, every character doubled). The function returns
    True as soon as any variant occurs in the host.

    Args:
        url: URL string; a scheme is optional.

    Returns:
        True if a look-alike variant is found, otherwise False. Any
        exception raised while inspecting the URL also yields False.
    """
    # Frequently impersonated brand / domain names.
    common_brands = {
        'google', 'facebook', 'amazon', 'microsoft', 'apple',
        'netflix', 'paypal', 'twitter', 'instagram', 'linkedin',
        'youtube', 'yahoo', 'gmail', 'whatsapp', 'tiktok',
        'geocities', 'angelfire', 'newadvent', 'wikipedia',
    }
    digit_swaps = (
        ('o', '0'), ('i', '1'), ('l', '1'),
        ('e', '3'), ('a', '4'), ('s', '5'),
    )

    def lookalikes(brand):
        # Same variants, in the same order, as the original pattern list.
        for letter, digit in digit_swaps:
            yield brand.replace(letter, digit)
        yield brand + '-'
        yield brand + '_'
        yield brand[:-1]                       # last character removed
        yield ''.join(ch * 2 for ch in brand)  # every character doubled

    try:
        # urlparse only fills netloc when the input contains '//'.
        netloc = urlparse(url if '//' in url else '//' + url).netloc
        # Fall back to the whole URL when no netloc could be parsed.
        haystack = (netloc or url).lower()
        return any(
            variant in haystack
            for brand in common_brands
            if brand not in haystack
            for variant in lookalikes(brand)
        )
    except Exception:
        return False
|
||||
|
||||
|
||||
|
||||
# URL compression-ratio helper.
def compression_ratio(url: str) -> float:
    """Return the zlib-compressed size of ``url`` divided by its UTF-8 size.

    An empty (falsy) ``url`` yields 0.0. Short strings give ratios above 1
    because of zlib's fixed overhead.
    """
    if not url:
        return 0.0
    raw = url.encode('utf-8')
    return len(zlib.compress(raw)) / len(raw)
|
||||
|
||||
|
||||
def extract_features(url):
    """Compute the lexical feature dictionary for a single URL.

    URLs accepted by ``url_is_whitelisted`` are not measured; they receive
    a fixed set of benign-looking placeholder values. All other URLs are
    measured.

    The key names (including the space in ``"is country_tld"``) and their
    insertion order are part of the contract: consumers build DataFrame
    columns directly from the keys, so neither may change.

    Args:
        url: URL string to analyse; a scheme is optional.

    Returns:
        dict mapping feature name to an int or float value.
    """
    # Parsed first so malformed input fails exactly as it did before.
    parsed_url = urlparse(url)

    # Whitelisted hosts short-circuit before any measurement work (including
    # the tldextract lookup), since none of it would be used.
    if url_is_whitelisted(url):
        return {
            # Fixed values chosen so the URL is treated as "normal".
            "url_length_cat": 1,
            "num_dots": 1,
            "num_digits": 0,
            "num_special_chars": 1,
            "url_keyword": 0,
            "num_underbar": 0,
            "extract_consecutive_numbers": 0,
            "number": 0,
            "upper": 0,

            "is_common_tld": 1,
            "is country_tld": 0,
            "is_suspicious_tld": 0,

            "domain_length": 5,
            "has_subdomain": 0,
            "subdomain_length": 0,
            "subdomain_count": 0,

            "path_depth": 0,
            "has_query": 0,
            "query_length": 0,
            "query_param_count": 0,
            "url_shorteners": 0,

            "compression_ratio": 1.0,
            "check_similar_brand": 0,
            "entropy": 3.0,
            "digit_ratio": 0.0,
            "special_char_ratio": 0.1
        }

    # Union of the two original keyword lists; duplicates collapse in the set.
    keywords = {
        'login', 'verify', 'account', 'update', 'secure', 'banking',
        'paypal', 'confirm', 'signin', 'auth', 'redirect', 'free',
        'bonus', 'admin', 'support', 'server', 'password', 'click',
        'urgent', 'immediate', 'alert', 'security', 'prompt',
        'wallet', 'cryptocurrency', 'bitcoin', 'ethereum',
        'validation', 'authenticate', 'reset', 'recover', 'access',
        'limited', 'offer', 'prize', 'win', 'winner', 'payment',
        'bank', 'credit', 'debit', 'card', 'expire', 'suspension',
        'unusual', 'activity', 'document', 'invoice',
    }
    # 1 when at least one keyword occurs as a whole word, case-insensitively.
    contains_keyword = int(any(
        re.search(r'\b' + keyword + r'\b', url, re.IGNORECASE)
        for keyword in keywords
    ))

    url_length = len(url)

    # Registered-domain decomposition.
    extracted = tldextract.extract(url)
    tld = extracted.suffix
    domain = extracted.domain
    subdomain = extracted.subdomain

    common_tlds = ['com', 'org', 'net', 'edu', 'gov', 'mil', 'io', 'co', 'info', 'biz']
    country_tlds = ['us', 'uk', 'ca', 'au', 'de', 'fr', 'jp', 'cn', 'ru', 'br', 'in', 'it', 'es']
    suspicious_tlds = ['xyz', 'top', 'club', 'online', 'site', 'icu', 'vip', 'work', 'rest', 'fit']
    url_shorteners = ['bit.ly', 'tinyurl.com', 'goo.gl', 't.co', 'ow.ly', 'is.gd', 'buff.ly', 'adf.ly', 'tiny.cc']
    full_domain = f"{domain}.{tld}" if tld else domain

    query = parsed_url.query
    path = parsed_url.path

    # Character-class counts, each computed once and reused below.
    digit_count = sum(c.isdigit() for c in url)
    special_char_count = len(re.findall(r'[^a-zA-Z0-9]', url))

    # Shannon entropy (bits per character) of the URL's character distribution.
    if url:
        entropy = -sum(
            freq * math.log2(freq)
            for freq in (count / url_length for count in Counter(url).values())
        )
    else:
        entropy = 0

    # Bucket the raw length into four categories.
    if url_length <= 13:
        url_length_cat = 0
    elif url_length <= 18:
        url_length_cat = 1
    elif url_length <= 25:
        url_length_cat = 2
    else:
        url_length_cat = 3

    return {
        "url_length_cat": url_length_cat,
        "num_dots": url.count("."),
        "num_digits": digit_count,
        "num_special_chars": special_char_count,
        "url_keyword": contains_keyword,
        "num_underbar": url.count("_"),
        # 1 when some digit is immediately repeated (e.g. "11").
        "extract_consecutive_numbers": int(bool(re.search(r'(\d)\1+', url))),
        # 1 when three adjacent digits occur with no two neighbours equal.
        "number": int(bool(re.search(r'(\d)(?!\1)(\d)(?!\2)(\d)', url))),
        "upper": int(any(c.isupper() for c in url)),

        "is_common_tld": 1 if tld in common_tlds else 0,
        "is country_tld": 1 if tld in country_tlds else 0,
        "is_suspicious_tld": 1 if tld in suspicious_tlds else 0,

        "domain_length": len(domain) if domain else 0,
        "has_subdomain": 1 if subdomain else 0,
        "subdomain_length": len(subdomain) if subdomain else 0,
        "subdomain_count": len(subdomain.split('.')) if subdomain else 0,

        "path_depth": path.count('/') if path else 0,
        "has_query": 1 if query else 0,
        "query_length": len(query) if query else 0,
        "query_param_count": len(parse_qs(query)),
        "url_shorteners": 1 if full_domain in url_shorteners else 0,

        "compression_ratio": compression_ratio(url),
        # Cast so the value has the same type as the whitelist branch's 0.
        "check_similar_brand": int(check_similar_brand(url)),

        "entropy": entropy,
        "digit_ratio": digit_count / url_length if url_length > 0 else 0,
        "special_char_ratio": special_char_count / url_length if url_length > 0 else 0
    }
|
||||
|
||||
@@ -2,7 +2,6 @@ from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
from app.model_load import use_model # predictor.py에서 함수 import
|
||||
from app.exe import predict_url_maliciousness
|
||||
from app.utils import convert_numpy_to_python_types
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
app = FastAPI()
|
||||
@@ -28,13 +27,15 @@ def root():
|
||||
def predict(request: UrlRequest):
|
||||
url = request.url
|
||||
|
||||
result_model1 = convert_numpy_to_python_types(use_model(url))
|
||||
result_model2 = convert_numpy_to_python_types(predict_url_maliciousness(url))
|
||||
|
||||
response_data = {
|
||||
"url": url,
|
||||
"model1": result_model1,
|
||||
"model2": result_model2
|
||||
}
|
||||
|
||||
return convert_numpy_to_python_types(response_data)
|
||||
result_model1 = use_model(url)
|
||||
result_model2 = predict_url_maliciousness(url)
|
||||
# print("model1 : ")
|
||||
# print(result_model1.values())
|
||||
# print("model2 : ")
|
||||
# print(result_model2.values())
|
||||
|
||||
return {
|
||||
"url" : url,
|
||||
"model1": result_model1,
|
||||
"model2": result_model2
|
||||
}
|
||||
|
||||
@@ -29,6 +29,12 @@ def use_model(url : str):
|
||||
input_data = featured_df[features_cols]
|
||||
|
||||
# 학습된 모델에 적용
|
||||
model_pred = round(float(np.mean([model.predict_proba(input_data)[:, 1] for model in models_load])), 4)
|
||||
model_pred = round(np.mean([model.predict_proba(input_data)[:, 1] for model in models_load]), 4)
|
||||
|
||||
return model_pred
|
||||
#return model_pred
|
||||
return {
|
||||
"url" : url,
|
||||
"malicious_probability" : float(model_pred),
|
||||
"is_malicious" : bool(model_pred > best_threshold),
|
||||
"threshold" : float(best_threshold)
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ def predict_url(url: str) -> dict:
|
||||
input_data = preprocessed[features_cols]
|
||||
|
||||
# ✅ 전처리된 데이터 확인
|
||||
print("Preprocessed input:", input_data)
|
||||
#print("Preprocessed input:", input_data)
|
||||
|
||||
# 평균 확률 계산
|
||||
probs = [float(model.predict_proba(input_data)[0, 1]) for model in models_load]
|
||||
@@ -61,8 +61,8 @@ def predict_url(url: str) -> dict:
|
||||
# 예: malicious_probability가 np.float32 타입일 경우
|
||||
return {
|
||||
"url": url,
|
||||
"malicious_probability": mean_pred, # ⬅️ numpy -> float
|
||||
"is_malicious": bool(is_malicious), # ⬅️ numpy -> bool
|
||||
"malicious_probability": mean_pred,
|
||||
"is_malicious": is_malicious,
|
||||
"threshold": float(BEST_THRESHOLD) # ⬅️ numpy -> float
|
||||
}
|
||||
|
||||
|
||||
4
backend/app/testexe.py
Normal file
4
backend/app/testexe.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from exe import predict_url_maliciousness


def main():
    """Manual smoke check: score one known-good URL and print the result."""
    result_model2 = predict_url_maliciousness("www.naver.com")
    print(result_model2)


# Guarded so that importing this module does not run inference as a side effect.
if __name__ == "__main__":
    main()
|
||||
@@ -1,18 +0,0 @@
|
||||
import numpy as np
|
||||
|
||||
def convert_numpy_to_python_types(obj):
    """Recursively replace numpy values with native Python equivalents.

    Arrays become (nested) lists, numpy floating scalars become ``float``,
    other numpy numbers become ``int``, ``np.bool_`` becomes ``bool``.
    Dict values and list/tuple items are converted recursively; tuples come
    back as lists. Anything else is returned unchanged.
    """
    if isinstance(obj, np.ndarray):
        return convert_numpy_to_python_types(obj.tolist())
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.number):
        # Remaining numpy numeric scalars go through int().
        return int(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, dict):
        return {key: convert_numpy_to_python_types(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_numpy_to_python_types(element) for element in obj]
    return obj
|
||||
Reference in New Issue
Block a user