Source code for graphrag_sdk.schema.auto_detect

import os
import json
import time
import queue
import sqlite3
import concurrent.futures

from graphrag_sdk.prompts import *
from openai import OpenAI
from sqlalchemy.orm import sessionmaker
from sqlalchemy import create_engine, inspect, MetaData, Table, types

# Convert from SQLAlchemy type to Python type
def sqlalchemy_to_python_type(sqlalchemy_type):
    """Map a SQLAlchemy column type instance to the Python type used for graph attributes.

    Entries are tested in order and the first ``isinstance`` match wins. In
    SQLAlchemy, ``Float`` subclasses ``Numeric`` and ``Text`` subclasses
    ``String``, so those later entries are effectively covered by the earlier ones.
    Date/time, binary and JSON columns are all represented as ``str``.

    Returns:
        The matching Python type, or the string ``'Unknown Type'`` when no
        entry matches.
    """
    type_table = (
        (types.Integer, int),
        (types.Numeric, float),
        (types.Float, float),
        (types.String, str),
        (types.Text, str),
        (types.Boolean, bool),
        (types.Date, str),
        (types.DateTime, str),
        (types.Time, str),
        (types.LargeBinary, str),
        (types.JSON, str),
    )
    return next(
        (py_type for sa_type, py_type in type_table if isinstance(sqlalchemy_type, sa_type)),
        'Unknown Type',
    )
# create entity node
def create_entity(s, client, model, name:str, columns) -> None:
    """Add a node type ``name`` to schema ``s`` with one attribute per column.

    Primary-key columns become unique, mandatory attributes. If no column is a
    primary key, the chat model is asked which attribute best identifies a node
    of this type, and that attribute is marked unique and mandatory.

    Args:
        s: Graph schema being built (``add_entity`` is called on it).
        client: OpenAI client used for the chat completion.
        model: Chat model name.
        name: Entity (node type) name, the SQL table name.
        columns: Column dicts as returned by SQLAlchemy ``Inspector.get_columns``;
            the ``name``, ``type`` and ``primary_key`` keys are read.

    Raises:
        ValueError: if the model names an attribute the entity does not have.
    """
    entity = s.add_entity(name)
    has_pk = False
    for col in columns:
        # Use a separate variable: rebinding ``name`` here would overwrite the
        # entity name that the prompt below refers to.
        col_name = col['name']
        attr_type = sqlalchemy_to_python_type(col['type'])
        pk = bool(col['primary_key'])
        if pk:
            has_pk = True
        entity.add_attribute(col_name, attr_type, unique=pk, mandatory=pk)

    # Require each entity to have at least one unique attribute
    if has_pk:
        return

    attribute_list = ','.join(col['name'] for col in columns)
    user_msg = f"""Which of the following attributes is most suitble to identify a graph node of type: {name} ? List of attributes: {attribute_list} """
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": FIND_UNIQUE_ATTRIBUTE_PROMPT},
            {"role": "user", "content": user_msg},
        ],
    )
    # Models often wrap the answer in whitespace or quotes; strip them before lookup.
    attr_name = (response.choices[0].message.content or "").strip().strip('"\'`')
    attr = entity.get_attribute(attr_name)
    if attr is None:
        raise ValueError(
            f"Model suggested unknown attribute {attr_name!r} as identifier for entity {name!r}"
        )
    attr.unique = True
    attr.mandatory = True
