import bisect
import re
import unicodedata
from typing import Dict, List, Optional, Union

from . import idnadata
from .intranges import intranges_contain

# Canonical combining class value that marks a virama; used by the CONTEXTJ
# rules for ZWNJ/ZWJ below.
_virama_combining_class = 9
# ASCII-compatible-encoding prefix that marks a Punycode A-label.
_alabel_prefix = b"xn--"
# Full stop plus the ideographic / fullwidth / halfwidth full stops, accepted
# as label separators when not in strict mode.
_unicode_dots_re = re.compile("[\u002e\u3002\uff0e\uff61]")


# NOTE: IDNAError derives from UnicodeError, which is itself a ValueError.
# Any ``except ValueError`` / ``except UnicodeError`` therefore also catches
# every exception class below; keep ``try`` bodies narrow.
class IDNAError(UnicodeError):
    """Base exception for all IDNA-encoding related problems"""

    pass


class IDNABidiError(IDNAError):
    """Exception when bidirectional requirements are not satisfied"""

    pass


class InvalidCodepoint(IDNAError):
    """Exception when a disallowed or unallocated codepoint is used"""

    pass


class InvalidCodepointContext(IDNAError):
    """Exception when the codepoint is not valid in the context it is used"""

    pass


def _combining_class(cp: int) -> int:
    """Return the canonical combining class of code point *cp*.

    Raises:
        ValueError: when the class is 0 and the character has no name in
            ``unicodedata`` (``unicodedata.name`` itself raises for unnamed
            characters). Callers use this to detect code points unknown to
            the running Python's Unicode database.
    """
    v = unicodedata.combining(chr(cp))
    if v == 0:
        if not unicodedata.name(chr(cp)):
            raise ValueError("Unknown character in unicodedata")
    return v


def _is_script(cp: str, script: str) -> bool:
    """Return True if the single character *cp* belongs to *script*.

    *script* is a key of ``idnadata.scripts`` (e.g. "Greek", "Hebrew", "Han").
    """
    return intranges_contain(ord(cp), idnadata.scripts[script])


def _punycode(s: str) -> bytes:
    """Encode *s* with Python's built-in ``punycode`` codec (no ``xn--`` prefix)."""
    return s.encode("punycode")


def _unot(s: int) -> str:
    """Format code point *s* in ``U+XXXX`` notation for error messages."""
    return "U+{:04X}".format(s)


def valid_label_length(label: Union[bytes, str]) -> bool:
    """Return True if *label* is at most 63 units long (the DNS label limit)."""
    return len(label) <= 63


def valid_string_length(label: Union[bytes, str], trailing_dot: bool) -> bool:
    """Return True if the whole domain *label* fits the DNS length limit.

    The limit is 253 units, or 254 when the domain carries a trailing dot
    (which is then included in ``len(label)``).
    """
    return len(label) <= (254 if trailing_dot else 253)


