Coverage for python/lum/clu/odin/serialization.py: 72%

105 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-17 18:41 +0000

1from lum.clu.odin.mention import (Mention, TextBoundMention, RelationMention, EventMention, CrossSentenceMention) 

2from lum.clu.processors.document import Document 

3from lum.clu.processors.interval import Interval 

4import typing 

5import collections 

6 

7__all__ = ["OdinJsonSerializer"] 

8 

9 

10 

11 # ("type" -> longString) ~ 

12 # // used for correspondence with paths map 

13 # ("id" -> id) ~ // tb.id would just create a different TextBoundMentionOps to provide the id 

14 # ("text" -> tb.text) ~ 

15 # ("labels" -> tb.labels) ~ 

16 # ("tokenInterval" -> Map("start" -> tb.tokenInterval.start, "end" -> tb.tokenInterval.end)) ~ 

17 # ("characterStartOffset" -> tb.startOffset) ~ 

18 # ("characterEndOffset" -> tb.endOffset) ~ 

19 # ("sentence" -> tb.sentence) ~ 

20 # ("document" -> documentEquivalenceHash.toString) ~ 

21 # ("keep" -> tb.keep) ~ 

22 # ("foundBy" -> tb.foundBy) 

23 

24# object TextBoundMentionOps { 

25# val string = "TextBoundMention" 

26# val shortString = "T" 

27# } 

28 

29# object EventMentionOps { 

30# val string = "EventMention" 

31# val shortString = "E" 

32# } 

33 

34# object RelationMentionOps { 

35# val string = "RelationMention" 

36# val shortString = "R" 

37# } 

38 

39# object CrossSentenceMentionOps { 

40# val string = "CrossSentenceMention" 

41# val shortString = "CS" 

42# } 

43 

class OdinJsonSerializer:
    """(De)serializes Odin mentions from the "compact" JSON format.

    Python port of the Scala `org.clulab.odin` JSON serialization
    (see the commented Scala reference above for the field layout).
    """

    # JSON "type" discriminator values, one per Mention subclass.
    MENTION_TB_TYPE = "TextBoundMention"
    MENTION_R_TYPE = "RelationMention"
    MENTION_E_TYPE = "EventMention"
    MENTION_C_TYPE = "CrossSentenceMention"

    # @staticmethod
    # def to_compact_mentions_json(jdata: dict[str, typing.Any]) -> list[Mention]:
    #     pass

54 

55 # don't blow the stack 

56 @staticmethod 

57 def from_compact_mentions_json(compact_json: dict[str, typing.Any]) -> list[Mention]: 

58 

59 # populate mapping of doc id -> Document 

60 docs_map = dict() 

61 for doc_id, doc_json in compact_json["documents"].items(): 

62 # store ID if not set 

63 if "id" not in doc_json: 

64 doc_json.update({"id": doc_id}) 

65 docs_map[doc_id] = Document(**doc_json) 

66 

67 mentions_map: dict[str, Mention] = dict() 

68 mention_ids: typing.Set[str] = {mn.get("id") for mn in compact_json["mentions"]} 

69 # attack TBMs first 

70 srt_fn = lambda mid: -1 if mid.startswith("T") else 1 

71 # make a queue w/ TBMs first 

72 missing: collections.deque = collections.deque(sorted(list(mention_ids), key=srt_fn)) 

73 

74 while len(missing) > 0: 

75 m_id = missing.popleft() 

76 # pop a key and try to create the mention map 

77 _, mns_map = OdinJsonSerializer._fetch_mention( 

78 m_id=m_id, 

79 compact_json=compact_json, 

80 docs_map=docs_map, 

81 mentions_map=mentions_map 

82 ) 

83 # store new results 

84 mentions_map.update(mns_map) 

85 # filter out newly constructed mentions 

86 missing = collections.deque([k for k in missing if k not in mentions_map]) 

87 #return list(mentions_map.values()) 

88 # avoids unraveling mentions to include triggers, etc. 

89 return [m for mid, m in mentions_map.items() if mid in mention_ids] 