# Create entity node in addition to a single relation
def create_entity_with_relation(s, client, model, name:str, columns, foreign_keys) -> None:
    """Create node type ``name`` plus a relation to the table its foreign key refers to.

    Foreign-key columns are excluded from the entity's attributes (the relation
    represents them instead). The relation goes from the new entity to the
    entity of ``foreign_keys[0]['referred_table']``; its type name is suggested
    by the chat model and upper-cased.

    The caller's ``columns`` list is not modified.

    Args:
        s: Graph schema being built.
        client: OpenAI client used for the chat completions.
        model: Chat model name.
        name: Entity name (the SQL table name).
        columns: Column dicts from ``Inspector.get_columns``.
        foreign_keys: Foreign-key dicts from ``Inspector.get_foreign_keys``; the
            ``constrained_columns`` and ``referred_table`` keys are read.
    """
    fk_columns = {
        column
        for foreign_key in foreign_keys
        for column in foreign_key['constrained_columns']
    }
    attribute_columns = [col for col in columns if col['name'] not in fk_columns]

    create_entity(s, client, model, name, attribute_columns)

    # Determine relation destination node type
    src = s.get_entity(name)
    # NOTE(review): if the referred table has not been converted to an entity
    # yet, get_entity may not find it - confirm its behavior for missing names.
    dest = s.get_entity(foreign_keys[0]['referred_table'])

    # Get relationship name
    user_msg = f"""Suggest a name for a relationship-type connecting source node of type {src.name} to destination node of type {dest.name}"""
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": RELATIONSHIP_NAMING_PROMPT},
            {"role": "user", "content": user_msg},
        ],
    )
    relation_name = (response.choices[0].message.content or "").strip().upper()

    s.add_relation(relation_name, src, dest)
# Create relation between two entities
def create_relation(s, name:str, foreign_keys):
    """Turn a two-foreign-key (join) table into a relation ``(src)-[NAME]->(dest)``.

    The entity referenced by the first foreign key is the source and the one
    referenced by the second is the destination; the relation type is
    ``name`` upper-cased.
    """
    source_table = foreign_keys[0]['referred_table']
    target_table = foreign_keys[1]['referred_table']
    source_entity = s.get_entity(source_table)
    target_entity = s.get_entity(target_table)
    s.add_relation(name.upper(), source_entity, target_entity)
# Convert SQL table into a node type or a relation
def convert_table(s, client, model, tbl, columns, foreign_keys):
    """Convert one SQL table into graph schema elements based on its foreign-key count.

    * 0 foreign keys -> a node type (entity)
    * 1 foreign key  -> a node type plus a relation to the referred table
    * 2 foreign keys -> only a relation between the two referred tables
    * more           -> skipped, with a message printed
    """
    fk_count = len(foreign_keys)

    if fk_count > 2:
        print("can't process table with more then 2 forigen keys")
        return

    if fk_count == 2:
        create_relation(s, tbl, foreign_keys)
    elif fk_count == 1:
        create_entity_with_relation(s, client, model, tbl, columns, foreign_keys)
    else:
        create_entity(s, client, model, tbl, columns)
def convert_tables(s, client, model, inspector, table_names):
    """Convert every table in ``table_names`` into schema elements of ``s``.

    Tables without foreign keys are converted first, so the entities that
    foreign keys refer to tend to exist before relations to them are created.
    Within each group, tables are processed in ``table_names`` order. Each table
    is introspected once (columns and foreign keys), instead of once per pass.

    Args:
        s: Graph schema being built.
        client: OpenAI client forwarded to ``convert_table``.
        model: Chat model name forwarded to ``convert_table``.
        inspector: SQLAlchemy inspector for the source database.
        table_names: Names of the tables to convert.
    """
    table_info = [
        (table_name, inspector.get_columns(table_name), inspector.get_foreign_keys(table_name))
        for table_name in table_names
    ]
    without_fks = [info for info in table_info if len(info[2]) == 0]
    with_fks = [info for info in table_info if len(info[2]) > 0]

    for table_name, columns, foreign_keys in without_fks + with_fks:
        convert_table(s, client, model, table_name, columns, foreign_keys)
