Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ streamlit_js_eval
extra-streamlit-components
scikit-learn
joblib
numpy
numpy #<2.0 (recomended)
PyJWT
statsmodels
prophet
Expand Down
192 changes: 92 additions & 100 deletions test_oauth.py
Original file line number Diff line number Diff line change
@@ -1,110 +1,102 @@
"""
Test script for OAuth functionality
Run this to test OAuth configuration without starting the full app
Pytest-based OAuth Test Suite
Run with: pytest -v
"""

import os
import sys
import pytest
from urllib.parse import urlparse
from dotenv import load_dotenv

# Load environment variables
load_dotenv()


# -----------------------------
# OAuth Config Tests
# -----------------------------

def test_oauth_config():
"""Test OAuth configuration"""
print("πŸ” Testing OAuth Configuration...")

try:
from auth.oauth_config import oauth_config

print(f"Available providers: {oauth_config.get_available_providers()}")

for provider in oauth_config.get_available_providers():
print(f"\nβœ… {provider.upper()} OAuth configured")
provider_config = oauth_config.get_provider(provider)
print(f" Client ID: {provider_config.client_id[:10]}...")
print(f" Redirect URI: {provider_config.redirect_uri}")

if not oauth_config.get_available_providers():
print("❌ No OAuth providers configured")
print("Please set up your OAuth credentials in .env file")
return False

return True

except Exception as e:
print(f"❌ Error testing OAuth config: {e}")
return False

def test_database():
"""Test database initialization"""
print("\nπŸ—„οΈ Testing Database...")

try:
from auth.auth_utils import init_db
init_db()
print("βœ… Database initialized successfully")
return True
except Exception as e:
print(f"❌ Error initializing database: {e}")
return False

def test_oauth_utils():
"""Test OAuth utility functions"""
print("\nπŸ”§ Testing OAuth Utils...")

try:
from auth.oauth_utils import generate_state, store_oauth_state, verify_oauth_state

# Test state generation
state = generate_state()
print(f"βœ… State generated: {state[:10]}...")

# Test state storage and verification
store_oauth_state(state, "google")
verified_provider = verify_oauth_state(state)

if verified_provider == "google":
print("βœ… State storage and verification working")
else:
print("❌ State verification failed")
return False

return True

except Exception as e:
print(f"❌ Error testing OAuth utils: {e}")
return False

def main():
"""Run all tests"""
print("πŸš€ TalkHeal OAuth Test Suite")
print("=" * 40)

tests = [
test_oauth_config,
test_database,
test_oauth_utils
]

passed = 0
total = len(tests)

for test in tests:
if test():
passed += 1

print("\n" + "=" * 40)
print(f"πŸ“Š Test Results: {passed}/{total} tests passed")

if passed == total:
print("πŸŽ‰ All tests passed! OAuth is ready to use.")
else:
print("⚠️ Some tests failed. Please check the configuration.")

return passed == total

if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)
"""Ensure at least one OAuth provider is configured properly"""
from auth.oauth_config import oauth_config

providers = oauth_config.get_available_providers()
assert isinstance(providers, list)

assert len(providers) > 0, (
"No OAuth providers configured. "
"Ensure credentials are set in .env"
)

for provider in providers:
config = oauth_config.get_provider(provider)

assert config.client_id is not None
assert config.client_secret is not None
assert config.redirect_uri is not None

# Validate redirect URI structure
parsed = urlparse(config.redirect_uri)
assert parsed.scheme in ["http", "https"]
assert parsed.netloc != ""


# -----------------------------
# Database Tests
# -----------------------------

def test_database_initialization():
"""Ensure database initializes without error"""
from auth.auth_utils import init_db

# Should not raise any exception
init_db()


# -----------------------------
# OAuth Utility Tests
# -----------------------------

def test_state_generation_entropy():
"""Ensure generated states are unique and sufficiently long"""
from auth.oauth_utils import generate_state

states = {generate_state() for _ in range(300)}

# Ensure uniqueness
assert len(states) == 300

# Ensure reasonable entropy length
for state in states:
assert len(state) >= 32


def test_state_storage_and_verification():
"""Ensure stored state verifies correctly"""
from auth.oauth_utils import generate_state, store_oauth_state, verify_oauth_state

state = generate_state()
store_oauth_state(state, "google")

provider = verify_oauth_state(state)
assert provider == "google"


def test_state_reuse_protection():
"""Ensure state cannot be reused (prevents replay attacks)"""
from auth.oauth_utils import generate_state, store_oauth_state, verify_oauth_state

state = generate_state()
store_oauth_state(state, "google")

# First verification should pass
assert verify_oauth_state(state) == "google"

# Second verification should fail
assert verify_oauth_state(state) is None


def test_invalid_state_rejected():
"""Ensure invalid state is rejected"""
from auth.oauth_utils import verify_oauth_state

assert verify_oauth_state("invalid-state") is None