diff --git a/examples/graph-analytics-serverless-ml-models.ipynb b/examples/graph-analytics-serverless-ml-models.ipynb new file mode 100644 index 000000000..4c4072a85 --- /dev/null +++ b/examples/graph-analytics-serverless-ml-models.ipynb @@ -0,0 +1,419 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "tags": [ + "aura" + ] + }, + "source": [ + "# Aura Graph Analytics using Models" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "\n", + " \"Open\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This Jupyter notebook is hosted [here](https://github.com/neo4j/graph-data-science-client/blob/main/examples/graph-analytics-serverless.ipynb) in the Neo4j Graph Data Science Client Github repository.\n", + "\n", + "The notebook shows how to use the `graphdatascience` Python library to create, manage, and use a GDS Session.\n", + "\n", + "We consider a graph of people and fruits, which we're using as a simple example to show how to connect your AuraDB instance to a GDS Session, run algorithms, and eventually write back your analytical results to the AuraDB database. \n", + "We will cover all management operations: creation, listing, and deletion.\n", + "\n", + "If you are using self managed DB, follow [this example](../graph-analytics-serverless-self-managed)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prerequisites\n", + "\n", + "This notebook requires having an AuraDB instance available and have the Aura Graph Analytics [feature](https://neo4j.com/docs/aura/graph-analytics/#aura-gds-serverless) enabled for your project.\n", + "\n", + "You also need to have the `graphdatascience` Python library installed, version `1.15` or later." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "verify-version" + ] + }, + "outputs": [], + "source": [ + "%pip install \"graphdatascience>=1.20\" python-dotenv \"neo4j_viz[gds]\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from dotenv import load_dotenv\n", + "\n", + "# This allows to load required secrets from `.env` file in local directory\n", + "# This can include Aura API Credentials and Database Credentials.\n", + "# If file does not exist this is a noop.\n", + "load_dotenv(\"staging_ci_enterprise.env\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Aura API credentials\n", + "\n", + "The entry point for managing GDS Sessions is the `GdsSessions` object, which requires creating [Aura API credentials](https://neo4j.com/docs/aura/api/authentication)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from graphdatascience.session import AuraAPICredentials, GdsSessions\n", + "\n", + "# you can also use AuraAPICredentials.from_env() to load credentials from environment variables\n", + "api_credentials = AuraAPICredentials(\n", + " client_id=os.environ[\"CLIENT_ID\"],\n", + " client_secret=os.environ[\"CLIENT_SECRET\"],\n", + " # If your account is a member of several project, you must also specify the project ID to use\n", + " project_id=os.environ.get(\"PROJECT_ID\", None),\n", + ")\n", + "\n", + "sessions = GdsSessions(api_credentials=api_credentials)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating a new session\n", + "\n", + "A new session is created by calling `sessions.get_or_create()` with the following parameters:\n", + "\n", + "* A session name, which lets you reconnect to an existing session by calling `get_or_create` again.\n", + "* The `DbmsConnectionInfo` containing the address, user name and password to an AuraDB instance\n", + "* The session memory. \n", + "* The cloud location.\n", + "* A time-to-live (TTL), which ensures that the session is automatically deleted after being unused for the set time, to avoid incurring costs.\n", + "\n", + "See the API reference [documentation](https://neo4j.com/docs/graph-data-science-client/current/api/sessions/gds_sessions/#graphdatascience.session.gds_sessions.GdsSessions.get_or_create) or the manual for more details on the parameters." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from graphdatascience.session import AlgorithmCategory, SessionMemory\n", + "\n", + "# Explicitly define memory\n", + "memory = SessionMemory.m_2GB\n", + "\n", + "# Estimate the memory needed for the GDS session\n", + "memory = sessions.estimate(\n", + " node_count=20,\n", + " relationship_count=50,\n", + " algorithm_categories=[AlgorithmCategory.NODE_EMBEDDING],\n", + ")\n", + "\n", + "print(f\"Estimated memory for the session: {memory}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from datetime import timedelta\n", + "\n", + "from graphdatascience.session import CloudLocation\n", + "\n", + "# Create a GDS session!\n", + "gds = sessions.get_or_create(\n", + " # we give it a representative name\n", + " session_name=\"training_session\",\n", + " memory=memory,\n", + " db_connection=None,\n", + " ttl=timedelta(minutes=30),\n", + " cloud_location=CloudLocation(\"gcp\", \"europe-west1\"),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Verify the connectivity. Hints towards TLS or firewall issues if this fails directly after get_or_create\n", + "gds.verify_connectivity()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Listing sessions\n", + "\n", + "You can use `sessions.list()` to see the details for each created session." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pandas import DataFrame\n", + "\n", + "gds_sessions = sessions.list()\n", + "\n", + "# for better visualization\n", + "DataFrame(gds_sessions)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Projecting Graphs\n", + "\n", + "Now that we have imported a graph to our database, we can project it into our GDS Session.\n", + "We do that by using the `gds.graph.project()` endpoint.\n", + "\n", + "The remote projection query that we are using selects all `Person` nodes and their `LIKES` relationships, and all `Fruit` nodes and their `LIKES` relationships.\n", + "Additionally, we project node properties for illustrative purposes.\n", + "We can use these node properties as input to algorithms, although we do not do that in this notebook." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "G = gds.v2.graph.datasets.load_cora()\n", + "str(G)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let us visualize the projected graph\n", + "from neo4j_viz.gds import from_gds\n", + "\n", + "VG = from_gds(gds, gds.graph.get(G.name()), max_node_count=50)\n", + "for node in VG.nodes:\n", + " node.caption = node.properties.get(\"name\")\n", + "\n", + "VG.render(initial_zoom=1.2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training the GraphSage model\n", + "\n", + "You can run algorithms on the constructed graph using the standard GDS Python Client API. See the other tutorials for more examples." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Running GraphSage ...\")\n", + "model, result = gds.v2.graph_sage.train(\n", + " G, model_name=\"gs_example_model\", feature_properties=[\"subject\", \"features\"], store_model_to_disk=True\n", + ")\n", + "print(f\"Training result: {result}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model, result = gds.v2.graph_sage.train(G, model_name=\"gs_example_model_2\", feature_properties=[\"subject\", \"features\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_name = \"gs_example_model_2\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "store_result = gds.v2.model.store(model_name)\n", + "print(f\"Model stored with result: {store_result}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_details = gds.v2.model.get(model_name)\n", + "print(model_details)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gds.v2.model.delete(model_name) # remove the persisted model. its still available in the session" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gds.v2.model.store(model_name) # store it again to make sure it is available after deletion" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sessions.delete(session_name=\"training_session\") # delete the session as we are done with the training" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create embeddings\n", + "\n", + "Now we can use the model in new sessions to create embeddings for the same graph or also new graphs with the same schema." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from datetime import timedelta\n", + "\n", + "from graphdatascience.session import CloudLocation, SessionMemory\n", + "\n", + "memory = SessionMemory.m_2GB\n", + "gds = sessions.get_or_create(\n", + " session_name=\"inference_session\",\n", + " memory=memory,\n", + " ttl=timedelta(minutes=30),\n", + " cloud_location=CloudLocation(\"gcp\", \"europe-west1\"),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "G = gds.v2.graph.datasets.load_cora()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gds.v2.model.list() # check the model is available" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_details = gds.v2.model.get(model_name)\n", + "print(model_details)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gds.v2.model.load(model_name)\n", + "gds.v2.graph_sage.stream(G, model_name=model_name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sessions.delete(session_name=\"inference_session\") # delete the session as we are done with the inference" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# let's also make sure the deleted session is truly gone:\n", + "sessions.list()" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}