# Create graph ontology from a set of SQL tables
def generate_graph_schema(s, client, model, database_file='ontology.db'):
    """Create a graph ontology in ``s`` from the tables of a local SQLite database.

    Args:
        s: Graph schema being built.
        client: OpenAI client forwarded to ``convert_tables``.
        model: Chat model name forwarded to ``convert_tables``.
        database_file: Path of the SQLite file to introspect. Defaults to
            ``ontology.db``, the file ``schema_auto_detect`` writes.
    """
    engine = create_engine(f'sqlite:///{database_file}')
    try:
        # Opening a connection verifies the database is reachable; the
        # with-block closes it even if conversion raises.
        with engine.connect():
            inspector = inspect(engine)
            table_names = inspector.get_table_names()
            # Process tables into schema
            convert_tables(s, client, model, inspector, table_names)
    finally:
        # Release pooled connections / file handles held by the engine.
        engine.dispose()
# Improve relations naming #for rel in s.relations: # r = s.relations[rel] # user_msg = f"""Suggest a name for a relationship-type connecting # source node of type {r.src.name} to destination node of type {r.dest.name}""" # response = client.chat.completions.create( # model=model, # messages=[ # {"role": "system", "content": RELATIONSHIP_NAMING_PROMPT}, # {"role": "user", "content": user_msg} # ] # ) # old = r.name # r.name = response.choices[0].message.content.upper() # print(f"Renamed from {old} to: {r.name}") # OpenAI assistant tool # queue sql statement
def run_sql(arg, queue):
    """Tool handler for the assistant's ``run_sql`` function.

    Nothing is executed here: the statement in ``arg["sql"]`` is printed and
    put on ``queue`` for the caller to collect.
    """
    statement = arg["sql"]
    print(f"SQL: {statement}")
    queue.put(statement)
# OpenAI available functions: maps each tool name the assistant may call to its
# local handler; handlers are invoked as handler(args, queue).
available_functions = {
    "run_sql": run_sql,
}

# Get OpenAI ontology detection assistant
def get_assistant(client, assistant_name, model, id_file="./ONTOLOGY_ASSISTANT_ID"):
    """Return the ontology-detection assistant, reusing a cached one when possible.

    The assistant ID is cached in ``id_file``. A cached assistant is reused only
    if it can still be retrieved and was created for ``model``. Otherwise a new
    assistant is created with ``ONTOLOGY_DETECTION_SYSTEM_PROMPT`` and a single
    ``run_sql`` tool, and its ID is written to ``id_file``.

    Args:
        client: OpenAI client.
        assistant_name: Name given to a newly created assistant.
        model: Model the assistant must use.
        id_file: Path of the file caching the assistant ID.

    Returns:
        The OpenAI assistant object.
    """
    # Try reusing existing assistant
    if os.path.exists(id_file):
        with open(id_file, 'r', encoding='utf-8') as f:
            assistant_id = f.read().strip()
        try:
            existing = client.beta.assistants.retrieve(assistant_id)
        except Exception:
            # Best effort: a stale/deleted ID or an API error just means we
            # fall through and create a fresh assistant.
            existing = None
        # An assistant cached for a different model must not be reused, or the
        # caller's ``model`` argument would be silently ignored.
        if existing is not None and getattr(existing, "model", model) == model:
            return existing

    # Create OpenAI assistant
    assistant = client.beta.assistants.create(
        name=assistant_name,
        instructions=ONTOLOGY_DETECTION_SYSTEM_PROMPT,
        model=model,
        tools=[{
            "type": "function",
            "function": {
                "name": "run_sql",
                "description": "Run SQL statement",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "sql": {"type": "string", "description": "SQL statement"}
                    },
                    "required": ["sql"],
                },
            },
        }],
    )

    # Save assistant ID to local file
    with open(id_file, 'w', encoding='utf-8') as f:
        f.write(assistant.id)

    return assistant
# Start an interaction with OpenAI assistant
def initiate_interaction(client, assistant, user_message):
    """Create a thread holding ``user_message`` and start ``assistant`` on it.

    Returns:
        The run object returned by ``client.beta.threads.create_and_run``.
    """
    thread_spec = {
        "messages": [
            {"role": "user", "content": user_message},
        ]
    }
    return client.beta.threads.create_and_run(assistant_id=assistant.id, thread=thread_spec)
