python
This commit is contained in:
135
backend/database.py
Normal file
135
backend/database.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from sqlalchemy import create_engine, Column, Integer, String, DateTime, ForeignKey
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker, Session, relationship
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
telegram_id = Column(String, unique=True, nullable=False)
|
||||
username = Column(String, nullable=True)
|
||||
first_name = Column(String, nullable=True)
|
||||
last_name = Column(String, nullable=True)
|
||||
language_code = Column(String, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
search_history = relationship("SearchHistory", back_populates="user")
|
||||
downloads = relationship("Download", back_populates="user")
|
||||
|
||||
class SearchHistory(Base):
|
||||
__tablename__ = "search_history"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
query = Column(String, nullable=False)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
# Relationship
|
||||
user = relationship("User", back_populates="search_history")
|
||||
|
||||
class Download(Base):
|
||||
__tablename__ = "downloads"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
video_id = Column(String, nullable=False)
|
||||
title = Column(String, nullable=False)
|
||||
file_path = Column(String, nullable=False)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
# Relationship
|
||||
user = relationship("User", back_populates="downloads")
|
||||
|
||||
class Database:
|
||||
def __init__(self):
|
||||
# Get database URL from environment or use default
|
||||
database_url = os.getenv("DATABASE_URL")
|
||||
if not database_url:
|
||||
raise ValueError("DATABASE_URL environment variable is required")
|
||||
|
||||
# Create engine
|
||||
self.engine = create_engine(
|
||||
database_url,
|
||||
poolclass=StaticPool,
|
||||
connect_args={"check_same_thread": False} if "sqlite" in database_url else {},
|
||||
echo=False
|
||||
)
|
||||
|
||||
# Create session factory
|
||||
self.SessionLocal = sessionmaker(bind=self.engine)
|
||||
|
||||
# Create tables
|
||||
self.init_db()
|
||||
|
||||
def init_db(self):
|
||||
"""Initialize database tables"""
|
||||
Base.metadata.create_all(bind=self.engine)
|
||||
|
||||
def get_session(self) -> Session:
|
||||
"""Get database session"""
|
||||
return self.SessionLocal()
|
||||
|
||||
def close(self):
|
||||
"""Close database connection"""
|
||||
self.engine.dispose()
|
||||
|
||||
async def get_user_by_telegram_id(self, telegram_id: str) -> Optional[User]:
|
||||
"""Get user by Telegram ID"""
|
||||
with self.get_session() as session:
|
||||
return session.query(User).filter(User.telegram_id == telegram_id).first()
|
||||
|
||||
async def add_user(self, telegram_id: str, username: str = None,
|
||||
first_name: str = None, last_name: str = None,
|
||||
language_code: str = None) -> User:
|
||||
"""Add new user"""
|
||||
with self.get_session() as session:
|
||||
user = User(
|
||||
telegram_id=telegram_id,
|
||||
username=username,
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
language_code=language_code
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
return user
|
||||
|
||||
async def add_search_history(self, user_id: int, query: str) -> SearchHistory:
|
||||
"""Add search history record"""
|
||||
with self.get_session() as session:
|
||||
history = SearchHistory(user_id=user_id, query=query)
|
||||
session.add(history)
|
||||
session.commit()
|
||||
session.refresh(history)
|
||||
return history
|
||||
|
||||
async def add_download(self, user_id: int, video_id: str, title: str, file_path: str) -> Download:
|
||||
"""Add download record"""
|
||||
with self.get_session() as session:
|
||||
download = Download(
|
||||
user_id=user_id,
|
||||
video_id=video_id,
|
||||
title=title,
|
||||
file_path=file_path
|
||||
)
|
||||
session.add(download)
|
||||
session.commit()
|
||||
session.refresh(download)
|
||||
return download
|
||||
|
||||
async def get_search_history(self, user_id: int, limit: int = 10) -> List[SearchHistory]:
|
||||
"""Get user search history"""
|
||||
with self.get_session() as session:
|
||||
return session.query(SearchHistory)\
|
||||
.filter(SearchHistory.user_id == user_id)\
|
||||
.order_by(SearchHistory.created_at.desc())\
|
||||
.limit(limit)\
|
||||
.all()
|
||||
Reference in New Issue
Block a user