sqlmodel-selectin/database.py
Waylon S. Walker 48c6176c47 init
2025-07-03 09:48:05 -05:00

138 lines
4.3 KiB
Python

from sqlmodel import SQLModel, create_engine, Session, select
from sqlalchemy import text
from rich.progress import Progress
from faker import Faker
# Shared Faker instance; used below to generate usernames, display names
# and post timestamps.
fake = Faker()
import random
from models import User, Post
# Setup SQLite DB
# SQLAlchemy database URL naming the SQLite file "data.db".
sqlite_url = "sqlite:///data.db"
# Module-level engine used by the functions in this file; echo=True asks
# SQLAlchemy to log the SQL statements it emits.
engine = create_engine(sqlite_url, echo=True)
def create_db_and_tables():
    """Create the tables registered on ``SQLModel.metadata`` via the module engine."""
    metadata = SQLModel.metadata
    metadata.create_all(engine)
def setup_db():
    """Run a fixed list of SQLite PRAGMA statements on one engine connection.

    The statements are executed in the order listed, all on the same
    connection, which is closed when the ``with`` block exits.

    NOTE(review): several SQLite pragmas are scoped to the connection that
    issued them -- confirm that connections used elsewhere (e.g. by
    ``Session``) actually pick these settings up.
    """
    pragma_statements = (
        "PRAGMA foreign_keys = ON;",
        "PRAGMA journal_mode = WAL;",
        "PRAGMA synchronous = NORMAL;",
        "PRAGMA temp_store = MEMORY;",
        "PRAGMA cache_size = 50000;",
    )
    with engine.connect() as connection:
        for statement in pragma_statements:
            connection.execute(text(statement))
# Sentence templates for generated post text. Each template uses some of the
# {topic}, {quote} and {location} placeholders, filled in with str.format().
post_templates = [
    "Just published a new blog about {topic}. Check it out!",
    "Today's thoughts: {quote}",
    "Here's a photo from my trip to {location} 🌍",
    "Shoutout to all the amazing devs working with {topic}!",
    "Feeling inspired by {quote}",
    "If you're into {topic}, we should connect!",
    "Throwback to {location} last summer ☀️",
]

# Candidate values for the {topic} placeholder.
topics = [
    "AI",
    "SQLModel",
    "Linux",
    "Neovim",
    "Docker",
    "Kubernetes",
    "FastAPI",
]

# Candidate values for the {quote} placeholder.
quotes = [
    "Keep it simple.",
    "Fail fast, learn faster.",
    "In code we trust.",
    "Life is short, automate it.",
    "Stay curious.",
]

# Candidate values for the {location} placeholder.
locations = [
    "Tokyo",
    "Paris",
    "New York",
    "Berlin",
    "São Paulo",
    "Cairo",
]
def add_posts_to_user(user_id, num_posts=5):
    """Add ``num_posts`` randomly generated posts for ``user_id`` and commit them.

    Args:
        user_id: Value stored in each new post's ``user_id`` field.
        num_posts: How many posts to create (default 5).
    """
    with Session(engine) as db:
        for _ in range(num_posts):
            # Pick a template, then fill its placeholders with random picks.
            template = random.choice(post_templates)
            body = template.format(
                topic=random.choice(topics),
                quote=random.choice(quotes),
                location=random.choice(locations),
            )
            created_at = fake.date_time_this_year()
            db.add(Post(content=body, timestamp=created_at, user_id=user_id))
        # One commit covers every post added above.
        db.commit()
def _insert_in_batches(objects, batch_size):
    """Bulk-save ``objects`` in slices of ``batch_size``, committing each slice."""
    with Session(engine) as session:
        for start in range(0, len(objects), batch_size):
            session.bulk_save_objects(objects[start : start + batch_size])
            session.commit()


def populate_data(
    num_users: int = 100, posts_per_user_range=(5, 20), batch_size=10_000
):
    """Seed the database with fake users and a random number of posts each.

    Args:
        num_users: Number of new users to create.
        posts_per_user_range: ``(low, high)`` inclusive bounds passed to
            ``random.randint`` to pick how many posts each user gets.
        batch_size: Maximum number of objects saved per bulk insert/commit.

    Raises:
        ValueError: If ``posts_per_user_range`` has ``low > high`` or
            ``batch_size`` is smaller than 1. Both are checked before any
            row is written, so a bad argument cannot leave users inserted
            without posts.
    """
    min_posts, max_posts = posts_per_user_range
    if min_posts > max_posts:
        raise ValueError(
            "posts_per_user_range must be (low, high) with low <= high, "
            f"got {posts_per_user_range!r}"
        )
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size!r}")

    with Progress(transient=True) as progress:
        user_task = progress.add_task("[bold blue]Creating users...", total=num_users)
        users = []
        for _ in range(num_users):
            users.append(
                User(
                    username=fake.user_name(),
                    display_name=fake.name(),
                )
            )
            progress.advance(user_task)

        # Bulk insert users
        _insert_in_batches(users, batch_size)

        # Fetch user ids after bulk insert. The query is unfiltered, so it
        # returns every row in the user table, including rows that existed
        # before this call; posts are generated for all of them.
        with Session(engine) as session:
            user_ids = session.exec(select(User.id)).all()

        # Generate posts. The bar is sized by the ids actually iterated so
        # it cannot overrun when the table already held users.
        post_task = progress.add_task(
            "[green]Creating posts...", total=len(user_ids)
        )
        posts = []
        for user_id in user_ids:
            for _ in range(random.randint(min_posts, max_posts)):
                content = random.choice(post_templates).format(
                    topic=random.choice(topics),
                    quote=random.choice(quotes),
                    location=random.choice(locations),
                )
                posts.append(
                    Post(
                        content=content,
                        timestamp=fake.date_time_this_year(),
                        user_id=user_id,
                    )
                )
            progress.advance(post_task)

        # Bulk insert posts
        _insert_in_batches(posts, batch_size)

    print(f"✅ Done! Inserted {len(users)} users and {len(posts)} posts.")
def query_example():
    """Print every user, each followed by one line per post.

    NOTE(review): how ``user.posts`` is loaded depends on the relationship
    defined in ``models``, which is not visible from this file.
    """
    with Session(engine) as db:
        everyone = db.exec(select(User)).all()
        for person in everyone:
            print(f"User: {person.display_name} ({person.username})")
            for entry in person.posts:
                print(f" - {entry.content} [{entry.timestamp.isoformat()}]")
def parse_cli_args(argv):
    """Turn ``[prog, n_users, max_posts]`` into arguments for ``populate_data``.

    Args:
        argv: Argument list shaped like ``sys.argv``; both values are
            optional and default to 100 users and at most 20 posts.

    Returns:
        ``(n_users, (min_posts, max_posts))``. ``min_posts`` is 5, lowered
        to ``max_posts`` when that is smaller, so the range is never
        inverted (``random.randint`` rejects ``low > high``).

    Raises:
        ValueError: If a supplied argument is not an integer literal.
    """
    n_users = int(argv[1]) if len(argv) > 1 else 100
    max_posts = int(argv[2]) if len(argv) > 2 else 20
    min_posts = min(5, max_posts)
    return n_users, (min_posts, max_posts)


if __name__ == "__main__":
    import sys

    # Parse first so a malformed argument fails before the database is touched.
    n_users, posts_range = parse_cli_args(sys.argv)
    create_db_and_tables()
    setup_db()
    populate_data(n_users, posts_range)