diff --git a/pyproject.toml b/pyproject.toml index 00a4edb..a7c762f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "humanize", "Pillow", "pillow-heif", + "IP2Location", ] [tool.setuptools.packages.find] diff --git a/src/snek/IP2LOCATION-LITE-DB11.BIN b/src/snek/IP2LOCATION-LITE-DB11.BIN new file mode 100755 index 0000000..8f04af8 Binary files /dev/null and b/src/snek/IP2LOCATION-LITE-DB11.BIN differ diff --git a/src/snek/app.py b/src/snek/app.py index c092b41..4bc8129 100644 --- a/src/snek/app.py +++ b/src/snek/app.py @@ -8,7 +8,7 @@ from snek import snode from snek.view.threads import ThreadsView import json logging.basicConfig(level=logging.DEBUG) - +from ipaddress import ip_address from concurrent.futures import ThreadPoolExecutor from aiohttp import web @@ -20,7 +20,7 @@ from aiohttp_session import ( from aiohttp_session.cookie_storage import EncryptedCookieStorage from app.app import Application as BaseApplication from jinja2 import FileSystemLoader - +import IP2Location from snek.sssh import start_ssh_server from snek.docs.app import Application as DocsApplication from snek.mapper import get_mappers @@ -69,6 +69,29 @@ async def session_middleware(request, handler): response = await handler(request) return response +@web.middleware +async def ip2location_middleware(request, handler): + response = await handler(request) + ip = request.headers.get("X-Forwarded-For", request.remote) + ipaddress = ip_address(ip) + if ipaddress.is_private: + return response + if not request.app.session.get("uid"): + return response + user = await request.app.services.user.get(uid=request.app.session.get("uid")) + if not user: + return response + location = request.app.ip2location.get(ip) + original_city = user['city'] + if user['city'] != location.city: + user['country_long'] = location.country + user['country_short'] = locaion.country_short + user['city'] = location.city + user['region'] = location.region + user['latitude'] = location.latitude + user['longitude'] = location.longitude + await request.app.services.user.update(user) + return response @web.middleware async def trailing_slash_middleware(request, handler): @@ -84,6 +107,7 @@ class Application(BaseApplication): middlewares = [ cors_middleware, web.normalize_path_middleware(merge_slashes=True), + ip2location_middleware ] self.template_path = pathlib.Path(__file__).parent.joinpath("templates") self.static_path = pathlib.Path(__file__).parent.joinpath("static") @@ -111,11 +135,15 @@ class Application(BaseApplication): self.broadcast_service = None self.user_availability_service_task = None + base_path = pathlib.Path(__file__).parent + self.ip2location = IP2Location.IP2Location(base_path.joinpath("IP2LOCATION-LITE-DB11.BIN")) + self.on_startup.append(self.prepare_asyncio) self.on_startup.append(self.start_user_availability_service) self.on_startup.append(self.start_ssh_server) self.on_startup.append(self.prepare_database) - + + @property def uptime_seconds(self): return (datetime.now() - self.time_start).total_seconds() diff --git a/src/snek/model/user.py b/src/snek/model/user.py index 9869456..000577b 100644 --- a/src/snek/model/user.py +++ b/src/snek/model/user.py @@ -30,6 +30,13 @@ class UserModel(BaseModel): last_ping = ModelField(name="last_ping", required=False, kind=str) is_admin = ModelField(name="is_admin", required=False, kind=bool) + + country_short = ModelField(name="country_short", required=False, kind=str) + country_long = ModelField(name="country_long", required=False, kind=str) + city = ModelField(name="city", required=False, kind=str) + latitude = ModelField(name="latitude", required=False, kind=float) + longitude = ModelField(name="longitude", required=False, kind=float) + region = ModelField(name="region", required=False, kind=str) async def get_property(self, name): prop = await self.app.services.user_property.find_one(