Star-Mapper/app.py

687 lines
22 KiB
Python

"""
Neo4j Graph Visualizer - A beautiful, high-performance graph visualization app.
Connects to Neo4j, executes Cypher queries, precomputes layouts in Python,
and renders stunning visualizations in the browser.
"""
import os
import json
import hashlib
import colorsys
import logging
import time
import base64
from collections import defaultdict
from urllib.parse import urlparse
import requests as http_requests
from flask import Flask, render_template, jsonify, request
from layout_engine import compute_layout, get_available_algorithms
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

app = Flask(__name__)

# Connection settings for Neo4j's HTTP transactional endpoint (HTTP, not Bolt).
# Every value can be overridden through the environment of the same name.
NEO4J_HTTP_URL = os.getenv("NEO4J_HTTP_URL", "")
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "")
NEO4J_DATABASE = os.getenv("NEO4J_DATABASE", "neo4j")

# JSON file of {"name", "query"} sample queries; by default
# config/sample_queries.json next to this module.
SAMPLE_QUERIES_FILE = os.getenv(
    "SAMPLE_QUERIES_FILE",
    os.path.join(os.path.dirname(__file__), "config", "sample_queries.json"),
)

# Raw data of the most recent /api/query call, kept so /api/relayout can
# recompute positions without querying Neo4j again. Empty until the first
# successful query.
_last_query_cache: dict = {}
# ---------------------------------------------------------------------------
# Neo4j HTTP Transactional API helpers
# ---------------------------------------------------------------------------
def _neo4j_auth_header():
    """Return request headers for the Neo4j HTTP API.

    The headers carry HTTP Basic credentials built from the module-level
    NEO4J_USER / NEO4J_PASSWORD, plus JSON content negotiation.
    """
    token = base64.b64encode(f"{NEO4J_USER}:{NEO4J_PASSWORD}".encode()).decode()
    headers = {"Authorization": "Basic " + token}
    headers["Content-Type"] = "application/json"
    headers["Accept"] = "application/json;charset=UTF-8"
    return headers
def _neo4j_tx_url(database=None):
    """Return the ``<base>/db/<database>/tx/commit`` endpoint URL.

    A falsy ``database`` falls back to NEO4J_DATABASE. Trailing slashes on
    NEO4J_HTTP_URL are dropped so the path is not doubled.
    """
    if not database:
        database = NEO4J_DATABASE
    return "{}/db/{}/tx/commit".format(NEO4J_HTTP_URL.rstrip("/"), database)
def execute_cypher(cypher: str, params: dict | None = None):
    """
    Run ``cypher`` through Neo4j's HTTP transactional commit endpoint.

    Both the tabular ("row") and graph ("graph") result formats are requested,
    so a single round trip feeds the table view and the graph canvas.

    Returns:
        A 4-tuple ``(nodes, edges, records, keys)``:
        - nodes: node id (str) -> {"id", "labels", "properties", "label"},
          where "label" is a display string truncated to 80 characters.
        - edges: list of {"id", "source", "target", "type", "properties"},
          one entry per distinct relationship id.
        - records: one dict per result row, keyed by column name.
        - keys: column names of the last statement result.

    Raises:
        requests.HTTPError: the HTTP call returned a non-2xx status.
        RuntimeError: Neo4j reported statement errors (messages joined by "; ").
    """
    request_body = {
        "statements": [
            {
                "statement": cypher,
                "parameters": params or {},
                "resultDataContents": ["row", "graph"],
            }
        ]
    }
    response = http_requests.post(
        _neo4j_tx_url(), json=request_body, headers=_neo4j_auth_header(), timeout=120
    )
    response.raise_for_status()
    body = response.json()

    # Neo4j reports Cypher failures inside a 200 response body.
    errors = body.get("errors")
    if errors:
        raise RuntimeError("; ".join(err.get("message", str(err)) for err in errors))

    nodes: dict = {}
    edges: list = []
    edge_ids: set = set()
    records: list = []
    keys: list = []

    for statement_result in body.get("results", []):
        keys = statement_result.get("columns", [])
        for datum in statement_result.get("data", []):
            # Table view: pad missing trailing values with None.
            values = datum.get("row", [])
            records.append(
                {
                    column: (values[pos] if pos < len(values) else None)
                    for pos, column in enumerate(keys)
                }
            )

            # Graph view: the same node/relationship can appear in many rows.
            graph = datum.get("graph", {})
            for raw_node in graph.get("nodes", []):
                node_id = str(raw_node["id"])
                if node_id in nodes:
                    continue
                labels = raw_node.get("labels", [])
                props = raw_node.get("properties", {})
                # First truthy of name/title/id/sku, else first label, else id.
                fallback = labels[0] if labels else node_id
                candidates = (props.get(field) for field in ("name", "title", "id", "sku"))
                display = next((value for value in candidates if value), fallback)
                nodes[node_id] = {
                    "id": node_id,
                    "labels": labels,
                    "properties": props,
                    "label": str(display)[:80],
                }

            for raw_rel in graph.get("relationships", []):
                rel_id = str(raw_rel["id"])
                if rel_id in edge_ids:
                    continue
                edge_ids.add(rel_id)
                edges.append(
                    {
                        "id": rel_id,
                        "source": str(raw_rel["startNode"]),
                        "target": str(raw_rel["endNode"]),
                        "type": raw_rel.get("type", "RELATED"),
                        "properties": raw_rel.get("properties", {}),
                    }
                )

    return nodes, edges, records, keys