# OpenAI run handler
def _call_tool(function, queue):
    """Decode a tool call's JSON arguments and dispatch it to its handler in ``available_functions``.

    Raises ValueError (invalid JSON), TypeError (missing arguments), KeyError
    (unknown tool or missing argument), or whatever the handler raises.
    """
    args = json.loads(function.arguments)
    available_functions[function.name](args, queue)


def handle_run(client, run, queue):
    """Poll an assistant run until it ends, executing its tool calls locally.

    Tool calls are picked up from in-progress run steps as soon as they appear,
    and when the run reports ``requires_action`` any not yet executed are run
    then. Each tool-call ID is executed at most once, so the same SQL statement
    is not queued again on every poll. The output submitted back for each call
    is ``"OK"``, or an error description if the call could not be executed.

    Args:
        client: OpenAI client.
        run: Run object returned by ``initiate_interaction``.
        queue: Queue handed to every tool handler.
    """
    executed = set()  # ids of tool calls already dispatched
    while True:
        run = client.beta.threads.runs.retrieve(thread_id=run.thread_id, run_id=run.id)
        status = run.status

        if status == "completed":
            print("Done processing document")
            return
        if status == "expired":
            print("Processing expired")
            return
        if status in ("cancelling", "cancelled", "failed", "incomplete"):
            print("Processing failed")
            return

        if status == "in_progress":
            # Execute tool calls early, while the run is still in progress.
            run_steps = client.beta.threads.runs.steps.list(thread_id=run.thread_id, run_id=run.id)
            for step in run_steps.data:
                if step.status != "in_progress":
                    continue
                step_details = step.step_details
                if step_details.type != "tool_calls":
                    continue
                for tool_call in step_details.tool_calls:
                    if tool_call.type != "function" or tool_call.id in executed:
                        continue
                    try:
                        _call_tool(tool_call.function, queue)
                    except (ValueError, KeyError, TypeError, AttributeError):
                        # Not executable yet (e.g. arguments not parseable);
                        # retried on a later poll or at requires_action.
                        print(f"tool_call: {tool_call}")
                        continue
                    executed.add(tool_call.id)

        elif status == "requires_action" and run.required_action.type == "submit_tool_outputs":
            outputs = []
            for tool in run.required_action.submit_tool_outputs.tool_calls:
                output = "OK"
                if tool.id not in executed:
                    try:
                        _call_tool(tool.function, queue)
                        executed.add(tool.id)
                    except (ValueError, KeyError, TypeError) as err:
                        # Report the failure to the assistant instead of
                        # crashing the worker thread.
                        output = f"Error: {err}"
                outputs.append({"tool_call_id": tool.id, "output": output})

            # respond to action
            run = client.beta.threads.runs.submit_tool_outputs(
                thread_id=run.thread_id,
                run_id=run.id,
                tool_outputs=outputs,
            )
            continue

        # queued, in_progress, or any unhandled state: wait before polling again
        # (avoids a busy loop on statuses the code does not recognize).
        time.sleep(1)
# process source # present content to LLM for graph ontology detection
def process_source(client, assistant, src, queue, max_chars=32000):
    """Present one source's text to the assistant and collect the SQL it emits.

    The text from ``src.load()`` is truncated to ``max_chars`` characters
    (presumably to stay within the model's context window - confirm). The
    thread created for the run is deleted even if run handling raises.

    Args:
        client: OpenAI client.
        assistant: Assistant returned by ``get_assistant``.
        src: Source object exposing ``load()``, which is sliced like a string.
        queue: Queue that receives the SQL statements emitted by tool calls.
        max_chars: Maximum number of characters of source text to send.
    """
    text = src.load()[:max_chars]
    user_message = "Extract schema from following text:" + text
    run = initiate_interaction(client, assistant, user_message)
    try:
        handle_run(client, run, queue)
    finally:
        client.beta.threads.delete(run.thread_id)
