Skip to content

Commit e9ab3f9

Browse files
authored
Merge pull request #9 from Gurummang/develop
KAN-456 <feat: set RabbitMQ for alerts>
2 parents 51a2a53 + 6505dad commit e9ab3f9

File tree

4 files changed

+142
-20
lines changed

4 files changed

+142
-20
lines changed

app/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,19 @@
3434
# Exchange 설정
3535
EXCHANGE_NAME = os.getenv("RABBITMQ_EXCHANGE_NAME")
3636
EXCHANGE_TYPE = os.getenv("RABBITMQ_EXCHANGE_TYPE")
37+
ALERT_EXCHANGE_NAME = os.getenv("RABBITMQ_ALERT_EXCHANGE")
3738

3839
# Queue 설정
3940
DOC_SCAN_QUEUE = os.getenv("RABBITMQ_DOC_QUEUE_NAME")
4041
EXE_SCAN_QUEUE = os.getenv("RABBITMQ_EXE_QUEUE_NAME")
4142
IMG_SCAN_QUEUE = os.getenv("RABBITMQ_IMG_QUEUE_NAME")
43+
ALERT_QUEUE = os.getenv("RABBITMQ_SUSPICIOUS_QUEUE")
4244

4345
# Routing Key 설정
4446
EXE_ROUTING_KEY = os.getenv("RABBITMQ_EXE_ROUTING_KEY")
4547
IMG_ROUTING_KEY = os.getenv("RABBITMQ_IMG_ROUTING_KEY")
4648
DOC_ROUTING_KEY = os.getenv("RABBITMQ_DOC_ROUTING_KEY")
49+
ALERT_ROUTING_KEY = os.getenv("RABBITMQ_SUSPICIOUS_ROUTING_KEY")
4750

4851
# S3
4952
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")

app/models.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pydantic import BaseModel
2-
from sqlalchemy import BigInteger, Boolean, Column, Integer, Text
2+
from sqlalchemy import BigInteger, Boolean, Column, Integer, Text, String, DateTime
33
from sqlalchemy.ext.declarative import declarative_base
44

55
Base = declarative_base()
@@ -34,3 +34,13 @@ class FileStatus(Base):
3434
gscan_status = Column(Boolean, nullable=True, default=-1)
3535
dlp_status = Column(Boolean, nullable=True, default=-1)
3636
vt_status = Column(Boolean, nullable=True, default=-1)
37+
38+
39+
class FileUpload(Base):
40+
__tablename__ = "file_upload"
41+
id = Column(BigInteger, primary_key=True, autoincrement=True)
42+
org_saas_id = Column(Integer, nullable=False)
43+
saas_file_id = Column(String(255), nullable=True)
44+
upload_ts = Column(DateTime, nullable=True)
45+
salted_hash = Column(String(255), nullable=True)
46+
deleted = Column(Boolean, default=False, nullable=False)

app/rabbitmq_sender.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import logging
2+
import ssl
3+
import time
4+
import struct
5+
import pika
6+
7+
from app import (
8+
ALERT_EXCHANGE_NAME,
9+
ALERT_ROUTING_KEY,
10+
EXCHANGE_TYPE,
11+
RABBITMQ_HOST,
12+
RABBITMQ_PASSWORD,
13+
RABBITMQ_PORT,
14+
RABBITMQ_SSL_ENABLED,
15+
RABBITMQ_USER,
16+
RETRY_INTERVAL,
17+
)
18+
19+
# SSL 설정
20+
ssl_options = None
21+
if RABBITMQ_SSL_ENABLED:
22+
ssl_context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH)
23+
ssl_options = pika.SSLOptions(context=ssl_context)
24+
25+
26+
def connect_to_rabbitmq():
27+
while True:
28+
try:
29+
credentials = pika.PlainCredentials(RABBITMQ_USER, RABBITMQ_PASSWORD)
30+
parameters = pika.ConnectionParameters(
31+
host=RABBITMQ_HOST,
32+
port=int(RABBITMQ_PORT),
33+
credentials=credentials,
34+
ssl_options=ssl_options,
35+
connection_attempts=3,
36+
retry_delay=5,
37+
socket_timeout=10.0, # 타임아웃 설정 (초)
38+
)
39+
connection = pika.BlockingConnection(parameters)
40+
return connection
41+
except pika.exceptions.AMQPConnectionError as e:
42+
logging.error(
43+
f"Connection failed, retrying in {RETRY_INTERVAL} seconds... Error: {e}"
44+
)
45+
time.sleep(RETRY_INTERVAL)
46+
47+
48+
def send_message(message: int):
49+
connection = connect_to_rabbitmq()
50+
channel = connection.channel()
51+
52+
# Exchange 선언
53+
channel.exchange_declare(
54+
exchange=ALERT_EXCHANGE_NAME, exchange_type=EXCHANGE_TYPE, durable=True
55+
)
56+
57+
# int 메시지를 바이트로 변환
58+
message_bytes = struct.pack('!Q', message) # '!Q'는 unsigned long long 형식입니다.
59+
60+
channel.basic_publish(
61+
exchange=ALERT_EXCHANGE_NAME,
62+
routing_key=ALERT_ROUTING_KEY, # 적절한 라우팅 키로 변경
63+
body=message_bytes,
64+
properties=pika.BasicProperties(
65+
delivery_mode=2 # 메시지 영속화
66+
)
67+
)
68+
69+
print(f"Sent message: {message}")
70+
connection.close()