def check_bidi(label: str, check_ltr: bool = False) -> bool:
    """Validate *label* against the numbered IDNA Bidi rules (RFC 5893, section 2).

    The rules are only enforced when the label contains an R, AL or AN
    character, unless *check_ltr* is True, in which case they are applied
    to every label.

    The label must be non-empty when the rules are applied: ``label[0]`` is
    read. ``check_label`` rejects empty labels before calling this.

    Returns:
        True when the label is acceptable.

    Raises:
        IDNABidiError: if a character has unknown directionality, or any of
            Bidi rules 1-6 is violated.
    """
    # Bidi rules should only be applied if string contains RTL characters
    bidi_label = False
    for idx, cp in enumerate(label, 1):
        direction = unicodedata.bidirectional(cp)
        if direction == "":
            # String likely comes from a newer version of Unicode
            raise IDNABidiError("Unknown directionality in label {} at position {}".format(repr(label), idx))
        if direction in ["R", "AL", "AN"]:
            bidi_label = True
    if not bidi_label and not check_ltr:
        return True

    # Bidi rule 1: the first character decides whether this is an RTL or LTR label.
    direction = unicodedata.bidirectional(label[0])
    if direction in ["R", "AL"]:
        rtl = True
    elif direction == "L":
        rtl = False
    else:
        raise IDNABidiError("First codepoint in label {} must be directionality L, R or AL".format(repr(label)))

    valid_ending = False
    number_type: Optional[str] = None
    for idx, cp in enumerate(label, 1):
        direction = unicodedata.bidirectional(cp)

        if rtl:
            # Bidi rule 2: only these directionalities may appear in an RTL label.
            if direction not in [
                "R",
                "AL",
                "AN",
                "EN",
                "ES",
                "CS",
                "ET",
                "ON",
                "BN",
                "NSM",
            ]:
                raise IDNABidiError("Invalid direction for codepoint at position {} in a right-to-left label".format(idx))
            # Bidi rule 3: must end in R/AL/EN/AN, optionally followed by NSMs
            # (an NSM leaves the current verdict untouched).
            if direction in ["R", "AL", "EN", "AN"]:
                valid_ending = True
            elif direction != "NSM":
                valid_ending = False
            # Bidi rule 4: EN and AN digits may not both occur.
            if direction in ["AN", "EN"]:
                if not number_type:
                    number_type = direction
                else:
                    if number_type != direction:
                        raise IDNABidiError("Can not mix numeral types in a right-to-left label")
        else:
            # Bidi rule 5: only these directionalities may appear in an LTR label.
            if direction not in ["L", "EN", "ES", "CS", "ET", "ON", "BN", "NSM"]:
                raise IDNABidiError("Invalid direction for codepoint at position {} in a left-to-right label".format(idx))
            # Bidi rule 6: must end in L/EN, optionally followed by NSMs.
            if direction in ["L", "EN"]:
                valid_ending = True
            elif direction != "NSM":
                valid_ending = False

    if not valid_ending:
        raise IDNABidiError("Label ends with illegal codepoint directionality")

    return True


def check_initial_combiner(label: str) -> bool:
    """Reject a non-empty *label* whose first character is a combining mark (category M*).

    Returns True when the label is acceptable; raises IDNAError otherwise.
    """
    if unicodedata.category(label[0])[0] == "M":
        raise IDNAError("Label begins with an illegal combining character")
    return True


def check_hyphen_ok(label: str) -> bool:
    """Reject hyphens in positions 3-4 and at either end of a non-empty *label*.

    Returns True when the label is acceptable; raises IDNAError otherwise.
    """
    if label[2:4] == "--":
        raise IDNAError("Label has disallowed hyphens in 3rd and 4th position")
    if label[0] == "-" or label[-1] == "-":
        raise IDNAError("Label must not start or end with a hyphen")
    return True


def check_nfc(label: str) -> None:
    """Raise IDNAError unless *label* is already in Unicode Normalization Form C."""
    if unicodedata.normalize("NFC", label) != label:
        raise IDNAError("Label must be in Normalization Form C")


def valid_contextj(label: str, pos: int) -> bool:
    """Evaluate the CONTEXTJ rule for the joiner at ``label[pos]``.

    * U+200C (ZWNJ) is valid directly after a virama. Otherwise it must sit
      between a joining-type L or D character on its left and an R or D
      character on its right, skipping over transparent (T) characters.
    * U+200D (ZWJ) is valid only directly after a virama.
    * Any other code point returns False.

    Raises:
        ValueError: propagated from ``_combining_class`` when the preceding
            character is unknown to ``unicodedata``.
    """
    cp_value = ord(label[pos])

    if cp_value == 0x200C:
        if pos > 0:
            if _combining_class(ord(label[pos - 1])) == _virama_combining_class:
                return True

        # Look left, skipping transparent characters, for a L/D joiner.
        ok = False
        for i in range(pos - 1, -1, -1):
            joining_type = idnadata.joining_types.get(ord(label[i]))
            if joining_type == ord("T"):
                continue
            elif joining_type in [ord("L"), ord("D")]:
                ok = True
                break
            else:
                break

        if not ok:
            return False

        # Look right, skipping transparent characters, for a R/D joiner.
        ok = False
        for i in range(pos + 1, len(label)):
            joining_type = idnadata.joining_types.get(ord(label[i]))
            if joining_type == ord("T"):
                continue
            elif joining_type in [ord("R"), ord("D")]:
                ok = True
                break
            else:
                break
        return ok

    if cp_value == 0x200D:
        if pos > 0:
            if _combining_class(ord(label[pos - 1])) == _virama_combining_class:
                return True
        return False

    else:
        return False


