135 lines
4.9 KiB
Python
135 lines
4.9 KiB
Python
import os
from datetime import datetime, timezone
from typing import Optional, List

from sqlalchemy import create_engine, Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session, relationship
from sqlalchemy.pool import StaticPool
|
|
|
|
# Shared declarative base: the ORM model classes in this module subclass it,
# and Database.init_db() creates the tables registered on Base.metadata.
Base = declarative_base()
|
|
|
|
class User(Base):
    """A Telegram user record, stored in the ``users`` table.

    ``telegram_id`` is required and unique; the remaining profile columns
    are optional. Each user owns a collection of ``SearchHistory`` rows and
    a collection of ``Download`` rows.
    """

    __tablename__ = "users"

    # Surrogate integer primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Telegram ID, stored as text; must be unique across users.
    telegram_id = Column(String, unique=True, nullable=False)
    username = Column(String, nullable=True)
    first_name = Column(String, nullable=True)
    last_name = Column(String, nullable=True)
    language_code = Column(String, nullable=True)
    # Naive UTC timestamp assigned on insert. datetime.utcnow() is deprecated
    # since Python 3.12, so take an aware "now" and drop tzinfo: the stored
    # value is the same naive UTC datetime utcnow() used to produce.
    created_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
    )

    # Relationships (each is mirrored by a ``user`` attribute on the other side)
    search_history = relationship("SearchHistory", back_populates="user")
    downloads = relationship("Download", back_populates="user")

    def __repr__(self) -> str:
        """Return a short debugging representation of the user."""
        return (
            f"{type(self).__name__}(id={self.id!r}, "
            f"telegram_id={self.telegram_id!r}, username={self.username!r})"
        )
|
|
|
|
class SearchHistory(Base):
    """One search query recorded for a user (``search_history`` table)."""

    __tablename__ = "search_history"

    # Surrogate integer primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Owning user; required foreign key to users.id.
    user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
    # The query text; required.
    query = Column(String, nullable=False)
    # Naive UTC timestamp assigned on insert. datetime.utcnow() is deprecated
    # since Python 3.12, so take an aware "now" and drop tzinfo: the stored
    # value is the same naive UTC datetime utcnow() used to produce.
    created_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
    )

    # Relationship (mirror of User.search_history)
    user = relationship("User", back_populates="search_history")

    def __repr__(self) -> str:
        """Return a short debugging representation of the history entry."""
        return (
            f"{type(self).__name__}(id={self.id!r}, "
            f"user_id={self.user_id!r}, query={self.query!r})"
        )
|
|
|
|
class Download(Base):
    """One download recorded for a user (``downloads`` table).

    All of ``video_id``, ``title`` and ``file_path`` are required text
    columns.
    """

    __tablename__ = "downloads"

    # Surrogate integer primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Owning user; required foreign key to users.id.
    user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
    video_id = Column(String, nullable=False)
    title = Column(String, nullable=False)
    file_path = Column(String, nullable=False)
    # Naive UTC timestamp assigned on insert. datetime.utcnow() is deprecated
    # since Python 3.12, so take an aware "now" and drop tzinfo: the stored
    # value is the same naive UTC datetime utcnow() used to produce.
    created_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
    )

    # Relationship (mirror of User.downloads)
    user = relationship("User", back_populates="downloads")

    def __repr__(self) -> str:
        """Return a short debugging representation of the download."""
        return (
            f"{type(self).__name__}(id={self.id!r}, "
            f"user_id={self.user_id!r}, video_id={self.video_id!r})"
        )
