from typing import List, Dict
from malaya_graph.utils.triplet import dict_to_list, rebel_format, parse_rebel
import logging
logger = logging.getLogger(__name__)
class Base:
    """
    Common base for the model wrappers in this module.

    Subclasses store their wrapped object on ``self._model``; this base only
    exposes helpers that reach through to ``self._model.model``.
    """

    def cuda(self, **kwargs):
        """
        Call ``.cuda(**kwargs)`` on the inner ``self._model.model`` object.

        Parameters
        ----------
        **kwargs:
            Forwarded unchanged to the inner object's ``cuda`` method.
            (Presumably a torch ``nn.Module`` -- TODO confirm.)

        Returns
        -------
        result:
            Whatever the inner ``cuda`` call returns.
        """
        inner_model = self._model.model
        return inner_model.cuda(**kwargs)
class TexttoKG(Base):
    """Wrap a generation model whose outputs are parsed into knowledge-graph triplets."""

    def __init__(self, model):
        self._model = model

    def generate(self, strings: List[str], got_networkx: bool = True, **kwargs):
        """
        Generate list of knowledge graphs from the input.

        Parameters
        ----------
        strings : List[str]
            Input texts, passed as-is to the wrapped model's ``generate``.
        got_networkx: bool, optional (default=True)
            If True, will generate networkx.MultiDiGraph.
            Silently downgraded to False (with a warning) when pandas or
            networkx cannot be imported.
        **kwargs: vector arguments pass to huggingface `generate` method.
            Read more at https://huggingface.co/docs/transformers/main_classes/text_generation

        Returns
        -------
        result: List[Dict]
            One dict per model output, with keys:

            - ``'G'``: a ``networkx.MultiDiGraph`` built from the triplets
              (edges ``head -> tail`` carrying a ``type`` attribute), or
              ``None`` when graphs are disabled/unavailable or construction
              failed.
            - ``'triple'``: the value returned by ``parse_rebel`` for that
              output (presumably a list of ``{'head', 'type', 'tail'}``
              dicts -- TODO confirm against ``parse_rebel``).
            - ``'rebel'``: the raw output of the wrapped model for that input.
        """
        if got_networkx:
            try:
                import pandas as pd
                import networkx as nx
            except ImportError:
                # Only a missing/broken install should disable graphs; do not
                # swallow KeyboardInterrupt/SystemExit as BaseException would.
                logger.warning(
                    'pandas and networkx not installed. Please install it by `pip install pandas networkx` and try again. Will skip to generate networkx.MultiDiGraph'
                )
                got_networkx = False

        rebels = self._model.generate(strings, **kwargs)
        results = []
        for rebel in rebels:
            triples = parse_rebel(rebel)
            G = None
            if got_networkx:
                try:
                    df = pd.DataFrame(triples)
                    G = nx.from_pandas_edgelist(
                        df,
                        source='head',
                        target='tail',
                        edge_attr='type',
                        create_using=nx.MultiDiGraph(),
                    )
                except Exception as e:
                    # Best-effort: a malformed/empty triplet list (e.g. no
                    # `head` column) yields G=None instead of failing the batch.
                    logger.warning(e)
                    G = None
            results.append({'G': G, 'triple': triples, 'rebel': rebel})
        return results
class KGtoText(Base):
    """Wrap a generation model that turns knowledge-graph triplets into text."""

    # Every triplet dict must carry these keys, each with a non-empty value.
    _REQUIRED_KEYS = ('head', 'type', 'tail')

    def __init__(self, model):
        self._model = model

    def generate(self, kgs: List[List[Dict]], **kwargs):
        """
        Generate a text from list of knowledge graph dictionary.

        Parameters
        ----------
        kgs: List[List[Dict]]
            list of list of {'head', 'type', 'tail'}
        **kwargs: vector arguments pass to huggingface `generate` method.
            Read more at https://huggingface.co/docs/transformers/main_classes/text_generation

        Returns
        -------
        result: List[str]

        Raises
        ------
        ValueError
            If any triplet dict is missing one of `head`, `type`, `tail`,
            or if any of those values has zero length.
        """
        for kg in kgs:
            for no, k in enumerate(kg):
                # `any` (not `and`): a single missing key must be rejected too,
                # otherwise it surfaces later as an opaque KeyError.
                if any(key not in k for key in self._REQUIRED_KEYS):
                    raise ValueError('a dict must have `head`, `type` and `tail` properties.')
                for key in self._REQUIRED_KEYS:
                    if not len(k[key]):
                        raise ValueError(
                            f'`{key}` length must > 0 for knowledge graph index {no}'
                        )
        rebels = [rebel_format(dict_to_list(kg)) for kg in kgs]
        return self._model.generate(rebels, **kwargs)