Coverage for python/lum/clu/odin/serialization.py: 72%
105 statements
« prev ^ index » next — coverage.py v7.6.7, created at 2024-11-17 18:41 +0000
1from lum.clu.odin.mention import (Mention, TextBoundMention, RelationMention, EventMention, CrossSentenceMention)
2from lum.clu.processors.document import Document
3from lum.clu.processors.interval import Interval
4import typing
5import collections
# Public API of this module: everything else is an implementation detail.
__all__ = ["OdinJsonSerializer"]
11 # ("type" -> longString) ~
12 # // used for correspondence with paths map
13 # ("id" -> id) ~ // tb.id would just create a different TextBoundMentionOps to provide the id
14 # ("text" -> tb.text) ~
15 # ("labels" -> tb.labels) ~
16 # ("tokenInterval" -> Map("start" -> tb.tokenInterval.start, "end" -> tb.tokenInterval.end)) ~
17 # ("characterStartOffset" -> tb.startOffset) ~
18 # ("characterEndOffset" -> tb.endOffset) ~
19 # ("sentence" -> tb.sentence) ~
20 # ("document" -> documentEquivalenceHash.toString) ~
21 # ("keep" -> tb.keep) ~
22 # ("foundBy" -> tb.foundBy)
24# object TextBoundMentionOps {
25# val string = "TextBoundMention"
26# val shortString = "T"
27# }
29# object EventMentionOps {
30# val string = "EventMention"
31# val shortString = "E"
32# }
34# object RelationMentionOps {
35# val string = "RelationMention"
36# val shortString = "R"
37# }
39# object CrossSentenceMentionOps {
40# val string = "CrossSentenceMention"
41# val shortString = "CS"
42# }
class OdinJsonSerializer:
    """Deserializes Odin-style "compact" mention JSON into ``Mention`` objects.

    The compact format stores each ``Document`` once under ``"documents"``
    (keyed by doc id) and a flat list of mention JSON objects under
    ``"mentions"``; mentions reference their document, trigger, and arguments
    by id rather than nesting them.
    """

    # Values of the "type" field in serialized mention JSON.
    MENTION_TB_TYPE = "TextBoundMention"
    MENTION_R_TYPE = "RelationMention"
    MENTION_E_TYPE = "EventMention"
    MENTION_C_TYPE = "CrossSentenceMention"

    # @staticmethod
    # def to_compact_mentions_json(jdata: dict[str, typing.Any]) -> list[Mention]:
    #     pass

    @staticmethod
    def from_compact_mentions_json(compact_json: dict[str, typing.Any]) -> list[Mention]:
        """Reconstructs all mentions from a compact JSON payload.

        Uses an explicit work queue (rather than pure recursion) so that large
        payloads don't blow the stack.

        :param compact_json: dict with ``"documents"`` (doc id -> document JSON)
            and ``"mentions"`` (list of mention JSON objects).
        :returns: the mentions listed in ``compact_json["mentions"]``; nested
            triggers/arguments are constructed but not returned separately.
        """
        # populate mapping of doc id -> Document
        docs_map: dict[str, Document] = dict()
        for doc_id, doc_json in compact_json["documents"].items():
            # store ID on the document JSON if not already set
            if "id" not in doc_json:
                doc_json.update({"id": doc_id})
            docs_map[doc_id] = Document(**doc_json)

        mentions_map: dict[str, Mention] = dict()
        # skip any entries lacking an "id": they cannot be referenced or sorted
        mention_ids: typing.Set[str] = {
            mn["id"] for mn in compact_json["mentions"] if mn.get("id") is not None
        }
        # construct TextBoundMentions ("T…" ids) first, since the other
        # mention types depend on them (triggers, arguments)
        srt_fn = lambda mid: -1 if mid.startswith("T") else 1
        # work queue w/ TBM ids first
        missing: collections.deque = collections.deque(sorted(mention_ids, key=srt_fn))

        while len(missing) > 0:
            m_id = missing.popleft()
            # construct this mention (and, recursively, anything it references)
            _, mns_map = OdinJsonSerializer._fetch_mention(
                m_id=m_id,
                compact_json=compact_json,
                docs_map=docs_map,
                mentions_map=mentions_map
            )
            # store new results
            mentions_map.update(mns_map)
            # filter out newly constructed mentions
            missing = collections.deque([k for k in missing if k not in mentions_map])
        # only return mentions explicitly listed in the payload;
        # avoids unraveling mentions to include triggers, etc.
        return [m for mid, m in mentions_map.items() if mid in mention_ids]

    @staticmethod
    def _fetch_mention(m_id: str, compact_json: dict[str, typing.Any], docs_map: dict[str, Document], mentions_map: dict[str, Mention]) -> typing.Tuple[Mention, dict[str, Mention]]:
        """Returns the ``Mention`` for ``m_id``, constructing it (and anything
        it references) on demand.

        :param m_id: id of the mention to fetch/construct.
        :param compact_json: full compact payload (may be extended in place if
            an argument's JSON is missing from ``"mentions"``).
        :param docs_map: doc id -> ``Document``.
        :param mentions_map: cache of already-constructed mentions; updated in place.
        :returns: ``(mention, mentions_map)``.
        :raises Exception: if the mention's ``"type"`` is unrecognized.
        """
        # base case: already constructed
        if m_id in mentions_map:
            return mentions_map[m_id], mentions_map

        mjson: dict[str, typing.Any] = [mn for mn in compact_json["mentions"] if mn.get("id", None) == m_id][0]
        mtype = mjson["type"]
        # gather info common to all mention types
        labels = mjson["labels"]
        token_interval = Interval(**mjson["tokenInterval"])
        document = docs_map[mjson["document"]]
        start = mjson["characterStartOffset"]
        end = mjson["characterEndOffset"]
        sentence_index = mjson["sentence"]
        found_by = mjson["foundBy"]
        keep = mjson.get("keep", True)
        # easy case. We have everything we need.
        if mtype == OdinJsonSerializer.MENTION_TB_TYPE:
            m = TextBoundMention(
                labels=labels,
                token_interval=token_interval,
                sentence_index=sentence_index,
                start=start,
                end=end,
                document=document,
                found_by=found_by,
                keep=keep
            )
            mentions_map[m_id] = m
            return (m, mentions_map)
        # everything else *might* have paths
        paths: typing.Optional[Mention.Paths] = OdinJsonSerializer.construct_paths(mjson.get("paths", None))
        # retrieve all args recursively
        arguments: Mention.Arguments = dict()
        for role, mns_json in mjson.get("arguments", {}).items():
            role_mns = arguments.get(role, [])
            for mn_json in mns_json:
                _mid = mn_json["id"]
                if _mid in mentions_map:
                    _mn = mentions_map[_mid]
                else:
                    # NOTE: in certain cases, the referenced id might not be found
                    # among compact_json["mentions"]. Add its JSON so the
                    # recursive call below can locate it.
                    # (use .get to tolerate id-less entries in the list)
                    if all(m.get("id") != _mid for m in compact_json["mentions"]):
                        compact_json["mentions"] = compact_json["mentions"] + [mn_json]
                    _mn, _mns_map = OdinJsonSerializer._fetch_mention(
                        m_id=_mid,
                        compact_json=compact_json,
                        docs_map=docs_map,
                        mentions_map=mentions_map
                    )
                    # update our progress
                    mentions_map.update(_mns_map)
                # store this argument mention under its role
                role_mns.append(_mn)
            # update our args
            arguments[role] = role_mns

        if mtype == OdinJsonSerializer.MENTION_E_TYPE:
            # get or load trigger
            trigger_mjson = mjson["trigger"]
            trigger_id = trigger_mjson["id"]
            if trigger_id in mentions_map:
                trigger = mentions_map[trigger_id]
            else:
                # construct the trigger inline (avoid a recursive call
                # for the sake of the stack...)
                trigger = TextBoundMention(
                    labels=trigger_mjson["labels"],
                    token_interval=Interval(**trigger_mjson["tokenInterval"]),
                    sentence_index=trigger_mjson["sentence"],
                    start=trigger_mjson["characterStartOffset"],
                    end=trigger_mjson["characterEndOffset"],
                    document=docs_map[trigger_mjson["document"]],
                    found_by=trigger_mjson["foundBy"],
                    # NOTE(review): default differs from the `True` used for all
                    # other mentions above — confirm this asymmetry is intended
                    keep=trigger_mjson.get("keep", False)
                )
            # we have what we need
            m = EventMention(
                labels=labels,
                token_interval=token_interval,
                trigger=trigger,
                sentence_index=sentence_index,
                start=start,
                end=end,
                document=document,
                arguments=arguments,
                paths=paths,
                found_by=found_by,
                keep=keep
            )
            mentions_map[m_id] = m
            return (m, mentions_map)
        if mtype == OdinJsonSerializer.MENTION_R_TYPE:
            # we have what we need
            m = RelationMention(
                labels=labels,
                token_interval=token_interval,
                sentence_index=sentence_index,
                start=start,
                end=end,
                document=document,
                arguments=arguments,
                paths=paths,
                found_by=found_by,
                keep=keep
            )
            mentions_map[m_id] = m
            return (m, mentions_map)
        if mtype == OdinJsonSerializer.MENTION_C_TYPE:
            # anchor & neighbor will be among our args
            # (see https://github.com/clulab/processors/blob/9f89ea7bf6ac551f77dbfdbb8eec9bf216711df4/main/src/main/scala/org/clulab/odin/Mention.scala#L535),
            # so we can be lazy and read them from the cache
            anchor: Mention = mentions_map[mjson["anchor"]["id"]]
            neighbor: Mention = mentions_map[mjson["neighbor"]["id"]]
            # we have what we need
            m = CrossSentenceMention(
                labels=labels,
                token_interval=token_interval,
                anchor=anchor,
                neighbor=neighbor,
                # corresponds to anchor.sentence_index
                sentence_index=sentence_index,
                start=start,
                end=end,
                document=document,
                arguments=arguments,
                paths=None,
                found_by=found_by,
                keep=keep
            )
            mentions_map[m_id] = m
            return (m, mentions_map)
        else:
            raise Exception(f"Unrecognized mention type {mtype}. Expected one of the following {OdinJsonSerializer.MENTION_TB_TYPE}, {OdinJsonSerializer.MENTION_E_TYPE}, {OdinJsonSerializer.MENTION_R_TYPE}, {OdinJsonSerializer.MENTION_C_TYPE}")

    @staticmethod
    def construct_paths(maybe_path_data: typing.Optional[dict[str, typing.Any]]) -> typing.Optional[Mention.Paths]:
        """Constructs ``Mention.Paths`` from serialized path data.

        :returns: currently always ``None``.
        """
        # FIXME: implement me
        return None

    @staticmethod
    def _load_mention_from_compact_JSON(mention_id: str, compact_json: dict[str, typing.Any], docs_dict: dict[str, Document], mentions_dict: dict[str, Mention]):
        """Unimplemented legacy loader; use :meth:`_fetch_mention` instead.

        :raises NotImplementedError: always, for every recognized mention type.
        :raises Exception: if the mention's ``"type"`` is unrecognized.
        """
        mjson = compact_json["mentions"][mention_id]
        # recover document
        document = docs_dict[mjson["document"]]
        # TODO: load args
        # collect components
        mtype = mjson["type"]
        labels = mjson["labels"]
        token_interval = Interval(**mjson["tokenInterval"])
        if mtype == OdinJsonSerializer.MENTION_TB_TYPE:
            raise NotImplementedError
        elif mtype == OdinJsonSerializer.MENTION_E_TYPE:
            # get or load trigger
            raise NotImplementedError
        elif mtype == OdinJsonSerializer.MENTION_R_TYPE:
            raise NotImplementedError
        elif mtype == OdinJsonSerializer.MENTION_C_TYPE:
            raise NotImplementedError
        else:
            # previously this fell through into dead code referencing the
            # undefined names `doc`/`doc_id` (a guaranteed NameError); fail
            # loudly and clearly instead
            raise Exception(f"Unrecognized mention type {mtype}. Expected one of the following {OdinJsonSerializer.MENTION_TB_TYPE}, {OdinJsonSerializer.MENTION_E_TYPE}, {OdinJsonSerializer.MENTION_R_TYPE}, {OdinJsonSerializer.MENTION_C_TYPE}")
276 # def to_JSON_dict(self):
277 # m = dict()
278 # m["id"] = self.id
279 # m["type"] = self.type
280 # m["label"] = self.label
281 # m["labels"] = self.labels
282 # m["tokenInterval"] = self.tokenInterval.to_JSON_dict()
283 # m["characterStartOffset"] = self.characterStartOffset
284 # m["characterEndOffset"] = self.characterEndOffset
285 # m["sentence"] = self.sentence
286 # m["document"] = self._doc_id
287 # # do we have a trigger?
288 # if self.trigger:
289 # m["trigger"] = self.trigger.to_JSON_dict()
290 # # do we have arguments?
291 # if self.arguments:
292 # m["arguments"] = self._arguments_to_JSON_dict()
293 # # handle paths
294 # if self.paths:
295 # m["paths"] = self.paths
296 # m["keep"] = self.keep
297 # m["foundBy"] = self.foundBy
298 # return m