Files
squeezelite-esp32/tools/protoc_utils/ProtoElement.py
2025-03-18 17:38:34 -04:00

219 lines
8.4 KiB
Python

from __future__ import annotations
from functools import partial
from typing import ClassVar, List, Any, Dict
from typing import Callable
from google.protobuf import message, descriptor_pb2
from google.protobuf.descriptor import Descriptor, FieldDescriptor, FileDescriptor
from google.protobuf.descriptor_pb2 import FieldDescriptorProto,DescriptorProto,EnumDescriptorProto
import google.protobuf.descriptor_pool as descriptor_pool
import logging
import copy
# import custom_options_pb2 as custom
# Renderer signature: called with a ProtoElement, returns a Dict.
RendererType = Callable[["ProtoElement"], Dict]
class ProtoElement:
    """Tree node wrapping a protobuf message or field descriptor for code generation.

    Each instance records its descriptor and its parent element (appending
    itself to ``parent.childs``). It also records the options declared on it
    and on its containing message, the comments and source position
    registered for its full name, and a ``render`` callable bound to the
    class-wide renderer.

    Shared state (message prototypes, comments, positions, descriptor pool,
    logger and renderer) is installed once through the ``set_*``
    classmethods and read by every instance.
    """

    childs: List[ProtoElement]
    descriptor: Descriptor | FieldDescriptor
    comments: Dict[str, str]
    enum_type: EnumDescriptorProto
    # Shared comment text. Keys are presumably "<full_name>.<kind>" (kind being
    # e.g. "leading", "trailing" or "detached"). This is inferred from how
    # _comments_for() splits them; confirm against the producer of this map.
    _comments: Dict[str, str] = {}
    pool: descriptor_pool.DescriptorPool
    # Message classes keyed by message full name.
    prototypes: dict[str, type[message.Message]]
    renderer: RendererType
    package: str
    file: FileDescriptor
    message: str
    # Shared source positions keyed by full name. Empty until
    # set_positions_base() is called, so lookups yield None instead of failing.
    _positions: Dict[str, tuple] = {}
    position: tuple
    options: Dict[str, Any]
    # Only assigned for message-typed elements (see __init__).
    _message_instance: message.Message
    # Shared logger; None until set_logger() installs one.
    logger: ClassVar[logging.Logger | None] = None
    # Renderer installed by set_render(); bound to each new instance as `render`.
    render_class: ClassVar[RendererType]

    @classmethod
    def set_prototypes(cls, prototypes: dict[str, type[message.Message]]):
        """Install the full-name -> message-class registry used to build message instances."""
        cls.prototypes = prototypes

    @classmethod
    def set_comments_base(cls, comments: Dict[str, str]):
        """Install the shared comment map consulted by every element created afterwards."""
        cls._comments = comments

    @classmethod
    def set_positions_base(cls, positions: Dict[str, tuple]):
        """Install the shared full-name -> source-position map."""
        cls._positions = positions

    @classmethod
    def set_pool(cls, pool: descriptor_pool.DescriptorPool):
        """Store the descriptor pool on the class."""
        cls.pool = pool

    @classmethod
    def set_logger(cls, logger=None):
        """Install the shared class logger.

        Args:
            logger: logger to use. When omitted and no logger has been set
                yet, a module logger is created and ``logging.basicConfig``
                is applied at INFO level. When omitted and a logger already
                exists, nothing changes.
        """
        if logger:
            cls.logger = logger
        elif not cls.logger:
            # `logger` defaults to None on the class, so this check cannot
            # raise AttributeError on the very first call.
            cls.logger = logging.getLogger(__name__)
            logging.basicConfig(
                level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
            )

    @classmethod
    def set_render(cls, render):
        """Install the renderer; each element created afterwards binds it as ``render``."""
        cls.render_class = render

    def __init__(self, descriptor: Descriptor | FieldDescriptor, parent: ProtoElement | None = None):
        """Wrap *descriptor* and attach the new element to *parent*, if any.

        Args:
            descriptor: message (``Descriptor``) or field (``FieldDescriptor``)
                descriptor to wrap.
            parent: enclosing element. When given, this element is appended
                to ``parent.childs``.

        Raises:
            KeyError: if the element is message-typed and no prototype is
                registered for it (see ``_create_message_instance``).
        """
        ProtoElement.set_logger()
        self.descriptor = descriptor
        self.file = descriptor.file
        self.package = getattr(descriptor, "file", parent).package
        self.descriptorname = descriptor.name
        self.json_name = getattr(descriptor, 'json_name', '')
        self.type_name = getattr(descriptor, 'type_name', descriptor.name)
        self.parent = parent
        self.fullname = descriptor.full_name
        # Descriptors without a `type` attribute (message descriptors) are
        # treated as message-typed.
        self.type = getattr(descriptor, 'type', FieldDescriptor.TYPE_MESSAGE)
        if self.type == FieldDescriptor.TYPE_MESSAGE:
            self._message_instance = self._create_message_instance()
        self.label = getattr(descriptor, 'label', None)
        self.childs = []
        self.options = self._collect_options(descriptor)
        self.render = partial(self.render_class, self)
        self.comments = self._comments_for(self.path)
        self.position = self._positions.get(self.path)

    def _create_message_instance(self):
        """Instantiate the registered prototype for this element.

        The lookup first tries the full name of the field's message type.
        When that fails, the element's own full name is used instead. The
        first lookup can fail because message descriptors have no
        ``message_type`` attribute, or because the name is not registered.

        Raises:
            KeyError: if neither name is registered in ``prototypes``.
        """
        try:
            return self.prototypes[self.descriptor.message_type.full_name]()
        except (KeyError, AttributeError):
            return self.prototypes[self.descriptor.full_name]()

    @staticmethod
    def _collect_options(descriptor) -> Dict[str, Any]:
        """Return the options set on *descriptor*, merged with its containing type's options.

        Options of the containing type are applied last, so they override
        same-named options of the element itself.
        """
        if descriptor.has_options:
            options = {descr.name: value for descr, value in descriptor.GetOptions().ListFields()}
        else:
            options = {}
        # containing_type may be missing or None (e.g. for a top-level message).
        container = getattr(descriptor, 'containing_type', None)
        if container is not None and container.has_options:
            options.update({descr.name: value for descr, value in container.GetOptions().ListFields()})
        return options

    @classmethod
    def _comments_for(cls, path: str) -> Dict[str, str]:
        """Return the comments registered directly on *path*, keyed by comment kind.

        Only keys of the exact form ``"<path>.<kind>"`` are kept. Comments of
        children (``"<path>.<child>.<kind>"``) are excluded, as are siblings
        that merely share a name prefix (``"<path>X.<kind>"``). Including
        them would overwrite this element's own entries, because results are
        keyed by the last segment only.
        """
        selected = {}
        for key, text in cls._comments.items():
            owner, _, kind = key.rpartition('.')
            if owner == path:
                selected[kind] = text
        return selected

    @property
    def cpp_type(self) -> str:
        """C identifier ``<package>_<ContainingTypeName>`` for this element's containing type."""
        return f'{self.package}_{self.descriptor.containing_type.name}'

    @property
    def cpp_member(self) -> str:
        """Member name used in generated code (same as ``name``)."""
        return self.name

    @property
    def cpp_type_member_prefix(self) -> str:
        """``<cpp_type>_<cpp_member>`` identifier prefix."""
        return f'{self.cpp_type}_{self.cpp_member}'

    @property
    def cpp_type_member(self) -> str:
        """``<cpp_type>.<cpp_member>`` member access expression."""
        return f'{self.cpp_type}.{self.cpp_member}'

    @property
    def main_message(self) -> bool:
        """True when this element has no parent."""
        return self.parent is None

    @property
    def parent(self) -> ProtoElement:
        """Enclosing element, or None for a top-level element."""
        return self._parent

    @parent.setter
    def parent(self, value: ProtoElement | None):
        # Registering with the parent here keeps `childs` in sync with `parent`.
        self._parent = value
        if value:
            self._parent.childs.append(self)

    @property
    def root(self) -> ProtoElement:
        """Topmost ancestor of this element (the element itself when it has no parent)."""
        node = self
        while node.parent is not None:
            node = node.parent
        return node

    @property
    def enum_type(self) -> EnumDescriptorProto:
        """The descriptor's ``enum_type`` attribute."""
        return self.descriptor.enum_type

    @property
    def cpp_root(self):
        """``<cpp_type>_ROOT`` identifier."""
        return f'{self.cpp_type}_ROOT'

    @property
    def cpp_child(self):
        """``<cpp_type>_CHILD`` identifier."""
        return f'{self.cpp_type}_CHILD'

    @property
    def proto_file_line(self):
        """``"<file>:<start line>"`` when a position is known, otherwise just ``"<file>"``."""
        if self.position:
            # Only the start line is needed. Indexing (instead of unpacking
            # exactly three values) also accepts 4-element spans.
            return f"{self.file.name}:{self.position[0]}"
        return f"{self.file.name}"

    @property
    def message_instance(self):
        """This element's prototype instance, else the nearest ancestor's (None if none)."""
        try:
            return self._message_instance
        except AttributeError:
            # Only consult ancestors when this element has no instance of its own.
            return getattr(self.parent, 'message_instance', None)

    @property
    def new_message_instance(self):
        """A fresh prototype instance for message-typed elements, None otherwise.

        Raises:
            KeyError: if no prototype is registered for this element.
        """
        if self.type == FieldDescriptor.TYPE_MESSAGE:
            return self._create_message_instance()
        return None

    @property
    def tree(self):
        """Compact dump of this subtree, e.g. ``Msg->(a, b->(c))``."""
        if not self.childs:
            return f'{self.name}'
        return f"{self.name}->({', '.join(child.tree for child in self.childs)})"

    @property
    def name(self):
        """Descriptor name; falls back to the parent's name, then to the package."""
        if len(self.descriptorname) > 0:
            return self.descriptorname
        return self.parent.name if self.parent else self.package

    @property
    def enum_values(self) -> List[str]:
        """Names of the enum's values (``values`` or ``value`` attribute); empty if none."""
        return [n.name for n in getattr(self.enum_type, "values", getattr(self.enum_type, "value", []))]

    @property
    def enum_values_str(self) -> str:
        """Enum value names joined with ``", "``."""
        return ', '.join(self.enum_values)

    @property
    def fields(self) -> List[FieldDescriptor]:
        """Fields of this message, or of this field's message type; None if neither exists."""
        return getattr(self.descriptor, "fields", getattr(getattr(self.descriptor, "message_type", None), "fields", None))

    @property
    def _default_value(self):
        """Default for this element.

        An explicit ``default_value`` option wins. Otherwise the result
        depends on the type:

        - integer types: ``0``;
        - float and double: ``0.0``;
        - bool: ``False``;
        - string and bytes: ``""``;
        - enums: the first value name (``0`` if there are none);
        - anything else: ``None``.
        """
        if 'default_value' in self.options:
            return self.options['default_value']
        if self.type in (FieldDescriptorProto.TYPE_INT32, FieldDescriptorProto.TYPE_INT64,
                         FieldDescriptorProto.TYPE_UINT32, FieldDescriptorProto.TYPE_UINT64,
                         FieldDescriptorProto.TYPE_SINT32, FieldDescriptorProto.TYPE_SINT64,
                         FieldDescriptorProto.TYPE_FIXED32, FieldDescriptorProto.TYPE_FIXED64,
                         FieldDescriptorProto.TYPE_SFIXED32, FieldDescriptorProto.TYPE_SFIXED64):
            return 0
        elif self.type in (FieldDescriptorProto.TYPE_FLOAT, FieldDescriptorProto.TYPE_DOUBLE):
            return 0.0
        elif self.type == FieldDescriptorProto.TYPE_BOOL:
            return False
        elif self.type in (FieldDescriptorProto.TYPE_STRING, FieldDescriptorProto.TYPE_BYTES):
            return ""
        elif self.is_enum:
            return self.enum_values[0] if self.enum_values else 0

    @property
    def detached_leading_comments(self) -> str:
        """Detached comment text, or ``""`` when there is none."""
        # Previously this checked for "detached" but returned the "leading" entry.
        return self.comments.get("detached", "")

    @property
    def leading_comment(self) -> str:
        """Leading comment text, or ``""`` when there is none."""
        return self.comments.get("leading", "")

    @property
    def trailing_comment(self) -> str:
        """Trailing comment text, or ``""`` when there is none."""
        return self.comments.get("trailing", "")

    @property
    def is_enum(self):
        """True when the element's type is TYPE_ENUM."""
        return self.type == FieldDescriptorProto.TYPE_ENUM

    @property
    def path(self) -> str:
        """Fully-qualified descriptor name, used as the key for comments and positions."""
        return self.descriptor.full_name

    @property
    def enum_name(self) -> str:
        """``type_name`` with everything up to and including its first ``.`` removed."""
        return self.type_name.split('.', maxsplit=1)[-1]

    @property
    def repeated(self) -> bool:
        """True when the field label is LABEL_REPEATED."""
        return self.label == FieldDescriptor.LABEL_REPEATED