def valid_contexto(label: str, pos: int, exception: bool = False) -> bool:
    """Evaluate the CONTEXTO rule for the code point at ``label[pos]``.

    * U+00B7 MIDDLE DOT: must be between two ``l`` (U+006C) characters.
    * U+0375 GREEK KERAIA: must be followed by a Greek character.
    * U+05F3 / U+05F4 (Hebrew geresh / gershayim): must follow a Hebrew character.
    * U+30FB KATAKANA MIDDLE DOT: the label must also contain a Hiragana,
      Katakana or Han character.
    * U+0660..U+0669 and U+06F0..U+06F9: the two Arabic digit sets must not
      be mixed in one label.

    The U+30FB and digit rules depend only on the whole label, not on *pos*.
    Any other code point returns False. *exception* is accepted for API
    compatibility and is not used.
    """
    cp_value = ord(label[pos])

    if cp_value == 0x00B7:
        if 0 < pos < len(label) - 1:
            if ord(label[pos - 1]) == 0x006C and ord(label[pos + 1]) == 0x006C:
                return True
        return False

    elif cp_value == 0x0375:
        if pos < len(label) - 1 and len(label) > 1:
            return _is_script(label[pos + 1], "Greek")
        return False

    elif cp_value == 0x05F3 or cp_value == 0x05F4:
        if pos > 0:
            return _is_script(label[pos - 1], "Hebrew")
        return False

    elif cp_value == 0x30FB:
        for cp in label:
            if cp == "\u30fb":
                continue
            if _is_script(cp, "Hiragana") or _is_script(cp, "Katakana") or _is_script(cp, "Han"):
                return True
        return False

    elif 0x660 <= cp_value <= 0x669:
        for cp in label:
            if 0x6F0 <= ord(cp) <= 0x06F9:
                return False
        return True

    elif 0x6F0 <= cp_value <= 0x6F9:
        for cp in label:
            if 0x660 <= ord(cp) <= 0x0669:
                return False
        return True

    return False


def _contexto_is_label_wide(cp_value: int) -> bool:
    """Return True for CONTEXTO code points whose ``valid_contexto`` verdict ignores ``pos``."""
    return cp_value == 0x30FB or 0x660 <= cp_value <= 0x669 or 0x6F0 <= cp_value <= 0x6F9


def check_label(label: Union[str, bytes, bytearray]) -> None:
    """Validate a single U-label; return None on success.

    Bytes input is decoded as UTF-8 first. The checks run in this order:
    non-empty, NFC, hyphen placement, no leading combining mark, then a
    per-code-point check against the PVALID / CONTEXTJ / CONTEXTO classes,
    and finally the Bidi rules.

    Raises:
        IDNAError: empty label, normalization/hyphen/combiner problems, or an
            unknown code point next to a joiner.
        InvalidCodepointContext: a CONTEXTJ/CONTEXTO code point whose context
            rule fails.
        InvalidCodepoint: a code point in none of the allowed classes.
        IDNABidiError: a Bidi rule violation.
    """
    if isinstance(label, (bytes, bytearray)):
        label = label.decode("utf-8")
    if not label:
        raise IDNAError("Empty Label")

    check_nfc(label)
    check_hyphen_ok(label)
    check_initial_combiner(label)

    # Verdicts for CONTEXTO rules that scan the whole label and ignore the
    # position. Caching them keeps a label full of such code points linear
    # instead of quadratic in length; the result for the label is unchanged.
    label_wide_contexto: Dict[int, bool] = {}

    for pos, cp in enumerate(label):
        cp_value = ord(cp)
        if intranges_contain(cp_value, idnadata.codepoint_classes["PVALID"]):
            continue
        elif intranges_contain(cp_value, idnadata.codepoint_classes["CONTEXTJ"]):
            # Only the rule evaluation may raise ValueError for an unknown
            # neighbour. The InvalidCodepointContext below is itself a
            # ValueError, so it must stay outside this ``try``; otherwise it
            # would be swallowed and re-reported with the wrong message.
            try:
                joiner_ok = valid_contextj(label, pos)
            except ValueError as err:
                raise IDNAError(
                    "Unknown codepoint adjacent to joiner {} at position {} in {}".format(
                        _unot(cp_value), pos + 1, repr(label)
                    )
                ) from err
            if not joiner_ok:
                raise InvalidCodepointContext(
                    "Joiner {} not allowed at position {} in {}".format(_unot(cp_value), pos + 1, repr(label))
                )
        elif intranges_contain(cp_value, idnadata.codepoint_classes["CONTEXTO"]):
            if _contexto_is_label_wide(cp_value):
                cached = label_wide_contexto.get(cp_value)
                if cached is None:
                    cached = valid_contexto(label, pos)
                    label_wide_contexto[cp_value] = cached
                context_ok = cached
            else:
                context_ok = valid_contexto(label, pos)
            if not context_ok:
                raise InvalidCodepointContext(
                    "Codepoint {} not allowed at position {} in {}".format(_unot(cp_value), pos + 1, repr(label))
                )
        else:
            raise InvalidCodepoint(
                "Codepoint {} at position {} of {} not allowed".format(_unot(cp_value), pos + 1, repr(label))
            )

    check_bidi(label)