app/utils.py

Lines changed: 58 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
MYSQL_USER,
1717
)
1818
from app.models import FileScanRequest
19-
19+
from app.rabbitmq_sender import send_message
2020

2121
def load_yara_rules(directory):
2222
rule_files = []
@@ -74,7 +74,7 @@ def stream_file_from_s3(s3_key):
7474
raise
7575

7676

77-
def save_scan_result(file_id, detect, detail):
77+
def save_scan_result(uploadId: int, stored_file_id, detect, detail):
7878
try:
7979
conn = mysql.connector.connect(
8080
host=MYSQL_HOST, user=MYSQL_USER, password=MYSQL_PASSWORD, database=MYSQL_DB
@@ -84,7 +84,7 @@ def save_scan_result(file_id, detect, detail):
8484
try:
8585
cursor.execute(
8686
"INSERT INTO scan_table (file_id, detect, step2_detail) VALUES (%s, %s, %s)",
87-
(file_id, detect, detail),
87+
(stored_file_id, detect, detail),
8888
)
8989
conn.commit() # 첫 번째 쿼리 커밋
9090
except Exception as e:
@@ -94,9 +94,10 @@ def save_scan_result(file_id, detect, detail):
9494

9595
try:
9696
cursor.execute(
97-
"UPDATE file_status SET gscan_status = 1 WHERE file_id = %s", (file_id,)
97+
"UPDATE file_status SET gscan_status = 1 WHERE file_id = %s", (stored_file_id,)
9898
)
9999
conn.commit() # 두 번째 쿼리 커밋
100+
send_message(uploadId) # RabbitMQ 전송: Alerts
100101
except Exception as e:
101102
conn.rollback() # 두 번째 쿼리 롤백
102103
logging.error(f"Failed to update file_status: {e}")
@@ -147,21 +148,17 @@ def yara_test_match(file_path, yara_rules):
147148
return detect, detail
148149

149150

150-
def scan_file(file_id: int, yara_rules):
151+
def scan_file(upload_id: int, yara_rules):
151152
try:
152-
conn = mysql.connector.connect(
153-
host=MYSQL_HOST, user=MYSQL_USER, password=MYSQL_PASSWORD, database=MYSQL_DB
154-
)
155-
cursor = conn.cursor(dictionary=True)
156-
cursor.execute("SELECT * FROM stored_file WHERE id = %s", (file_id,))
157-
file_record = cursor.fetchone()
158-
cursor.close()
159-
conn.close()
160-
161-
if not file_record:
162-
raise HTTPException(status_code=404, detail="File not found")
163-
164-
s3_key = file_record["save_path"]
153+
# 파일 업로드 정보 가져오기
154+
file_record = get_file_upload(upload_id)
155+
salted_hash = file_record["salted_hash"]
156+
157+
# 저장된 파일 정보 가져오기
158+
stored_file_record = get_stored_file(salted_hash)
159+
stored_file_id = stored_file_record["id"]
160+
s3_key = stored_file_record["save_path"]
161+
165162
file_stream = stream_file_from_s3(s3_key)
166163

167164
# 파일 전체를 한 번에 읽음
@@ -182,7 +179,49 @@ def scan_file(file_id: int, yara_rules):
182179
logging.info(f"detail: {detail}")
183180
logging.info(f"most_common_keyword: {most_common_keyword}")
184181

185-
save_scan_result(file_id, detect, most_common_keyword)
182+
save_scan_result(upload_id, stored_file_id, detect, most_common_keyword)
186183
except Exception as e:
187184
logging.error(f"Error scanning file: {e}")
188185
raise HTTPException(status_code=500, detail="Error scanning file")
186+
187+
188+
def get_stored_file(hash:str):
189+
try:
190+
conn = mysql.connector.connect(
191+
host=MYSQL_HOST, user=MYSQL_USER, password=MYSQL_PASSWORD, database=MYSQL_DB
192+
)
193+
cursor = conn.cursor(dictionary=True)
194+
cursor.execute("SELECT * FROM stored_file WHERE salted_hash = %s", (hash,))
195+
stored_file_record = cursor.fetchone()
196+
cursor.close()
197+
conn.close()
198+
199+
if not stored_file_record:
200+
raise HTTPException(status_code=404, detail="File not found in stored_file table")
201+
202+
return stored_file_record
203+
204+
except Exception as e:
205+
logging.error(f"Error fetching stored file record: {e}")
206+
raise HTTPException(status_code=500, detail="Error fetching stored file record")
207+
208+
209+
def get_file_upload(file_id: int):
210+
try:
211+
conn = mysql.connector.connect(
212+
host=MYSQL_HOST, user=MYSQL_USER, password=MYSQL_PASSWORD, database=MYSQL_DB
213+
)
214+
cursor = conn.cursor(dictionary=True)
215+
cursor.execute("SELECT * FROM file_upload WHERE id = %s", (file_id,))
216+
file_record = cursor.fetchone()
217+
cursor.close()
218+
conn.close()
219+
220+
if not file_record:
221+
raise HTTPException(status_code=404, detail="File not found in file_upload table")
222+
223+
return file_record
224+
225+
except Exception as e:
226+
logging.error(f"Error fetching file upload record: {e}")
227+
raise HTTPException(status_code=500, detail="Error fetching file upload record")

0 commit comments

Comments
 (0)