def _execute_simple(cypher: str):
    """Run a parameterless Cypher statement and return its raw rows.

    Each element of the result is the "row" list of one returned datum (one
    value per RETURN column). The HTTP call uses a 30 second timeout.

    Raises:
        requests.HTTPError: the HTTP call returned a non-2xx status.
        RuntimeError: Neo4j reported statement errors.
    """
    response = http_requests.post(
        _neo4j_tx_url(),
        json={"statements": [{"statement": cypher}]},
        headers=_neo4j_auth_header(),
        timeout=30,
    )
    response.raise_for_status()
    body = response.json()
    errors = body.get("errors")
    if errors:
        raise RuntimeError("; ".join(err.get("message", str(err)) for err in errors))
    return [
        datum.get("row", [])
        for result in body.get("results", [])
        for datum in result.get("data", [])
    ]
def _default_sample_queries():
"""Fallback sample queries when no external file is available."""
return [
{
"name": "Product Neighborhood (200)",
"query": "MATCH (p) WHERE 'Product' IN labels(p) WITH p LIMIT 200 MATCH (p)-[r]-(n) RETURN p, r, n LIMIT 1000",
},
{
"name": "Products by Category",
"query": "MATCH (p)-[r]-(c) WHERE 'Product' IN labels(p) AND 'Category' IN labels(c) RETURN p, r, c LIMIT 800",
},
{
"name": "Products by Brand",
"query": "MATCH (p)-[r]-(b) WHERE 'Product' IN labels(p) AND 'Brand' IN labels(b) RETURN p, r, b LIMIT 800",
},
{
"name": "Supplier to Product Network",
"query": "MATCH (s)-[r]-(p) WHERE 'Supplier' IN labels(s) AND 'Product' IN labels(p) RETURN s, r, p LIMIT 800",
},
{
"name": "Product Attributes",
"query": "MATCH (p)-[r]-(a) WHERE 'Product' IN labels(p) AND any(lbl IN labels(a) WHERE lbl IN ['Attribute','Color','Material','Tag']) RETURN p, r, a LIMIT 1000",
},
{
"name": "Most Connected Products",
"query": "MATCH (p)-[r]-() WHERE 'Product' IN labels(p) WITH p, count(r) AS degree ORDER BY degree DESC LIMIT 25 MATCH (p)-[r2]-(n) RETURN p, r2, n LIMIT 1200",
},
{
"name": "Category Graph (Depth 2)",
"query": "MATCH (c) WHERE 'Category' IN labels(c) WITH c LIMIT 20 MATCH path=(c)-[*1..2]-(related) RETURN path LIMIT 500",
},
{
"name": "Review Connections",
"query": "MATCH (p)-[r]-(rv) WHERE 'Product' IN labels(p) AND 'Review' IN labels(rv) RETURN p, r, rv LIMIT 800",
},
{
"name": "Relationship Type Counts",
"query": "MATCH ()-[r]->() RETURN type(r) AS type, count(*) AS count ORDER BY count DESC LIMIT 25",
},
{
"name": "Node Label Counts",
"query": "MATCH (n) UNWIND labels(n) AS label RETURN label, count(*) AS count ORDER BY count DESC LIMIT 25",
},
{"name": "Schema Visualization", "query": "CALL db.schema.visualization()"},
]
def _load_sample_queries():
    """Return the sample queries offered in the UI.

    SAMPLE_QUERIES_FILE must hold a JSON array of objects with non-empty
    string ``name`` and ``query`` fields. Malformed entries are skipped with a
    warning, and both fields are returned stripped of surrounding whitespace.

    The built-in ``_default_sample_queries()`` list is returned instead when
    the file is missing, cannot be read or parsed, is not an array, or yields
    no valid entry.
    """
    source = SAMPLE_QUERIES_FILE
    try:
        with open(source, encoding="utf-8") as handle:
            entries = json.load(handle)
    except FileNotFoundError:
        logger.warning("Sample query file not found: %s", source)
        return _default_sample_queries()
    except Exception as err:  # unreadable / invalid JSON: fall back to defaults
        logger.warning("Failed to load sample queries from %s: %s", source, err)
        return _default_sample_queries()

    if not isinstance(entries, list):
        logger.warning("Sample query file must contain a JSON array: %s", source)
        return _default_sample_queries()

    accepted = []
    for position, entry in enumerate(entries):
        problem = None
        if not isinstance(entry, dict):
            problem = "Skipping sample query #%d: expected object"
        else:
            title = entry.get("name")
            cypher = entry.get("query")
            if not isinstance(title, str) or not title.strip():
                problem = "Skipping sample query #%d: missing non-empty 'name'"
            elif not isinstance(cypher, str) or not cypher.strip():
                problem = "Skipping sample query #%d: missing non-empty 'query'"
        if problem is not None:
            logger.warning(problem, position)
            continue
        accepted.append({"name": title.strip(), "query": cypher.strip()})

    if accepted:
        return accepted
    logger.warning("No valid sample queries found in %s", source)
    return _default_sample_queries()