90 

    @staticmethod
    def _fetch_mention(m_id: str, compact_json: dict[str, typing.Any], docs_map: dict[str, Document], mentions_map: dict[str, Mention]) -> typing.Tuple[Mention, dict[str, Mention]]:
        """Resolve the mention with id ``m_id``, recursively constructing any
        argument mentions it depends on.

        Returns a tuple of (the resolved Mention, ``mentions_map``).
        ``mentions_map`` acts as a cache and is mutated in place; newly
        constructed mentions (including args/triggers) are added to it.

        Raises:
            IndexError: if ``m_id`` is not present in ``compact_json["mentions"]``.
            Exception: if the mention's ``type`` field is unrecognized.
        """
        # base case: already constructed on a previous call
        if m_id in mentions_map:
            return mentions_map[m_id], mentions_map

        # locate this mention's JSON by id (assumes ids are unique)
        mjson: dict[str, typing.Any] = [mn for mn in compact_json["mentions"] if mn.get("id", None)== m_id][0]
        mtype = mjson["type"]
        # gather general info shared by all mention types
        labels = mjson["labels"]
        token_interval = Interval(**mjson["tokenInterval"])
        document = docs_map[mjson["document"]]
        start = mjson["characterStartOffset"]
        end = mjson["characterEndOffset"]
        sentence_index = mjson["sentence"]
        found_by = mjson["foundBy"]
        keep = mjson.get("keep", True)
        # easy case. We have everything we need (TBMs have no args/trigger).
        if mtype == OdinJsonSerializer.MENTION_TB_TYPE:
            m = TextBoundMention(
                labels=labels,
                token_interval=token_interval,
                sentence_index=sentence_index,
                start=start,
                end=end,
                document=document,
                found_by=found_by,
                keep=keep
            )
            mentions_map[m_id] = m
            return (m, mentions_map)
        # everything else *might* have paths
        paths: typing.Optional[Mention.Paths] = OdinJsonSerializer.construct_paths(mjson.get("paths", None))
        # retrieve all args recursively
        arguments: Mention.Arguments = dict()
        for role, mns_json in mjson.get("arguments", {}).items():
            role_mns = arguments.get(role, [])
            for mn_json in mns_json:
                _mid = mn_json["id"]
                if _mid in mentions_map:
                    _mn = mentions_map[_mid]
                else:
                    # NOTE: in certain cases, the referenced mid might not be found in the compact_json.
                    # we'll add it to be safe (the arg JSON carries a full mention record).
                    if all(m["id"] != _mid for m in compact_json["mentions"]):
                        compact_json["mentions"] = compact_json["mentions"] + [mn_json]
                    _mn, _mns_map = OdinJsonSerializer._fetch_mention(
                        m_id=_mid,
                        compact_json=compact_json,
                        docs_map=docs_map, mentions_map=mentions_map
                    )
                    # update our progress
                    mentions_map.update(_mns_map)
                # store this arg mention under its role
                role_mns.append(_mn)
            # update our args
            arguments[role] = role_mns

        if mtype == OdinJsonSerializer.MENTION_E_TYPE:
            # get or load trigger
            trigger_mjson = mjson["trigger"]
            trigger_id = trigger_mjson["id"]
            if trigger_id in mentions_map:
                trigger = mentions_map[trigger_id]
            # otherwise build the trigger TBM inline: avoid a recursive call
            # for the sake of the stack...
            else:
                trigger = TextBoundMention(
                    labels=trigger_mjson["labels"],
                    token_interval=Interval(**trigger_mjson["tokenInterval"]),
                    sentence_index=trigger_mjson["sentence"],
                    start=trigger_mjson["characterStartOffset"],
                    end=trigger_mjson["characterEndOffset"],
                    document=docs_map[trigger_mjson["document"]],
                    found_by=trigger_mjson["foundBy"],
                    # NOTE(review): default keep=False here, but keep=True
                    # everywhere else in this method — confirm intentional
                    keep=trigger_mjson.get("keep", False)
                )
            # we have what we need
            m = EventMention(
                labels=labels,
                token_interval=token_interval,
                trigger=trigger,
                sentence_index=sentence_index,
                start=start,
                end=end,
                document=document,
                arguments=arguments,
                paths=paths,
                found_by=found_by,
                keep=keep
            )
            mentions_map[m_id] = m
            return (m, mentions_map)
        if mtype == OdinJsonSerializer.MENTION_R_TYPE:
            # we have what we need
            m = RelationMention(
                labels=labels,
                token_interval=token_interval,
                sentence_index=sentence_index,
                start=start,
                end=end,
                document=document,
                arguments=arguments,
                paths=paths,
                found_by=found_by,
                keep=keep
            )
            mentions_map[m_id] = m
            return (m, mentions_map)
        if mtype == OdinJsonSerializer.MENTION_C_TYPE:
            # anchor
            # this will be one of our args (see https://github.com/clulab/processors/blob/9f89ea7bf6ac551f77dbfdbb8eec9bf216711df4/main/src/main/scala/org/clulab/odin/Mention.scala#L535), so we'll be lazy
            # NOTE(review): assumes the anchor was already built by the args
            # loop above — a missing entry raises KeyError here
            anchor: Mention = mentions_map[mjson["anchor"]["id"]]
            # neighbor
            # this will be one of our args (see https://github.com/clulab/processors/blob/9f89ea7bf6ac551f77dbfdbb8eec9bf216711df4/main/src/main/scala/org/clulab/odin/Mention.scala#L535), so we'll be lazy
            neighbor: Mention = mentions_map[mjson["neighbor"]["id"]]
            # we have what we need
            m = CrossSentenceMention(
                labels=labels,
                token_interval=token_interval,
                anchor=anchor,
                neighbor=neighbor,
                # corresponds to anchor.sentence_index
                sentence_index=sentence_index,
                start=start,
                end=end,
                document=document,
                arguments=arguments,
                # cross-sentence mentions carry no paths
                paths=None,
                found_by=found_by,
                keep=keep
            )
            mentions_map[m_id] = m
            return (m, mentions_map)
        else:
            raise Exception(f"Unrecognized mention type {mtype}. Expected one of the following {OdinJsonSerializer.MENTION_TB_TYPE}, {OdinJsonSerializer.MENTION_E_TYPE}, {OdinJsonSerializer.MENTION_R_TYPE}, {OdinJsonSerializer.MENTION_C_TYPE}")