# Detect graph ontology from a set of sources
def _create_table_name(stmt):
    """Return the table name declared by a CREATE TABLE statement, or None if there is none."""
    import re
    match = re.search(
        r'CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?[`"\[]?(\w+)',
        stmt,
        re.IGNORECASE,
    )
    return match.group(1) if match else None


def _strip_code_fence(text):
    """Remove a surrounding Markdown code fence (e.g. ```sql ... ```) from an LLM reply."""
    text = (text or "").strip()
    if text.startswith("```"):
        # Drop the opening fence line (it may carry a language tag) ...
        text = text.split("\n", 1)[1] if "\n" in text else ""
        # ... and the closing fence.
        text = text.rstrip()
        if text.endswith("```"):
            text = text[:-3]
    return text.strip()


def _merge_statements(client, model, statements):
    """Fold several CREATE TABLE statements for one table into one, pairwise, via the LLM."""
    pending = list(statements)
    while len(pending) > 1:
        first = pending.pop()
        second = pending.pop()
        user_msg = f"Merge to following two SQL CREATE TABLE statments: 1. {first} 2. \n{second} "
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": SQL_FOLDING_PROMPT},
                {"role": "user", "content": user_msg},
            ],
        )
        merged_stmt = _strip_code_fence(response.choices[0].message.content)
        print(f"merged_stmt: {merged_stmt}")
        pending.append(merged_stmt)
    return pending[0]


def schema_auto_detect(schema, sources, model="gpt-3.5-turbo-0125"):
    """Detect a graph ontology from ``sources`` and write it into ``schema``.

    Pipeline:
      1. Each source is sent (up to 4 concurrently) to an OpenAI assistant that
         emits CREATE TABLE statements through its ``run_sql`` tool.
      2. Statements are grouped by table name, case-insensitively (SQLite treats
         table names that way), and each group is merged pairwise by the LLM.
      3. The merged statements are executed against a fresh ``ontology.db``;
         statements SQLite rejects are reported and skipped.
      4. ``generate_graph_schema`` converts the resulting tables into ``schema``.

    Sources whose processing raises are reported and skipped.

    Args:
        schema: Graph schema to populate.
        sources: Objects exposing ``load()`` (see ``process_source``).
        model: OpenAI model name used for the assistant and chat completions.
    """
    client = OpenAI()

    # Get assistant
    assistant = get_assistant(client, "ontology detection", model)

    # Create a new local SQLite DB
    db_file = "ontology.db"
    if os.path.exists(db_file):
        os.remove(db_file)

    # Process sources: extract entities and relationships from each one
    q = queue.Queue()
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        tasks = [executor.submit(process_source, client, assistant, src, q) for src in sources]
        concurrent.futures.wait(tasks)
    # Surface worker failures instead of losing them silently.
    for task in tasks:
        err = task.exception()
        if err is not None:
            print(f"Failed processing source: {err!r}")

    # Collect duplicated CREATE statements, grouped by table name
    tables = {}
    while not q.empty():
        stmt = q.get_nowait()
        print(f"sql: {stmt}")
        tbl = _create_table_name(stmt)
        if tbl is None:
            print(f"Skipping statement without CREATE TABLE: {stmt}")
            continue
        print(f"tbl: {tbl}")
        tables.setdefault(tbl.lower(), []).append(stmt)

    # Merge duplicated create table statements
    statements = [_merge_statements(client, model, stmts) for stmts in tables.values()]

    # Execute SQL statements
    con = sqlite3.connect(db_file)
    try:
        cur = con.cursor()
        try:
            for stmt in statements:
                print(f"Executing: {stmt}")
                try:
                    cur.execute(stmt)
                except sqlite3.Error as err:
                    print(f"Failed executing statement: {err}")
                    continue
                con.commit()
            # TODO: a disabled pass that pruned less descriptive columns
            # (via SQL_FOCUS_PROMPT) used to sit here.
        finally:
            cur.close()
    finally:
        con.close()

    # generate schema
    generate_graph_schema(schema, client, model)