Skip to content

Commit db7208f

Browse files
authored
Merge pull request #30 from Gurummang/develop
Develop
2 parents e339932 + e99eb3a commit db7208f

File tree

3 files changed

+114
-70
lines changed

3 files changed

+114
-70
lines changed

app/rabbitmq_consumer.py

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -51,33 +51,34 @@ async def connect_to_rabbitmq() -> Optional[Connection]:
5151

5252

5353
async def on_message(message: IncomingMessage, yara_rules):
54-
async with message.process():
54+
try:
55+
body = message.body
56+
logging.info(f"Received message: {body}")
57+
58+
if not body:
59+
logging.error("Received empty message")
60+
await message.nack(requeue=False) # 재처리 없이 nack
61+
return
62+
63+
try:
64+
message_str = body.decode("utf-8")
65+
logging.info(f"Decoded message: {message_str}")
66+
except UnicodeDecodeError:
67+
logging.error(f"Failed to decode message: {body}")
68+
await message.nack(requeue=False) # 재처리 없이 nack
69+
return
70+
5571
try:
56-
body = message.body
57-
logging.info(f"Received message: {body}")
58-
59-
if not body:
60-
logging.error("Received empty message")
61-
return
62-
63-
try:
64-
message_str = body.decode("utf-8")
65-
logging.info(f"Decoded message: {message_str}")
66-
except UnicodeDecodeError:
67-
logging.error(f"Failed to decode message: {body}")
68-
await message.nack(requeue=False) # 잘못된 메시지 재처리 안 함
69-
return
70-
71-
try:
72-
file_id = int(message_str)
73-
logging.info(f"Processing file with ID: {file_id}")
74-
await scan_file(file_id, yara_rules)
75-
except ValueError:
76-
logging.error(f"Invalid file ID format: {message_str}")
77-
await message.nack(requeue=False) # 잘못된 파일 ID 재처리 안 함
78-
except Exception as e:
79-
logging.exception(f"Error processing message: {e}")
80-
await message.nack(requeue=True) # 예외 발생 시 메시지 재처리 가능
72+
file_id = int(message_str)
73+
logging.info(f"Processing file with ID: {file_id}")
74+
await scan_file(file_id, yara_rules)
75+
await message.ack() # 성공적으로 처리되면 ack
76+
except ValueError:
77+
logging.error(f"Invalid file ID format: {message_str}")
78+
await message.nack(requeue=False) # 잘못된 파일 ID는 재처리 안 함
79+
except Exception as e:
80+
logging.exception(f"Error processing message: {e}")
81+
await message.nack(requeue=False) # 예외 발생 시 메시지를 재처리 가능
8182

8283

8384
async def start_consuming(queue_name: str, yara_rules, routing_key: str):
@@ -96,11 +97,20 @@ async def start_consuming(queue_name: str, yara_rules, routing_key: str):
9697
await queue.bind(exchange, routing_key)
9798

9899
logging.info(f"Waiting for messages in {queue_name}. To exit press CTRL+C")
99-
await queue.consume(lambda message: on_message(message, yara_rules))
100100

101+
async def shutdown():
102+
logging.info("Shutting down consumer.")
103+
await connection.close()
104+
105+
# 메시지를 처리하는 부분에 shutdown 처리 로직을 추가
106+
loop = asyncio.get_event_loop()
101107
try:
108+
await queue.consume(lambda message: on_message(message, yara_rules))
102109
await asyncio.Future() # 무한 대기
110+
except asyncio.CancelledError:
111+
await shutdown()
103112
finally:
104-
await connection.close()
113+
await shutdown() # 종료 시 RabbitMQ 연결 정리
105114
except Exception as e:
106115
logging.exception(f"Error in start_consuming: {e}")
116+

app/rabbitmq_sender.py

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import logging
22
import ssl
3-
import time
43
import struct
54
import pika
5+
import aio_pika
6+
import asyncio
67

78
from app import (
89
ALERT_EXCHANGE_NAME,
@@ -15,6 +16,7 @@
1516
RABBITMQ_USER,
1617
RETRY_INTERVAL,
1718
)
19+
MAX_RETRIES = 10
1820

