1
1
import asyncio
2
2
import logging
3
3
import os
4
+ import functools
4
5
from collections import defaultdict
5
6
from datetime import datetime
6
7
import pytz
@@ -86,13 +87,18 @@ async def stream_file_from_s3(s3_key):
86
87
87
88
try :
88
89
loop = asyncio .get_event_loop ()
89
- response = await loop .run_in_executor (None , s3_client .get_object , {"Bucket" : bucket_name , "Key" : key })
90
+ # run_in_executor forwards only positional arguments, so bind Bucket/Key as keywords with functools.partial
91
+ response = await loop .run_in_executor (
92
+ None , functools .partial (s3_client .get_object , Bucket = bucket_name , Key = key )
93
+ )
90
94
return response ["Body" ]
91
95
except Exception as e :
92
96
logging .error (f"Failed to stream file from S3: { e } " )
93
97
raise
94
98
95
99
100
+
101
+
96
102
async def save_scan_result (uploadId : int , stored_file_id , detect , detail ):
97
103
try :
98
104
conn = await aiomysql .connect (
@@ -101,7 +107,8 @@ async def save_scan_result(uploadId: int, stored_file_id, detect, detail):
101
107
async with conn .cursor () as cursor :
102
108
try :
103
109
await cursor .execute (
104
- "INSERT INTO scan_table (file_id, detect, step2_detail) VALUES (%s, %s, %s)" ,
110
+ "INSERT INTO scan_table (file_id, detect, step2_detail) VALUES (%s, %s, %s) "
111
+ "ON DUPLICATE KEY UPDATE detect=VALUES(detect), step2_detail=VALUES(step2_detail)" ,
105
112
(stored_file_id , detect , detail ),
106
113
)
107
114
await conn .commit ()
@@ -125,7 +132,7 @@ async def save_scan_result(uploadId: int, stored_file_id, detect, detail):
125
132
raise
126
133
127
134
128
- async def select_keyword (matches ):
135
+ def select_keyword (matches ):
129
136
keyword_count = defaultdict (int )
130
137
131
138
for match in matches :
@@ -134,12 +141,14 @@ async def select_keyword(matches):
134
141
keyword_count [atk_type ] += 1
135
142
136
143
if keyword_count :
137
- most_common_keyword = max (keyword_count , key = keyword_count .get )
138
- logging .info (f"Most common atk_type: { most_common_keyword } " )
139
- return most_common_keyword
144
+ # NOTE(review): this no longer picks the most common atk_type; it stringifies the whole keys view of keyword_count
145
+ keywords = str (keyword_count .keys ())
146
+ logging .info (f"Most common atk_type: { keywords } " )
147
+ return keywords
140
148
else :
141
149
logging .info ("No atk_type found in matches" )
142
- return None
150
+ return "unmatched" # None 대신 기본값 반환
151
+
143
152
144
153
145
154
async def yara_test_match (file_path , yara_rules ):
@@ -161,21 +170,31 @@ async def yara_test_match(file_path, yara_rules):
161
170
162
171
async def scan_file (upload_id : int , yara_rules ):
163
172
try :
173
+ # 파일 업로드 정보 가져오기
164
174
file_record = await get_file_upload (upload_id )
165
175
salted_hash = file_record ["salted_hash" ]
166
176
177
+ # 저장된 파일 정보 가져오기
167
178
stored_file_record = await get_stored_file (salted_hash )
168
179
stored_file_id = stored_file_record ["id" ]
169
180
s3_key = stored_file_record ["save_path" ]
170
181
171
182
file_stream = await stream_file_from_s3 (s3_key )
172
- file_data = await file_stream .read ()
173
183
184
+ # stream_file_from_s3 returns the S3 response body, which is a stream object rather than bytes.
185
+ file_data = file_stream .read () # read() is still required here: it returns the object's contents as bytes
186
+
187
+ # YARA 룰 매칭
174
188
matches = yara_rules .match (data = file_data )
189
+
175
190
detect = 1 if matches else 0
176
191
177
- most_common_keyword = await select_keyword (matches )
178
- detail = "\n " .join ([str (match ) for match in matches ]) if matches else "unmatched"
192
+ most_common_keyword = select_keyword (matches )
193
+ if most_common_keyword is None :
194
+ most_common_keyword = "unmatched"
195
+ detail = (
196
+ "\n " .join ([str (match ) for match in matches ]) if matches else "unmatched"
197
+ )
179
198
180
199
logging .info (f"result: { matches } " )
181
200
logging .info (f"detect: { detect } " )
@@ -188,6 +207,7 @@ async def scan_file(upload_id: int, yara_rules):
188
207
raise HTTPException (status_code = 500 , detail = "Error scanning file" )
189
208
190
209
210
+
191
211
async def get_stored_file (hash : str ):
192
212
try :
193
213
conn = await aiomysql .connect (
0 commit comments