|
|
|
|
class Database:
    """Engine/session holder plus the persistence operations on the models.

    Every data method opens a short-lived session, does its work, and closes
    the session before returning. Objects handed back are therefore detached:
    column values that were already loaded stay readable, but relationships
    that were not loaded (e.g. ``user.downloads``) cannot be lazy-loaded
    afterwards.

    NOTE(review): the data methods are declared ``async`` but contain no
    ``await``; the database call runs synchronously and blocks the event
    loop for its duration.
    """

    def __init__(self, database_url: Optional[str] = None):
        """Create the engine and session factory, then create missing tables.

        Args:
            database_url: SQLAlchemy database URL. When omitted or empty, the
                ``DATABASE_URL`` environment variable is used.

        Raises:
            ValueError: If no URL is given and ``DATABASE_URL`` is unset or
                empty.
        """
        # An explicit argument wins; otherwise fall back to the environment.
        database_url = database_url or os.getenv("DATABASE_URL")
        if not database_url:
            raise ValueError("DATABASE_URL environment variable is required")

        engine_options = {"echo": False}
        # StaticPool keeps exactly one connection for all callers, which only
        # suits SQLite; other backends keep SQLAlchemy's default pool. Test
        # the URL scheme prefix so that "sqlite" appearing in a password or
        # database name of another backend does not match.
        if database_url.startswith("sqlite"):
            engine_options["poolclass"] = StaticPool
            # The single pooled connection may be used from a thread other
            # than the one that created it.
            engine_options["connect_args"] = {"check_same_thread": False}

        self.engine = create_engine(database_url, **engine_options)

        # Session factory bound to the engine.
        self.SessionLocal = sessionmaker(bind=self.engine)

        # Create tables
        self.init_db()

    def init_db(self):
        """Create all tables registered on ``Base.metadata``."""
        Base.metadata.create_all(bind=self.engine)

    def get_session(self) -> Session:
        """Return a new session; the caller is responsible for closing it."""
        return self.SessionLocal()

    def close(self):
        """Dispose of the engine and its pooled connections."""
        self.engine.dispose()

    def _save(self, instance):
        """Insert *instance*, commit, reload it from the database, return it."""
        with self.get_session() as session:
            session.add(instance)
            session.commit()
            # commit() expires the instance; refresh() reloads its columns
            # (including generated ones such as ``id``) so they remain
            # readable after the session is closed.
            session.refresh(instance)
            return instance

    async def get_user_by_telegram_id(self, telegram_id: str) -> Optional[User]:
        """Return the user with the given Telegram ID, or ``None``.

        Args:
            telegram_id: Telegram ID to look up. It is converted with
                ``str()`` before comparison because the column is text, so an
                integer ID is accepted as well.
        """
        with self.get_session() as session:
            return (
                session.query(User)
                .filter(User.telegram_id == str(telegram_id))
                .first()
            )

    async def add_user(self, telegram_id: str, username: Optional[str] = None,
                       first_name: Optional[str] = None,
                       last_name: Optional[str] = None,
                       language_code: Optional[str] = None) -> User:
        """Insert a new user and return the stored row.

        Args:
            telegram_id: Telegram ID; stored as ``str(telegram_id)``.
            username: Optional username.
            first_name: Optional first name.
            last_name: Optional last name.
            language_code: Optional language code.

        Returns:
            The committed ``User``, detached from its session.

        NOTE(review): ``telegram_id`` is declared unique, so inserting a
        duplicate is expected to fail at commit with the driver's integrity
        error; this method does not catch it.
        """
        return self._save(
            User(
                telegram_id=str(telegram_id),
                username=username,
                first_name=first_name,
                last_name=last_name,
                language_code=language_code,
            )
        )

    async def add_search_history(self, user_id: int, query: str) -> SearchHistory:
        """Insert a search-history row for *user_id* and return it."""
        return self._save(SearchHistory(user_id=user_id, query=query))

    async def add_download(self, user_id: int, video_id: str, title: str, file_path: str) -> Download:
        """Insert a download row for *user_id* and return it."""
        return self._save(
            Download(
                user_id=user_id,
                video_id=video_id,
                title=title,
                file_path=file_path,
            )
        )

    async def get_search_history(self, user_id: int, limit: int = 10) -> List[SearchHistory]:
        """Return up to *limit* search-history rows for *user_id*, newest first.

        Args:
            user_id: Primary key of the owning user (``users.id``).
            limit: Maximum number of rows to return.
        """
        with self.get_session() as session:
            return (
                session.query(SearchHistory)
                .filter(SearchHistory.user_id == user_id)
                .order_by(SearchHistory.created_at.desc())
                .limit(limit)
                .all()
            )