# ---------------------------------------------------------------------------
# Color generation
# ---------------------------------------------------------------------------
_PALETTE = [
"#00d4ff",
"#ff6b6b",
"#ffd93d",
"#6bcb77",
"#9b59b6",
"#e67e22",
"#1abc9c",
"#e74c3c",
"#3498db",
"#f39c12",
"#2ecc71",
"#e91e63",
"#00bcd4",
"#ff9800",
"#8bc34a",
"#673ab7",
"#009688",
"#ff5722",
"#607d8b",
"#cddc39",
]
def color_for_label(label: str) -> str:
"""Return a vivid, consistent color for a label string."""
idx = int(hashlib.md5(label.encode()).hexdigest()[:8], 16)
return _PALETTE[idx % len(_PALETTE)]
# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------
@app.route("/")
def index():
    """Serve the single-page visualizer UI from the ``index.html`` template."""
    page = render_template("index.html")
    return page
def _parse_layout_options(data):
    """Validate a request body and extract its layout options.

    Args:
        data: The decoded JSON request body.

    Returns:
        ``(layout_algo, spacing, iterations)``. Missing keys default to
        ``"auto"``, ``1.0`` and ``300``. ``spacing`` is coerced with
        ``float`` and ``iterations`` with ``int``.

    Raises:
        TypeError: ``data`` is not a JSON object, or an option has a type the
            numeric conversion rejects (e.g. ``null``).
        ValueError: an option is a string that does not parse as a number.
    """
    if not isinstance(data, dict):
        raise TypeError("request body must be a JSON object")
    layout_algo = data.get("layout", "auto")
    spacing = float(data.get("spacing", 1.0))
    iterations = int(data.get("iterations", 300))
    return layout_algo, spacing, iterations


