diff --git a/generate.py b/generate.py index 9c0e8d4..4d39cbe 100755 --- a/generate.py +++ b/generate.py @@ -61,12 +61,14 @@ def generate_protocol(output: str) -> None: ] ) + enumerations = {e['name']: e for e in lsp_json['enumerations']} + content += '\n\n\n' content += '\n\n\n'.join(generate_enumerations(lsp_json['enumerations'], ENUM_OVERRIDES)) content += '\n\n' - content += '\n'.join(generate_type_aliases(lsp_json['typeAliases'], ALIAS_OVERRIDES)) + content += '\n'.join(generate_type_aliases(lsp_json['typeAliases'], ALIAS_OVERRIDES, enumerations)) content += '\n\n\n' - content += '\n\n\n'.join(generate_structures(lsp_json['structures'])) + content += '\n\n\n'.join(generate_structures(lsp_json['structures'], enumerations)) content += '\n' content += '\n'.join(get_new_literal_structures()) @@ -101,10 +103,12 @@ def generate_custom(output: str) -> None: requests = sorted(lsp_json['requests'], key=itemgetter('typeName')) notifications = sorted(lsp_json['notifications'], key=itemgetter('typeName')) + enumerations = {e['name']: e for e in lsp_json['enumerations']} + content += '\n\n\n' - content += '\n\n\n'.join(generate_requests_and_responses(requests)) + content += '\n\n\n'.join(generate_requests_and_responses(requests, enumerations)) content += '\n\n\n' - content += '\n\n\n'.join(generate_notifications(notifications)) + content += '\n\n\n'.join(generate_notifications(notifications, enumerations)) content += '\n' # Remove trailing spaces. diff --git a/generated/lsp_types.py b/generated/lsp_types.py index e5d46c2..042ab4e 100644 --- a/generated/lsp_types.py +++ b/generated/lsp_types.py @@ -1249,7 +1249,7 @@ class FoldingRange(TypedDict): """ endCharacter: NotRequired[Uint] """The zero-based character offset before the folded range ends. If not defined, defaults to the length of the end line.""" - kind: NotRequired['FoldingRangeKind'] + kind: NotRequired[Union[str, FoldingRangeKind]] """ Describes the kind of the folding range such as 'comment' or 'region'. The kind is used to categorize folding ranges and used by commands like 'Fold all comments'. @@ -3322,7 +3322,7 @@ class CodeAction(TypedDict): title: str """A short, human-readable, title for this code action.""" - kind: NotRequired['CodeActionKind'] + kind: NotRequired[Union[str, CodeActionKind]] """ The kind of the code action. @@ -3389,7 +3389,7 @@ class CodeActionRegistrationOptions(TypedDict): A document selector to identify the scope of the registration. If set to null the document selector provided on the client side will be used. """ - codeActionKinds: NotRequired[List['CodeActionKind']] + codeActionKinds: NotRequired[List[Union[str, CodeActionKind]]] """ CodeActionKinds that this server may return. @@ -4624,7 +4624,7 @@ class TextDocumentItem(TypedDict): uri: DocumentUri """The text document's uri.""" - languageId: 'LanguageKind' + languageId: Union[str, LanguageKind] """The text document's language identifier.""" version: int """ @@ -4804,7 +4804,7 @@ class ServerCapabilities(TypedDict): server. """ - positionEncoding: NotRequired['PositionEncodingKind'] + positionEncoding: NotRequired[Union[str, PositionEncodingKind]] """ The position encoding the server picked from the encodings offered by the client via the client capability `general.positionEncodings`. @@ -4987,7 +4987,7 @@ class FileSystemWatcher(TypedDict): @since 3.17.0 support for relative patterns. """ - kind: NotRequired['WatchKind'] + kind: NotRequired[Union[Uint, WatchKind]] """ The kind of events of interest. If omitted it defaults to WatchKind.Create | WatchKind.Change | WatchKind.Delete @@ -5419,7 +5419,7 @@ class CodeActionContext(TypedDict): that these accurately reflect the error state of the resource. The primary parameter to compute code actions is the provided range. """ - only: NotRequired[List['CodeActionKind']] + only: NotRequired[List[Union[str, CodeActionKind]]] """ Requested kind of actions to return. @@ -5452,7 +5452,7 @@ class CodeActionDisabled(TypedDict): class CodeActionOptions(TypedDict): """Provider options for a {@link CodeActionRequest}.""" - codeActionKinds: NotRequired[List['CodeActionKind']] + codeActionKinds: NotRequired[List[Union[str, CodeActionKind]]] """ CodeActionKinds that this server may return. @@ -6123,7 +6123,7 @@ class CodeActionKindDocumentation(TypedDict): @proposed """ - kind: 'CodeActionKind' + kind: Union[str, CodeActionKind] """ The kind of the code action being documented. @@ -6513,7 +6513,7 @@ class GeneralClientCapabilities(TypedDict): @since 3.16.0 """ - positionEncodings: NotRequired[List['PositionEncodingKind']] + positionEncodings: NotRequired[List[Union[str, PositionEncodingKind]]] """ The position encodings supported by the client. Client and server have to agree on the same position encoding to ensure that offsets @@ -7886,7 +7886,7 @@ class ClientCodeLensResolveOptions(TypedDict): class ClientFoldingRangeKindOptions(TypedDict): """@since 3.18.0""" - valueSet: NotRequired[List['FoldingRangeKind']] + valueSet: NotRequired[List[Union[str, FoldingRangeKind]]] """ The folding range kind values the client supports. When this property exists the client also guarantees that it will @@ -8003,7 +8003,7 @@ class ClientSignatureParameterInformationOptions(TypedDict): class ClientCodeActionKindOptions(TypedDict): """@since 3.18.0""" - valueSet: List['CodeActionKind'] + valueSet: List[Union[str, CodeActionKind]] """ The code action kind values the client supports. When this property exists the client also guarantees that it will diff --git a/lsp_schema.py b/lsp_schema.py index c23578d..8cfe6a3 100644 --- a/lsp_schema.py +++ b/lsp_schema.py @@ -225,6 +225,7 @@ class MetaModel(TypedDict): EveryType = ( BaseType + | EnumerationType | ReferenceType | ArrayType | MapType diff --git a/utils/generate_notifications.py b/utils/generate_notifications.py index 78c2371..87c7104 100644 --- a/utils/generate_notifications.py +++ b/utils/generate_notifications.py @@ -5,16 +5,17 @@ from utils.helpers import indentation if TYPE_CHECKING: + from lsp_schema import Enumeration from lsp_schema import Notification -def generate_notifications(notifications: list[Notification]) -> list[str]: +def generate_notifications(notifications: list[Notification], enumerations: dict[str, Enumeration]) -> list[str]: client_notification_names: list[str] = [] server_notification_names: list[str] = [] definitions: list[str] = [] for notification in notifications: message_direction = notification['messageDirection'] - name, definition = generate_notification(notification) + name, definition = generate_notification(notification, enumerations) if message_direction == 'clientToServer': client_notification_names.append(name) elif message_direction == 'serverToClient': @@ -32,14 +33,14 @@ def generate_notifications(notifications: list[Notification]) -> list[str]: ] -def generate_notification(notification: Notification) -> tuple[str, str]: +def generate_notification(notification: Notification, enumerations: dict[str, Enumeration]) -> tuple[str, str]: method = notification['method'] params = notification.get('params') name = notification['typeName'] definition = f'class {name}(TypedDict):\n' definition += f"{indentation}method: Literal['{method}']\n" if params: - definition += f'{indentation}params: {format_type(params, {"root_symbol_name": ""})}' + definition += f'{indentation}params: {format_type(params, {"enumerations": enumerations})}' else: definition += f'{indentation}params: None' return (name, definition) diff --git a/utils/generate_requests_and_responses.py b/utils/generate_requests_and_responses.py index 454c583..7fce0cd 100644 --- a/utils/generate_requests_and_responses.py +++ b/utils/generate_requests_and_responses.py @@ -5,10 +5,11 @@ from utils.helpers import indentation if TYPE_CHECKING: + from lsp_schema import Enumeration from lsp_schema import Request -def generate_requests_and_responses(requests: list[Request]) -> list[str]: +def generate_requests_and_responses(requests: list[Request], enumerations: dict[str, Enumeration]) -> list[str]: client_request_names: list[str] = [] server_request_names: list[str] = [] client_response_names: list[str] = [] @@ -18,7 +19,7 @@ def generate_requests_and_responses(requests: list[Request]) -> list[str]: for request in requests: message_direction = request['messageDirection'] # Requests - req_name, req_definition = generate_request(request) + req_name, req_definition = generate_request(request, enumerations) if message_direction == 'clientToServer': client_request_names.append(req_name) elif message_direction == 'serverToClient': @@ -28,7 +29,7 @@ def generate_requests_and_responses(requests: list[Request]) -> list[str]: server_request_names.append(req_name) req_definitions.append(req_definition) # Responses - res_name, res_definition = generate_response(request) + res_name, res_definition = generate_response(request, enumerations) if message_direction == 'clientToServer': server_response_names.append(res_name) elif message_direction == 'serverToClient': @@ -51,20 +52,20 @@ def generate_requests_and_responses(requests: list[Request]) -> list[str]: ] -def generate_request(request: Request) -> tuple[str, str]: +def generate_request(request: Request, enumerations: dict[str, Enumeration]) -> tuple[str, str]: method = request['method'] params = request.get('params') name = request['typeName'] definition = f'class {name}(TypedDict):\n' definition += f"{indentation}method: Literal['{method}']\n" if params: - definition += f'{indentation}params: {format_type(params, {"root_symbol_name": ""})}' + definition += f'{indentation}params: {format_type(params, {"enumerations": enumerations})}' else: definition += f'{indentation}params: None' return (name, definition) -def generate_response(request: Request) -> tuple[str, str]: +def generate_response(request: Request, enumerations: dict[str, Enumeration]) -> tuple[str, str]: method = request['method'] result = request['result'] params = request.get('params') @@ -73,7 +74,7 @@ def generate_response(request: Request) -> tuple[str, str]: definition = f'class {name}(TypedDict):\n' definition += f"{indentation}method: Literal['{method}']\n" if request['messageDirection'] == 'serverToClient': - typ = format_type(params, {'root_symbol_name': ''}) if params else None + typ = format_type(params, {'enumerations': enumerations}) if params else None definition += f'{indentation}params: {typ}\n' - definition += f'{indentation}result: {format_type(result, {"root_symbol_name": ""})}' + definition += f'{indentation}result: {format_type(result, {"enumerations": enumerations})}' return (name, definition) diff --git a/utils/generate_structures.py b/utils/generate_structures.py index fb48c2d..c22a11a 100644 --- a/utils/generate_structures.py +++ b/utils/generate_structures.py @@ -11,18 +11,21 @@ from utils.helpers import StructureKind if TYPE_CHECKING: + from lsp_schema import Enumeration from lsp_schema import Structure -def generate_structures(structures: list[Structure]) -> list[str]: +def generate_structures(structures: list[Structure], enumerations: dict[str, Enumeration]) -> list[str]: def to_string(structure: Structure) -> str: kind = StructureKind.Function if has_invalid_property_name(structure['properties']) else StructureKind.Class - return generate_structure(structure, structures, kind) + return generate_structure(structure, structures, kind, enumerations) return [to_string(structure) for structure in structures if not structure['name'].startswith('_')] -def get_additional_properties(for_structure: Structure, structures: list[Structure]) -> list[FormattedProperty]: +def get_additional_properties( + for_structure: Structure, structures: list[Structure], enumerations: dict[str, Enumeration] +) -> list[FormattedProperty]: """Return properties from extended and mixin types.""" result: list[FormattedProperty] = [] additional_structures = for_structure.get('extends') or [] @@ -33,16 +36,21 @@ def get_additional_properties(for_structure: Structure, structures: list[Structu raise Exception(error, additional_structure['kind']) structure = next(structure for structure in structures if structure['name'] == additional_structure['name']) if structure: - properties = get_formatted_properties(structure['properties'], structure['name']) + properties = get_formatted_properties(structure['properties'], {'enumerations': enumerations}) result.extend(properties) return result -def generate_structure(structure: Structure, structures: list[Structure], structure_kind: StructureKind) -> str: +def generate_structure( + structure: Structure, + structures: list[Structure], + structure_kind: StructureKind, + enumerations: dict[str, Enumeration], +) -> str: result = '' symbol_name = structure['name'] - properties = get_formatted_properties(structure['properties'], structure['name']) - additional_properties = get_additional_properties(structure, structures) + properties = get_formatted_properties(structure['properties'], {'enumerations': enumerations}) + additional_properties = get_additional_properties(structure, structures, enumerations) # add extended properties taken_property_names = [p['name'] for p in properties] diff --git a/utils/generate_type_aliases.py b/utils/generate_type_aliases.py index d345db9..0be3cb0 100644 --- a/utils/generate_type_aliases.py +++ b/utils/generate_type_aliases.py @@ -5,17 +5,20 @@ from utils.helpers import format_type if TYPE_CHECKING: + from lsp_schema import Enumeration from lsp_schema import TypeAlias -def generate_type_aliases(type_aliases: list[TypeAlias], overrides: dict[str, str]) -> list[str]: +def generate_type_aliases( + type_aliases: list[TypeAlias], overrides: dict[str, str], enumerations: dict[str, Enumeration] +) -> list[str]: def to_string(type_alias: TypeAlias) -> str: symbol_name = type_alias['name'] documentation = format_comment(type_alias.get('documentation')) if symbol_name in overrides: value = overrides[symbol_name] else: - value = format_type(type_alias['type'], {'root_symbol_name': symbol_name}) + value = format_type(type_alias['type'], {'enumerations': enumerations}) result = f""" {symbol_name}: TypeAlias = {value}""" if documentation: diff --git a/utils/helpers.py b/utils/helpers.py index d08699b..daf919c 100644 --- a/utils/helpers.py +++ b/utils/helpers.py @@ -8,6 +8,8 @@ if TYPE_CHECKING: from lsp_schema import BaseType + from lsp_schema import Enumeration + from lsp_schema import EnumerationType from lsp_schema import EveryType from lsp_schema import MapKeyType from lsp_schema import Property @@ -54,7 +56,7 @@ class StructureKind(Enum): class FormatTypeContext(TypedDict): - root_symbol_name: str + enumerations: dict[str, Enumeration] def format_type(typ: EveryType, context: FormatTypeContext) -> str: @@ -63,13 +65,15 @@ def format_type(typ: EveryType, context: FormatTypeContext) -> str: return format_base_types(typ) if typ['kind'] == 'reference': literal_symbol_name = typ['name'] + if (enum := context['enumerations'].get(literal_symbol_name)) and enum.get('supportsCustomValues'): + return f'Union[{format_type(enum["type"], context)}, {literal_symbol_name}]' return f"'{literal_symbol_name}'" if typ['kind'] == 'array': literal_symbol_name = format_type(typ['element'], context) return f'List[{literal_symbol_name}]' if typ['kind'] == 'map': key = format_base_types(typ['key']) - value = format_type(typ['value'], {'root_symbol_name': key}) + value = format_type(typ['value'], {'enumerations': context['enumerations']}) return f'Dict[{key}, {value}]' if typ['kind'] == 'and': pass @@ -91,7 +95,7 @@ def format_type(typ: EveryType, context: FormatTypeContext) -> str: return result -def format_base_types(base_type: BaseType | MapKeyType) -> str: +def format_base_types(base_type: BaseType | MapKeyType | EnumerationType) -> str: mapping: dict[str, str] = { 'integer': 'int', 'uinteger': 'Uint', @@ -111,11 +115,11 @@ class FormattedProperty(TypedDict): documentation: str -def get_formatted_properties(properties: list[Property], root_symbol_name: str) -> list[FormattedProperty]: +def get_formatted_properties(properties: list[Property], context: FormatTypeContext) -> list[FormattedProperty]: result: list[FormattedProperty] = [] for p in properties: key = p['name'] - value = format_type(p['type'], {'root_symbol_name': root_symbol_name + '_' + key}) + value = format_type(p['type'], context) if p.get('optional'): value = f'NotRequired[{value}]' documentation = p.get('documentation') or ''