def alabel(label: str) -> bytes:
    """Convert one label to its ASCII (A-label) form.

    An all-ASCII label is validated through ``ulabel`` and returned as-is,
    encoded to bytes. Any other label is validated with ``check_label`` and
    Punycode-encoded with the ``xn--`` prefix.

    Raises:
        IDNAError (or a subclass): invalid label, or result longer than 63 bytes.
    """
    try:
        label_bytes = label.encode("ascii")
        ulabel(label_bytes)
        if not valid_label_length(label_bytes):
            raise IDNAError("Label too long")
        return label_bytes
    except UnicodeEncodeError:
        # Non-ASCII input: fall through to Punycode encoding.
        pass

    check_label(label)
    label_bytes = _alabel_prefix + _punycode(label)

    if not valid_label_length(label_bytes):
        raise IDNAError("Label too long")

    return label_bytes


def ulabel(label: Union[str, bytes, bytearray]) -> str:
    """Convert one label to its Unicode (U-label) form.

    * A non-ASCII ``str`` is validated and returned unchanged.
    * Otherwise the ASCII bytes are lower-cased. With an ``xn--`` prefix, the
      remainder is Punycode-decoded and the result validated. Without one, the
      label is validated and returned as lower-case ASCII text.

    Bytes input is expected to be ASCII; non-ASCII bytes surface as a
    ``UnicodeDecodeError`` from ``.decode("ascii")``.

    Raises:
        IDNAError (or a subclass): malformed A-label, undecodable Punycode,
            or a label failing ``check_label``.
    """
    if not isinstance(label, (bytes, bytearray)):
        try:
            label_bytes = label.encode("ascii")
        except UnicodeEncodeError:
            check_label(label)
            return label
    else:
        label_bytes = label

    label_bytes = label_bytes.lower()
    if label_bytes.startswith(_alabel_prefix):
        label_bytes = label_bytes[len(_alabel_prefix) :]
        if not label_bytes:
            raise IDNAError("Malformed A-label, no Punycode eligible content found")
        if label_bytes.decode("ascii")[-1] == "-":
            raise IDNAError("A-label must not end with a hyphen")
    else:
        check_label(label_bytes)
        return label_bytes.decode("ascii")

    try:
        decoded = label_bytes.decode("punycode")
    except UnicodeError as err:
        raise IDNAError("Invalid A-label") from err
    check_label(decoded)
    return decoded


