# Source code for malaya_graph.model.text_to_kg

from typing import List, Dict
from malaya_graph.utils.triplet import dict_to_list, rebel_format, parse_rebel
import logging

logger = logging.getLogger(__name__)


class Base:
    """Shared helpers for the model wrappers defined in this module.

    Subclasses are expected to set ``self._model``, an object exposing a
    ``.model`` attribute.
    """

    def cuda(self, **kwargs):
        """
        Call ``.cuda(**kwargs)`` on the wrapped model (``self._model.model``).

        Presumably this moves the underlying model to a CUDA device; confirm
        against the wrapped model's own ``cuda`` implementation.

        Parameters
        ----------
        **kwargs:
            Forwarded unchanged to ``self._model.model.cuda``.

        Returns
        -------
        result: whatever ``self._model.model.cuda`` returns.
        """
        inner_model = self._model.model
        return inner_model.cuda(**kwargs)


class TexttoKG(Base):
    """Text -> knowledge-graph wrapper.

    Runs ``self._model.generate`` on input strings, parses each output with
    ``parse_rebel`` and optionally builds a ``networkx.MultiDiGraph``.
    """

    def __init__(self, model):
        self._model = model

    def generate(self, strings: List[str], got_networkx: bool = True, **kwargs):
        """
        Generate list of knowledge graphs from the input.

        Parameters
        ----------
        strings : List[str]
        got_networkx: bool, optional (default=True)
            If True, will generate networkx.MultiDiGraph.
            Requires `pandas` and `networkx`; if either cannot be imported a
            warning is logged and graphs are skipped (``G`` is None).
        **kwargs: keyword arguments pass to huggingface `generate` method.
            Read more at https://huggingface.co/docs/transformers/main_classes/text_generation

        Returns
        -------
        result: List[Dict]
            One dict per model output, with keys:

            * ``'G'``: ``networkx.MultiDiGraph`` built from the triplets, or
              None if graph building was skipped or failed.
            * ``'triple'``: result of ``parse_rebel`` on the output (presumably
              a list of ``{'head', 'type', 'tail'}`` dicts, since those keys
              are used as the graph's source/edge/target columns).
            * ``'rebel'``: the raw model output before parsing.
        """
        if got_networkx:
            try:
                import pandas as pd
                import networkx as nx
            except Exception:
                # Best-effort optional dependency. Catch Exception rather than
                # BaseException so Ctrl-C / SystemExit are not swallowed. Still
                # broader than ImportError because a broken install can raise
                # other errors at import time.
                logger.warning(
                    'pandas and networkx not installed. Please install it by `pip install pandas networkx` and try again. Will skip to generate networkx.MultiDiGraph'
                )
                got_networkx = False

        raw_outputs = self._model.generate(strings, **kwargs)
        parsed_outputs = [parse_rebel(o) for o in raw_outputs]

        results = []
        for raw, triples in zip(raw_outputs, parsed_outputs):
            G = None
            if got_networkx:
                try:
                    df = pd.DataFrame(triples)
                    G = nx.from_pandas_edgelist(
                        df,
                        source='head',
                        target='tail',
                        edge_attr='type',
                        create_using=nx.MultiDiGraph(),
                    )
                except Exception as e:
                    # A malformed parse (e.g. missing columns) should not abort
                    # the whole batch; keep the triplets and leave G empty.
                    logger.warning(e)
                    G = None
            results.append({'G': G, 'triple': triples, 'rebel': raw})
        return results
class KGtoText(Base):
    """Knowledge-graph -> text wrapper.

    Converts each knowledge graph with ``rebel_format(dict_to_list(...))`` and
    passes the results to ``self._model.generate``.
    """

    def __init__(self, model):
        self._model = model

    def generate(self, kgs: List[List[Dict]], **kwargs):
        """
        Generate a text from list of knowledge graph dictionary.

        Parameters
        ----------
        kgs: List[List[Dict]]
            list of list of {'head', 'type', 'tail'}
        **kwargs: keyword arguments pass to huggingface `generate` method.
            Read more at https://huggingface.co/docs/transformers/main_classes/text_generation

        Returns
        -------
        result: List[str]

        Raises
        ------
        ValueError
            If a triplet dict is missing any of `head`, `type`, `tail`, or
            any of those values is empty.
        """
        required = ('head', 'type', 'tail')
        for kg_index, kg in enumerate(kgs):
            for no, k in enumerate(kg):
                # Every key must be present. Checking with `and` would only
                # catch dicts missing all three, and a partially-filled dict
                # would then fail below with a KeyError instead.
                if not all(key in k for key in required):
                    raise ValueError('a dict must have `head`, `type` and `tail` properties.')
                for key in required:
                    if not len(k[key]):
                        raise ValueError(
                            f'`{key}` length must > 0 for knowledge graph index {kg_index}, triplet index {no}'
                        )
        rebels = [rebel_format(dict_to_list(kg)) for kg in kgs]
        return self._model.generate(rebels, **kwargs)