def _build_graph_payload(
    nodes_dict, edges, records, keys, *, layout_algo, spacing, iterations, query_time_ms
):
    """Lay out a graph and assemble the JSON payload of the query endpoints.

    The node dicts in ``nodes_dict`` are mutated in place: ``x``, ``y``,
    ``color`` and ``size`` are added. Pass copies if the originals must stay
    untouched.

    Args:
        nodes_dict: node id -> node dict (as produced by ``execute_cypher``).
        edges: edge dicts with at least ``source``, ``target`` and ``type``.
        records: tabular rows; only the first 500 are returned.
        keys: column names for ``records``.
        layout_algo, spacing, iterations: forwarded to ``compute_layout``.
        query_time_ms: value reported as ``stats.query_time_ms``.

    Returns:
        Dict with ``nodes``, ``edges``, ``label_colors``, ``records``,
        ``keys`` and ``stats`` entries.
    """
    # One colour per distinct label, in first-seen order.
    label_colors: dict[str, str] = {}
    for nd in nodes_dict.values():
        for lb in nd.get("labels", []):
            if lb not in label_colors:
                label_colors[lb] = color_for_label(lb)

    t1 = time.time()
    positions = compute_layout(
        nodes_dict,
        edges,
        algorithm=layout_algo,
        spacing=spacing,
        iterations=iterations,
    )
    t_layout = time.time() - t1

    # Node size grows with degree relative to the busiest node (3 .. 25).
    # Degree is counted over all edges, before the de-duplication below.
    degree: dict[str, int] = defaultdict(int)
    for e in edges:
        degree[e["source"]] += 1
        degree[e["target"]] += 1
    max_deg = max(degree.values()) if degree else 1

    nodes_list = []
    for nid, nd in nodes_dict.items():
        pos = positions.get(nid, {"x": 0, "y": 0})
        primary = nd["labels"][0] if nd.get("labels") else "Unknown"
        nd["x"] = pos["x"]
        nd["y"] = pos["y"]
        nd["color"] = label_colors.get(primary, "#888888")
        nd["size"] = 3 + (degree.get(nid, 0) / max(max_deg, 1)) * 22
        nodes_list.append(nd)

    # Keep one edge per (source, target, type) combination.
    seen = set()
    unique_edges = []
    for e in edges:
        key = (e["source"], e["target"], e["type"])
        if key not in seen:
            seen.add(key)
            unique_edges.append(e)

    return {
        "nodes": nodes_list,
        "edges": unique_edges,
        "label_colors": label_colors,
        "records": records[:500],  # cap tabular results
        "keys": keys,
        "stats": {
            "node_count": len(nodes_list),
            "edge_count": len(unique_edges),
            "labels": list(label_colors.keys()),
            "query_time_ms": query_time_ms,
            "layout_time_ms": round(t_layout * 1000),
        },
    }


@app.route("/api/query", methods=["POST"])
def api_query():
    """Run a Cypher query, lay out the resulting graph and return it as JSON.

    Expects a JSON object with ``query`` (required, non-empty string) and
    optional ``layout``, ``spacing`` and ``iterations``. Invalid input and
    query/layout failures are answered with ``{"error": ...}`` and HTTP 400.
    """
    data = request.get_json(force=True)
    # Validate before touching Neo4j so bad input gives a JSON 400, not a 500.
    try:
        layout_algo, spacing, iterations = _parse_layout_options(data)
    except (TypeError, ValueError) as exc:
        return jsonify({"error": f"Invalid request: {exc}"}), 400
    cypher = data.get("query", "")
    if not isinstance(cypher, str):
        return jsonify({"error": "Query must be a string"}), 400
    cypher = cypher.strip()
    if not cypher:
        return jsonify({"error": "Empty query"}), 400

    try:
        t0 = time.time()
        nodes_dict, edges, records, keys = execute_cypher(cypher)
        t_query = time.time() - t0

        # Cache raw results for re-layout. The per-node/per-edge dict copies
        # keep the x/y/color/size added below out of the cache.
        _last_query_cache.clear()
        _last_query_cache["nodes_dict"] = {k: dict(v) for k, v in nodes_dict.items()}
        _last_query_cache["edges"] = [dict(e) for e in edges]
        _last_query_cache["records"] = records
        _last_query_cache["keys"] = keys

        payload = _build_graph_payload(
            nodes_dict,
            edges,
            records,
            keys,
            layout_algo=layout_algo,
            spacing=spacing,
            iterations=iterations,
            query_time_ms=round(t_query * 1000),
        )
        return jsonify(payload)
    except Exception as exc:
        logger.exception("Query failed")
        return jsonify({"error": str(exc)}), 400