1921
# SSL 설정
2022
ssl_options = None
@@ -23,48 +25,60 @@
2325
ssl_options = pika.SSLOptions(context=ssl_context)
2426

2527

26-
def connect_to_rabbitmq():
27-
while True:
28+
async def connect_to_rabbitmq():
29+
retry_count = 0
30+
while retry_count < MAX_RETRIES:
2831
try:
29-
credentials = pika.PlainCredentials(RABBITMQ_USER, RABBITMQ_PASSWORD)
30-
parameters = pika.ConnectionParameters(
32+
ssl_context = None
33+
if RABBITMQ_SSL_ENABLED:
34+
ssl_context = ssl.create_default_context()
35+
ssl_context.check_hostname = False
36+
ssl_context.verify_mode = ssl.CERT_NONE
37+
38+
connection = await aio_pika.connect_robust(
3139
host=RABBITMQ_HOST,
3240
port=int(RABBITMQ_PORT),
33-
credentials=credentials,
34-
ssl_options=ssl_options,
35-
connection_attempts=3,
36-
retry_delay=5,
37-
socket_timeout=10.0, # 타임아웃 설정 (초)
41+
login=RABBITMQ_USER,
42+
password=RABBITMQ_PASSWORD,
43+
ssl=ssl_context,
44+
loop=asyncio.get_event_loop() # asyncio 이벤트 루프 사용
3845
)
39-
connection = pika.BlockingConnection(parameters)
4046
return connection
41-
except pika.exceptions.AMQPConnectionError as e:
47+
except aio_pika.AMQPConnectionError as e:
48+
retry_count += 1
4249
logging.error(
43-
f"Connection failed, retrying in {RETRY_INTERVAL} seconds... Error: {e}"
50+
f"Connection failed, retrying ({retry_count}/{MAX_RETRIES}) in {RETRY_INTERVAL} seconds... Error: {e}"
4451
)
45-
time.sleep(RETRY_INTERVAL)
52+
await asyncio.sleep(RETRY_INTERVAL)
53+
54+
logging.error(f"Failed to connect to RabbitMQ after {MAX_RETRIES} attempts.")
55+
return None
4656

4757

48-
def send_message(message: int):
49-
connection = connect_to_rabbitmq()
50-
channel = connection.channel()
58+
async def send_message(message: int):
59+
connection = await connect_to_rabbitmq()
60+
if not connection:
61+
logging.error("Failed to establish connection to RabbitMQ.")
62+
return
5163

52-
# Exchange 선언
53-
channel.exchange_declare(
54-
exchange=ALERT_EXCHANGE_NAME, exchange_type=EXCHANGE_TYPE, durable=True
55-
)
64+
async with connection:
65+
channel = await connection.channel()
66+
67+
# Exchange 선언
68+
exchange = await channel.declare_exchange(
69+
ALERT_EXCHANGE_NAME, aio_pika.ExchangeType(EXCHANGE_TYPE), durable=True
70+
)
5671

57-
# int 메시지를 바이트로 변환
58-
message_bytes = struct.pack('!Q', message) # '!Q'는 unsigned long long 형식입니다.
72+
# int 메시지를 바이트로 변환
73+
message_bytes = struct.pack('!Q', message) # '!Q'는 unsigned long long 형식입니다.
5974

60-
channel.basic_publish(
61-
exchange=ALERT_EXCHANGE_NAME,
62-
routing_key=ALERT_ROUTING_KEY, # 적절한 라우팅 키로 변경
63-
body=message_bytes,
64-
properties=pika.BasicProperties(
65-
delivery_mode=2 # 메시지 영속화
75+
await exchange.publish(
76+
aio_pika.Message(
77+
body=message_bytes,
78+
delivery_mode=aio_pika.DeliveryMode.PERSISTENT
79+
),
80+
routing_key=ALERT_ROUTING_KEY
6681
)
67-
)
6882

69-
print(f"Sent message: {message}")
70-
connection.close()
83+
logging.info(f"Sent message: {message}")
84+
# connection.close()

app/utils.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import logging
33
import os
4+
import functools
45
from collections import defaultdict
56
from datetime import datetime
67
import pytz
@@ -86,13 +87,18 @@ async def stream_file_from_s3(s3_key):
8687

8788
try:
8889
loop = asyncio.get_event_loop()
89-
response = await loop.run_in_executor(None, s3_client.get_object, {"Bucket": bucket_name, "Key": key})
90+
# functools.partial로 키워드 인자를 전달할 수 있도록 함
91+
response = await loop.run_in_executor(
92+
None, functools.partial(s3_client.get_object, Bucket=bucket_name, Key=key)
93+
)
9094
return response["Body"]
9195
except Exception as e:
9296
logging.error(f"Failed to stream file from S3: {e}")
9397
raise
9498

9599

100+
101+
96102
async def save_scan_result(uploadId: int, stored_file_id, detect, detail):
97103
try:
98104
conn = await aiomysql.connect(
@@ -101,7 +107,8 @@ async def save_scan_result(uploadId: int, stored_file_id, detect, detail):
101107
async with conn.cursor() as cursor:
102108
try:
103109
await cursor.execute(
104-
"INSERT INTO scan_table (file_id, detect, step2_detail) VALUES (%s, %s, %s)",
110+
"INSERT INTO scan_table (file_id, detect, step2_detail) VALUES (%s, %s, %s) "
111+
"ON DUPLICATE KEY UPDATE detect=VALUES(detect), step2_detail=VALUES(step2_detail)",
105112
(stored_file_id, detect, detail),
106113
)
107114
await conn.commit()
@@ -125,7 +132,7 @@ async def save_scan_result(uploadId: int, stored_file_id, detect, detail):
125132
raise
126133

127134

128-
async def select_keyword(matches):
135+
def select_keyword(matches):
129136
keyword_count = defaultdict(int)
130137

131138
for match in matches:
@@ -134,12 +141,14 @@ async def select_keyword(matches):
134141
keyword_count[atk_type] += 1
135142

136143
if keyword_count:
137-
most_common_keyword = max(keyword_count, key=keyword_count.get)
138-
logging.info(f"Most common atk_type: {most_common_keyword}")
139-
return most_common_keyword
144+
# 가장 많이 매칭된 atk_type 값을 추출
145+
keywords = str(keyword_count.keys())
146+
logging.info(f"Most common atk_type: {keywords}")
147+
return keywords
140148
else:
141149
logging.info("No atk_type found in matches")
142-
return None
150+
return "unmatched" # None 대신 기본값 반환
151+
143152

144153

145154
async def yara_test_match(file_path, yara_rules):
@@ -161,21 +170,31 @@ async def yara_test_match(file_path, yara_rules):
161170

162171
async def scan_file(upload_id: int, yara_rules):
163172
try:
173+
# 파일 업로드 정보 가져오기
164174
file_record = await get_file_upload(upload_id)
165175
salted_hash = file_record["salted_hash"]
166176

177+
# 저장된 파일 정보 가져오기
167178
stored_file_record = await get_stored_file(salted_hash)
168179
stored_file_id = stored_file_record["id"]
169180
s3_key = stored_file_record["save_path"]
170181

171182
file_stream = await stream_file_from_s3(s3_key)
172-
file_data = await file_stream.read()
173183

184+
# S3에서 반환된 file_stream은 이미 bytes 객체입니다.
185+
file_data = file_stream.read() # 여기에서 read()는 필요 없음, file_stream 자체가 파일 데이터임
186+
187+
# YARA 룰 매칭
174188
matches = yara_rules.match(data=file_data)
189+
175190
detect = 1 if matches else 0
176191

177-
most_common_keyword = await select_keyword(matches)
178-
detail = "\n".join([str(match) for match in matches]) if matches else "unmatched"
192+
most_common_keyword = select_keyword(matches)
193+
if most_common_keyword is None:
194+
most_common_keyword = "unmatched"
195+
detail = (
196+
"\n".join([str(match) for match in matches]) if matches else "unmatched"
197+
)
179198

180199
logging.info(f"result: {matches}")
181200
logging.info(f"detect: {detect}")
@@ -188,6 +207,7 @@ async def scan_file(upload_id: int, yara_rules):
188207
raise HTTPException(status_code=500, detail="Error scanning file")
189208

190209

210+
191211
async def get_stored_file(hash: str):
192212
try:
193213
conn = await aiomysql.connect(

0 commit comments

Comments
 (0)