# decode_dataclass.py
import dataclasses
import types
from dataclasses import fields, is_dataclass
from types import SimpleNamespace
from typing import (
    Any,
    Dict,
    List,
    Optional,
    Type,
    Union,
    get_args,
    get_origin,
    get_type_hints,
)
# Approach 1: a generic decoder based on recursive reflection
  6. def decode_dataclass(cls, data):
  7. """通用的dataclass解码器,支持嵌套结构"""
  8. if not is_dataclass(cls):
  9. return data
  10. if not isinstance(data, dict):
  11. raise ValueError(f"Expected dict for {cls.__name__}, got {type(data)}")
  12. # 获取类型提示
  13. type_hints = get_type_hints(cls)
  14. field_values = {}
  15. for field in fields(cls):
  16. field_name = field.name
  17. field_type = type_hints.get(field_name, field.type)
  18. if field_name not in data:
  19. if field.default != dataclasses.MISSING:
  20. field_values[field_name] = field.default
  21. elif field.default_factory != dataclasses.MISSING:
  22. field_values[field_name] = field.default_factory()
  23. else:
  24. raise ValueError(f"Missing required field: {field_name}")
  25. continue
  26. field_value = data[field_name]
  27. field_values[field_name] = _decode_field_value(field_type, field_value)
  28. output = cls(**field_values)
  29. return output
  30. def _decode_field_value(field_type, value):
  31. """解码单个字段值"""
  32. # 处理None值
  33. if value is None:
  34. return None
  35. # 获取类型的origin(如List, Dict等)
  36. origin = get_origin(field_type)
  37. args = get_args(field_type)
  38. # 处理Optional类型 (Union[T, None])
  39. if origin is Union:
  40. # 检查是否是Optional类型
  41. if len(args) == 2 and type(None) in args:
  42. non_none_type = args[0] if args[1] is type(None) else args[1]
  43. return _decode_field_value(non_none_type, value)
  44. else:
  45. # 其他Union类型,尝试第一个类型
  46. return _decode_field_value(args[0], value)
  47. # 处理List类型
  48. if origin is list or origin is List:
  49. if not isinstance(value, list):
  50. raise ValueError(f"Expected list, got {type(value)}")
  51. element_type = args[0] if args else Any
  52. return [_decode_field_value(element_type, item) for item in value]
  53. # 处理Dict类型
  54. if origin is dict or origin is Dict:
  55. if not isinstance(value, dict):
  56. raise ValueError(f"Expected dict, got {type(value)}")
  57. value_type = args[1] if len(args) > 1 else Any
  58. return {k: _decode_field_value(value_type, v) for k, v in value.items()}
  59. # 处理dataclass类型
  60. if is_dataclass(field_type):
  61. return decode_dataclass(field_type, value)
  62. # 基础类型直接返回
  63. return value
# Approach 4: a dedicated function for arrays of objects
  65. def decode_dataclass_list(cls, data_list):
  66. """将JSON对象数组解码为dataclass列表"""
  67. if not isinstance(data_list, list):
  68. raise ValueError(f"Expected list, got {type(data_list)}")
  69. return [decode_dataclass(cls, item) for item in data_list]
# Approach 5: extend the generic decoder to support a top-level list
  71. def decode_json_to_type(target_type, json_data):
  72. """更通用的JSON解码器,支持顶层数组"""
  73. origin = get_origin(target_type)
  74. args = get_args(target_type)
  75. # 处理List[SomeDataClass]类型
  76. if origin is list or origin is List:
  77. if not isinstance(json_data, list):
  78. raise ValueError(
  79. f"Expected list for {target_type}, got {type(json_data)}")
  80. element_type = args[0] if args else Any
  81. return [decode_json_to_type(element_type, item) for item in json_data]
  82. # 处理单个dataclass
  83. if is_dataclass(target_type):
  84. return decode_dataclass(target_type, json_data)
  85. # 其他类型直接返回
  86. return json_data
  87. def ns_to_dataclass(ns: Any, dataclass_type: Type[Any], field_mapping: Dict[str, str] = None) -> Any:
  88. """
  89. 将 SimpleNamespace 对象转换为 dataclass 实例,支持嵌套结构和列表。
  90. Args:
  91. ns: 要转换的 SimpleNamespace 对象或其他值。
  92. dataclass_type: 目标 dataclass 类型。
  93. field_mapping: 可选的字段映射字典,键为 SimpleNamespace 属性名,值为 dataclass 字段名。
  94. Returns:
  95. 转换后的 dataclass 实例或其他原始值。
  96. """
  97. if field_mapping is None:
  98. field_mapping = {}
  99. # 如果不是 SimpleNamespace,直接返回原始值
  100. if not isinstance(ns, SimpleNamespace):
  101. return ns
  102. # 获取 dataclass 的字段信息
  103. if not is_dataclass(dataclass_type) and not isinstance(dataclass_type, type):
  104. raise ValueError(f"{dataclass_type} 不是有效的 dataclass 类型")
  105. dc_fields = {f.name: f.type for f in fields(
  106. dataclass_type)} if is_dataclass(dataclass_type) else {}
  107. # 将 SimpleNamespace 转为字典
  108. ns_dict = vars(ns)
  109. result_dict = {}
  110. # 遍历 SimpleNamespace 的属性
  111. for ns_key, value in ns_dict.items():
  112. # 应用字段映射(如果有)
  113. dc_key = field_mapping.get(ns_key, ns_key)
  114. if dc_key not in dc_fields:
  115. continue # 忽略 dataclass 中不存在的字段
  116. # 获取 dataclass 字段的类型
  117. field_type = dc_fields.get(dc_key)
  118. # 处理嵌套的 SimpleNamespace
  119. if isinstance(value, SimpleNamespace):
  120. result_dict[dc_key] = ns_to_dataclass(
  121. value, field_type, field_mapping)
  122. # 处理列表
  123. elif isinstance(value, list) and field_type:
  124. origin_type = get_origin(field_type)
  125. if origin_type is list:
  126. item_type = get_args(field_type)[
  127. 0] if get_args(field_type) else Any
  128. result_dict[dc_key] = [
  129. ns_to_dataclass(item, item_type, field_mapping)
  130. if isinstance(item, SimpleNamespace)
  131. else item
  132. for item in value
  133. ]
  134. else:
  135. result_dict[dc_key] = value
  136. # 直接赋值其他类型
  137. else:
  138. result_dict[dc_key] = value
  139. # 创建 dataclass 实例
  140. return dataclass_type(**result_dict)