@app.route("/api/relayout", methods=["POST"])
def api_relayout():
    """Re-run layout on the cached query result without hitting Neo4j.

    Accepts the same optional ``layout``, ``spacing`` and ``iterations`` as
    ``/api/query``. Returns 400 if no query result is cached yet, the input
    is invalid, or layout fails.
    """
    if not _last_query_cache:
        return jsonify({"error": "No cached query result. Run a query first."}), 400
    data = request.get_json(force=True)
    try:
        layout_algo, spacing, iterations = _parse_layout_options(data)
    except (TypeError, ValueError) as exc:
        return jsonify({"error": f"Invalid request: {exc}"}), 400

    try:
        # Copy each cached node/edge dict so the layout annotations added by
        # _build_graph_payload don't leak back into the cache.
        nodes_dict = {k: dict(v) for k, v in _last_query_cache["nodes_dict"].items()}
        edges = [dict(e) for e in _last_query_cache["edges"]]
        payload = _build_graph_payload(
            nodes_dict,
            edges,
            _last_query_cache["records"],
            _last_query_cache["keys"],
            layout_algo=layout_algo,
            spacing=spacing,
            iterations=iterations,
            query_time_ms=0,
        )
        return jsonify(payload)
    except Exception as exc:
        logger.exception("Re-layout failed")
        return jsonify({"error": str(exc)}), 400
@app.route("/api/schema")
def api_schema():
    """Return the database's labels, relationship types and property keys.

    Success: ``{"labels": [...], "relationship_types": [...],
    "property_keys": [...]}``. On failure: ``{"error": <message>}`` with
    HTTP 400.
    """
    try:
        labels = [row[0] for row in _execute_simple("CALL db.labels()")]
        rel_types = [row[0] for row in _execute_simple("CALL db.relationshipTypes()")]
        prop_keys = [row[0] for row in _execute_simple("CALL db.propertyKeys()")]
    except Exception as exc:
        # Record the traceback like the other endpoints instead of silently
        # discarding it; the client still gets the short message.
        logger.exception("Schema fetch failed")
        return jsonify({"error": str(exc)}), 400
    return jsonify(
        {
            "labels": labels,
            "relationship_types": rel_types,
            "property_keys": prop_keys,
        }
    )
@app.route("/api/connection-test")
def api_connection_test():
    """Probe Neo4j with ``RETURN 1`` and report whether it answered.

    Success: ``{"status": "connected", "uri": NEO4J_HTTP_URL}``. Otherwise:
    ``{"status": "error", "message": ...}`` with HTTP 500.
    """
    try:
        rows = _execute_simple("RETURN 1 AS ok")
        healthy = bool(rows) and rows[0][0] == 1
        if not healthy:
            raise RuntimeError("Unexpected response")
    except Exception as exc:
        return jsonify({"status": "error", "message": str(exc)}), 500
    return jsonify({"status": "connected", "uri": NEO4J_HTTP_URL})
