WIP: Get tests passing for github oauth, almost fully implement it

This commit is contained in:
2024-09-14 00:36:37 -04:00
parent b2a0cafe6e
commit 64edba3cf8
19 changed files with 690 additions and 143 deletions

View File

@@ -8,13 +8,14 @@ import "@tsed/passport";
import { config } from "./config/index";
import * as rest from "./controllers/rest/index";
import * as pages from "./controllers/pages/index";
import { User } from "./entities/User";
@Configuration({
...config,
acceptMimes: ["application/json"],
httpPort: process.env.PORT || 8083,
httpsPort: false, // CHANGE
disableComponentsScan: true,
disableComponentsScan: false,
ajv: {
returnsCoercedValues: true
},
@@ -28,6 +29,10 @@ import * as pages from "./controllers/pages/index";
specVersion: "3.0.1"
}
],
componentsScan: [`./protocols/*.ts`, `./services/*.ts`],
passport: {
userInfoModel: User
},
middlewares: [
"cors",
"cookie-parser",

View File

@@ -0,0 +1,76 @@
import { describe, it, expect, vi, beforeEach, afterAll, beforeAll, afterEach } from "vitest";
import { PlatformTest, Req } from "@tsed/common";
import { GithubProtocol } from "../../protocols/GthubProtocol";
import { UserService } from "../../services/UserService";
import { Server } from "../../Server";
describe("GithubProtocol", () => {
let protocol: GithubProtocol;
let userService: UserService;
beforeAll(async () => {
await PlatformTest.create({ platform: Server });
});
afterAll(() => {
return PlatformTest.reset();
});
beforeEach(async () => {
userService = {
findOrCreate: vi.fn().mockResolvedValue({ id: "user123", username: "githubuser" })
} as unknown as UserService;
protocol = await PlatformTest.invoke<GithubProtocol>(GithubProtocol, [{ token: UserService, use: userService }]);
// Mock fetch for GitHub API call
global.fetch = vi.fn().mockResolvedValue({
json: vi.fn().mockResolvedValue([
{ value: "user@example.com", verified: true },
{ value: "user2@example.com", verified: false }
])
});
});
afterEach(() => {
vi.restoreAllMocks();
});
it("should call $onVerify and return a user", async () => {
const mockReq = {
query: { state: "github-state" }
} as unknown as Req;
const mockAccessToken = "mock-access-token";
const mockProfile = { username: "githubuser" };
const result = await protocol.$onVerify(mockReq, mockAccessToken, "", mockProfile);
expect(userService.findOrCreate).toHaveBeenCalledWith({
service: "github",
serviceIdentifier: "github-state",
username: "githubuser",
emails: [{ value: "user@example.com", verified: true }],
accessToken: mockAccessToken
});
expect(result).toEqual({ id: "user123", username: "githubuser" });
});
it("should throw an error if no verified emails are found", async () => {
global.fetch = vi.fn().mockResolvedValue({
json: vi.fn().mockResolvedValue([])
});
const mockReq = { query: { state: "github-state" } } as unknown as Req;
const mockAccessToken = "mock-access-token";
const mockProfile = { username: "githubuser" };
await expect(protocol.$onVerify(mockReq, mockAccessToken, "", mockProfile)).rejects.toThrow("No verified email found");
});
it("should fetch verified emails from GitHub", async () => {
const emails = await protocol.fetchVerifiedEmails("mock-access-token");
expect(emails).toEqual([{ value: "user@example.com", verified: true }]);
expect(global.fetch).toHaveBeenCalledWith("https://api.github.com/user/emails", expect.anything());
});
});

View File

@@ -0,0 +1,49 @@
import { Controller, Get, Req, Res, Next, QueryParams } from "@tsed/common";
import { Authenticate } from "@tsed/passport";
import { Configuration } from "@tsed/common";
import { Response, Request, NextFunction } from "express";
@Controller("/auth")
export class AuthController {
@Configuration()
private config: Configuration;
@Get("/github")
async githubLogin(
@Req() req: Request,
@Res() res: Response,
@Next() next: NextFunction,
@QueryParams("serviceIdentifier") serviceIdentifier: string
) {
if (!serviceIdentifier) {
res.status(400).send("serviceIdentifier is required");
return;
}
// Initiate authentication with the 'state' parameter
return Authenticate("github", {
scope: ["user:email"],
state: serviceIdentifier
})(req, res, next);
}
@Get("/github/callback")
@Authenticate("github", { failureRedirect: "/login" })
async githubCallback(@Req() req: Request, @Res() res: Response) {
// Authentication was successful
// You can redirect the user to a specific page or return user info
// Example: Redirect to the home page
res.redirect("/");
}
@Get("/logout")
logout(@Req() req: Request, @Res() res: Response) {
req.logout((err) => {
if (err) {
return res.status(500).send("Error logging out");
}
res.redirect("/");
});
}
}

View File

@@ -18,22 +18,26 @@ describe("LinkController", () => {
it("should call POST /rest/links and GET /rest/links/:id", async () => {
const request = SuperTest(PlatformTest.callback());
const username = `silentsilas-${randomUUID()}`;
const response = await request
const postResponse = await request
.post("/rest/links")
.send({
service: "github",
serviceUsername: username
serviceIdentifier: username
})
.expect(201);
const response2 = await request.get(`/rest/users/${response.body.userId}`).expect(200);
const getOneResponse = await request.get(`/rest/users/${postResponse.body.userId}`).expect(200);
const getAllResponse = await request.get("/rest/users/").expect(200);
expect(response.body.id).toBeTruthy();
expect(response.body.service).toEqual("github");
expect(response.body.serviceUsername).toEqual(username);
expect(postResponse.body.id).toBeTruthy();
expect(postResponse.body.service).toEqual("github");
expect(postResponse.body.serviceIdentifier).toEqual(username);
expect(response2.body.id).toEqual(response.body.userId);
expect(response2.body.service).toEqual("github");
expect(response2.body.serviceUsername).toEqual(username);
expect(getOneResponse.body.id).toEqual(postResponse.body.userId);
expect(getOneResponse.body.service).toEqual("github");
expect(getOneResponse.body.serviceIdentifier).toEqual(username);
expect(getAllResponse.body).toBeInstanceOf(Array);
expect(getAllResponse.body.length).toBeGreaterThan(0);
});
});

View File

@@ -1,67 +1,53 @@
import { BodyParams, PathParams } from "@tsed/platform-params";
import { BodyParams, PathParams, Req, Controller } from "@tsed/common";
import { Description, Get, Post, Returns, Summary } from "@tsed/schema";
import { Controller, Inject } from "@tsed/di";
import { Authenticate } from "@tsed/passport";
import { Link } from "../../entities/link/Link";
import { SqliteDatasource } from "../../datasources/SqliteDatasource";
import { DataSource } from "typeorm";
import { User } from "../../entities/User";
import { CreateLinkDto } from "../../entities/link/CreateLinkDTO";
import { executeWithRetry } from "../../datasources/SqliteDatasource";
import { UserService } from "../../services/UserService";
import { LinkService } from "../../services/LinkService"; // Create a new service for Link operations
import { User } from "../../entities/User";
@Controller("/links")
export class LinkController {
constructor(@Inject(SqliteDatasource) private sqliteDataSource: DataSource) {}
constructor(
private linkService: LinkService, // Inject LinkService
private userService: UserService // Inject UserService
) {}
@Post("/")
@Summary("Create a new link")
@Description("Creates a new link and associates it with a user")
@Returns(201, Link)
async create(@BodyParams() linkData: CreateLinkDto): Promise<Link> {
return executeWithRetry(async (queryRunner) => {
const userRepository = queryRunner.manager.getRepository(User);
const linkRepository = queryRunner.manager.getRepository(Link);
// Delegate user creation logic to UserService
const user = await this.userService.findOrCreate(linkData);
let user = await userRepository.findOne({
where: {
service: linkData.service,
serviceUsername: linkData.serviceUsername
}
});
if (!user) {
user = userRepository.create({
service: linkData.service,
serviceUsername: linkData.serviceUsername
});
user = await queryRunner.manager.save(User, user);
}
const link = linkRepository.create({
...linkData,
user
});
return queryRunner.manager.save(Link, link);
}, this.sqliteDataSource);
// Use LinkService to handle the link creation
return this.linkService.createLink(linkData, user);
}
@Get("/")
@Summary("Get all links")
@Authenticate("github")
@Summary("Get all links for the authenticated user")
@(Returns(200, Array).Of(Link))
async getList(): Promise<Link[]> {
return executeWithRetry(async (queryRunner) => {
const linkRepository = queryRunner.manager.getRepository(Link);
return linkRepository.find({ relations: ["user"] });
}, this.sqliteDataSource);
async getList(@Req() req: Req): Promise<Link[]> {
const user = req.user as User;
return this.linkService.getLinksForUser(user);
}
@Get("/:id")
@Summary("Get a link by ID")
@Summary("Get a link by ID without text content")
@Returns(200, Link)
async getOne(@PathParams("id") id: string): Promise<Link | null> {
return executeWithRetry(async (queryRunner) => {
const linkRepository = queryRunner.manager.getRepository(Link);
return linkRepository.findOne({ where: { id }, relations: ["user"] });
}, this.sqliteDataSource);
return this.linkService.getLinkById(id);
}
@Get("/:id/content")
@Authenticate("github")
@Summary("Get the content of a link if authorized")
@Returns(200, String)
async getLinkContent(@PathParams("id") id: string, @Req() req: Req): Promise<string> {
const user = req.user as User;
return this.linkService.getLinkContentById(id, user);
}
}

View File

@@ -1,21 +1,24 @@
import { expect, describe, it, beforeEach, afterEach, beforeAll } from "vitest";
import { expect, describe, it, afterEach, beforeAll, beforeEach } from "vitest";
import { PlatformTest } from "@tsed/common";
import { UserController } from "./UserController";
import { User } from "../../entities/User";
import { DataSource } from "typeorm";
import { v4 as uuidv4 } from "uuid";
import { Server } from "../../Server";
import { sqliteDatasource } from "src/datasources/SqliteDatasource";
import { SqliteDatasource, sqliteDatasource } from "src/datasources/SqliteDatasource";
describe("UserController", () => {
let controller: UserController;
beforeAll(PlatformTest.bootstrap(Server));
beforeAll(
PlatformTest.bootstrap(Server, {
imports: [sqliteDatasource]
})
);
beforeEach(async () => {
controller = await PlatformTest.invoke(UserController, [
{
token: DataSource,
token: SqliteDatasource,
use: sqliteDatasource
}
]);
@@ -30,7 +33,8 @@ describe("UserController", () => {
const user = new User();
user.id = userId;
user.service = "github";
user.serviceUsername = `silentsilas-${userId}`;
user.serviceIdentifier = `user-${userId}`;
user.links = [];
const repo = sqliteDatasource.getRepository(User);
await repo.save(user);
@@ -39,6 +43,6 @@ describe("UserController", () => {
expect(result).toEqual(user);
expect(result?.id).toEqual(userId);
expect(result?.service).toEqual("github");
expect(result?.serviceUsername).toEqual(`silentsilas-${userId}`);
expect(result?.serviceIdentifier).toEqual(`user-${userId}`);
});
});

View File

@@ -2,14 +2,13 @@ import { PathParams } from "@tsed/platform-params";
import { Description, Get, Post, Returns, Summary } from "@tsed/schema";
import { Controller, Inject } from "@tsed/di";
import { User } from "../../entities/User";
import { SqliteDatasource } from "../../datasources/SqliteDatasource";
import { DataSource } from "typeorm";
import { executeWithRetry } from "../../datasources/SqliteDatasource";
import { Forbidden } from "@tsed/exceptions";
import { UserService } from "../../services/UserService";
@Controller("/users")
export class UserController {
constructor(@Inject(SqliteDatasource) private sqliteDataSource: DataSource) {}
@Inject()
service: UserService;
// disable the create method and endpoint
@Post("/")
@@ -24,21 +23,13 @@ export class UserController {
@Summary("Get all users")
@(Returns(200, Array).Of(User))
async getList(): Promise<User[]> {
return executeWithRetry(async (queryRunner) => {
return queryRunner.manager.find(User);
}, this.sqliteDataSource);
return this.service.getAllUsers();
}
@Get("/:id")
@Summary("Get a user by ID")
@Returns(200, User)
async getOne(@PathParams("id") id: string): Promise<User | null> {
return executeWithRetry(async (queryRunner) => {
return queryRunner.manager.findOne(User, {
where: {
id
}
});
}, this.sqliteDataSource);
return this.service.getUserById(id);
}
}

View File

@@ -2,5 +2,6 @@
* @file Automatically generated by barrelsby.
*/
export * from "./AuthController";
export * from "./LinkController";
export * from "./UserController";

View File

@@ -3,7 +3,6 @@ import { DataSource } from "typeorm";
import { Logger } from "@tsed/logger";
import { User } from "../entities/User";
import { Link } from "../entities/link/Link";
import { QueryRunner } from "typeorm";
export const SqliteDatasource = Symbol.for("SqliteDatasource");
export type SqliteDatasource = DataSource;
@@ -37,48 +36,17 @@ registerProvider<DataSource>({
type: "typeorm:datasource",
deps: [Logger],
async useAsyncFactory(logger: Logger) {
await sqliteDatasource.initialize();
logger.info("Connected with typeorm to database: Sqlite");
if (!sqliteDatasource.isInitialized) {
await sqliteDatasource.initialize();
logger.info("Connected with TypeORM to database: Sqlite");
}
return sqliteDatasource;
},
hooks: {
$onDestroy(dataSource) {
return dataSource.isInitialized && dataSource.destroy();
async $onDestroy(dataSource: DataSource) {
if (dataSource.isInitialized) {
await dataSource.destroy();
}
}
}
});
export async function executeWithRetry<T>(
operation: (queryRunner: QueryRunner) => Promise<T>,
dataSource: DataSource,
maxRetries = 10,
delay = 1000
): Promise<T> {
let retries = 0;
while (true) {
const queryRunner = dataSource.createQueryRunner();
try {
await queryRunner.connect();
await queryRunner.startTransaction();
const result = await operation(queryRunner);
await queryRunner.commitTransaction();
return result;
} catch (error) {
await queryRunner.rollbackTransaction();
if (error.code === "SQLITE_BUSY" && retries < maxRetries) {
retries++;
await new Promise((resolve) => setTimeout(resolve, delay));
continue;
}
throw error;
} finally {
await queryRunner.release();
}
}
}

View File

@@ -17,12 +17,21 @@ export class User {
@Column({ length: 100 })
@MaxLength(100)
@Required()
serviceUsername: string;
serviceIdentifier: string;
@OneToMany(() => Link, (link) => link.user)
@CollectionOf(() => Link)
links: Link[];
@Column({ length: 100, nullable: true })
username: string;
@Column("simple-json", { nullable: true })
emails: string[];
@Column({ nullable: true })
accessToken: string;
@BeforeInsert()
generateId() {
if (!this.id) {

View File

@@ -1,13 +1,14 @@
import { Property, Required, MaxLength } from "@tsed/schema";
import { Property, Required, MaxLength, Enum } from "@tsed/schema";
export class CreateLinkDto {
@Property()
@Required()
@MaxLength(100)
@Enum("github")
service: string;
@Property()
@Required()
@MaxLength(100)
serviceUsername: string;
serviceIdentifier: string;
}

View File

@@ -1,8 +1,10 @@
import { MaxLength, Property, Required } from "@tsed/schema";
import { Enum, MaxLength, Property, Required } from "@tsed/schema";
import { Column, Entity, ManyToOne, PrimaryColumn, JoinColumn, BeforeInsert } from "typeorm";
import { User } from "../User";
import { v4 as uuidv4 } from "uuid";
export type Service = "github";
@Entity()
export class Link {
@PrimaryColumn("uuid")
@@ -12,12 +14,17 @@ export class Link {
@Column({ length: 100 })
@MaxLength(100)
@Required()
@Enum("github")
service: string;
@Column({ length: 100 })
@MaxLength(100)
@Required()
serviceUsername: string;
serviceIdentifier: string;
@Column({ nullable: true })
@Required()
text: string;
@ManyToOne(() => User, (user) => user.links, { onDelete: "SET NULL", onUpdate: "CASCADE" })
@JoinColumn({ name: "userId" })

View File

@@ -1,18 +1,25 @@
import { describe, it, expect, vi, beforeEach, afterEach, Mock } from "vitest";
import { Server } from "./Server";
vi.mock("@tsed/common", () => ({
$log: {
error: vi.fn()
},
PlatformApplication: vi.fn()
}));
vi.mock("@tsed/common", async (importOriginal) => {
const actual: any = await importOriginal();
return {
...actual,
$log: {
error: vi.fn()
}
};
});
vi.mock("@tsed/platform-express", () => ({
PlatformExpress: {
bootstrap: vi.fn()
}
}));
vi.mock("@tsed/platform-express", async (importOriginal) => {
const actual: any = await importOriginal();
return {
...actual,
PlatformExpress: {
bootstrap: vi.fn()
}
};
});
describe("bootstrap function", () => {
let bootstrap: () => Promise<void>;

View File

@@ -0,0 +1,73 @@
import { describe, it, expect, vi, beforeEach, afterAll, beforeAll } from "vitest";
import { PlatformTest, Req } from "@tsed/common";
import { UserService } from "../services/UserService";
import { Server } from "../Server";
import { sqliteDatasource } from "../datasources/SqliteDatasource";
import { GithubProtocol } from "./GthubProtocol";
describe("GithubProtocol", () => {
let protocol: GithubProtocol;
let userService: UserService;
beforeAll(async () => {
await PlatformTest.create({ platform: Server, imports: [sqliteDatasource] }); // ensure PlatformTest.create() is called
});
afterAll(() => {
return PlatformTest.reset();
});
beforeEach(async () => {
userService = {
findOrCreate: vi.fn().mockResolvedValue({ id: "user123", username: "githubuser" })
} as unknown as UserService;
protocol = await PlatformTest.invoke<GithubProtocol>(GithubProtocol, [{ token: UserService, use: userService }]);
});
it("should call $onVerify and return a user", async () => {
const mockReq = {
query: { state: "github-state" }
} as unknown as Req;
const mockAccessToken = "mock-access-token";
const mockProfile = { username: "githubuser" };
const fetchSpy = vi.spyOn(protocol, "fetchVerifiedEmails").mockResolvedValue([{ value: "user@example.com", verified: true }]);
const result = await protocol.$onVerify(mockReq, mockAccessToken, "", mockProfile);
expect(fetchSpy).toHaveBeenCalledWith(mockAccessToken);
expect(userService.findOrCreate).toHaveBeenCalledWith({
service: "github",
serviceIdentifier: "github-state",
username: "githubuser",
emails: [{ value: "user@example.com", verified: true }],
accessToken: mockAccessToken
});
expect(result).toEqual({ id: "user123", username: "githubuser" });
});
it("should throw an error if no verified emails are found", async () => {
const mockReq = { query: { state: "github-state" } } as unknown as Req;
const mockAccessToken = "mock-access-token";
const mockProfile = { username: "githubuser" };
vi.spyOn(protocol, "fetchVerifiedEmails").mockResolvedValue([]);
await expect(protocol.$onVerify(mockReq, mockAccessToken, "", mockProfile)).rejects.toThrow("No verified email found");
});
it("should fetch verified emails from GitHub", async () => {
global.fetch = vi.fn().mockResolvedValue({
json: vi.fn().mockResolvedValue([
{ value: "email1@example.com", verified: true },
{ value: "email2@example.com", verified: false }
])
});
const emails = await protocol.fetchVerifiedEmails("mock-access-token");
expect(emails).toEqual([{ value: "email1@example.com", verified: true }]);
expect(global.fetch).toHaveBeenCalledWith("https://api.github.com/user/emails", expect.anything());
});
});

View File

@@ -0,0 +1,80 @@
// protocols/GithubProtocol.ts
import { Protocol, OnVerify, OnInstall } from "@tsed/passport";
import { Req } from "@tsed/common";
import { Inject } from "@tsed/di";
import { UserService } from "../services/UserService";
import { Strategy as GithubStrategy } from "passport-github";
import { SqliteDatasource } from "../datasources/SqliteDatasource";
@Protocol({
name: "github",
useStrategy: GithubStrategy,
settings: {
clientID: process.env.GITHUB_CLIENT_ID || "your-client-id",
clientSecret: process.env.GITHUB_CLIENT_SECRET || "your-client-secret",
callbackURL: "http://localhost:8080/auth/github/callback",
scope: ["user:email"],
state: true,
passReqToCallback: true
}
})
export class GithubProtocol implements OnVerify, OnInstall {
@Inject()
userService: UserService;
@Inject()
sqliteDatasource: SqliteDatasource;
async $onVerify(@Req() req: Req, accessToken: string, _refreshToken: string, profile: any) {
const emails = await this.fetchVerifiedEmails(accessToken);
if (!emails.length) {
throw new Error("No verified email found");
}
const state = req.query.state;
let identifier: string;
if (typeof state === "string") {
identifier = state;
} else if (Array.isArray(state)) {
// If state is an array, take the first string element
identifier = state.find((s) => typeof s === "string") as string;
if (!identifier) {
throw new Error("Invalid service identifier");
}
} else {
throw new Error("Service identifier is missing or invalid");
}
const user = await this.userService.findOrCreate({
service: "github",
serviceIdentifier: identifier.toString(),
username: profile.username,
emails: emails,
accessToken: accessToken
});
return user;
}
async fetchVerifiedEmails(accessToken: string): Promise<any[]> {
const response = await fetch("https://api.github.com/user/emails", {
headers: {
Authorization: `token ${accessToken}`,
"User-Agent": "YourAppName"
}
});
const emails = await response.json();
return emails.filter((email: any) => email.verified);
}
$onInstall(strategy: GithubStrategy) {
console.log("Github strategy installed");
if (process.env.NODE_ENV === "development") {
console.log(strategy);
}
// Optional: additional strategy configuration
}
}

View File

@@ -0,0 +1,63 @@
import { Inject, Injectable } from "@tsed/di";
import { Forbidden, NotFound } from "@tsed/exceptions";
import { SqliteDatasource } from "../datasources/SqliteDatasource";
import { User } from "../entities/User";
import { CreateLinkDto } from "../entities/link/CreateLinkDTO";
import { Link } from "../entities/link/Link";
import { DataSource, Repository } from "typeorm";
@Injectable()
export class LinkService {
private linkRepository: Repository<Link>;
constructor(@Inject(SqliteDatasource) private dataSource: DataSource) {
this.linkRepository = this.dataSource.getRepository(Link);
}
async createLink(linkData: CreateLinkDto, user: User): Promise<Link> {
const link = this.linkRepository.create({
...linkData,
user
});
return this.linkRepository.save(link);
}
async getLinksForUser(user: User): Promise<Link[]> {
return this.linkRepository.find({
where: { user: { id: user.id } },
relations: ["user"],
select: ["id", "service", "serviceIdentifier", "text"]
});
}
async getLinkById(id: string): Promise<Link | null> {
const link = await this.linkRepository.findOne({
where: { id },
relations: ["user"],
select: ["id", "service", "serviceIdentifier"]
});
if (!link) {
throw new NotFound("Link not found");
}
return link;
}
async getLinkContentById(id: string, user: User): Promise<string> {
const link = await this.linkRepository.findOne({
where: { id },
relations: ["user"]
});
if (!link) {
throw new NotFound("Link not found");
}
if (link.user.serviceIdentifier !== user.serviceIdentifier) {
throw new Forbidden("You are not authorized to view this link content");
}
return link.text;
}
}

View File

@@ -0,0 +1,55 @@
import { Inject, Injectable } from "@tsed/di";
import { User } from "../entities/User";
import { SqliteDatasource } from "../datasources/SqliteDatasource";
import { DataSource, Repository } from "typeorm";
type ProfileData = {
service: string;
serviceIdentifier: string;
username?: string;
emails?: string[];
accessToken?: string;
};
@Injectable({
deps: [SqliteDatasource]
})
export class UserService {
private userRepository: Repository<User>;
constructor(@Inject(SqliteDatasource) private dataSource: DataSource) {
this.userRepository = this.dataSource.getRepository(User);
}
public async findOrCreate(profileData: ProfileData): Promise<User> {
let user = await this.userRepository.findOne({
where: { serviceIdentifier: profileData.serviceIdentifier }
});
if (!user) {
user = this.userRepository.create({
service: profileData.service,
serviceIdentifier: profileData.serviceIdentifier,
username: profileData.username,
emails: profileData.emails,
accessToken: profileData.accessToken
});
} else {
user.emails = profileData.emails || user.emails;
user.accessToken = profileData.accessToken || user.accessToken;
}
return this.userRepository.save(user);
}
public async getUserById(id: string): Promise<User | null> {
return this.userRepository.findOne({
where: { id },
relations: ["links"]
});
}
public async getAllUsers(): Promise<User[]> {
return this.userRepository.find();
}
}