227 

    @staticmethod
    def construct_paths(maybe_path_data: typing.Optional[dict[str, typing.Any]]) -> typing.Optional[Mention.Paths]:
        """Reconstruct a mention's syntactic ``paths`` from their JSON form.

        FIXME: implement me — path data is currently discarded and ``None``
        is always returned, regardless of input.
        """
        return None

232 

233 @staticmethod 

234 def _load_mention_from_compact_JSON(mention_id: str, compact_json: dict[str, typing.Any], docs_dict: dict[str, Document], mentions_dict: dict[str, Mention]): 

235 mjson = compact_json["mentions"][mention_id] 

236 # recover document 

237 document = docs_dict[mjson["document"]] 

238 # TODO: load args 

239 

240 # collect components 

241 mtype = mjson["type"] 

242 labels = mjson["labels"] 

243 token_interval = Interval(**mjson["tokenInterval"]) 

244 if mtype == OdinJsonSerializer.MENTION_TB_TYPE: 

245 raise NotImplementedError 

246 elif mtype == OdinJsonSerializer.MENTION_E_TYPE: 

247 # get or load trigger 

248 raise NotImplementedError 

249 elif mtype == OdinJsonSerializer.MENTION_R_TYPE: 

250 raise NotImplementedError 

251 elif mtype == OdinJsonSerializer.MENTION_C_TYPE: 

252 raise NotImplementedError 

253 

254 kwargs = { 

255 "label": mjson.get("label", labels[0]), 

256 "labels": labels, 

257 "token_interval": Interval.load_from_JSON(mjson["tokenInterval"]), 

258 "sentence": mjson["sentence"], 

259 "document": doc, 

260 "doc_id": doc_id, 

261 "trigger": mjson.get("trigger", None), 

262 "arguments": mjson.get("arguments", None), 

263 "paths": mjson.get("paths", None), 

264 "keep": mjson.get("keep", True), 

265 "foundBy": mjson["foundBy"] 

266 } 

267 m = Mention(**kwargs) 

268 # set IDs 

269 m.id = mjson["id"] 

270 m._doc_id = doc_id 

271 # set character offsets 

272 m.character_start_offset = mjson["characterStartOffset"] 

273 m.character_end_offset = mjson["characterEndOffset"] 

274 return m 

275 

276 # def to_JSON_dict(self): 

277 # m = dict() 

278 # m["id"] = self.id 

279 # m["type"] = self.type 

280 # m["label"] = self.label 

281 # m["labels"] = self.labels 

282 # m["tokenInterval"] = self.tokenInterval.to_JSON_dict() 

283 # m["characterStartOffset"] = self.characterStartOffset 

284 # m["characterEndOffset"] = self.characterEndOffset 

285 # m["sentence"] = self.sentence 

286 # m["document"] = self._doc_id 

287 # # do we have a trigger? 

288 # if self.trigger: 

289 # m["trigger"] = self.trigger.to_JSON_dict() 

290 # # do we have arguments? 

291 # if self.arguments: 

292 # m["arguments"] = self._arguments_to_JSON_dict() 

293 # # handle paths 

294 # if self.paths: 

295 # m["paths"] = self.paths 

296 # m["keep"] = self.keep 

297 # m["foundBy"] = self.foundBy 

298 # return m