@app.route("/api/reconnect", methods=["POST"])
def api_reconnect():
    """Switch the Neo4j HTTP endpoint and credentials, verifying them first.

    Expects a JSON object with ``uri`` (required), ``user`` and ``password``.
    ``uri`` and ``user`` are stripped; the password is used verbatim. The new
    settings are probed with ``RETURN 1``. If the probe fails, the previous
    settings are restored, so a typo does not break a working connection.

    Returns ``{"status": "connected", "uri": ...}`` on success. Returns
    ``{"status": "error", "message": ...}`` with 400 for invalid input, or
    with 500 when the probe fails.
    """
    global NEO4J_HTTP_URL, NEO4J_USER, NEO4J_PASSWORD
    data = request.get_json(force=True)
    if not isinstance(data, dict):
        return (
            jsonify({"status": "error", "message": "Request body must be a JSON object"}),
            400,
        )

    fields = {}
    for field in ("uri", "user", "password"):
        value = data.get(field) or ""  # treat missing/null as empty
        if not isinstance(value, str):
            return (
                jsonify({"status": "error", "message": f"'{field}' must be a string"}),
                400,
            )
        fields[field] = value
    new_url = fields["uri"].strip()
    new_user = fields["user"].strip()
    new_pass = fields["password"]
    if not new_url:
        return jsonify({"status": "error", "message": "URL is required"}), 400

    # _execute_simple reads the module globals, so install the candidate
    # settings for the probe and put the old ones back if it fails.
    previous = (NEO4J_HTTP_URL, NEO4J_USER, NEO4J_PASSWORD)
    NEO4J_HTTP_URL, NEO4J_USER, NEO4J_PASSWORD = new_url, new_user, new_pass
    try:
        rows = _execute_simple("RETURN 1 AS ok")
        if not (rows and rows[0][0] == 1):
            raise RuntimeError("Unexpected response")
    except Exception as exc:
        NEO4J_HTTP_URL, NEO4J_USER, NEO4J_PASSWORD = previous
        logger.warning("Reconnect to %s failed: %s", new_url, exc)
        return jsonify({"status": "error", "message": str(exc)}), 500
    return jsonify({"status": "connected", "uri": NEO4J_HTTP_URL})
@app.route("/api/layouts")
def api_layouts():
    """Return, as JSON, the layout algorithms reported by the layout engine."""
    algorithms = get_available_algorithms()
    return jsonify(algorithms)
@app.route("/api/sample-queries")
def api_sample_queries():
    """Return the configured (or built-in fallback) sample queries as JSON."""
    queries = _load_sample_queries()
    return jsonify(queries)
def _generate_demo_graph(size, rng):
    """Build a synthetic product-catalogue graph for the demo endpoint.

    Args:
        size: Number of nodes to create (0 or less gives an empty graph).
        rng: ``random.Random`` instance. All randomness is drawn from it, so
            the result depends only on its state.

    Returns:
        ``(nodes_dict, edges, degree, label_types)``, where ``degree`` maps
        node id -> number of generated edges touching that node.
    """
    import heapq

    label_types = [
        "Product",
        "Category",
        "Brand",
        "Supplier",
        "Attribute",
        "Color",
        "Material",
        "Tag",
        "Collection",
        "Review",
    ]
    rel_types = [
        "BELONGS_TO",
        "MADE_BY",
        "SUPPLIED_BY",
        "HAS_ATTRIBUTE",
        "HAS_COLOR",
        "MADE_OF",
        "TAGGED_WITH",
        "PART_OF",
        "REVIEWED_IN",
        "SIMILAR_TO",
    ]
    adj_names = [
        "Premium",
        "Eco",
        "Organic",
        "Classic",
        "Modern",
        "Vintage",
        "Smart",
        "Ultra",
        "Compact",
        "Deluxe",
    ]
    noun_names = [
        "Widget",
        "Gadget",
        "Module",
        "Unit",
        "Element",
        "Component",
        "System",
        "Kit",
        "Bundle",
        "Pack",
    ]
    # Relative frequency of each entry of label_types (more products, fewer reviews).
    weights = [30, 15, 10, 8, 10, 5, 5, 7, 5, 5]
    total_weight = sum(weights)

    nodes_dict = {}
    for i in range(size):
        # Weighted pick: first label whose cumulative weight reaches r.
        r = rng.random() * total_weight
        cumulative = 0
        chosen_label = label_types[0]
        for idx, w in enumerate(weights):
            cumulative += w
            if r <= cumulative:
                chosen_label = label_types[idx]
                break
        name = f"{rng.choice(adj_names)} {rng.choice(noun_names)} {i}"
        nid = f"demo_{i}"
        nodes_dict[nid] = {
            "id": nid,
            "labels": [chosen_label],
            "properties": {
                "name": name,
                "sku": f"SKU-{i:05d}",
                "price": round(rng.uniform(5, 500), 2),
            },
            "label": name,
        }

    # About 1.5 edge attempts per node. With probability 0.3 the target is one
    # of the 10 highest-degree nodes (preferential attachment), otherwise it is
    # uniform. Self-loops are discarded.
    node_ids = list(nodes_dict)
    edges = []
    degree = defaultdict(int)
    for _ in range(int(size * 1.5)):
        src = rng.choice(node_ids)
        if rng.random() < 0.3 and degree:
            # Same result as sorted(degree, key=..., reverse=True)[:10]
            # (ties keep insertion order), without a full sort per edge.
            top = heapq.nlargest(10, degree, key=degree.__getitem__)
            tgt = rng.choice(top)
        else:
            tgt = rng.choice(node_ids)
        if src != tgt:
            edges.append(
                {
                    "id": f"edge_{len(edges)}",
                    "source": src,
                    "target": tgt,
                    "type": rng.choice(rel_types),
                    "properties": {},
                }
            )
            degree[src] += 1
            degree[tgt] += 1
    return nodes_dict, edges, degree, label_types


