Source code for web2vec.extractors.dns_features

import logging
from dataclasses import dataclass, field
from functools import cache, lru_cache
from typing import List, Optional

import dns.resolver

from web2vec.utils import get_domain_from_url

logger = logging.getLogger(__name__)


@dataclass
class DNSRecordFeatures:
    """A single resolved DNS record set: its type, TTL and values."""

    # Record type that was queried, e.g. "A", "AAAA", "MX", "NS".
    record_type: str
    # Time-to-live reported for the record set.
    ttl: int
    # Record data, one string per resolved entry.
    values: List[str]
@dataclass
class DNSFeatures:
    """DNS-derived features for a single domain.

    ``records`` holds the resolved record sets (``get_dns_features`` appends
    one entry per record type it could resolve). The ``min_ttl`` and
    ``ttl_expires_within_*`` fields are not constructor arguments; they are
    populated by :meth:`compute_derived_features`.
    """

    domain: str
    # String forward reference: lets this class be created independently of
    # where DNSRecordFeatures is defined; runtime behaviour is unchanged.
    records: List["DNSRecordFeatures"] = field(default_factory=list)
    min_ttl: Optional[int] = field(init=False, default=None)
    ttl_expires_within_hour: Optional[bool] = field(init=False, default=None)
    ttl_expires_within_day: Optional[bool] = field(init=False, default=None)
    ttl_expires_within_week: Optional[bool] = field(init=False, default=None)

    def _count_values(self, record_type: str) -> int:
        """Return the total number of values over all records of ``record_type``."""
        # Summing (rather than reading only the first match) keeps the count
        # correct if several record sets of the same type are present.
        return sum(
            len(record.values)
            for record in self.records
            if record.record_type == record_type
        )

    @property
    def count_ips(self) -> int:
        """Return the number of resolved IPs (IPv4, i.e. "A" record values)."""
        return self._count_values("A")

    @property
    def count_name_servers(self) -> int:
        """Return number of NameServers (NS) resolved."""
        return self._count_values("NS")

    @property
    def count_mx_servers(self) -> int:
        """Return number of resolved MX Servers."""
        return self._count_values("MX")

    def _address_record_ttls(self) -> List[int]:
        """Return the non-None TTLs of all A and AAAA records, in record order."""
        return [
            record.ttl
            for record in self.records
            if record.record_type in ["A", "AAAA"] and record.ttl is not None
        ]

    def compute_derived_features(self) -> None:
        """Populate TTL-based indicators for downstream ML usage.

        Sets ``min_ttl`` to the smallest A/AAAA TTL (or None when there are
        none). The three ``ttl_expires_within_*`` flags are set to whether
        that minimum is at most one hour, one day or one week. When
        ``min_ttl`` is None, all flags are None as well.
        """
        ttl_values = self._address_record_ttls()
        self.min_ttl = min(ttl_values) if ttl_values else None
        if self.min_ttl is None:
            self.ttl_expires_within_hour = None
            self.ttl_expires_within_day = None
            self.ttl_expires_within_week = None
            return
        self.ttl_expires_within_hour = self.min_ttl <= 3600  # 1 hour
        self.ttl_expires_within_day = self.min_ttl <= 86400  # 24 hours
        self.ttl_expires_within_week = self.min_ttl <= 604800  # 7 days

    @property
    def extract_ttl(self) -> Optional[int]:
        """Return the lowest A/AAAA Time-to-live (TTL) value, or None if unknown.

        Uses ``min_ttl`` when it has been populated. Otherwise it computes the
        same minimum on the fly, so the answer no longer depends on whether
        ``compute_derived_features()`` ran first. (The old fallback returned
        the first TTL, not the minimum.)
        """
        if self.min_ttl is not None:
            return self.min_ttl
        ttl_values = self._address_record_ttls()
        return min(ttl_values) if ttl_values else None
def get_dns_features(domain: str) -> DNSFeatures:
    """Get DNS features for the given domain.

    Queries A, AAAA, MX, TXT, NS and CNAME records with
    ``dns.resolver.resolve``. The lookup is best effort: a missing record
    type is logged at debug level, and any other failure is logged as a
    warning and skipped. The result may therefore hold only some record
    types, or none.

    Args:
        domain: Hostname to resolve (e.g. ``"www.example.com"``).

    Returns:
        A DNSFeatures with one DNSRecordFeatures per successfully resolved
        record type, and with ``compute_derived_features()`` already applied.
    """
    dns_result = DNSFeatures(domain=domain)
    for record_type in ("A", "AAAA", "MX", "TXT", "NS", "CNAME"):
        try:
            answers = dns.resolver.resolve(domain, record_type)
            record_values = [rdata.to_text() for rdata in answers]
            dns_result.records.append(
                DNSRecordFeatures(record_type, answers.rrset.ttl, record_values)
            )
        except dns.resolver.NoAnswer:
            logger.debug("No %s record found for %s", record_type, domain)
        except dns.resolver.NXDOMAIN:
            # NXDOMAIN means the name itself does not exist, so every remaining
            # record type would fail identically; stop instead of re-querying
            # and logging the same warning once per type.
            logger.warning("%s does not exist", domain)
            break
        except Exception as e:  # noqa
            # Deliberately broad: timeouts, unreachable nameservers, etc. must
            # not abort extraction of the other record types. Lazy %-args fix
            # the old call, which passed an extra arg with no placeholder and
            # made logging fail to format the message.
            logger.warning(
                "Error fetching %s records for %s: %s", record_type, domain, e
            )
    dns_result.compute_derived_features()
    return dns_result
@lru_cache(maxsize=4096)
def get_dns_features_cached(domain: str) -> "DNSFeatures":
    """Get DNS features for the given domain, memoised per domain.

    Uses a bounded LRU cache rather than an unbounded ``functools.cache``, so
    a long-running process that sees many distinct domains cannot grow
    memory without limit. ``cache_info()`` and ``cache_clear()`` remain
    available.

    Caveats:
        - Cached entries are kept regardless of the DNS TTL.
        - The same DNSFeatures object is returned to every caller for a given
          domain, so callers should treat it as read-only.
    """
    return get_dns_features(domain)
if __name__ == "__main__":
    # Manual smoke run: take the domain from a sample URL, resolve its DNS
    # records and print the resulting DNSFeatures.
    sample_domain = get_domain_from_url("https://www.example.com")
    print(get_dns_features(sample_domain))