def uts46_remap(domain: str, std3_rules: bool = True, transitional: bool = False) -> str:
    """Re-map the characters in the string according to UTS46 processing.

    Each character's row in ``uts46data`` is looked up. Code points below
    256 index the table directly; others use the last row whose start is at
    or before the code point. It is assumed, not checked here, that the table
    is sorted by start code point and directly indexed for the first 256 —
    TODO confirm against uts46data.

    The row's status then decides what happens to the character:

    * "V" is kept.
    * "D" is kept, or replaced when *transitional* is True.
    * "M" is replaced by the row's mapping.
    * "3" is kept or replaced when *std3_rules* is False, depending on whether
      a mapping exists, and rejected when *std3_rules* is True.
    * "I" is dropped.
    * Anything else, or a required mapping that is missing, raises.

    The output is NFC-normalized.

    Raises:
        InvalidCodepoint: for a code point that is not allowed.
    """
    # Imported lazily; the mapping table is only needed for UTS46 processing.
    from .uts46data import uts46data

    output: List[str] = []

    for pos, char in enumerate(domain):
        code_point = ord(char)
        replacement: Optional[str] = None
        try:
            uts46row = uts46data[code_point if code_point < 256 else bisect.bisect_left(uts46data, (code_point, "Z")) - 1]
            status = uts46row[1]
            if len(uts46row) == 3:
                replacement = uts46row[2]
        except IndexError:
            # No usable table row: treat it like any other disallowed status.
            status = ""

        if (
            status == "V"
            or (status == "D" and not transitional)
            or (status == "3" and not std3_rules and replacement is None)
        ):
            output.append(char)
        elif replacement is not None and (
            status == "M" or (status == "3" and not std3_rules) or (status == "D" and transitional)
        ):
            output.append(replacement)
        elif status != "I":
            raise InvalidCodepoint(
                "Codepoint {} not allowed at position {} in {}".format(_unot(code_point), pos + 1, repr(domain))
            )

    return unicodedata.normalize("NFC", "".join(output))


def encode(
    s: Union[str, bytes, bytearray],
    strict: bool = False,
    uts46: bool = False,
    std3_rules: bool = False,
    transitional: bool = False,
) -> bytes:
    """Convert a domain name to its ASCII-compatible (A-label) form.

    Args:
        s: the domain. Non-``str`` input must be ASCII bytes-like.
        strict: split only on "." instead of also on the Unicode full stops.
        uts46: apply ``uts46_remap`` first.
        std3_rules: passed through to ``uts46_remap``.
        transitional: passed through to ``uts46_remap``.

    Returns:
        The dot-joined A-labels. A trailing dot in the input is preserved.

    Raises:
        IDNAError (or a subclass): non-ASCII bytes input, empty domain or
            label, an invalid label, or a result longer than the DNS limit.
    """
    if not isinstance(s, str):
        try:
            s = str(s, "ascii")
        except UnicodeDecodeError as err:
            raise IDNAError("should pass a unicode string to the function rather than a byte string.") from err
    if uts46:
        s = uts46_remap(s, std3_rules, transitional)
    trailing_dot = False
    result: List[bytes] = []
    labels = s.split(".") if strict else _unicode_dots_re.split(s)
    if not labels or labels == [""]:
        raise IDNAError("Empty domain")
    if labels[-1] == "":
        # A single trailing dot marks a fully-qualified name; re-added below.
        del labels[-1]
        trailing_dot = True
    for label in labels:
        encoded_label = alabel(label)
        if not encoded_label:
            raise IDNAError("Empty label")
        result.append(encoded_label)
    if trailing_dot:
        result.append(b"")
    encoded = b".".join(result)
    if not valid_string_length(encoded, trailing_dot):
        raise IDNAError("Domain too long")
    return encoded


def decode(
    s: Union[str, bytes, bytearray],
    strict: bool = False,
    uts46: bool = False,
    std3_rules: bool = False,
) -> str:
    """Convert a domain name to its Unicode (U-label) form.

    Args:
        s: the domain. Non-``str`` input must be ASCII bytes-like.
        strict: split only on "." instead of also on the Unicode full stops.
        uts46: apply ``uts46_remap`` (non-transitional) first.
        std3_rules: passed through to ``uts46_remap``.

    Returns:
        The dot-joined U-labels. A trailing dot in the input is preserved.

    Raises:
        IDNAError (or a subclass): non-ASCII bytes input, empty domain or
            label, or an invalid label.
    """
    try:
        if not isinstance(s, str):
            s = str(s, "ascii")
    except UnicodeDecodeError as err:
        raise IDNAError("Invalid ASCII in A-label") from err
    if uts46:
        s = uts46_remap(s, std3_rules, False)
    trailing_dot = False
    result: List[str] = []
    labels = s.split(".") if strict else _unicode_dots_re.split(s)
    if not labels or labels == [""]:
        raise IDNAError("Empty domain")
    if not labels[-1]:
        # A single trailing dot marks a fully-qualified name; re-added below.
        del labels[-1]
        trailing_dot = True
    for label in labels:
        decoded_label = ulabel(label)
        if not decoded_label:
            raise IDNAError("Empty label")
        result.append(decoded_label)
    if trailing_dot:
        result.append("")
    return ".".join(result)