@app.route("/api/demo", methods=["POST"])
def api_demo():
    """Generate a demo graph for testing the visualization without Neo4j.

    The optional JSON body may contain ``size`` (node count, clamped to
    0..5000, default 300), plus ``layout``, ``spacing`` and ``iterations``
    as for ``/api/query``. The graph is generated from a fixed seed, so it is
    the same on every call. Invalid input or layout failures are answered
    with ``{"error": ...}`` and HTTP 400.
    """
    import random

    data = request.get_json(force=True) if request.is_json else {}
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400
    try:
        size = max(0, min(int(data.get("size", 300)), 5000))
        layout_algo = data.get("layout", "auto")
        spacing = float(data.get("spacing", 1.0))
        iterations = int(data.get("iterations", 300))
    except (TypeError, ValueError) as exc:
        return jsonify({"error": f"Invalid request: {exc}"}), 400

    try:
        # A private generator seeded with 42 reproduces the demo graph without
        # reseeding the process-wide `random` state other code relies on.
        rng = random.Random(42)
        nodes_dict, edges, degree, label_types = _generate_demo_graph(size, rng)
        label_colors = {lt: color_for_label(lt) for lt in label_types}

        t1 = time.time()
        positions = compute_layout(
            nodes_dict, edges, algorithm=layout_algo, spacing=spacing, iterations=iterations
        )
        t_layout = time.time() - t1

        # Node size grows with degree relative to the busiest node (3 .. 25).
        max_deg = max(degree.values()) if degree else 1
        nodes_list = []
        for nid, nd in nodes_dict.items():
            pos = positions.get(nid, {"x": 0, "y": 0})
            nd["x"] = pos["x"]
            nd["y"] = pos["y"]
            nd["color"] = label_colors.get(nd["labels"][0], "#888")
            nd["size"] = 3 + (degree.get(nid, 0) / max(max_deg, 1)) * 22
            nodes_list.append(nd)
    except Exception as exc:
        logger.exception("Demo generation failed")
        return jsonify({"error": str(exc)}), 400

    return jsonify(
        {
            "nodes": nodes_list,
            "edges": edges,
            "label_colors": label_colors,
            "records": [],
            "keys": [],
            "stats": {
                "node_count": len(nodes_list),
                "edge_count": len(edges),
                "labels": list(label_colors.keys()),
                "query_time_ms": 0,
                "layout_time_ms": round(t_layout * 1000),
            },
        }
    )
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Werkzeug's interactive debugger runs arbitrary Python submitted from the
    # browser, so it must not be on by default while listening on 0.0.0.0.
    # Opt in explicitly with FLASK_DEBUG (any value except empty/0/false/no).
    _debug_flag = os.environ.get("FLASK_DEBUG", "")
    app.run(
        debug=bool(_debug_flag) and _debug_flag.lower() not in {"0", "false", "no"},
        host="0.0.0.0",
        port=5555,
    )